1515from __future__ import annotations
1616
1717import typing
18- from typing import Sequence , Union
18+ from typing import Sequence , Tuple , Union
1919
2020import bigframes_vendored .constants as constants
2121import bigframes_vendored .pandas .core .groupby as vendored_pandas_groupby
2626import bigframes .core as core
2727import bigframes .core .block_transforms as block_ops
2828import bigframes .core .blocks as blocks
29+ import bigframes .core .expression
2930import bigframes .core .ordering as order
3031import bigframes .core .utils as utils
3132import bigframes .core .validations as validations
@@ -334,24 +335,19 @@ def agg(self, func=None, **kwargs) -> typing.Union[df.DataFrame, series.Series]:
334335 return self ._agg_named (** kwargs )
335336
def _agg_string(self, func: str) -> df.DataFrame:
    """Apply one aggregation, named by ``func``, to every aggregated column.

    Args:
        func: Name of the aggregation, resolved with
            ``agg_ops.lookup_agg_func``.

    Returns:
        The aggregated frame, passed through ``_convert_index`` when
        ``_as_index`` is falsy.
    """
    column_ids, column_labels = self._aggregated_columns()
    # One aggregation expression per aggregated column, all using the same op.
    per_column_aggs = [
        agg(column_id, agg_ops.lookup_agg_func(func)) for column_id in column_ids
    ]
    aggregated = self._block.aggregate(
        by_column_ids=self._by_col_ids,
        aggregations=per_column_aggs,
        dropna=self._dropna,
        column_labels=column_labels,
    )
    # ``aggregate`` returns a pair; only the resulting block is needed here.
    result = df.DataFrame(aggregated[0])
    if self._as_index:
        return result
    return self._convert_index(result)
