Skip to content

Commit c35e21f

Browse files
chelsea-lin and sycai authored
refactor: add agg_ops.CutOp to the sqlglot compiler (#2268)
Fixes internal issue 445774480🦕 --------- Co-authored-by: Shenyang Cai <sycai@users.noreply.github.com>
1 parent e39dfe2 commit c35e21f

File tree

6 files changed

+278
-0
lines changed

6 files changed

+278
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,140 @@ def _(
111111
return apply_window_if_present(sge.func("COUNT", column.expr), window)
112112

113113

114+
@UNARY_OP_REGISTRATION.register(agg_ops.CutOp)
def _(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``agg_ops.CutOp`` by dispatching on the type of ``op.bins``.

    Args:
        op: The cut operation being compiled.
        column: The typed expression whose values are assigned to bins.
        window: Optional window spec, forwarded unchanged to the helper.

    Returns:
        The expression built by ``_cut_ops_w_int_bins`` when ``op.bins`` is
        an ``int``, otherwise the one built by ``_cut_ops_w_intervals``.
    """
    bins = op.bins
    if isinstance(bins, int):
        return _cut_ops_w_int_bins(op, column, bins, window)
    # Any non-int ``bins`` value is interpreted as explicit intervals.
    return _cut_ops_w_intervals(op, column, bins, window)
125+
126+
127+
def _cut_ops_w_int_bins(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: int,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Build the ``CASE`` expression for a cut into ``bins`` equal-width bins.

    The column's range is taken from ``MIN``/``MAX`` applied over ``window``
    (or over a default ``window_spec.WindowSpec()`` when ``window`` is falsy)
    and divided into ``bins`` parts with ``IEEE_DIVIDE``. One ``WHEN`` branch
    is emitted per bin, in ascending order; the final branch matches every
    remaining non-NULL value.

    The value produced by each branch depends on ``op.labels``:

    * ``False``: the integer bin index.
    * an iterable: the label at the bin's index, as a string literal.
    * anything else: a struct holding the bin's two bounds, with fields
      ``left_exclusive``/``right_inclusive`` when ``op.right`` is truthy and
      ``left_inclusive``/``right_exclusive`` otherwise.

    Args:
        op: The cut operation; ``op.labels`` and ``op.right`` are read.
        column: The typed expression whose values are binned.
        bins: Number of equal-width bins.
        window: Optional window spec for the ``MIN``/``MAX`` computations.

    Returns:
        The assembled ``sge.Case`` expression (no ``ELSE`` branch is added).

    Raises:
        IndexError: If ``op.labels`` is an iterable with fewer than ``bins``
            entries.
    """
    # Both aggregates are always windowed; share one spec between them.
    effective_window = window or window_spec.WindowSpec()
    col_min = apply_window_if_present(sge.func("MIN", column.expr), effective_window)
    col_max = apply_window_if_present(sge.func("MAX", column.expr), effective_window)

    col_range: sge.Expression = sge.Sub(this=col_max, expression=col_min)
    # 0.1% of the range, used to widen the outermost closed-side bound.
    # NOTE(review): presumably mirrors pandas.cut's range extension - confirm.
    adj: sge.Expression = col_range * sge.convert(0.001)
    bin_width: sge.Expression = sge.func(
        "IEEE_DIVIDE",
        col_range,
        sge.convert(bins),
    )

    def bin_edge(index: int) -> sge.Expression:
        """Return ``col_min + index * bin_width`` as a fresh expression."""
        return col_min + sge.convert(index) * bin_width

    # Materialise the labels once instead of on every loop iteration.
    label_values: typing.Optional[typing.List[typing.Any]] = None
    if op.labels is not False and isinstance(op.labels, typing.Iterable):
        label_values = list(op.labels)

    if op.right:
        left_name, right_name = "left_exclusive", "right_inclusive"
    else:
        left_name, right_name = "left_inclusive", "right_exclusive"

    case_expr = sge.Case()
    for this_bin in range(bins):
        is_first_bin = this_bin == 0
        is_last_bin = this_bin == bins - 1

        value: sge.Expression
        if op.labels is False:
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif label_values is not None:
            value = ir._literal(label_values[this_bin], dtypes.STRING_DTYPE)
        else:
            # Only the outermost bound on the open side is widened by ``adj``.
            left_adj: sge.Expression = (
                adj if is_first_bin and op.right else sge.convert(0)
            )
            right_adj: sge.Expression = (
                adj if is_last_bin and not op.right else sge.convert(0)
            )
            left: sge.Expression = bin_edge(this_bin) - left_adj
            right: sge.Expression = bin_edge(this_bin + 1) + right_adj
            value = sge.Struct(
                expressions=[
                    sge.PropertyEQ(
                        this=sge.Identifier(this=left_name, quoted=True),
                        expression=left,
                    ),
                    sge.PropertyEQ(
                        this=sge.Identifier(this=right_name, quoted=True),
                        expression=right,
                    ),
                ]
            )

        condition: sge.Expression
        if is_last_bin:
            # Catch-all: every non-NULL value not matched by an earlier bin.
            condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
        else:
            comparison = sge.LTE if op.right else sge.LT
            condition = comparison(
                this=column.expr,
                expression=bin_edge(this_bin + 1),
            )
        case_expr = case_expr.when(condition, value)
    return case_expr
197+
198+
199+
def _cut_ops_w_intervals(
    op: agg_ops.CutOp,
    column: typed_expr.TypedExpr,
    bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Build the ``CASE`` expression for a cut into explicit intervals.

    One ``WHEN`` branch is emitted per interval, in iteration order. A value
    matches an interval when it is ``> left AND <= right`` if ``op.right`` is
    truthy, and ``>= left AND < right`` otherwise.

    The value produced by each branch depends on ``op.labels``:

    * ``False``: the integer index of the interval.
    * an iterable: the label at the interval's index, as a string literal.
    * anything else: a struct holding the interval's two bounds, with fields
      ``left_exclusive``/``right_inclusive`` when ``op.right`` is truthy and
      ``left_inclusive``/``right_exclusive`` otherwise.

    Args:
        op: The cut operation; ``op.labels`` and ``op.right`` are read.
        column: The typed expression whose values are binned.
        bins: Intervals as ``(left, right)`` pairs.
        window: Not used by this function; kept so the signature is unchanged.

    Returns:
        The assembled ``sge.Case`` expression (no ``ELSE`` branch is added).

    Raises:
        IndexError: If ``op.labels`` is an iterable with fewer entries than
            there are intervals.
    """
    # Materialise the labels once instead of on every loop iteration.
    label_values: typing.Optional[typing.List[typing.Any]] = None
    if op.labels is not False and isinstance(op.labels, typing.Iterable):
        label_values = list(op.labels)

    if op.right:
        lower_cmp, upper_cmp = sge.GT, sge.LTE
        left_name, right_name = "left_exclusive", "right_inclusive"
    else:
        lower_cmp, upper_cmp = sge.GTE, sge.LT
        left_name, right_name = "left_inclusive", "right_exclusive"

    case_expr = sge.Case()
    for this_bin, interval in enumerate(bins):
        left: sge.Expression = ir._literal(
            interval[0], dtypes.infer_literal_type(interval[0])
        )
        right: sge.Expression = ir._literal(
            interval[1], dtypes.infer_literal_type(interval[1])
        )
        condition: sge.Expression = sge.And(
            this=lower_cmp(this=column.expr, expression=left),
            expression=upper_cmp(this=column.expr, expression=right),
        )

        value: sge.Expression
        if op.labels is False:
            value = ir._literal(this_bin, dtypes.INT_DTYPE)
        elif label_values is not None:
            value = ir._literal(label_values[this_bin], dtypes.STRING_DTYPE)
        else:
            value = sge.Struct(
                expressions=[
                    sge.PropertyEQ(
                        this=sge.Identifier(this=left_name, quoted=True),
                        expression=left,
                    ),
                    sge.PropertyEQ(
                        this=sge.Identifier(this=right_name, quoted=True),
                        expression=right,
                    ),
                ]
            )
        case_expr = case_expr.when(condition, value)
    return case_expr
246+
247+
114248
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
115249
def _(
116250
op: agg_ops.DateSeriesDiffOp,
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
10+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
11+
)
12+
THEN STRUCT(
13+
(
14+
MIN(`int64_col`) OVER () + (
15+
0 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
16+
)
17+
) - (
18+
(
19+
MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER ()
20+
) * 0.001
21+
) AS `left_exclusive`,
22+
MIN(`int64_col`) OVER () + (
23+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
24+
) + 0 AS `right_inclusive`
25+
)
26+
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
27+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
28+
)
29+
THEN STRUCT(
30+
(
31+
MIN(`int64_col`) OVER () + (
32+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
33+
)
34+
) - 0 AS `left_exclusive`,
35+
MIN(`int64_col`) OVER () + (
36+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
37+
) + 0 AS `right_inclusive`
38+
)
39+
WHEN `int64_col` IS NOT NULL
40+
THEN STRUCT(
41+
(
42+
MIN(`int64_col`) OVER () + (
43+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
44+
)
45+
) - 0 AS `left_exclusive`,
46+
MIN(`int64_col`) OVER () + (
47+
3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
48+
) + 0 AS `right_inclusive`
49+
)
50+
END AS `bfcol_1`
51+
FROM `bfcte_0`
52+
)
53+
SELECT
54+
`bfcol_1` AS `int_bins`
55+
FROM `bfcte_1`
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
10+
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
11+
)
12+
THEN 'a'
13+
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
14+
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
15+
)
16+
THEN 'b'
17+
WHEN `int64_col` IS NOT NULL
18+
THEN 'c'
19+
END AS `bfcol_1`
20+
FROM `bfcte_0`
21+
)
22+
SELECT
23+
`bfcol_1` AS `int_bins_labels`
24+
FROM `bfcte_1`
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `int64_col` > 0 AND `int64_col` <= 1
10+
THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`)
11+
WHEN `int64_col` > 1 AND `int64_col` <= 2
12+
THEN STRUCT(1 AS `left_exclusive`, 2 AS `right_inclusive`)
13+
END AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `interval_bins`
18+
FROM `bfcte_1`
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `int64_col` > 0 AND `int64_col` <= 1
10+
THEN 0
11+
WHEN `int64_col` > 1 AND `int64_col` <= 2
12+
THEN 1
13+
END AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `interval_bins_labels`
18+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,35 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
174174
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
175175

176176

177+
def test_cut(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot-test the SQL compiled for ``CutOp`` over several variants.

    Covers integer and interval bins, each with and without explicit
    ``labels``; every variant is compared against ``<variant name>.sql``.
    """
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]
    interval_bins = ((0, 1), (1, 2))

    # Variant name -> operation under test; the name doubles as the
    # snapshot file stem.
    cut_ops = {
        "int_bins": agg_ops.CutOp(bins=3, right=True, labels=None),
        "interval_bins": agg_ops.CutOp(bins=interval_bins, right=True, labels=None),
        "int_bins_labels": agg_ops.CutOp(
            bins=3, labels=("a", "b", "c"), right=False
        ),
        "interval_bins_labels": agg_ops.CutOp(
            bins=interval_bins, labels=False, right=True
        ),
    }
    window = window_spec.WindowSpec()

    for test_name, cut_op in cut_ops.items():
        agg_expr = agg_exprs.UnaryAggregation(cut_op, expression.deref(col_name))
        sql = _apply_unary_window_op(bf_df, agg_expr, window, test_name)
        snapshot.assert_match(sql, f"{test_name}.sql")
204+
205+
177206
def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
178207
col_name = "int64_col"
179208
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)