Skip to content
15 changes: 6 additions & 9 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,9 @@ def _cross_join_w_labels(
conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross"
)
if join_side == "left":
joined_array = self.join(labels_array, join_def=join)
joined_array = self.relational_join(labels_array, join_def=join)
else:
joined_array = labels_array.join(self, join_def=join)
joined_array = labels_array.relational_join(self, join_def=join)
return joined_array

def _create_unpivot_labels_array(
Expand All @@ -485,30 +485,27 @@ def _create_unpivot_labels_array(

return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)

def join(
def relational_join(
self,
other: ArrayValue,
join_def: join_def.JoinDefinition,
allow_row_identity_join: bool = False,
):
) -> ArrayValue:
join_node = nodes.JoinNode(
left_child=self.node,
right_child=other.node,
join=join_def,
allow_row_identity_join=allow_row_identity_join,
)
if allow_row_identity_join:
return ArrayValue(bigframes.core.rewrite.maybe_rewrite_join(join_node))
return ArrayValue(join_node)

def try_align_as_projection(
self,
other: ArrayValue,
join_type: join_def.JoinType,
join_keys: typing.Tuple[join_def.CoalescedColumnMapping, ...],
mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
) -> typing.Optional[ArrayValue]:
result = bigframes.core.rewrite.join_as_projection(
self.node, other.node, mappings, join_type
self.node, other.node, join_keys, mappings, join_type
)
if result is not None:
return ArrayValue(result)
Expand Down
79 changes: 45 additions & 34 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2008,7 +2008,7 @@ def merge(
mappings=(*left_mappings, *right_mappings),
type=how,
)
joined_expr = self.expr.join(other.expr, join_def=join_def)
joined_expr = self.expr.relational_join(other.expr, join_def=join_def)
result_columns = []
matching_join_labels = []

Expand Down Expand Up @@ -2267,25 +2267,33 @@ def join(
raise NotImplementedError(
f"Only how='outer','left','right','inner' currently supported. {constants.FEEDBACK_LINK}"
)
# Special case for null index,
# Handle null index, which only supports row join
if (self.index.nlevels == other.index.nlevels == 0) and not block_identity_join:
if not block_identity_join:
result = try_row_join(self, other, how=how)
if result is not None:
return result
raise bigframes.exceptions.NullIndexError(
"Cannot implicitly align objects. Set an explicit index using set_index."
)

# Oddly, pandas row-wise join ignores right index names
if (
(self.index.nlevels == other.index.nlevels == 0)
and not sort
and not block_identity_join
not block_identity_join
and (self.index.nlevels == other.index.nlevels)
and (self.index.dtypes == other.index.dtypes)
):
return join_indexless(self, other, how=how)
result = try_row_join(self, other, how=how)
if result is not None:
return result

self._throw_if_null_index("join")
other._throw_if_null_index("join")
if self.index.nlevels == other.index.nlevels == 1:
return join_mono_indexed(
self, other, how=how, sort=sort, block_identity_join=block_identity_join
)
else:
return join_mono_indexed(self, other, how=how, sort=sort)
else: # Handles cases where one or both sides are multi-indexed
# Always sort mult-index join
return join_multi_indexed(
self, other, how=how, sort=sort, block_identity_join=block_identity_join
)
return join_multi_indexed(self, other, how=how, sort=sort)

def _force_reproject(self) -> Block:
"""Forces a reprojection of the underlying tables expression. Used to force predicate/order application before subsequent operations."""
Expand Down Expand Up @@ -2623,46 +2631,55 @@ def is_uniquely_named(self: BlockIndexProperties):
return len(set(self.names)) == len(self.names)


def join_indexless(
def try_row_join(
left: Block,
right: Block,
*,
how="left",
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
"""Joins two blocks"""
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:
"""Joins two blocks that have a common root expression by merging the projections."""
left_expr = left.expr
right_expr = right.expr
# Create a new array value, mapping from both, then left, and then right
join_keys = tuple(
join_defs.CoalescedColumnMapping(
left_source_id=left_id,
right_source_id=right_id,
destination_id=guid.generate_guid(),
)
for left_id, right_id in zip(left.index_columns, right.index_columns)
)
left_mappings = [
join_defs.JoinColumnMapping(
source_table=join_defs.JoinSide.LEFT,
source_id=id,
destination_id=guid.generate_guid(),
)
for id in left_expr.column_ids
for id in left.value_columns
]
right_mappings = [
join_defs.JoinColumnMapping(
source_table=join_defs.JoinSide.RIGHT,
source_id=id,
destination_id=guid.generate_guid(),
)
for id in right_expr.column_ids
for id in right.value_columns
]
combined_expr = left_expr.try_align_as_projection(
right_expr,
join_type=how,
join_keys=join_keys,
mappings=(*left_mappings, *right_mappings),
)
if combined_expr is None:
raise bigframes.exceptions.NullIndexError(
"Cannot implicitly align objects. Set an explicit index using set_index."
)
return None
get_column_left = {m.source_id: m.destination_id for m in left_mappings}
get_column_right = {m.source_id: m.destination_id for m in right_mappings}
block = Block(
combined_expr,
column_labels=[*left.column_labels, *right.column_labels],
index_columns=(),
index_columns=(key.destination_id for key in join_keys),
index_labels=left.index.names,
)
return (
block,
Expand Down Expand Up @@ -2704,7 +2721,7 @@ def join_with_single_row(
mappings=(*left_mappings, *right_mappings),
type="cross",
)
combined_expr = left_expr.join(
combined_expr = left_expr.relational_join(
right_expr,
join_def=join_def,
)
Expand All @@ -2731,7 +2748,6 @@ def join_mono_indexed(
*,
how="left",
sort=False,
block_identity_join: bool = False,
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
left_expr = left.expr
right_expr = right.expr
Expand Down Expand Up @@ -2759,14 +2775,14 @@ def join_mono_indexed(
mappings=(*left_mappings, *right_mappings),
type=how,
)
combined_expr = left_expr.join(

combined_expr = left_expr.relational_join(
right_expr,
join_def=join_def,
allow_row_identity_join=(not block_identity_join),
)

get_column_left = join_def.get_left_mapping()
get_column_right = join_def.get_right_mapping()
# Drop original indices from each side. and used the coalesced combination generated by the join.
left_index = get_column_left[left.index_columns[0]]
right_index = get_column_right[right.index_columns[0]]
# Drop original indices from each side. and used the coalesced combination generated by the join.
Expand Down Expand Up @@ -2800,7 +2816,6 @@ def join_multi_indexed(
*,
how="left",
sort=False,
block_identity_join: bool = False,
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
if not (left.index.is_uniquely_named() and right.index.is_uniquely_named()):
raise ValueError("Joins not supported on indices with non-unique level names")
Expand All @@ -2819,8 +2834,6 @@ def join_multi_indexed(
left_join_ids = [left.index.resolve_level_exact(name) for name in common_names]
right_join_ids = [right.index.resolve_level_exact(name) for name in common_names]

names_fully_match = len(left_only_names) == 0 and len(right_only_names) == 0

left_expr = left.expr
right_expr = right.expr

Expand Down Expand Up @@ -2850,13 +2863,11 @@ def join_multi_indexed(
type=how,
)

combined_expr = left_expr.join(
combined_expr = left_expr.relational_join(
right_expr,
join_def=join_def,
# If we're only joining on a subset of the index columns, we need to
# perform a true join.
allow_row_identity_join=(names_fully_match and not block_identity_join),
)

get_column_left = join_def.get_left_mapping()
get_column_right = join_def.get_right_mapping()
left_ids_post_join = [get_column_left[id] for id in left_join_ids]
Expand Down
9 changes: 9 additions & 0 deletions bigframes/core/join_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ class JoinColumnMapping:
destination_id: str


@dataclasses.dataclass(frozen=True)
class CoalescedColumnMapping:
"""Special column mapping used only by implicit joiner only"""

left_source_id: str
right_source_id: str
destination_id: str


@dataclasses.dataclass(frozen=True)
class JoinDefinition:
conditions: Tuple[JoinCondition, ...]
Expand Down
1 change: 0 additions & 1 deletion bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ class JoinNode(BigFrameNode):
left_child: BigFrameNode
right_child: BigFrameNode
join: JoinDefinition
allow_row_identity_join: bool = False

@property
def row_preserving(self) -> bool:
Expand Down
74 changes: 38 additions & 36 deletions bigframes/core/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,33 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
)

def can_merge(
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
self,
right: SquashedSelect,
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
) -> bool:
"""Determines whether the two selections can be merged into a single selection."""
if join_def.type == "cross":
# Cannot convert cross join to projection
return False

r_exprs_by_id = {id: expr for expr, id in right.columns}
l_exprs_by_id = {id: expr for expr, id in self.columns}
l_join_exprs = [l_exprs_by_id[cond.left_id] for cond in join_def.conditions]
r_join_exprs = [r_exprs_by_id[cond.right_id] for cond in join_def.conditions]
l_join_exprs = [
l_exprs_by_id[join_key.left_source_id] for join_key in join_keys
]
r_join_exprs = [
r_exprs_by_id[join_key.right_source_id] for join_key in join_keys
]

if (self.root != right.root) or any(
l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)
):
if self.root != right.root:
return False
if len(l_join_exprs) != len(r_join_exprs):
return False
if any(l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)):
return False
return True

def merge(
self,
right: SquashedSelect,
join_type: join_defs.JoinType,
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
mappings: Tuple[join_defs.JoinColumnMapping, ...],
) -> SquashedSelect:
if self.root != right.root:
Expand All @@ -147,11 +152,9 @@ def merge(
l_relative, r_relative = relative_predicates(self.predicate, right.predicate)
lmask = l_relative if join_type in {"right", "outer"} else None
rmask = r_relative if join_type in {"left", "outer"} else None
if lmask is not None:
lselection = tuple((apply_mask(expr, lmask), id) for expr, id in lselection)
if rmask is not None:
rselection = tuple((apply_mask(expr, rmask), id) for expr, id in rselection)
new_columns = remap_names(mappings, lselection, rselection)
new_columns = merge_expressions(
join_keys, mappings, lselection, rselection, lmask, rmask
)

# Reconstruct ordering
reverse_root = self.reverse_root
Expand Down Expand Up @@ -204,34 +207,21 @@ def expand(self) -> nodes.BigFrameNode:
return nodes.ProjectionNode(child=root, assignments=self.columns)


def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
rewrite_common_node = common_selection_root(
join_node.left_child, join_node.right_child
)
if rewrite_common_node is None:
return join_node
left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node)
right_side = SquashedSelect.from_node_span(
join_node.right_child, rewrite_common_node
)
if left_side.can_merge(right_side, join_node.join):
return left_side.merge(
right_side, join_node.join.type, join_node.join.mappings
).expand()
return join_node


def join_as_projection(
l_node: nodes.BigFrameNode,
r_node: nodes.BigFrameNode,
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
mappings: Tuple[join_defs.JoinColumnMapping, ...],
how: join_defs.JoinType,
) -> Optional[nodes.BigFrameNode]:
rewrite_common_node = common_selection_root(l_node, r_node)
if rewrite_common_node is not None:
left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
merged = left_side.merge(right_side, how, mappings)
if not left_side.can_merge(right_side, join_keys):
# Most likely because join keys didn't match
return None
merged = left_side.merge(right_side, how, join_keys, mappings)
assert (
merged is not None
), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com."
Expand All @@ -240,21 +230,33 @@ def join_as_projection(
return None


def remap_names(
def merge_expressions(
join_keys: Tuple[join_defs.CoalescedColumnMapping, ...],
mappings: Tuple[join_defs.JoinColumnMapping, ...],
lselection: Selection,
rselection: Selection,
lmask: Optional[scalar_exprs.Expression],
rmask: Optional[scalar_exprs.Expression],
) -> Selection:
new_selection: Selection = tuple()
l_exprs_by_id = {id: expr for expr, id in lselection}
r_exprs_by_id = {id: expr for expr, id in rselection}
for key in join_keys:
# Join keys expressions are equivalent on both sides, so can choose either left or right key
assert l_exprs_by_id[key.left_source_id] == r_exprs_by_id[key.right_source_id]
expr = l_exprs_by_id[key.left_source_id]
id = key.destination_id
new_selection = (*new_selection, (expr, id))
for mapping in mappings:
if mapping.source_table == join_defs.JoinSide.LEFT:
expr = l_exprs_by_id[mapping.source_id]
if lmask is not None:
expr = apply_mask(expr, lmask)
else: # Right
expr = r_exprs_by_id[mapping.source_id]
id = mapping.destination_id
new_selection = (*new_selection, (expr, id))
if rmask is not None:
expr = apply_mask(expr, rmask)
new_selection = (*new_selection, (expr, mapping.destination_id))
return new_selection


Expand Down
Loading