@@ -35,7 +35,17 @@ class BaseBqml:
3535
3636 def __init__ (self , session : bigframes .session .Session ):
3737 self ._session = session
38- self ._base_sql_generator = ml_sql .BaseSqlGenerator ()
38+ self ._sql_generator = ml_sql .BaseSqlGenerator ()
39+
40+ def ai_forecast (
41+ self ,
42+ input_data : bpd .DataFrame ,
43+ options : Mapping [str , Union [str , int , float , Iterable [str ]]],
44+ ) -> bpd .DataFrame :
45+ result_sql = self ._sql_generator .ai_forecast (
46+ source_sql = input_data .sql , options = options
47+ )
48+ return self ._session .read_gbq (result_sql )
3949
4050
4151class BqmlModel (BaseBqml ):
@@ -55,8 +65,8 @@ def __init__(self, session: bigframes.Session, model: bigquery.Model):
5565 self ._model = model
5666 model_ref = self ._model .reference
5767 assert model_ref is not None
58- self ._model_manipulation_sql_generator = ml_sql .ModelManipulationSqlGenerator (
59- model_ref
68+ self ._sql_generator : ml_sql .ModelManipulationSqlGenerator = (
69+ ml_sql . ModelManipulationSqlGenerator ( model_ref )
6070 )
6171
6272 def _apply_ml_tvf (
@@ -126,30 +136,28 @@ def model(self) -> bigquery.Model:
126136 def recommend (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
127137 return self ._apply_ml_tvf (
128138 input_data ,
129- self ._model_manipulation_sql_generator .ml_recommend ,
139+ self ._sql_generator .ml_recommend ,
130140 )
131141
132142 def predict (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
133143 return self ._apply_ml_tvf (
134144 input_data ,
135- self ._model_manipulation_sql_generator .ml_predict ,
145+ self ._sql_generator .ml_predict ,
136146 )
137147
138148 def explain_predict (
139149 self , input_data : bpd .DataFrame , options : Mapping [str , int | float ]
140150 ) -> bpd .DataFrame :
141151 return self ._apply_ml_tvf (
142152 input_data ,
143- lambda source_sql : self ._model_manipulation_sql_generator .ml_explain_predict (
153+ lambda source_sql : self ._sql_generator .ml_explain_predict (
144154 source_sql = source_sql ,
145155 struct_options = options ,
146156 ),
147157 )
148158
149159 def global_explain (self , options : Mapping [str , bool ]) -> bpd .DataFrame :
150- sql = self ._model_manipulation_sql_generator .ml_global_explain (
151- struct_options = options
152- )
160+ sql = self ._sql_generator .ml_global_explain (struct_options = options )
153161 return (
154162 self ._session .read_gbq (sql )
155163 .sort_values (by = "attribution" , ascending = False )
@@ -159,7 +167,7 @@ def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
159167 def transform (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
160168 return self ._apply_ml_tvf (
161169 input_data ,
162- self ._model_manipulation_sql_generator .ml_transform ,
170+ self ._sql_generator .ml_transform ,
163171 )
164172
165173 def generate_text (
@@ -170,7 +178,7 @@ def generate_text(
170178 options ["flatten_json_output" ] = True
171179 return self ._apply_ml_tvf (
172180 input_data ,
173- lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_text (
181+ lambda source_sql : self ._sql_generator .ml_generate_text (
174182 source_sql = source_sql ,
175183 struct_options = options ,
176184 ),
@@ -186,7 +194,7 @@ def generate_embedding(
186194 options ["flatten_json_output" ] = True
187195 return self ._apply_ml_tvf (
188196 input_data ,
189- lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_embedding (
197+ lambda source_sql : self ._sql_generator .ml_generate_embedding (
190198 source_sql = source_sql ,
191199 struct_options = options ,
192200 ),
@@ -201,7 +209,7 @@ def generate_table(
201209 ) -> bpd .DataFrame :
202210 return self ._apply_ml_tvf (
203211 input_data ,
204- lambda source_sql : self ._model_manipulation_sql_generator .ai_generate_table (
212+ lambda source_sql : self ._sql_generator .ai_generate_table (
205213 source_sql = source_sql ,
206214 struct_options = options ,
207215 ),
@@ -216,14 +224,14 @@ def detect_anomalies(
216224
217225 return self ._apply_ml_tvf (
218226 input_data ,
219- lambda source_sql : self ._model_manipulation_sql_generator .ml_detect_anomalies (
227+ lambda source_sql : self ._sql_generator .ml_detect_anomalies (
220228 source_sql = source_sql ,
221229 struct_options = options ,
222230 ),
223231 )
224232
225233 def forecast (self , options : Mapping [str , int | float ]) -> bpd .DataFrame :
226- sql = self ._model_manipulation_sql_generator .ml_forecast (struct_options = options )
234+ sql = self ._sql_generator .ml_forecast (struct_options = options )
227235 timestamp_col_name = "forecast_timestamp"
228236 index_cols = [timestamp_col_name ]
229237 first_col_name = self ._session .read_gbq (sql ).columns .values [0 ]
@@ -232,9 +240,7 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
232240 return self ._session .read_gbq (sql , index_col = index_cols ).reset_index ()
233241
234242 def explain_forecast (self , options : Mapping [str , int | float ]) -> bpd .DataFrame :
235- sql = self ._model_manipulation_sql_generator .ml_explain_forecast (
236- struct_options = options
237- )
243+ sql = self ._sql_generator .ml_explain_forecast (struct_options = options )
238244 timestamp_col_name = "time_series_timestamp"
239245 index_cols = [timestamp_col_name ]
240246 first_col_name = self ._session .read_gbq (sql ).columns .values [0 ]
@@ -243,7 +249,7 @@ def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
243249 return self ._session .read_gbq (sql , index_col = index_cols ).reset_index ()
244250
245251 def evaluate (self , input_data : Optional [bpd .DataFrame ] = None ):
246- sql = self ._model_manipulation_sql_generator .ml_evaluate (
252+ sql = self ._sql_generator .ml_evaluate (
247253 input_data .sql if (input_data is not None ) else None
248254 )
249255
@@ -254,28 +260,24 @@ def llm_evaluate(
254260 input_data : bpd .DataFrame ,
255261 task_type : Optional [str ] = None ,
256262 ):
257- sql = self ._model_manipulation_sql_generator .ml_llm_evaluate (
258- input_data .sql , task_type
259- )
263+ sql = self ._sql_generator .ml_llm_evaluate (input_data .sql , task_type )
260264
261265 return self ._session .read_gbq (sql )
262266
263267 def arima_evaluate (self , show_all_candidate_models : bool = False ):
264- sql = self ._model_manipulation_sql_generator .ml_arima_evaluate (
265- show_all_candidate_models
266- )
268+ sql = self ._sql_generator .ml_arima_evaluate (show_all_candidate_models )
267269
268270 return self ._session .read_gbq (sql )
269271
270272 def arima_coefficients (self ) -> bpd .DataFrame :
271- sql = self ._model_manipulation_sql_generator .ml_arima_coefficients ()
273+ sql = self ._sql_generator .ml_arima_coefficients ()
272274
273275 return self ._session .read_gbq (sql )
274276
275277 def centroids (self ) -> bpd .DataFrame :
276278 assert self ._model .model_type == "KMEANS"
277279
278- sql = self ._model_manipulation_sql_generator .ml_centroids ()
280+ sql = self ._sql_generator .ml_centroids ()
279281
280282 return self ._session .read_gbq (
281283 sql , index_col = ["centroid_id" , "feature" ]
@@ -284,7 +286,7 @@ def centroids(self) -> bpd.DataFrame:
284286 def principal_components (self ) -> bpd .DataFrame :
285287 assert self ._model .model_type == "PCA"
286288
287- sql = self ._model_manipulation_sql_generator .ml_principal_components ()
289+ sql = self ._sql_generator .ml_principal_components ()
288290
289291 return self ._session .read_gbq (
290292 sql , index_col = ["principal_component_id" , "feature" ]
@@ -293,7 +295,7 @@ def principal_components(self) -> bpd.DataFrame:
293295 def principal_component_info (self ) -> bpd .DataFrame :
294296 assert self ._model .model_type == "PCA"
295297
296- sql = self ._model_manipulation_sql_generator .ml_principal_component_info ()
298+ sql = self ._sql_generator .ml_principal_component_info ()
297299
298300 return self ._session .read_gbq (sql )
299301
@@ -319,7 +321,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
319321 # truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models.
320322 # The possibility of conflicts should be low.
321323 vertex_ai_model_id = vertex_ai_model_id [:63 ]
322- sql = self ._model_manipulation_sql_generator .alter_model (
324+ sql = self ._sql_generator .alter_model (
323325 options = {"vertex_ai_model_id" : vertex_ai_model_id }
324326 )
325327 # Register the model and wait it to finish
0 commit comments