Skip to content

Commit 597eb55

Browse files
refactor: Simplify InNode definition (#2264)
1 parent f73fb98 commit 597eb55

File tree

12 files changed

+32
-67
lines changed

12 files changed

+32
-67
lines changed

bigframes/core/array_value.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,13 +450,15 @@ def project_window_expr(
450450
)
451451

452452
def isin(
453-
self, other: ArrayValue, lcol: str, rcol: str
453+
self,
454+
other: ArrayValue,
455+
lcol: str,
454456
) -> typing.Tuple[ArrayValue, str]:
457+
assert len(other.column_ids) == 1
455458
node = nodes.InNode(
456459
self.node,
457460
other.node,
458461
ex.deref(lcol),
459-
ex.deref(rcol),
460462
indicator_col=ids.ColumnId.unique(),
461463
)
462464
return ArrayValue(node), node.indicator_col.name

bigframes/core/blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2326,7 +2326,7 @@ def isin(self, other: Block):
23262326
return block
23272327

23282328
def _isin_inner(self: Block, col: str, unique_values: core.ArrayValue) -> Block:
2329-
expr, matches = self._expr.isin(unique_values, col, unique_values.column_ids[0])
2329+
expr, matches = self._expr.isin(unique_values, col)
23302330

23312331
new_value_cols = tuple(
23322332
val_col if val_col != col else matches for val_col in self.value_columns

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def compile_isin(
128128
return left.isin_join(
129129
right=right,
130130
indicator_col=node.indicator_col.sql,
131-
conditions=(node.left_col.id.sql, node.right_col.id.sql),
131+
conditions=(node.left_col.id.sql, list(node.right_child.ids)[0].sql),
132132
join_nulls=node.joins_nulls,
133133
)
134134

bigframes/core/compile/polars/compiler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -700,12 +700,11 @@ def compile_join(self, node: nodes.JoinNode):
700700
@compile_node.register
701701
def compile_isin(self, node: nodes.InNode):
702702
left = self.compile_node(node.left_child)
703-
right = self.compile_node(node.right_child).unique(node.right_col.id.sql)
703+
right = self.compile_node(node.right_child).unique()
704704
right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql))
705705

706-
left_ex, right_ex = lowering._coerce_comparables(
707-
node.left_col, node.right_col
708-
)
706+
right_col = ex.ResolvedDerefOp.from_field(node.right_child.fields[0])
707+
left_ex, right_ex = lowering._coerce_comparables(node.left_col, right_col)
709708

710709
left_pl_ex = self.expr_compiler.compile_expression(left_ex)
711710
right_pl_ex = self.expr_compiler.compile_expression(right_ex)

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,17 @@ def compile_join(
230230
def compile_isin_join(
231231
node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
232232
) -> ir.SQLGlotIR:
233+
right_field = node.right_child.fields[0]
233234
conditions = (
234235
typed_expr.TypedExpr(
235236
scalar_compiler.scalar_op_compiler.compile_expression(node.left_col),
236237
node.left_col.output_type,
237238
),
238239
typed_expr.TypedExpr(
239-
scalar_compiler.scalar_op_compiler.compile_expression(node.right_col),
240-
node.right_col.output_type,
240+
scalar_compiler.scalar_op_compiler.compile_expression(
241+
expression.DerefOp(right_field.id)
242+
),
243+
right_field.dtype,
241244
),
242245
)
243246

bigframes/core/nodes.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,10 @@ class InNode(BigFrameNode, AdditiveNode):
200200
left_child: BigFrameNode
201201
right_child: BigFrameNode
202202
left_col: ex.DerefOp
203-
right_col: ex.DerefOp
204203
indicator_col: identifiers.ColumnId
205204

206205
def _validate(self):
207-
assert not (
208-
set(self.left_child.ids) & set(self.right_child.ids)
209-
), "Join ids collide"
206+
assert len(self.right_child.fields) == 1
210207

211208
@property
212209
def row_preserving(self) -> bool:
@@ -259,7 +256,11 @@ def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
259256

260257
@property
261258
def referenced_ids(self) -> COLUMN_SET:
262-
return frozenset({self.left_col.id, self.right_col.id})
259+
return frozenset(
260+
{
261+
self.left_col.id,
262+
}
263+
)
263264

264265
@property
265266
def additive_base(self) -> BigFrameNode:
@@ -268,12 +269,13 @@ def additive_base(self) -> BigFrameNode:
268269
@property
269270
def joins_nulls(self) -> bool:
270271
left_nullable = self.left_child.field_by_id[self.left_col.id].nullable
271-
right_nullable = self.right_child.field_by_id[self.right_col.id].nullable
272+
# assumption: right side has one column
273+
right_nullable = self.right_child.fields[0].nullable
272274
return left_nullable or right_nullable
273275

274276
@property
275277
def _node_expressions(self):
276-
return (self.left_col, self.right_col)
278+
return (self.left_col,)
277279

278280
def replace_additive_base(self, node: BigFrameNode):
279281
return dataclasses.replace(self, left_child=node)
@@ -302,9 +304,6 @@ def remap_refs(
302304
left_col=self.left_col.remap_column_refs(
303305
mappings, allow_partial_bindings=True
304306
),
305-
right_col=self.right_col.remap_column_refs(
306-
mappings, allow_partial_bindings=True
307-
),
308307
) # type: ignore
309308

310309

bigframes/core/rewrite/identifiers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ def remap_variables(
6969
left_col=new_root.left_col.remap_column_refs(
7070
new_child_mappings[0], allow_partial_bindings=True
7171
),
72-
right_col=new_root.right_col.remap_column_refs(
73-
new_child_mappings[1], allow_partial_bindings=True
74-
),
7572
)
7673
else:
7774
new_root = new_root.remap_refs(downstream_mappings)

bigframes/core/rewrite/implicit_align.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import dataclasses
1717
import itertools
18-
from typing import cast, Optional, Sequence, Set, Tuple
18+
from typing import Optional, Sequence, Set, Tuple
1919

2020
import bigframes.core.expression
2121
import bigframes.core.identifiers
@@ -152,35 +152,6 @@ def pull_up_selection(
152152
return node, tuple(
153153
bigframes.core.nodes.AliasedRef.identity(field.id) for field in node.fields
154154
)
155-
# InNode needs special handling, as its a binary node, but row identity is from left side only.
156-
# TODO: Merge code with unary op paths
157-
if isinstance(node, bigframes.core.nodes.InNode):
158-
child_node, child_selections = pull_up_selection(
159-
node.left_child, stop=stop, rename_vars=rename_vars
160-
)
161-
mapping = {out: ref.id for ref, out in child_selections}
162-
163-
new_in_node: bigframes.core.nodes.InNode = dataclasses.replace(
164-
node, left_child=child_node
165-
)
166-
new_in_node = new_in_node.remap_refs(mapping)
167-
if rename_vars:
168-
new_in_node = cast(
169-
bigframes.core.nodes.InNode,
170-
new_in_node.remap_vars(
171-
{node.indicator_col: bigframes.core.identifiers.ColumnId.unique()}
172-
),
173-
)
174-
added_selection = tuple(
175-
(
176-
bigframes.core.nodes.AliasedRef(
177-
bigframes.core.expression.DerefOp(new_in_node.indicator_col),
178-
node.indicator_col,
179-
),
180-
)
181-
)
182-
new_selection = child_selections + added_selection
183-
return new_in_node, new_selection
184155

185156
if isinstance(node, bigframes.core.nodes.AdditiveNode):
186157
child_node, child_selections = pull_up_selection(

bigframes/core/rewrite/pruning.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ def prune_columns(node: nodes.BigFrameNode):
5555
result = node.replace_child(prune_node(node.child, node.consumed_ids))
5656
elif isinstance(node, nodes.AggregateNode):
5757
result = node.replace_child(prune_node(node.child, node.consumed_ids))
58-
elif isinstance(node, nodes.InNode):
59-
result = dataclasses.replace(
60-
node,
61-
right_child=prune_node(node.right_child, frozenset([node.right_col.id])),
62-
)
6358
else:
6459
result = node
6560
return result

bigframes/core/rewrite/schema_binding.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@ def bind_schema_to_node(
7171
left_col=ex.ResolvedDerefOp.from_field(
7272
node.left_child.field_by_id[node.left_col.id]
7373
),
74-
right_col=ex.ResolvedDerefOp.from_field(
75-
node.right_child.field_by_id[node.right_col.id]
76-
),
7774
)
7875

7976
if isinstance(node, nodes.AggregateNode):

0 commit comments

Comments
 (0)