Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion bigframes/session/polars_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@
bigframes.operations.ge_op,
bigframes.operations.le_op,
)
# Diff "before" line (the hunk's single deletion): only the two size
# aggregations were listed.
_COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp)
# Diff "after" lines (the hunk's nine additions): the same two size ops plus
# min, max, sum, mean and count.
# NOTE(review): presumably the aggregate op classes the polars executor
# accepts; the code that reads this tuple is not visible here -- confirm.
_COMPATIBLE_AGG_OPS = (
agg_ops.SizeOp,
agg_ops.SizeUnaryOp,
agg_ops.MinOp,
agg_ops.MaxOp,
agg_ops.SumOp,
agg_ops.MeanOp,
agg_ops.CountOp,
)


def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]:
Expand Down
9 changes: 6 additions & 3 deletions bigframes/testing/engine_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas.testing

from bigframes.core import nodes
from bigframes.session import semi_executor

Expand All @@ -26,6 +28,7 @@ def assert_equivalence_execution(
assert e1_result is not None
assert e2_result is not None
# Schemas might have extra nullity markers, normalize to node expected schema, which should be looser
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this comment no longer applies since you aren't casting here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

e1_table = e1_result.to_arrow_table().cast(node.schema.to_pyarrow())
e2_table = e2_result.to_arrow_table().cast(node.schema.to_pyarrow())
assert e1_table.equals(e2_table), f"{e1_table} is not equal to {e2_table}"
assert e1_result.schema == e2_result.schema
e1_table = e1_result.to_pandas()
e2_table = e2_result.to_pandas()
pandas.testing.assert_frame_equal(e1_table, e2_table, rtol=1e-10)
33 changes: 33 additions & 0 deletions tests/system/small/engines/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@
# Baseline engine: the tests in this file pass it to
# assert_equivalence_execution together with the engine under test.
REFERENCE_ENGINE = polars_executor.PolarsExecutor()


def apply_agg_to_all_valid(
    array: array_value.ArrayValue, op: agg_ops.UnaryAggregateOp, excluded_cols=()
) -> array_value.ArrayValue:
    """Aggregate every column of ``array`` that ``op`` accepts.

    Each id in ``array.column_ids`` is probed by calling ``op.output_type`` on
    the column's type. Columns for which the probe raises ``TypeError`` are
    skipped, as are ids listed in ``excluded_cols``. Every remaining column is
    aggregated with ``op`` and the output is named ``"<column id>-<op name>"``.

    Args:
        array: Array value whose columns are aggregated.
        op: Unary aggregation to apply to each accepted column.
        excluded_cols: Column ids to leave out. Only membership is tested, so
            any container works; the default is an immutable empty tuple.

    Returns:
        The result of ``array.aggregate`` over the accepted columns.

    Raises:
        AssertionError: If no column is accepted by ``op``.
    """
    exprs_by_name = []
    for arg in array.column_ids:
        if arg in excluded_cols:
            continue
        try:
            # Probe only: the result is discarded, a TypeError means the
            # column's type is not valid input for this op.
            op.output_type(array.get_column_type(arg))
        except TypeError:
            continue
        # Built outside the try so an unrelated TypeError raised here is not
        # mistaken for "incompatible column" and silently dropped.
        expr = expression.UnaryAggregation(op, expression.deref(arg))
        exprs_by_name.append((expr, f"{arg}-{op.name}"))
    assert exprs_by_name, f"no column is compatible with {op.name}"
    return array.aggregate(exprs_by_name)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_aggregate_size(
scalars_array_value: array_value.ArrayValue,
Expand All @@ -48,6 +67,20 @@ def test_engines_aggregate_size(
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    "op",
    [
        agg_ops.min_op,
        agg_ops.max_op,
        agg_ops.mean_op,
        agg_ops.sum_op,
        agg_ops.count_op,
    ],
)
def test_engines_unary_aggregates(
    scalars_array_value: array_value.ArrayValue,
    engine,
    op,
):
    """Check a unary aggregate runs equivalently on ``engine`` and the reference.

    ``op`` is applied via ``apply_agg_to_all_valid`` to the scalars fixture,
    and the resulting node is executed on both ``REFERENCE_ENGINE`` and the
    parametrized ``engine``.
    """
    aggregated = apply_agg_to_all_valid(scalars_array_value, op)
    assert_equivalence_execution(aggregated.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
"grouping_cols",
Expand Down