@@ -155,14 +155,15 @@ def _fit(
def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
    """Predict with the fitted model on the given input.

    Args:
        X: Input data accepted by ``utils.batch_convert_to_dataframe``.

    Returns:
        Whatever the underlying BQML model's ``predict`` returns for the
        converted input (annotated as ``bpd.DataFrame``).

    Raises:
        RuntimeError: If no fitted model is attached to this estimator.
    """
    if not self._bqml_model:
        raise RuntimeError("A model must be fitted before predict")

    # Pass the model's session to the conversion helper, matching what
    # predict_explain does.
    # NOTE(review): presumably this keeps locally supplied data in the same
    # session as the model -- confirm against the helper.
    (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

    return self._bqml_model.predict(X)
162161
163162 def predict_explain (
164163 self ,
165164 X : utils .ArrayType ,
165+ * ,
166+ top_k_features : int = 5 ,
166167 ) -> bpd .DataFrame :
167168 """
168169 Explain predictions for a linear regression model.
@@ -175,18 +176,32 @@ def predict_explain(
175176 X (bigframes.dataframe.DataFrame or bigframes.series.Series or
176177 pandas.core.frame.DataFrame or pandas.core.series.Series):
177178 Series or a DataFrame to explain its predictions.
179+ top_k_features (int, default 5):
180+ an INT64 value that specifies how many top feature attribution
181+ pairs are generated for each row of input data. The features are
182+ ranked by the absolute values of their attributions.
183+
184+ By default, top_k_features is set to 5. If its value is greater
185+ than the number of features in the training data, the
186+ attributions of all features are returned.
178187
179188 Returns:
180189 bigframes.pandas.DataFrame:
181190 The predicted DataFrames with explanation columns.
182191 """
183- # TODO(b/377366612): Add support for `top_k_features` parameter
192+ if top_k_features < 1 :
193+ raise ValueError (
194+ f"top_k_features must be at least 1, but is { top_k_features } ."
195+ )
196+
184197 if not self ._bqml_model :
185198 raise RuntimeError ("A model must be fitted before predict" )
186199
187200 (X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
188201
189- return self ._bqml_model .explain_predict (X )
202+ return self ._bqml_model .explain_predict (
203+ X , options = {"top_k_features" : top_k_features }
204+ )
190205
191206 def score (
192207 self ,
@@ -356,6 +371,8 @@ def predict(
356371 def predict_explain (
357372 self ,
358373 X : utils .ArrayType ,
374+ * ,
375+ top_k_features : int = 5 ,
359376 ) -> bpd .DataFrame :
360377 """
361378 Explain predictions for a logistic regression model.
@@ -368,18 +385,32 @@ def predict_explain(
368385 X (bigframes.dataframe.DataFrame or bigframes.series.Series or
369386 pandas.core.frame.DataFrame or pandas.core.series.Series):
370387 Series or a DataFrame to explain its predictions.
388+ top_k_features (int, default 5):
389+ an INT64 value that specifies how many top feature attribution
390+ pairs are generated for each row of input data. The features are
391+ ranked by the absolute values of their attributions.
392+
393+ By default, top_k_features is set to 5. If its value is greater
394+ than the number of features in the training data, the
395+ attributions of all features are returned.
371396
372397 Returns:
373398 bigframes.pandas.DataFrame:
374399 The predicted DataFrames with explanation columns.
375400 """
376- # TODO(b/377366612): Add support for `top_k_features` parameter
401+ if top_k_features < 1 :
402+ raise ValueError (
403+ f"top_k_features must be at least 1, but is { top_k_features } ."
404+ )
405+
377406 if not self ._bqml_model :
378407 raise RuntimeError ("A model must be fitted before predict" )
379408
380409 (X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
381410
382- return self ._bqml_model .explain_predict (X )
411+ return self ._bqml_model .explain_predict (
412+ X , options = {"top_k_features" : top_k_features }
413+ )
383414
384415 def score (
385416 self ,
0 commit comments