348348
349349 def _agg_dict (self , func : typing .Mapping ) -> df .DataFrame :
350- aggregations : typing .List [
351- typing .Tuple [
352- str , typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
353- ]
354- ] = []
350+ aggregations : typing .List [bigframes .core .expression .Aggregation ] = []
355351 column_labels = []
356352
357353 want_aggfunc_level = any (utils .is_list_like (aggs ) for aggs in func .values ())
@@ -362,7 +358,7 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
362358 funcs_for_id if utils .is_list_like (funcs_for_id ) else [funcs_for_id ]
363359 )
364360 for f in func_list :
365- aggregations .append ((col_id , agg_ops .lookup_agg_func (f )))
361+ aggregations .append (agg (col_id , agg_ops .lookup_agg_func (f )))
366362 column_labels .append (label )
367363 agg_block , _ = self ._block .aggregate (
368364 by_column_ids = self ._by_col_ids ,
@@ -373,7 +369,10 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
373369 agg_block = agg_block .with_column_labels (
374370 utils .combine_indices (
375371 pd .Index (column_labels ),
376- pd .Index (agg [1 ].name for agg in aggregations ),
372+ pd .Index (
373+ typing .cast (agg_ops .AggregateOp , agg .op ).name
374+ for agg in aggregations
375+ ),
377376 )
378377 )
379378 else :
@@ -382,34 +381,21 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
382381 return dataframe if self ._as_index else self ._convert_index (dataframe )
383382
384383 def _agg_list (self , func : typing .Sequence ) -> df .DataFrame :
384+ ids , labels = self ._aggregated_columns ()
385385 aggregations = [
386- (col_id , agg_ops .lookup_agg_func (f ))
387- for col_id in self ._aggregated_columns ()
388- for f in func
386+ agg (col_id , agg_ops .lookup_agg_func (f )) for col_id in ids for f in func
389387 ]
390388
391389 if self ._block .column_labels .nlevels > 1 :
392390 # Restructure MultiIndex for proper format: (idx1, idx2, func)
393391 # rather than ((idx1, idx2), func).
394- aggregated_columns = pd .MultiIndex .from_tuples (
395- [
396- self ._block .col_id_to_label [col_id ]
397- for col_id in self ._aggregated_columns ()
398- ],
399- names = [* self ._block .column_labels .names ],
400- ).to_frame (index = False )
401-
402392 column_labels = [
403- tuple (col_id ) + (f ,)
404- for col_id in aggregated_columns .to_numpy ()
405- for f in func
406- ]
407- else :
408- column_labels = [
409- (self ._block .col_id_to_label [col_id ], f )
410- for col_id in self ._aggregated_columns ()
393+ tuple (label ) + (f ,)
394+ for label in labels .to_frame (index = False ).to_numpy ()
411395 for f in func
412396 ]
397+ else : # Single-level index
398+ column_labels = [(label , f ) for label in labels for f in func ]
413399
414400 agg_block , _ = self ._block .aggregate (
415401 by_column_ids = self ._by_col_ids ,
@@ -435,7 +421,7 @@ def _agg_named(self, **kwargs) -> df.DataFrame:
435421 if not isinstance (v , tuple ) or (len (v ) != 2 ):
436422 raise TypeError ("kwargs values must be 2-tuples of column, aggfunc" )
437423 col_id = self ._resolve_label (v [0 ])
438- aggregations .append ((col_id , agg_ops .lookup_agg_func (v [1 ])))
424+ aggregations .append (agg (col_id , agg_ops .lookup_agg_func (v [1 ])))
439425 column_labels .append (k )
440426 agg_block , _ = self ._block .aggregate (
441427 by_column_ids = self ._by_col_ids ,
@@ -470,15 +456,19 @@ def _raise_on_non_numeric(self, op: str):
470456 )
471457 return self
472458
473- def _aggregated_columns (self , numeric_only : bool = False ) -> typing .Sequence [str ]:
459+ def _aggregated_columns (
460+ self , numeric_only : bool = False
461+ ) -> Tuple [typing .Sequence [str ], pd .Index ]:
474462 valid_agg_cols : list [str ] = []
475- for col_id in self ._selected_cols :
463+ offsets : list [int ] = []
464+ for i , col_id in enumerate (self ._block .value_columns ):
476465 is_numeric = (
477466 self ._column_type (col_id ) in dtypes .NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
478467 )
479- if is_numeric or not numeric_only :
468+ if (col_id in self ._selected_cols ) and (is_numeric or not numeric_only ):
469+ offsets .append (i )
480470 valid_agg_cols .append (col_id )
481- return valid_agg_cols
471+ return valid_agg_cols , self . _block . column_labels . take ( offsets )
482472
483473 def _column_type (self , col_id : str ) -> dtypes .Dtype :
484474 col_offset = self ._block .value_columns .index (col_id )
@@ -488,11 +478,12 @@ def _column_type(self, col_id: str) -> dtypes.Dtype:
488478 def _aggregate_all (
489479 self , aggregate_op : agg_ops .UnaryAggregateOp , numeric_only : bool = False
490480 ) -> df .DataFrame :
491- aggregated_col_ids = self ._aggregated_columns (numeric_only = numeric_only )
492- aggregations = [(col_id , aggregate_op ) for col_id in aggregated_col_ids ]
481+ aggregated_col_ids , labels = self ._aggregated_columns (numeric_only = numeric_only )
482+ aggregations = [agg (col_id , aggregate_op ) for col_id in aggregated_col_ids ]
493483 result_block , _ = self ._block .aggregate (
494484 by_column_ids = self ._by_col_ids ,
495485 aggregations = aggregations ,
486+ column_labels = labels ,
496487 dropna = self ._dropna ,
497488 )
498489 dataframe = df .DataFrame (result_block )
@@ -508,7 +499,7 @@ def _apply_window_op(
508499 window_spec = window or window_specs .cumulative_rows (
509500 grouping_keys = tuple (self ._by_col_ids )
510501 )
511- columns = self ._aggregated_columns (numeric_only = numeric_only )
502+ columns , _ = self ._aggregated_columns (numeric_only = numeric_only )
512503 block , result_ids = self ._block .multi_apply_window_op (
513504 columns , op , window_spec = window_spec
514505 )
@@ -639,11 +630,11 @@ def prod(self, *args) -> series.Series:
639630 def agg (self , func = None ) -> typing .Union [df .DataFrame , series .Series ]:
640631 column_names : list [str ] = []
641632 if isinstance (func , str ):
642- aggregations = [(self ._value_column , agg_ops .lookup_agg_func (func ))]
633+ aggregations = [agg (self ._value_column , agg_ops .lookup_agg_func (func ))]
643634 column_names = [func ]
644635 elif utils .is_list_like (func ):
645636 aggregations = [
646- (self ._value_column , agg_ops .lookup_agg_func (f )) for f in func
637+ agg (self ._value_column , agg_ops .lookup_agg_func (f )) for f in func
647638 ]
648639 column_names = list (func )
649640 else :
@@ -756,7 +747,7 @@ def expanding(self, min_periods: int = 1) -> windows.Window:
756747 def _aggregate (self , aggregate_op : agg_ops .UnaryAggregateOp ) -> series .Series :
757748 result_block , _ = self ._block .aggregate (
758749 self ._by_col_ids ,
759- ((self ._value_column , aggregate_op ),),
750+ (agg (self ._value_column , aggregate_op ),),
760751 dropna = self ._dropna ,
761752 )
762753
@@ -781,3 +772,13 @@ def _apply_window_op(
781772 window_spec = window_spec ,
782773 )
783774 return series .Series (block .select_column (result_id ))
775+
776+
def agg(input: str, op: agg_ops.AggregateOp) -> bigframes.core.expression.Aggregation:
    """Build an aggregation expression that applies ``op`` to column ``input``.

    Args:
        input: Id of the column to aggregate. It is wrapped with
            ``bigframes.core.expression.deref`` for unary ops and is not
            used for nullary ops.
        op: The aggregate operation to apply.

    Returns:
        A ``UnaryAggregation`` over the dereferenced column when ``op`` is a
        ``UnaryAggregateOp``, or a ``NullaryAggregation`` when ``op`` is a
        ``NullaryAggregateOp``.

    Raises:
        TypeError: If ``op`` is neither a unary nor a nullary aggregate op.
    """
    if isinstance(op, agg_ops.UnaryAggregateOp):
        return bigframes.core.expression.UnaryAggregation(
            op, bigframes.core.expression.deref(input)
        )
    if isinstance(op, agg_ops.NullaryAggregateOp):
        return bigframes.core.expression.NullaryAggregation(op)
    # Raise explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would let an unsupported op fall through.
    raise TypeError(
        f"Unsupported aggregate op type: {type(op).__name__}; "
        "expected a unary or nullary aggregate op"
    )
0 commit comments