@@ -76,18 +76,7 @@ def _apply_aggregate(
7676 self ,
7777 op : agg_ops .UnaryAggregateOp ,
7878 ):
79- agg_col_ids = [
80- col_id
81- for col_id in self ._value_column_ids
82- if col_id != self ._skip_agg_column_id
83- ]
84- agg_block = self ._aggregate_block (op , agg_col_ids )
85-
86- if self ._skip_agg_column_id is not None :
87- # Concat the skipped column to the result.
88- agg_block , _ = agg_block .join (
89- self ._block .select_column (self ._skip_agg_column_id ), how = "outer"
90- )
79+ agg_block = self ._aggregate_block (op )
9180
9281 if self ._is_series :
9382 from bigframes .series import Series
@@ -102,9 +91,12 @@ def _apply_aggregate(
10291 ]
10392 return DataFrame (agg_block )._reindex_columns (column_labels )
10493
105- def _aggregate_block (
106- self , op : agg_ops .UnaryAggregateOp , agg_col_ids : typing .List [str ]
107- ) -> blocks .Block :
94+ def _aggregate_block (self , op : agg_ops .UnaryAggregateOp ) -> blocks .Block :
95+ agg_col_ids = [
96+ col_id
97+ for col_id in self ._value_column_ids
98+ if col_id != self ._skip_agg_column_id
99+ ]
108100 block , result_ids = self ._block .multi_apply_window_op (
109101 agg_col_ids ,
110102 op ,
@@ -123,39 +115,71 @@ def _aggregate_block(
123115 block = block .set_index (col_ids = index_ids )
124116
125117 labels = [self ._block .col_id_to_label [col ] for col in agg_col_ids ]
118+ if self ._skip_agg_column_id is not None :
119+ result_ids = [self ._skip_agg_column_id , * result_ids ]
120+ labels .insert (0 , self ._block .col_id_to_label [self ._skip_agg_column_id ])
121+
126122 return block .select_columns (result_ids ).with_column_labels (labels )
127123
128124
129125def create_range_window (
130126 block : blocks .Block ,
131127 window : pandas .Timedelta | numpy .timedelta64 | datetime .timedelta | str ,
128+ * ,
129+ value_column_ids : typing .Sequence [str ] = tuple (),
132130 min_periods : int | None ,
131+ on : str | None = None ,
133132 closed : typing .Literal ["right" , "left" , "both" , "neither" ],
134133 is_series : bool ,
134+ grouping_keys : typing .Sequence [str ] = tuple (),
135+ drop_null_groups : bool = True ,
135136) -> Window :
136137
137- index_dtypes = block .index .dtypes
138- if len (index_dtypes ) > 1 :
139- raise ValueError ("Range rolling on MultiIndex is not supported" )
140- if index_dtypes [0 ] != dtypes .TIMESTAMP_DTYPE :
141- raise ValueError ("Index type should be timestamps with timezones" )
138+ if on is None :
139+ # Rolling on index
140+ index_dtypes = block .index .dtypes
141+ if len (index_dtypes ) > 1 :
142+ raise ValueError ("Range rolling on MultiIndex is not supported" )
143+ if index_dtypes [0 ] != dtypes .TIMESTAMP_DTYPE :
144+ raise ValueError ("Index type should be timestamps with timezones" )
145+ rolling_key_col_id = block .index_columns [0 ]
146+ else :
147+ # Rolling on a specific column
148+ rolling_key_col_id = block .resolve_label_exact_or_error (on )
149+ if block .expr .get_column_type (rolling_key_col_id ) != dtypes .TIMESTAMP_DTYPE :
150+ raise ValueError (f"Column { on } type should be timestamps with timezones" )
142151
143152 order_direction = window_ordering .find_order_direction (
144- block .expr .node , block . index_columns [ 0 ]
153+ block .expr .node , rolling_key_col_id
145154 )
146155 if order_direction is None :
156+ target_str = "index" if on is None else f"column { on } "
147157 raise ValueError (
148- "The index might not be in a monotonic order. Please sort the index before rolling."
158+ f "The { target_str } might not be in a monotonic order. Please sort by { target_str } before rolling."
149159 )
150160 if isinstance (window , str ):
151161 window = pandas .Timedelta (window )
152162 spec = window_spec .WindowSpec (
153163 bounds = window_spec .RangeWindowBounds .from_timedelta_window (window , closed ),
154164 min_periods = 1 if min_periods is None else min_periods ,
155165 ordering = (
156- ordering .OrderingExpression (
157- ex .deref (block .index_columns [0 ]), order_direction
158- ),
166+ ordering .OrderingExpression (ex .deref (rolling_key_col_id ), order_direction ),
159167 ),
168+ grouping_keys = tuple (ex .deref (col ) for col in grouping_keys ),
169+ )
170+
171+ selected_value_col_ids = (
172+ value_column_ids if value_column_ids else block .value_columns
173+ )
174+ # This step must be done after finding the order direction of the window key.
175+ if grouping_keys :
176+ block = block .order_by ([ordering .ascending_over (col ) for col in grouping_keys ])
177+
178+ return Window (
179+ block ,
180+ spec ,
181+ value_column_ids = selected_value_col_ids ,
182+ is_series = is_series ,
183+ skip_agg_column_id = None if on is None else rolling_key_col_id ,
184+ drop_null_groups = drop_null_groups ,
160185 )
161- return Window (block , spec , block .value_columns , is_series = is_series )
0 commit comments