Skip to content

Commit 027d406

Browse files
committed
feat: add GroupBy.size() to get number of rows in each group
1 parent 21b2188 commit 027d406

File tree

7 files changed

+181
-24
lines changed

7 files changed

+181
-24
lines changed

bigframes/core/blocks.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,35 @@ def aggregate_all_and_stack(
933933
index_labels=self.index.names,
934934
)
935935

936+
def aggregate_size(
    self,
    by_column_ids: typing.Sequence[str] = (),
    *,
    dropna: bool = True,
):
    """Build a block that computes the size of each group.

    Args:
        by_column_ids: Ids of the columns to group by. They are used as the
            index columns of the returned block.
        dropna: Passed through unchanged to ``self.expr.aggregate``.

    Returns:
        A ``(block, output_col_ids)`` tuple, where ``output_col_ids`` lists
        the generated id of the size column.
    """
    size_spec = (ex.NullaryAggregation(agg_ops.SizeOp()), guid.generate_guid())
    agg_specs = [size_spec]
    output_col_ids = [col_id for _, col_id in agg_specs]
    result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)
    aggregate_labels = self._get_labels_for_columns(["size"])
    # A grouping key that is a value column contributes its column label;
    # any other key is looked up as an index column and contributes its
    # index name.
    names: typing.List[Label] = [
        self.col_id_to_label[by_col_id]
        if by_col_id in self.value_columns
        else self.col_id_to_index_name[by_col_id]
        for by_col_id in by_column_ids
    ]
    size_block = Block(
        result_expr,
        index_columns=by_column_ids,
        column_labels=aggregate_labels,
        index_labels=names,
    )
    return (size_block, output_col_ids)
964+
936965
def select_column(self, id: str) -> Block:
    """Select one column by delegating to ``select_columns`` with ``[id]``."""
    column_ids = [id]
    return self.select_columns(column_ids)
938967

