Skip to content

Commit fc17398

Browse files
fix: Fewer relation joins from df self-operations
1 parent bb1b1e3 commit fc17398

File tree

6 files changed

+115
-80
lines changed

6 files changed

+115
-80
lines changed

bigframes/core/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,9 @@ def _cross_join_w_labels(
460460
conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross"
461461
)
462462
if join_side == "left":
463-
joined_array = self.join(labels_array, join_def=join)
463+
joined_array = self.relational_join(labels_array, join_def=join)
464464
else:
465-
joined_array = labels_array.join(self, join_def=join)
465+
joined_array = labels_array.relational_join(self, join_def=join)
466466
return joined_array
467467

468468
def _create_unpivot_labels_array(
@@ -485,30 +485,27 @@ def _create_unpivot_labels_array(
485485

486486
return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)
487487

488-
def join(
488+
def relational_join(
489489
self,
490490
other: ArrayValue,
491491
join_def: join_def.JoinDefinition,
492-
allow_row_identity_join: bool = False,
493-
):
492+
) -> ArrayValue:
494493
join_node = nodes.JoinNode(
495494
left_child=self.node,
496495
right_child=other.node,
497496
join=join_def,
498-
allow_row_identity_join=allow_row_identity_join,
499497
)
500-
if allow_row_identity_join:
501-
return ArrayValue(bigframes.core.rewrite.maybe_rewrite_join(join_node))
502498
return ArrayValue(join_node)
503499

504500
def try_align_as_projection(
505501
self,
506502
other: ArrayValue,
507503
join_type: join_def.JoinType,
504+
join_keys: typing.Tuple[join_def.CoalescedColumnMapping, ...],
508505
mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
509506
) -> typing.Optional[ArrayValue]:
510507
result = bigframes.core.rewrite.join_as_projection(
511-
self.node, other.node, mappings, join_type
508+
self.node, other.node, join_keys, mappings, join_type
512509
)
513510
if result is not None:
514511
return ArrayValue(result)

bigframes/core/blocks.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,7 +2010,7 @@ def merge(
20102010
mappings=(*left_mappings, *right_mappings),
20112011
type=how,
20122012
)
2013-
joined_expr = self.expr.join(other.expr, join_def=join_def)
2013+
joined_expr = self.expr.relational_join(other.expr, join_def=join_def)
20142014
result_columns = []
20152015
matching_join_labels = []
20162016

@@ -2269,25 +2269,33 @@ def join(
22692269
raise NotImplementedError(
22702270
f"Only how='outer','left','right','inner' currently supported. {constants.FEEDBACK_LINK}"
22712271
)
2272-
# Special case for null index,
2272+
# Handle null index, which only supports row join
2273+
if (self.index.nlevels == other.index.nlevels == 0) and not block_identity_join:
2274+
if not block_identity_join:
2275+
result = try_row_join(self, other, how=how)
2276+
if result is not None:
2277+
return result
2278+
raise bigframes.exceptions.NullIndexError(
2279+
"Cannot implicitly align objects. Set an explicit index using set_index."
2280+
)
2281+
2282+
# Oddly, pandas row-wise join ignores right index names
22732283
if (
2274-
(self.index.nlevels == other.index.nlevels == 0)
2275-
and not sort
2276-
and not block_identity_join
2284+
not block_identity_join
2285+
and (self.index.nlevels == other.index.nlevels)
2286+
and (self.index.dtypes == other.index.dtypes)
22772287
):
2278-
return join_indexless(self, other, how=how)
2288+
result = try_row_join(self, other, how=how)
2289+
if result is not None:
2290+
return result
22792291

22802292
self._throw_if_null_index("join")
22812293
other._throw_if_null_index("join")
22822294
if self.index.nlevels == other.index.nlevels == 1:
2283-
return join_mono_indexed(
2284-
self, other, how=how, sort=sort, block_identity_join=block_identity_join
2285-
)
2286-
else:
2295+
return join_mono_indexed(self, other, how=how, sort=sort)
2296+
else: # Handles cases where one or both sides are multi-indexed
22872297
# Always sort multi-index join
2288-
return join_multi_indexed(
2289-
self, other, how=how, sort=sort, block_identity_join=block_identity_join
2290-
)
2298+
return join_multi_indexed(self, other, how=how, sort=sort)
22912299

22922300
def _force_reproject(self) -> Block:
22932301
"""Forces a reprojection of the underlying tables expression. Used to force predicate/order application before subsequent operations."""
@@ -2626,46 +2634,55 @@ def is_uniquely_named(self: BlockIndexProperties):
26262634
return len(set(self.names)) == len(self.names)
26272635

26282636

2629-
def join_indexless(
2637+
def try_row_join(
26302638
left: Block,
26312639
right: Block,
26322640
*,
26332641
how="left",
2634-
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
2635-
"""Joins two blocks"""
2642+
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:
2643+
"""Joins two blocks that have a common root expression by merging the projections."""
26362644
left_expr = left.expr
26372645
right_expr = right.expr
2646+
# Create a new array value, mapping from both, then left, and then right
2647+
join_keys = tuple(
2648+
join_defs.CoalescedColumnMapping(
2649+
left_source_id=left_id,
2650+
right_source_id=right_id,
2651+
destination_id=guid.generate_guid(),
2652+
)
2653+
for left_id, right_id in zip(left.index_columns, right.index_columns)
2654+
)
26382655
left_mappings = [
26392656
join_defs.JoinColumnMapping(
26402657
source_table=join_defs.JoinSide.LEFT,
26412658
source_id=id,
26422659
destination_id=guid.generate_guid(),
26432660
)
2644-
for id in left_expr.column_ids
2661+
for id in left.value_columns
26452662
]
26462663
right_mappings = [
26472664
join_defs.JoinColumnMapping(
26482665
source_table=join_defs.JoinSide.RIGHT,
26492666
source_id=id,
26502667
destination_id=guid.generate_guid(),
26512668
)
2652-
for id in right_expr.column_ids
2669+
for id in right.value_columns
26532670
]
26542671
combined_expr = left_expr.try_align_as_projection(
26552672
right_expr,
26562673
join_type=how,
2674+
join_keys=join_keys,
26572675
mappings=(*left_mappings, *right_mappings),
26582676
)
26592677
if combined_expr is None:
2660-
raise bigframes.exceptions.NullIndexError(
2661-
"Cannot implicitly align objects. Set an explicit index using set_index."
2662-
)
2678+
return None
26632679
get_column_left = {m.source_id: m.destination_id for m in left_mappings}
26642680
get_column_right = {m.source_id: m.destination_id for m in right_mappings}
26652681
block = Block(
26662682
combined_expr,
26672683
column_labels=[*left.column_labels, *right.column_labels],
2668-
index_columns=(),
2684+
index_columns=(key.destination_id for key in join_keys),
2685+
index_labels=left.index.names,
26692686
)
26702687
return (
26712688
block,
@@ -2707,7 +2724,7 @@ def join_with_single_row(
27072724
mappings=(*left_mappings, *right_mappings),
27082725
type="cross",
27092726
)
2710-
combined_expr = left_expr.join(
2727+
combined_expr = left_expr.relational_join(
27112728
right_expr,
27122729
join_def=join_def,
27132730
)
@@ -2734,7 +2751,6 @@ def join_mono_indexed(
27342751
*,
27352752
how="left",
27362753
sort=False,
2737-
block_identity_join: bool = False,
27382754
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
27392755
left_expr = left.expr
27402756
right_expr = right.expr
@@ -2762,14 +2778,14 @@ def join_mono_indexed(
27622778
mappings=(*left_mappings, *right_mappings),
27632779
type=how,
27642780
)
2765-
combined_expr = left_expr.join(
2781+
2782+
combined_expr = left_expr.relational_join(
27662783
right_expr,
27672784
join_def=join_def,
2768-
allow_row_identity_join=(not block_identity_join),
27692785
)
2786+
27702787
get_column_left = join_def.get_left_mapping()
27712788
get_column_right = join_def.get_right_mapping()
2772-
# Drop original indices from each side. and used the coalesced combination generated by the join.
27732789
left_index = get_column_left[left.index_columns[0]]
27742790
right_index = get_column_right[right.index_columns[0]]
27752791
# Drop original indices from each side and use the coalesced combination generated by the join.
@@ -2803,7 +2819,6 @@ def join_multi_indexed(
28032819
*,
28042820
how="left",
28052821
sort=False,
2806-
block_identity_join: bool = False,
28072822
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
28082823
if not (left.index.is_uniquely_named() and right.index.is_uniquely_named()):
28092824
raise ValueError("Joins not supported on indices with non-unique level names")
@@ -2822,8 +2837,6 @@ def join_multi_indexed(
28222837
left_join_ids = [left.index.resolve_level_exact(name) for name in common_names]
28232838
right_join_ids = [right.index.resolve_level_exact(name) for name in common_names]
28242839

2825-
names_fully_match = len(left_only_names) == 0 and len(right_only_names) == 0
2826-
28272840
left_expr = left.expr
28282841
right_expr = right.expr
28292842

@@ -2853,13 +2866,11 @@ def join_multi_indexed(
28532866
type=how,
28542867
)
28552868

2856-
combined_expr = left_expr.join(
2869+
combined_expr = left_expr.relational_join(
28572870
right_expr,
28582871
join_def=join_def,
2859-
# If we're only joining on a subset of the index columns, we need to
2860-
# perform a true join.
2861-
allow_row_identity_join=(names_fully_match and not block_identity_join),
28622872
)
2873+
28632874
get_column_left = join_def.get_left_mapping()
28642875
get_column_right = join_def.get_right_mapping()
28652876
left_ids_post_join = [get_column_left[id] for id in left_join_ids]

bigframes/core/join_def.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ class JoinColumnMapping:
4343
destination_id: str
4444

4545

46+
@dataclasses.dataclass(frozen=True)
47+
class CoalescedColumnMapping:
48+
"""Special column mapping used only by implicit joiner only"""
49+
50+
left_source_id: str
51+
right_source_id: str
52+
destination_id: str
53+
54+
4655
@dataclasses.dataclass(frozen=True)
4756
class JoinDefinition:
4857
conditions: Tuple[JoinCondition, ...]

bigframes/core/nodes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ class JoinNode(BigFrameNode):
183183
left_child: BigFrameNode
184184
right_child: BigFrameNode
185185
join: JoinDefinition
186-
allow_row_identity_join: bool = False
187186

188187
@property
189188
def row_preserving(self) -> bool:

bigframes/core/rewrite.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -106,28 +106,33 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
106106
)
107107

108108
def can_merge(
109-
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
109+
self,
110+
right: SquashedSelect,
111+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
110112
) -> bool:
111113
"""Determines whether the two selections can be merged into a single selection."""
112-
if join_def.type == "cross":
113-
# Cannot convert cross join to projection
114-
return False
115-
116114
r_exprs_by_id = {id: expr for expr, id in right.columns}
117115
l_exprs_by_id = {id: expr for expr, id in self.columns}
118-
l_join_exprs = [l_exprs_by_id[cond.left_id] for cond in join_def.conditions]
119-
r_join_exprs = [r_exprs_by_id[cond.right_id] for cond in join_def.conditions]
116+
l_join_exprs = [
117+
l_exprs_by_id[join_key.left_source_id] for join_key in join_keys
118+
]
119+
r_join_exprs = [
120+
r_exprs_by_id[join_key.right_source_id] for join_key in join_keys
121+
]
120122

121-
if (self.root != right.root) or any(
122-
l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)
123-
):
123+
if self.root != right.root:
124+
return False
125+
if len(l_join_exprs) != len(r_join_exprs):
126+
return False
127+
if any(l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)):
124128
return False
125129
return True
126130

127131
def merge(
128132
self,
129133
right: SquashedSelect,
130134
join_type: join_defs.JoinType,
135+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
131136
mappings: Tuple[join_defs.JoinColumnMapping, ...],
132137
) -> SquashedSelect:
133138
if self.root != right.root:
@@ -147,11 +152,9 @@ def merge(
147152
l_relative, r_relative = relative_predicates(self.predicate, right.predicate)
148153
lmask = l_relative if join_type in {"right", "outer"} else None
149154
rmask = r_relative if join_type in {"left", "outer"} else None
150-
if lmask is not None:
151-
lselection = tuple((apply_mask(expr, lmask), id) for expr, id in lselection)
152-
if rmask is not None:
153-
rselection = tuple((apply_mask(expr, rmask), id) for expr, id in rselection)
154-
new_columns = remap_names(mappings, lselection, rselection)
155+
new_columns = merge_expressions(
156+
join_keys, mappings, lselection, rselection, lmask, rmask
157+
)
155158

156159
# Reconstruct ordering
157160
reverse_root = self.reverse_root
@@ -204,34 +207,18 @@ def expand(self) -> nodes.BigFrameNode:
204207
return nodes.ProjectionNode(child=root, assignments=self.columns)
205208

206209

207-
def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
208-
rewrite_common_node = common_selection_root(
209-
join_node.left_child, join_node.right_child
210-
)
211-
if rewrite_common_node is None:
212-
return join_node
213-
left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node)
214-
right_side = SquashedSelect.from_node_span(
215-
join_node.right_child, rewrite_common_node
216-
)
217-
if left_side.can_merge(right_side, join_node.join):
218-
return left_side.merge(
219-
right_side, join_node.join.type, join_node.join.mappings
220-
).expand()
221-
return join_node
222-
223-
224210
def join_as_projection(
225211
l_node: nodes.BigFrameNode,
226212
r_node: nodes.BigFrameNode,
213+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
227214
mappings: Tuple[join_defs.JoinColumnMapping, ...],
228215
how: join_defs.JoinType,
229216
) -> Optional[nodes.BigFrameNode]:
230217
rewrite_common_node = common_selection_root(l_node, r_node)
231218
if rewrite_common_node is not None:
232219
left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
233220
right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
234-
merged = left_side.merge(right_side, how, mappings)
221+
merged = left_side.merge(right_side, how, join_keys, mappings)
235222
assert (
236223
merged is not None
237224
), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com."
@@ -240,21 +227,33 @@ def join_as_projection(
240227
return None
241228

242229

243-
def remap_names(
230+
def merge_expressions(
231+
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
244232
mappings: Tuple[join_defs.JoinColumnMapping, ...],
245233
lselection: Selection,
246234
rselection: Selection,
235+
lmask: Optional[scalar_exprs.Expression],
236+
rmask: Optional[scalar_exprs.Expression],
247237
) -> Selection:
248238
new_selection: Selection = tuple()
249239
l_exprs_by_id = {id: expr for expr, id in lselection}
250240
r_exprs_by_id = {id: expr for expr, id in rselection}
241+
for key in join_keys:
242+
# Join key expressions are equivalent on both sides, so either the left or the right key can be used
243+
assert l_exprs_by_id[key.left_source_id] == r_exprs_by_id[key.right_source_id]
244+
expr = l_exprs_by_id[key.left_source_id]
245+
id = key.destination_id
246+
new_selection = (*new_selection, (expr, id))
251247
for mapping in mappings:
252248
if mapping.source_table == join_defs.JoinSide.LEFT:
253249
expr = l_exprs_by_id[mapping.source_id]
250+
if lmask is not None:
251+
expr = apply_mask(expr, lmask)
254252
else: # Right
255253
expr = r_exprs_by_id[mapping.source_id]
256-
id = mapping.destination_id
257-
new_selection = (*new_selection, (expr, id))
254+
if rmask is not None:
255+
expr = apply_mask(expr, rmask)
256+
new_selection = (*new_selection, (expr, mapping.destination_id))
258257
return new_selection
259258

260259

0 commit comments

Comments
 (0)