Skip to content

Commit b0f16e6

Browse files
authored
refactor: add ToArrayOp and ArrayReduceOp to the sqlglot compiler (#2263)
Fixes internal issue 446726636 🦕
1 parent 95a83f7 commit b0f16e6

File tree

10 files changed

+216
-11
lines changed

10 files changed

+216
-11
lines changed

bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def compile(
2727
op: agg_ops.WindowOp,
2828
column: typed_expr.TypedExpr,
2929
*,
30-
order_by: tuple[sge.Expression, ...],
30+
order_by: tuple[sge.Expression, ...] = (),
3131
) -> sge.Expression:
3232
return ORDERED_UNARY_OP_REGISTRATION[op](op, column, order_by=order_by)
3333

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ def _(
4949
return sge.func("IFNULL", result, sge.true())
5050

5151

52+
@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
def _(
    op: agg_ops.AnyOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``AnyOp`` into ``COALESCE(LOGICAL_OR(<column>), FALSE)``.

    The ``LOGICAL_OR`` call is handed to ``apply_window_if_present`` together
    with ``window`` before it is wrapped in ``COALESCE``.
    """
    logical_or = sge.func("LOGICAL_OR", column.expr)
    any_value = apply_window_if_present(logical_or, window)

    # BigQuery yields NULL for an empty column, whereas pandas yields False,
    # so fall back to a FALSE literal.
    false_literal = sge.convert(False)
    return sge.func("COALESCE", any_value, false_literal)
63+
64+
5265
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
5366
def _(
5467
op: agg_ops.ApproxQuartilesOp,

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,16 @@
1616

1717
import typing
1818

19-
import sqlglot
19+
import sqlglot as sg
2020
import sqlglot.expressions as sge
2121

2222
from bigframes import operations as ops
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
25+
import bigframes.dtypes as dtypes
2526

2627
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
27-
28-
29-
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
30-
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
31-
return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'")
28+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
3229

3330

3431
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
@@ -41,17 +38,45 @@ def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
4138
)
4239

4340

41+
@register_unary_op(ops.ArrayReduceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
    """Compile ``ArrayReduceOp`` into a subquery aggregating the array's elements.

    The array expression is unnested under the column alias
    ``bf_arr_reduce_uid`` and ``op.aggregation`` is compiled against that
    alias, typed with the array's inner dtype.
    """
    element_id = sg.to_identifier("bf_arr_reduce_uid")
    element = TypedExpr(element_id, dtypes.get_array_inner_type(expr.dtype))

    # The aggregation compilers are imported at call time rather than at
    # module import time, as in the original layout of this function.
    if not op.aggregation.order_independent:
        from bigframes.core.compile.sqlglot.aggregations import ordered_unary_compiler

        # No explicit ordering is supplied; the compiler's default is used.
        reduced = ordered_unary_compiler.compile(op.aggregation, element)
    else:
        from bigframes.core.compile.sqlglot.aggregations import unary_compiler

        reduced = unary_compiler.compile(op.aggregation, element)

    unnested = sge.Unnest(
        expressions=[expr.expr],
        alias=sge.TableAlias(columns=[element_id]),
    )
    return sge.select(reduced).from_(unnested).subquery()
)
67+
68+
4469
@register_unary_op(ops.ArraySliceOp, pass_op=True)
4570
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
46-
slice_idx = sqlglot.to_identifier("slice_idx")
71+
slice_idx = sg.to_identifier("slice_idx")
4772

4873
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
4974

5075
if op.stop is not None:
5176
conditions.append(slice_idx < op.stop)
5277

5378
# local name for each element in the array
54-
el = sqlglot.to_identifier("el")
79+
el = sg.to_identifier("el")
5580

5681
selected_elements = (
5782
sge.select(el)
@@ -66,3 +91,27 @@ def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
6691
)
6792

6893
return sge.array(selected_elements)
94+
95+
96+
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
    """Compile ``ArrayToStringOp`` into a sqlglot ``ArrayToString`` expression.

    ``expr.expr`` is the array operand and ``op.delimiter`` is passed as the
    separator argument.
    """
    # Build the delimiter as a real string-literal node so the SQL generator
    # is responsible for quoting and escaping it. Hand-wrapping the value in
    # single quotes produces invalid SQL when the delimiter itself contains a
    # quote character.
    delimiter = sge.convert(op.delimiter)
    return sge.ArrayToString(this=expr.expr, expression=delimiter)
99+
100+
101+
@register_nary_op(ops.ToArrayOp)
def _(*exprs: TypedExpr) -> sge.Expression:
    """Compile ``ToArrayOp`` into an array literal of the given expressions.

    When at least one input is numeric (booleans not counted), every input is
    routed through ``_coerce_bool_to_int`` so boolean inputs are cast to
    INT64; otherwise the expressions are used unchanged.
    """
    has_numeric_input = any(
        dtypes.is_numeric(item.dtype, include_bool=False) for item in exprs
    )
    if not has_numeric_input:
        return sge.Array(expressions=[item.expr for item in exprs])

    return sge.Array(expressions=[_coerce_bool_to_int(item) for item in exprs])
111+
112+
113+
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
    """Return the expression cast to INT64 if it is boolean, else unchanged."""
    inner = typed_expr.expr
    is_bool = typed_expr.dtype == dtypes.BOOL_DTYPE
    return sge.Cast(this=inner, to="INT64") if is_bool else inner

tests/system/small/engines/test_array_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2727

2828

29-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
29+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
3030
def test_engines_to_array_op(scalars_array_value: array_value.ArrayValue, engine):
3131
# Bigquery won't allow you to materialize arrays with null, so use non-nullable
3232
int64_non_null = ops.coalesce_op.as_expr("int64_col", expression.const(0))
@@ -46,7 +46,7 @@ def test_engines_to_array_op(scalars_array_value: array_value.ArrayValue, engine
4646
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
4747

4848

49-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
49+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5050
def test_engines_array_reduce_op(arrays_array_value: array_value.ArrayValue, engine):
5151
arr, _ = arrays_array_value.compute_values(
5252
[
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `bool_col`
12+
FROM `bfcte_1`
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `bool_col` IS NULL
10+
THEN NULL
11+
ELSE COALESCE(LOGICAL_OR(`bool_col`) OVER (), FALSE)
12+
END AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `agg_bool`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ def test_all(scalar_types_df: bpd.DataFrame, snapshot):
8888
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
8989

9090

91+
def test_any(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``AnyOp``, plain and over a window."""
    column = "bool_col"
    frame = scalar_types_df[[column]]
    any_expr = agg_ops.AnyOp().as_expr(column)

    # Plain aggregation
    plain_sql = _apply_unary_agg_ops(frame, [any_expr], [column])
    snapshot.assert_match(plain_sql, "out.sql")

    # Window tests
    ordered_window = window_spec.WindowSpec(
        ordering=(ordering.ascending_over(column),)
    )
    windowed_sql = _apply_unary_window_op(frame, any_expr, ordered_window, "agg_bool")
    snapshot.assert_match(windowed_sql, "window_out.sql")
103+
104+
91105
def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
92106
col_name = "int64_col"
93107
bf_df = scalar_types_df[[col_name]]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_list_col`,
4+
`float_list_col`,
5+
`string_list_col`
6+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
(
11+
SELECT
12+
COALESCE(SUM(bf_arr_reduce_uid), 0)
13+
FROM UNNEST(`float_list_col`) AS bf_arr_reduce_uid
14+
) AS `bfcol_3`,
15+
(
16+
SELECT
17+
STDDEV(bf_arr_reduce_uid)
18+
FROM UNNEST(`float_list_col`) AS bf_arr_reduce_uid
19+
) AS `bfcol_4`,
20+
(
21+
SELECT
22+
COUNT(bf_arr_reduce_uid)
23+
FROM UNNEST(`string_list_col`) AS bf_arr_reduce_uid
24+
) AS `bfcol_5`,
25+
(
26+
SELECT
27+
COALESCE(LOGICAL_OR(bf_arr_reduce_uid), FALSE)
28+
FROM UNNEST(`bool_list_col`) AS bf_arr_reduce_uid
29+
) AS `bfcol_6`
30+
FROM `bfcte_0`
31+
)
32+
SELECT
33+
`bfcol_3` AS `sum_float`,
34+
`bfcol_4` AS `std_float`,
35+
`bfcol_5` AS `count_str`,
36+
`bfcol_6` AS `any_bool`
37+
FROM `bfcte_1`
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`,
4+
`float64_col`,
5+
`int64_col`,
6+
`string_col`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
8+
), `bfcte_1` AS (
9+
SELECT
10+
*,
11+
[COALESCE(`bool_col`, FALSE)] AS `bfcol_8`,
12+
[COALESCE(`int64_col`, 0)] AS `bfcol_9`,
13+
[COALESCE(`string_col`, ''), COALESCE(`string_col`, '')] AS `bfcol_10`,
14+
[
15+
COALESCE(`int64_col`, 0),
16+
CAST(COALESCE(`bool_col`, FALSE) AS INT64),
17+
COALESCE(`float64_col`, 0.0)
18+
] AS `bfcol_11`
19+
FROM `bfcte_0`
20+
)
21+
SELECT
22+
`bfcol_8` AS `bool_col`,
23+
`bfcol_9` AS `int64_col`,
24+
`bfcol_10` AS `strs_col`,
25+
`bfcol_11` AS `numeric_col`
26+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_array_ops.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
import pytest
1616

1717
from bigframes import operations as ops
18+
from bigframes.core import expression
1819
from bigframes.operations._op_converters import convert_index, convert_slice
20+
import bigframes.operations.aggregations as agg_ops
1921
import bigframes.pandas as bpd
2022
from bigframes.testing import utils
2123

@@ -42,6 +44,20 @@ def test_array_index(repeated_types_df: bpd.DataFrame, snapshot):
4244
snapshot.assert_match(sql, "out.sql")
4345

4446

47+
def test_array_reduce_op(repeated_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for several ``ArrayReduceOp`` aggregations."""
    # (output name, aggregation, source array column); the output columns
    # follow the order of this list.
    cases = [
        ("sum_float", agg_ops.SumOp(), "float_list_col"),
        ("std_float", agg_ops.StdOp(), "float_list_col"),
        ("count_str", agg_ops.CountOp(), "string_list_col"),
        ("any_bool", agg_ops.AnyOp(), "bool_list_col"),
    ]
    reduce_exprs = [
        ops.ArrayReduceOp(aggregation).as_expr(source)
        for _, aggregation, source in cases
    ]
    output_names = [name for name, _, _ in cases]

    sql = utils._apply_ops_to_sql(repeated_types_df, reduce_exprs, output_names)
    snapshot.assert_match(sql, "out.sql")
59+
60+
4561
def test_array_slice_with_only_start(repeated_types_df: bpd.DataFrame, snapshot):
4662
col_name = "string_list_col"
4763
bf_df = repeated_types_df[[col_name]]
@@ -60,3 +76,24 @@ def test_array_slice_with_start_and_stop(repeated_types_df: bpd.DataFrame, snaps
6076
)
6177

6278
snapshot.assert_match(sql, "out.sql")
79+
80+
81+
def test_to_array_op(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ToArrayOp`` over several input mixes."""
    bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col", "string_col"]]

    def non_null(column, fill_value):
        # Bigquery won't allow you to materialize arrays with null, so every
        # input is coalesced to a non-null constant first.
        return ops.coalesce_op.as_expr(column, expression.const(fill_value))

    int_input = non_null("int64_col", 0)
    bool_input = non_null("bool_col", False)
    float_input = non_null("float64_col", 0.0)
    str_input = non_null("string_col", "")

    to_array = {
        "bool_col": ops.ToArrayOp().as_expr(bool_input),
        "int64_col": ops.ToArrayOp().as_expr(int_input),
        "strs_col": ops.ToArrayOp().as_expr(str_input, str_input),
        "numeric_col": ops.ToArrayOp().as_expr(int_input, bool_input, float_input),
    }

    sql = utils._apply_ops_to_sql(bf_df, list(to_array.values()), list(to_array.keys()))
    snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)