bigframes/core/compile/aggregate_compiler.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def compile_aggregate(
3535
aggregate: ex.Aggregation,
3636
bindings: typing.Dict[str, ibis_types.Value],
3737
) -> ibis_types.Value:
38+
if isinstance(aggregate, ex.NullaryAggregation):
39+
return compile_nullary_agg(aggregate.op)
3840
if isinstance(aggregate, ex.UnaryAggregation):
3941
input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings)
4042
return compile_unary_agg(
@@ -54,7 +56,9 @@ def compile_analytic(
5456
window: window_spec.WindowSpec,
5557
bindings: typing.Dict[str, ibis_types.Value],
5658
) -> ibis_types.Value:
57-
if isinstance(aggregate, ex.UnaryAggregation):
59+
if isinstance(aggregate, ex.NullaryAggregation):
60+
return compile_nullary_agg(aggregate.op, window)
61+
elif isinstance(aggregate, ex.UnaryAggregation):
5862
input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings)
5963
return compile_unary_agg(aggregate.op, input, window)
6064
elif isinstance(aggregate, ex.BinaryAggregation):
@@ -81,6 +85,14 @@ def compile_unary_agg(
8185
raise ValueError(f"Can't compile unrecognized operation: {op}")
8286

8387

88+
@functools.singledispatch
def compile_nullary_agg(
    op: agg_ops.WindowOp,
    window: Optional[window_spec.WindowSpec] = None,
) -> ibis_types.Value:
    """Fallback of the single-dispatch compiler for zero-argument aggregations.

    Implementations for concrete op types are attached through
    ``compile_nullary_agg.register``; this base version only runs when no
    registered implementation matches the type of ``op``.

    Raises:
        ValueError: Always, naming the unrecognized operation.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)
94+
95+
8496
def numeric_op(operation):
8597
@functools.wraps(operation)
8698
def constrained_op(op, column: ibis_types.Column, window=None):
@@ -101,6 +113,11 @@ def constrained_op(op, column: ibis_types.Column, window=None):
101113
### Specific Op implementations Below
102114

103115

116+
@compile_nullary_agg.register
def _(op: agg_ops.SizeOp, window=None) -> ibis_types.NumericValue:
    """Compile ``SizeOp`` as a count over the constant ``1``.

    The count expression is handed to ``_apply_window_if_present`` together
    with ``window``.
    """
    row_count = vendored_ibis_ops.count(1)
    return _apply_window_if_present(row_count, window)
119+
120+
104121
@compile_unary_agg.register
105122
@numeric_op
106123
def _(

bigframes/core/expression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ class Aggregation(abc.ABC):
4040
op: agg_ops.WindowOp = dataclasses.field()
4141

4242

43+
@dataclasses.dataclass(frozen=True)
class NullaryAggregation(Aggregation):
    """Aggregation described by its operator alone, with no input expression."""

    # Narrows the operator type declared on the base class.
    op: agg_ops.NullaryWindowOp = dataclasses.field()
46+
47+
4348
@dataclasses.dataclass(frozen=True)
4449
class UnaryAggregation(Aggregation):
4550
op: agg_ops.UnaryWindowOp = dataclasses.field()

bigframes/core/groupby/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,20 @@ def __getitem__(
102102
dropna=self._dropna,
103103
)
104104

105+
def size(self) -> typing.Union[df.DataFrame, series.Series]:
    """Compute the size of each group.

    Returns:
        When ``self._as_index`` is truthy, the ``"size"`` column as a Series
        renamed to ``None``. Otherwise the result of ``self._convert_index``
        applied to the DataFrame holding the ``"size"`` column.
    """
    size_block, _ = self._block.aggregate_size(
        by_column_ids=self._by_col_ids,
        dropna=self._dropna,
    )
    size_frame = df.DataFrame(size_block.with_column_labels(pd.Index(["size"])))

    if not self._as_index:
        return self._convert_index(size_frame)

    # Named ``sizes`` so the local does not shadow the ``series`` module.
    sizes = size_frame["size"]
    return sizes.rename(None)
118+
105119
def sum(self, numeric_only: bool = False, *args) -> df.DataFrame:
106120
if not numeric_only:
107121
self._raise_on_non_numeric("sum")
@@ -475,6 +489,13 @@ def std(self, *args, **kwargs) -> series.Series:
475489
def var(self, *args, **kwargs) -> series.Series:
    """Aggregate with ``agg_ops.var_op``; extra arguments are accepted but unused."""
    variance_op = agg_ops.var_op
    return self._aggregate(variance_op)
477491

492+
def size(self) -> series.Series:
    """Compute the size of each group as a Series named ``self._value_name``."""
    # Only the block is needed; the generated output column ids are ignored.
    size_block, _output_col_ids = self._block.aggregate_size(
        by_column_ids=self._by_col_ids,
        dropna=self._dropna,
    )
    return series.Series(size_block, name=self._value_name)
498+
478499
def skew(self, *args, **kwargs) -> series.Series:
    """Wrap the block returned by ``block_ops.skew`` in a Series.

    The value column is passed as a one-element list together with the
    grouping column ids; extra arguments are accepted but unused.
    """
    value_columns = [self._value_column]
    skew_block = block_ops.skew(self._block, value_columns, self._by_col_ids)
    return series.Series(skew_block)

bigframes/operations/aggregations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def handles_ties(self):
3333
return False
3434

3535

36+
@dataclasses.dataclass(frozen=True)
class NullaryWindowOp(WindowOp):
    """Window operation that takes no input arguments."""

    @property
    def arguments(self) -> int:
        """Number of input arguments, which is always 0."""
        return 0
41+
42+
3643
@dataclasses.dataclass(frozen=True)
3744
class UnaryWindowOp(WindowOp):
3845
@property
@@ -55,6 +62,13 @@ def arguments(self) -> int:
5562
...
5663

5764

65+
@dataclasses.dataclass(frozen=True)
class NullaryAggregateOp(AggregateOp, NullaryWindowOp):
    """Aggregate operation that takes no input arguments."""

    @property
    def arguments(self) -> int:
        """Number of input arguments, which is always 0."""
        return 0
70+
71+
5872
@dataclasses.dataclass(frozen=True)
5973
class UnaryAggregateOp(AggregateOp, UnaryWindowOp):
6074
@property
@@ -69,6 +83,11 @@ def arguments(self) -> int:
6983
return 2
7084

7185

86+
@dataclasses.dataclass(frozen=True)
class SizeOp(NullaryAggregateOp):
    """Zero-argument aggregate operation identified by the name ``"size"``."""

    name: ClassVar[str] = "size"
89+
90+
7291
@dataclasses.dataclass(frozen=True)
7392
class SumOp(UnaryAggregateOp):
7493
name: ClassVar[str] = "sum"
@@ -270,6 +289,7 @@ class CovOp(BinaryAggregateOp):
270289
name: ClassVar[str] = "cov"
271290

272291

292+
size_op = SizeOp()
273293
sum_op = SumOp()
274294
mean_op = MeanOp()
275295
median_op = MedianOp()

tests/system/small/test_groupby.py

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
from tests.system.utils import assert_pandas_df_equal
2020

2121

22+
# =================
23+
# DataFrame.groupby
24+
# =================
25+
2226
@pytest.mark.parametrize(
2327
("operator"),
2428
[
@@ -250,21 +254,26 @@ def test_dataframe_groupby_analytic(
250254
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
251255

252256

253-
def test_series_groupby_skew(scalars_df_index, scalars_pandas_df_index):
254-
bf_result = scalars_df_index.groupby("bool_col")["int64_too"].skew().to_pandas()
255-
pd_result = scalars_pandas_df_index.groupby("bool_col")["int64_too"].skew()
257+
def test_dataframe_groupby_size_as_index_false(
258+
scalars_df_index, scalars_pandas_df_index
259+
):
260+
bf_result = scalars_df_index.groupby("string_col", as_index=False).size()
261+
bf_result_computed = bf_result.to_pandas()
262+
pd_result = scalars_pandas_df_index.groupby("string_col", as_index=False).size()
256263

257-
pd.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
264+
pd.testing.assert_frame_equal(
265+
pd_result, bf_result_computed, check_dtype=False, check_index_type=False
266+
)
258267

259268

260-
def test_series_groupby_kurt(scalars_df_index, scalars_pandas_df_index):
261-
bf_result = scalars_df_index.groupby("bool_col")["int64_too"].kurt().to_pandas()
262-
# Pandas doesn't have groupby.kurt yet: https://github.com/pandas-dev/pandas/issues/40139
263-
pd_result = scalars_pandas_df_index.groupby("bool_col")["int64_too"].apply(
264-
pd.Series.kurt
265-
)
269+
def test_dataframe_groupby_size_as_index_true(
270+
scalars_df_index, scalars_pandas_df_index
271+
):
272+
bf_result = scalars_df_index.groupby("string_col", as_index=True).size()
273+
pd_result = scalars_pandas_df_index.groupby("string_col", as_index=True).size()
274+
bf_result_computed = bf_result.to_pandas()
266275

267-
pd.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
276+
pd.testing.assert_series_equal(pd_result, bf_result_computed, check_dtype=False)
268277

269278

270279
def test_dataframe_groupby_skew(scalars_df_index, scalars_pandas_df_index):
@@ -337,6 +346,26 @@ def test_dataframe_groupby_getitem_list(
337346
pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
338347

339348

349+
def test_dataframe_groupby_nonnumeric_with_mean():
    """Grouped ``mean`` over two string key columns matches pandas."""
    source_frame = pd.DataFrame(
        {
            "key1": ["a", "a", "a", "b"],
            "key2": ["a", "a", "c", "c"],
            "key3": [1, 2, 3, 4],
            "key4": [1.6, 2, 3, 4],
        }
    )
    pd_result = source_frame.groupby(["key1", "key2"]).mean()

    bf_grouped = bpd.DataFrame(source_frame).groupby(["key1", "key2"])
    bf_result = bf_grouped.mean().to_pandas()

    pd.testing.assert_frame_equal(
        pd_result, bf_result, check_index_type=False, check_dtype=False
    )
364+
365+
# ==============
366+
# Series.groupby
367+
# ==============
368+
340369
def test_series_groupby_agg_string(scalars_df_index, scalars_pandas_df_index):
341370
bf_result = (
342371
scalars_df_index["int64_col"]
@@ -373,18 +402,46 @@ def test_series_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
373402
)
374403

375404

376-
def test_dataframe_groupby_nonnumeric_with_mean():
377-
df = pd.DataFrame(
378-
{
379-
"key1": ["a", "a", "a", "b"],
380-
"key2": ["a", "a", "c", "c"],
381-
"key3": [1, 2, 3, 4],
382-
"key4": [1.6, 2, 3, 4],
383-
}
405+
def test_series_groupby_kurt(scalars_df_index, scalars_pandas_df_index):
406+
bf_result = (
407+
scalars_df_index["int64_too"]
408+
.groupby(scalars_df_index["bool_col"])
409+
.kurt()
410+
.to_pandas()
411+
)
412+
# Pandas doesn't have groupby.kurt yet: https://github.com/pandas-dev/pandas/issues/40139
413+
pd_result = scalars_pandas_df_index.groupby("bool_col")["int64_too"].apply(
414+
pd.Series.kurt
384415
)
385-
pd_result = df.groupby(["key1", "key2"]).mean()
386-
bf_result = bpd.DataFrame(df).groupby(["key1", "key2"]).mean().to_pandas()
387416

388-
pd.testing.assert_frame_equal(
389-
pd_result, bf_result, check_index_type=False, check_dtype=False
417+
pd.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
418+
419+
420+
def test_series_groupby_size(scalars_df_index, scalars_pandas_df_index):
    """``Series.groupby(...).size()`` matches pandas for a boolean key."""
    bf_grouped = scalars_df_index["int64_too"].groupby(scalars_df_index["bool_col"])
    bf_result = bf_grouped.size()

    pd_grouped = scalars_pandas_df_index["int64_too"].groupby(
        scalars_pandas_df_index["bool_col"]
    )
    pd_result = pd_grouped.size()

    # Materialize the bigframes result only after both sides are built,
    # in the same order as before.
    bf_result_computed = bf_result.to_pandas()

    pd.testing.assert_series_equal(pd_result, bf_result_computed, check_dtype=False)
432+
433+
434+
def test_series_groupby_skew(scalars_df_index, scalars_pandas_df_index):
    """``Series.groupby(...).skew()`` matches pandas for a boolean key."""
    bf_grouped = scalars_df_index["int64_too"].groupby(scalars_df_index["bool_col"])
    bf_result = bf_grouped.skew().to_pandas()

    pd_grouped = scalars_pandas_df_index["int64_too"].groupby(
        scalars_pandas_df_index["bool_col"]
    )
    pd_result = pd_grouped.skew()

    pd.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)

third_party/bigframes_vendored/ibis/expr/operations/analytic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22

33
from __future__ import annotations
44

5+
import ibis
56
import ibis.expr.operations as ops
67
import ibis.expr.rules as rlz
78

89

10+
@ibis.udf.agg.builtin
def count(value: int) -> int:
    """Count of a scalar.

    Declared through ``ibis.udf.agg.builtin``; the return statement below is
    a placeholder excluded from coverage.
    """
    return 0  # pragma: NO COVER
14+
15+
916
class FirstNonNullValue(ops.Analytic):
1017
"""Retrieve the first element."""
1118

@@ -21,6 +28,7 @@ class LastNonNullValue(ops.Analytic):
2128

2229

2330
__all__ = [
31+
"count",
2432
"FirstNonNullValue",
2533
"LastNonNullValue",
2634
]

0 commit comments

Comments
 (0)