Skip to content

Commit 6eab121

Browse files
authored
refactor: add parenthesization for binary operations (#2193)
1 parent 4f568b1 commit 6eab121

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@ class ScalarOpCompiler:
3131
typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression],
3232
] = {}
3333

34+
# A set of SQLGlot classes that may need to be parenthesized
35+
SQLGLOT_NEEDS_PARENS = {
36+
# Numeric operations
37+
sge.Add,
38+
sge.Sub,
39+
sge.Mul,
40+
sge.Div,
41+
sge.Mod,
42+
sge.Pow,
43+
# Comparison operations
44+
sge.GTE,
45+
sge.GT,
46+
sge.LTE,
47+
sge.LT,
48+
sge.EQ,
49+
sge.NEQ,
50+
# Logical operations
51+
sge.And,
52+
sge.Or,
53+
sge.Xor,
54+
# Bitwise operations
55+
sge.BitwiseAnd,
56+
sge.BitwiseOr,
57+
sge.BitwiseXor,
58+
sge.BitwiseLeftShift,
59+
sge.BitwiseRightShift,
60+
sge.BitwiseNot,
61+
# Other operations
62+
sge.Is,
63+
}
64+
3465
@functools.singledispatchmethod
3566
def compile_expression(
3667
self,
@@ -110,10 +141,12 @@ def register_binary_op(
110141

111142
def decorator(impl: typing.Callable[..., sge.Expression]):
112143
def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp):
144+
left = self._add_parentheses(args[0])
145+
right = self._add_parentheses(args[1])
113146
if pass_op:
114-
return impl(args[0], args[1], op)
147+
return impl(left, right, op)
115148
else:
116-
return impl(args[0], args[1])
149+
return impl(left, right)
117150

118151
self._register(key, normalized_impl)
119152
return impl
@@ -177,6 +210,12 @@ def _register(
177210
raise ValueError(f"Operation name {op_name} already registered")
178211
self._registry[op_name] = impl
179212

213+
@classmethod
214+
def _add_parentheses(cls, expr: TypedExpr) -> TypedExpr:
215+
if type(expr.expr) in cls.SQLGLOT_NEEDS_PARENS:
216+
return TypedExpr(sge.paren(expr.expr, copy=False), expr.dtype)
217+
return expr
218+
180219

181220
# Singleton compiler
182221
scalar_op_compiler = ScalarOpCompiler()

tests/unit/core/compile/sqlglot/test_scalar_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,43 @@ def _(*args: TypedExpr, op: ops.NaryOp) -> sge.Expression:
170170
mock_impl.assert_called_once_with(arg1, arg2, arg3, arg4, op=mock_op)
171171

172172

173+
def test_binary_op_parentheses():
174+
compiler = scalar_compiler.ScalarOpCompiler()
175+
176+
class MockAddOp(ops.BinaryOp):
177+
name = "mock_add_op"
178+
179+
class MockMulOp(ops.BinaryOp):
180+
name = "mock_mul_op"
181+
182+
add_op = MockAddOp()
183+
mul_op = MockMulOp()
184+
185+
@compiler.register_binary_op(add_op)
186+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
187+
return sge.Add(this=left.expr, expression=right.expr)
188+
189+
@compiler.register_binary_op(mul_op)
190+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
191+
return sge.Mul(this=left.expr, expression=right.expr)
192+
193+
a = TypedExpr(sge.Identifier(this="a"), "int")
194+
b = TypedExpr(sge.Identifier(this="b"), "int")
195+
c = TypedExpr(sge.Identifier(this="c"), "int")
196+
197+
# (a + b) * c
198+
add_expr = compiler.compile_row_op(add_op, [a, b])
199+
add_typed_expr = TypedExpr(add_expr, "int")
200+
result1 = compiler.compile_row_op(mul_op, [add_typed_expr, c])
201+
assert result1.sql() == "(a + b) * c"
202+
203+
# a * (b + c)
204+
add_expr_2 = compiler.compile_row_op(add_op, [b, c])
205+
add_typed_expr_2 = TypedExpr(add_expr_2, "int")
206+
result2 = compiler.compile_row_op(mul_op, [a, add_typed_expr_2])
207+
assert result2.sql() == "a * (b + c)"
208+
209+
173210
def test_register_duplicate_op_raises():
174211
compiler = scalar_compiler.ScalarOpCompiler()
175212

0 commit comments

Comments
 (0)