Skip to content

Commit ac43ee5

Browse files
refactor: WindowOpNode can create multiple cols (#2284)
1 parent f7fd2d2 commit ac43ee5

File tree

16 files changed: +284 additions, −236 deletions

bigframes/core/array_value.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ def compute_values(self, assignments: Sequence[ex.Expression]):
268268

269269
def compute_general_expression(self, assignments: Sequence[ex.Expression]):
270270
named_exprs = [
271-
expression_factoring.NamedExpression(expr, ids.ColumnId.unique())
272-
for expr in assignments
271+
nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
273272
]
274273
# TODO: Push this to rewrite later to go from block expression to planning form
275274
# TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
@@ -279,7 +278,7 @@ def compute_general_expression(self, assignments: Sequence[ex.Expression]):
279278
for expr in named_exprs
280279
)
281280
)
282-
target_ids = tuple(named_expr.name for named_expr in named_exprs)
281+
target_ids = tuple(named_expr.id for named_expr in named_exprs)
283282
new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
284283
return (ArrayValue(new_root), target_ids)
285284

@@ -403,22 +402,24 @@ def aggregate(
403402

404403
def project_window_expr(
405404
self,
406-
expression: agg_expressions.Aggregation,
405+
expressions: Sequence[agg_expressions.Aggregation],
407406
window: WindowSpec,
408-
skip_reproject_unsafe: bool = False,
409407
):
410-
output_name = self._gen_namespaced_uid()
408+
id_strings = [self._gen_namespaced_uid() for _ in expressions]
409+
agg_exprs = tuple(
410+
nodes.ColumnDef(expression, ids.ColumnId(id_str))
411+
for expression, id_str in zip(expressions, id_strings)
412+
)
413+
411414
return (
412415
ArrayValue(
413416
nodes.WindowOpNode(
414417
child=self.node,
415-
expression=expression,
418+
agg_exprs=agg_exprs,
416419
window_spec=window,
417-
output_name=ids.ColumnId(output_name),
418-
skip_reproject_unsafe=skip_reproject_unsafe,
419420
)
420421
),
421-
output_name,
422+
id_strings,
422423
)
423424

424425
def isin(

bigframes/core/block_transforms.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,11 @@ def _interpolate_column(
232232
masked_offsets,
233233
agg_ops.LastNonNullOp(),
234234
backwards_window,
235-
skip_reproject_unsafe=True,
236235
)
237236
block, next_value_offset = block.apply_window_op(
238237
masked_offsets,
239238
agg_ops.FirstNonNullOp(),
240239
forwards_window,
241-
skip_reproject_unsafe=True,
242240
)
243241

244242
if interpolate_method == "linear":

bigframes/core/blocks.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,20 +1091,14 @@ def multi_apply_window_op(
10911091
*,
10921092
skip_null_groups: bool = False,
10931093
) -> typing.Tuple[Block, typing.Sequence[str]]:
1094-
block = self
1095-
result_ids = []
1096-
for i, col_id in enumerate(columns):
1097-
label = self.col_id_to_label[col_id]
1098-
block, result_id = block.apply_window_op(
1099-
col_id,
1100-
op,
1101-
window_spec=window_spec,
1102-
skip_reproject_unsafe=(i + 1) < len(columns),
1103-
result_label=label,
1104-
skip_null_groups=skip_null_groups,
1105-
)
1106-
result_ids.append(result_id)
1107-
return block, result_ids
1094+
return self.apply_analytic(
1095+
agg_exprs=(
1096+
agg_expressions.UnaryAggregation(op, ex.deref(col)) for col in columns
1097+
),
1098+
window=window_spec,
1099+
result_labels=self._get_labels_for_columns(columns),
1100+
skip_null_groups=skip_null_groups,
1101+
)
11081102

11091103
def multi_apply_unary_op(
11101104
self,
@@ -1181,44 +1175,39 @@ def apply_window_op(
11811175
*,
11821176
result_label: Label = None,
11831177
skip_null_groups: bool = False,
1184-
skip_reproject_unsafe: bool = False,
11851178
) -> typing.Tuple[Block, str]:
11861179
agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
1187-
return self.apply_analytic(
1188-
agg_expr,
1180+
block, ids = self.apply_analytic(
1181+
[agg_expr],
11891182
window_spec,
1190-
result_label,
1191-
skip_reproject_unsafe=skip_reproject_unsafe,
1183+
[result_label],
11921184
skip_null_groups=skip_null_groups,
11931185
)
1186+
return block, ids[0]
11941187

11951188
def apply_analytic(
11961189
self,
1197-
agg_expr: agg_expressions.Aggregation,
1190+
agg_exprs: Iterable[agg_expressions.Aggregation],
11981191
window: windows.WindowSpec,
1199-
result_label: Label,
1192+
result_labels: Iterable[Label],
12001193
*,
1201-
skip_reproject_unsafe: bool = False,
12021194
skip_null_groups: bool = False,
1203-
) -> typing.Tuple[Block, str]:
1195+
) -> typing.Tuple[Block, Sequence[str]]:
12041196
block = self
12051197
if skip_null_groups:
12061198
for key in window.grouping_keys:
12071199
block = block.filter(ops.notnull_op.as_expr(key))
1208-
expr, result_id = block._expr.project_window_expr(
1209-
agg_expr,
1200+
expr, result_ids = block._expr.project_window_expr(
1201+
tuple(agg_exprs),
12101202
window,
1211-
skip_reproject_unsafe=skip_reproject_unsafe,
12121203
)
12131204
block = Block(
12141205
expr,
12151206
index_columns=self.index_columns,
1216-
column_labels=self.column_labels.insert(
1217-
len(self.column_labels), result_label
1218-
),
1207+
column_labels=self.column_labels.append(pd.Index(result_labels)),
12191208
index_labels=self._index_labels,
12201209
)
1221-
return (block, result_id)
1210+
return (block, result_ids)
12221211

12231212
def copy_values(self, source_column_id: str, destination_column_id: str) -> Block:
12241213
expr = self.expr.assign(source_column_id, destination_column_id)

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,13 @@ def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
265265

266266
@_compile_node.register
267267
def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
268-
result = child.project_window_op(
269-
node.expression,
270-
node.window_spec,
271-
node.output_name.sql,
272-
)
268+
result = child
269+
for cdef in node.agg_exprs:
270+
result = result.project_window_op(
271+
cdef.expression, # type: ignore
272+
node.window_spec,
273+
cdef.id.sql,
274+
)
273275
return result
274276

275277

bigframes/core/compile/polars/compiler.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -853,20 +853,23 @@ def compile_window(self, node: nodes.WindowOpNode):
853853
"min_period not yet supported for polars engine"
854854
)
855855

856-
if (window.bounds is None) or (window.is_unbounded):
857-
# polars will automatically broadcast the aggregate to the matching input rows
858-
agg_pl = self.agg_compiler.compile_agg_expr(node.expression)
859-
if window.grouping_keys:
860-
agg_pl = agg_pl.over(
861-
self.expr_compiler.compile_expression(key)
862-
for key in window.grouping_keys
856+
result = df
857+
for cdef in node.agg_exprs:
858+
assert isinstance(cdef.expression, agg_expressions.Aggregation)
859+
if (window.bounds is None) or (window.is_unbounded):
860+
# polars will automatically broadcast the aggregate to the matching input rows
861+
agg_pl = self.agg_compiler.compile_agg_expr(cdef.expression)
862+
if window.grouping_keys:
863+
agg_pl = agg_pl.over(
864+
self.expr_compiler.compile_expression(key)
865+
for key in window.grouping_keys
866+
)
867+
result = result.with_columns(agg_pl.alias(cdef.id.sql))
868+
else: # row-bounded window
869+
window_result = self._calc_row_analytic_func(
870+
result, cdef.expression, node.window_spec, cdef.id.sql
863871
)
864-
result = df.with_columns(agg_pl.alias(node.output_name.sql))
865-
else: # row-bounded window
866-
window_result = self._calc_row_analytic_func(
867-
df, node.expression, node.window_spec, node.output_name.sql
868-
)
869-
result = pl.concat([df, window_result], how="horizontal")
872+
result = pl.concat([result, window_result], how="horizontal")
870873
return result
871874

872875
def _calc_row_analytic_func(

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 67 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@
1919

2020
import sqlglot.expressions as sge
2121

22-
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
22+
from bigframes.core import (
23+
agg_expressions,
24+
expression,
25+
guid,
26+
identifiers,
27+
nodes,
28+
pyarrow_utils,
29+
rewrite,
30+
)
2331
from bigframes.core.compile import configs
2432
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
2533
from bigframes.core.compile.sqlglot.aggregations import windows
@@ -310,67 +318,71 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG
310318
@_compile_node.register
311319
def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
312320
window_spec = node.window_spec
313-
if node.expression.op.order_independent and window_spec.is_unbounded:
314-
# notably percentile_cont does not support ordering clause
315-
window_spec = window_spec.without_order()
321+
result = child
322+
for cdef in node.agg_exprs:
323+
assert isinstance(cdef.expression, agg_expressions.Aggregation)
324+
if cdef.expression.op.order_independent and window_spec.is_unbounded:
325+
# notably percentile_cont does not support ordering clause
326+
window_spec = window_spec.without_order()
316327

317-
window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
328+
window_op = aggregate_compiler.compile_analytic(cdef.expression, window_spec)
318329

319-
inputs: tuple[sge.Expression, ...] = tuple(
320-
scalar_compiler.scalar_op_compiler.compile_expression(
321-
expression.DerefOp(column)
330+
inputs: tuple[sge.Expression, ...] = tuple(
331+
scalar_compiler.scalar_op_compiler.compile_expression(
332+
expression.DerefOp(column)
333+
)
334+
for column in cdef.expression.column_references
322335
)
323-
for column in node.expression.column_references
324-
)
325336

326-
clauses: list[tuple[sge.Expression, sge.Expression]] = []
327-
if window_spec.min_periods and len(inputs) > 0:
328-
if not node.expression.op.nulls_count_for_min_values:
329-
# Most operations do not count NULL values towards min_periods
330-
not_null_columns = [
331-
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
332-
for column in inputs
333-
]
334-
# All inputs must be non-null for observation to count
335-
if not not_null_columns:
336-
is_observation_expr: sge.Expression = sge.convert(True)
337+
clauses: list[tuple[sge.Expression, sge.Expression]] = []
338+
if window_spec.min_periods and len(inputs) > 0:
339+
if not cdef.expression.op.nulls_count_for_min_values:
340+
# Most operations do not count NULL values towards min_periods
341+
not_null_columns = [
342+
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
343+
for column in inputs
344+
]
345+
# All inputs must be non-null for observation to count
346+
if not not_null_columns:
347+
is_observation_expr: sge.Expression = sge.convert(True)
348+
else:
349+
is_observation_expr = not_null_columns[0]
350+
for expr in not_null_columns[1:]:
351+
is_observation_expr = sge.And(
352+
this=is_observation_expr, expression=expr
353+
)
354+
is_observation = ir._cast(is_observation_expr, "INT64")
355+
observation_count = windows.apply_window_if_present(
356+
sge.func("SUM", is_observation), window_spec
357+
)
337358
else:
338-
is_observation_expr = not_null_columns[0]
339-
for expr in not_null_columns[1:]:
340-
is_observation_expr = sge.And(
341-
this=is_observation_expr, expression=expr
342-
)
343-
is_observation = ir._cast(is_observation_expr, "INT64")
344-
observation_count = windows.apply_window_if_present(
345-
sge.func("SUM", is_observation), window_spec
346-
)
347-
else:
348-
# Operations like count treat even NULLs as valid observations
349-
# for the sake of min_periods notnull is just used to convert
350-
# null values to non-null (FALSE) values to be counted.
351-
is_observation = ir._cast(
352-
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
353-
"INT64",
354-
)
355-
observation_count = windows.apply_window_if_present(
356-
sge.func("COUNT", is_observation), window_spec
357-
)
358-
359-
clauses.append(
360-
(
361-
observation_count < sge.convert(window_spec.min_periods),
362-
sge.Null(),
359+
# Operations like count treat even NULLs as valid observations
360+
# for the sake of min_periods notnull is just used to convert
361+
# null values to non-null (FALSE) values to be counted.
362+
is_observation = ir._cast(
363+
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
364+
"INT64",
365+
)
366+
observation_count = windows.apply_window_if_present(
367+
sge.func("COUNT", is_observation), window_spec
368+
)
369+
370+
clauses.append(
371+
(
372+
observation_count < sge.convert(window_spec.min_periods),
373+
sge.Null(),
374+
)
363375
)
376+
if clauses:
377+
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
378+
window_op = sge.Case(ifs=when_expressions, default=window_op)
379+
380+
# TODO: check if we can directly window the expression.
381+
result = child.window(
382+
window_op=window_op,
383+
output_column_id=cdef.id.sql,
364384
)
365-
if clauses:
366-
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
367-
window_op = sge.Case(ifs=when_expressions, default=window_op)
368-
369-
# TODO: check if we can directly window the expression.
370-
return child.window(
371-
window_op=window_op,
372-
output_column_id=node.output_name.sql,
373-
)
385+
return result
374386

375387

376388
def _replace_unsupported_ops(node: nodes.BigFrameNode):

Comments: 0