|
19 | 19 |
|
20 | 20 | import sqlglot.expressions as sge |
21 | 21 |
|
22 | | -from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite |
| 22 | +from bigframes.core import ( |
| 23 | + agg_expressions, |
| 24 | + expression, |
| 25 | + guid, |
| 26 | + identifiers, |
| 27 | + nodes, |
| 28 | + pyarrow_utils, |
| 29 | + rewrite, |
| 30 | +) |
23 | 31 | from bigframes.core.compile import configs |
24 | 32 | import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler |
25 | 33 | from bigframes.core.compile.sqlglot.aggregations import windows |
@@ -310,67 +318,71 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG |
@_compile_node.register
def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
    """Compile a window node by windowing each of its aggregations in turn.

    Every entry of ``node.agg_exprs`` is compiled to an analytic expression
    over the node's window spec and attached to the running result under the
    entry's own output id. When the window spec sets ``min_periods`` and the
    aggregation reads at least one input column, the analytic expression is
    wrapped in a ``CASE`` that yields NULL while the number of observations in
    the window is below ``min_periods``.

    Args:
        node: The window node; supplies ``window_spec`` and ``agg_exprs``.
        child: The already-compiled input the window columns are added to.

    Returns:
        The IR obtained by applying ``window(...)`` once per aggregation,
        each call building on the result of the previous one. If the node
        has no aggregations, ``child`` is returned unchanged.
    """
    result = child
    for cdef in node.agg_exprs:
        agg = cdef.expression
        assert isinstance(agg, agg_expressions.Aggregation)

        # Resolve the spec per aggregation. Rebinding a variable shared by the
        # whole loop would strip the ordering for every later aggregation too,
        # including order-dependent ones.
        window_spec = node.window_spec
        if agg.op.order_independent and window_spec.is_unbounded:
            # notably percentile_cont does not support ordering clause
            window_spec = window_spec.without_order()

        window_op = aggregate_compiler.compile_analytic(agg, window_spec)

        inputs: tuple[sge.Expression, ...] = tuple(
            scalar_compiler.scalar_op_compiler.compile_expression(
                expression.DerefOp(column)
            )
            for column in agg.column_references
        )

        if window_spec.min_periods and inputs:
            if not agg.op.nulls_count_for_min_values:
                # Most operations do not count NULL values towards min_periods.
                # All inputs must be non-null for an observation to count.
                # `inputs` is non-empty here, so there is always a first term.
                not_null_columns = [
                    sge.Not(this=sge.Is(this=column, expression=sge.Null()))
                    for column in inputs
                ]
                is_observation_expr: sge.Expression = not_null_columns[0]
                for not_null in not_null_columns[1:]:
                    is_observation_expr = sge.And(
                        this=is_observation_expr, expression=not_null
                    )
                is_observation = ir._cast(is_observation_expr, "INT64")
                observation_count = windows.apply_window_if_present(
                    sge.func("SUM", is_observation), window_spec
                )
            else:
                # Operations like count treat even NULLs as valid observations
                # for the sake of min_periods. notnull is just used to convert
                # null values to non-null (FALSE) values to be counted.
                is_observation = ir._cast(
                    sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
                    "INT64",
                )
                observation_count = windows.apply_window_if_present(
                    sge.func("COUNT", is_observation), window_spec
                )

            # Yield NULL until the window holds enough observations.
            too_few_observations = sge.When(
                this=observation_count < sge.convert(window_spec.min_periods),
                true=sge.Null(),
            )
            window_op = sge.Case(ifs=[too_few_observations], default=window_op)

        # TODO: check if we can directly window the expression.
        # Chain on `result`, not `child`, so that the columns produced by
        # earlier aggregations are kept instead of being discarded.
        result = result.window(
            window_op=window_op,
            output_column_id=cdef.id.sql,
        )
    return result
374 | 386 |
|
375 | 387 |
|
376 | 388 | def _replace_unsupported_ops(node: nodes.BigFrameNode): |
|
0 commit comments