Skip to content
28 changes: 28 additions & 0 deletions bigframes/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,34 @@ def fit(
return self._fit(X, y)


class SupervisedTrainableWithIdColPredictor(SupervisedTrainablePredictor):
    """A supervised trainable predictor whose ``fit`` additionally accepts an
    optional time-series id column (``id_col``) alongside ``X`` and ``y``."""

    def __init__(self):
        super().__init__()
        # Remembers the id_col passed to the most recent fit(); None before any fit.
        self.id_col = None

    def _fit(
        self,
        X: utils.ArrayType,
        y: utils.ArrayType,
        transforms=None,
        id_col: Optional[utils.ArrayType] = None,
    ):
        # Default no-op training hook; subclasses supply the real logic.
        return self

    def fit(
        self,
        X: utils.ArrayType,
        y: utils.ArrayType,
        transforms=None,
        id_col: Optional[utils.ArrayType] = None,
    ):
        """Record ``id_col`` on the instance and delegate training to ``_fit``."""
        self.id_col = id_col
        return self._fit(X, y, transforms=transforms, id_col=id_col)


class TrainableWithEvaluationPredictor(TrainablePredictor):
"""A BigQuery DataFrames ML Model base class that can be used to fit and predict outputs.

Expand Down
27 changes: 22 additions & 5 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,23 @@ def detect_anomalies(

def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
    """Run ML.FORECAST for this model and return the result frame.

    The result is indexed on "forecast_timestamp"; when the query's first
    column is something else (presumably a time-series id column — confirm
    against the generated SQL), that column is included in the index as
    well. The index is reset before returning.
    """
    sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
    # Peek at the result schema to detect a leading non-timestamp column.
    leading_col = self._session.read_gbq(sql).columns.values[0]
    index_cols = ["forecast_timestamp"]
    if leading_col != "forecast_timestamp":
        index_cols.append(leading_col)
    return self._session.read_gbq(sql, index_col=index_cols).reset_index()

def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
    """Run ML.EXPLAIN_FORECAST for this model and return the result frame.

    The result is indexed on "time_series_timestamp"; when the query's
    first column is something else (presumably a time-series id column —
    confirm against the generated SQL), that column joins the index too.
    The index is reset before returning.
    """
    sql = self._model_manipulation_sql_generator.ml_explain_forecast(
        struct_options=options
    )
    # Peek at the result schema to detect a leading non-timestamp column.
    leading_col = self._session.read_gbq(sql).columns.values[0]
    index_cols = ["time_series_timestamp"]
    if leading_col != "time_series_timestamp":
        index_cols.append(leading_col)
    return self._session.read_gbq(sql, index_col=index_cols).reset_index()

def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
sql = self._model_manipulation_sql_generator.ml_evaluate(
Expand Down Expand Up @@ -390,6 +398,7 @@ def create_time_series_model(
self,
X_train: bpd.DataFrame,
y_train: bpd.DataFrame,
id_col: Optional[bpd.DataFrame] = None,
transforms: Optional[Iterable[str]] = None,
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> BqmlModel:
Expand All @@ -399,13 +408,21 @@ def create_time_series_model(
assert (
y_train.columns.size == 1
), "Time stamp data input must only contain 1 column."
assert id_col is None or (
id_col is not None and id_col.columns.size == 1
), "Time series id input is either None or must only contain 1 column."

options = dict(options)
# Cache dataframes to make sure base table is not a snapshot
# cached dataframe creates a full copy, never uses snapshot
input_data = X_train.join(y_train, how="outer").cache()
input_data = X_train.join(y_train, how="outer")
if id_col is not None:
input_data = input_data.join(id_col, how="outer")
input_data = input_data.cache()
options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]})
options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]})
if id_col is not None:
options.update({"TIME_SERIES_ID_COL": id_col.columns.tolist()[0]})

session = X_train._session
model_ref = self._create_model_ref(session._anonymous_dataset)
Expand Down
58 changes: 44 additions & 14 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


@log_adapter.class_logger
class ARIMAPlus(base.SupervisedTrainablePredictor):
class ARIMAPlus(base.SupervisedTrainableWithIdColPredictor):
"""Time Series ARIMA Plus model.

Args:
def _fit(
    self,
    X: utils.ArrayType,
    y: utils.ArrayType,
    transforms: Optional[List[str]] = None,
    id_col: Optional[utils.ArrayType] = None,
) -> ARIMAPlus:
    """Fit the model to training data.

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series,
            or pandas.core.frame.DataFrame or pandas.core.series.Series):
            A dataframe or series of training timestamps.
        y (bigframes.dataframe.DataFrame, or bigframes.series.Series,
            or pandas.core.frame.DataFrame, or pandas.core.series.Series):
            Target values for training.
        transforms (Optional[List[str]], default None):
            Do not use. Internal param to be deprecated.
            Use bigframes.ml.pipeline instead.
        id_col (Optional[bigframes.dataframe.DataFrame]
            or Optional[bigframes.series.Series]
            or Optional[pandas.core.frame.DataFrame]
            or Optional[pandas.core.series.Series]
            or None, default None):
            An optional dataframe or series holding the training id column.

    Returns:
        ARIMAPlus: Fitted estimator.

    Raises:
        ValueError: If X, y, or id_col contains more than 1 column.
    """
    X, y = utils.batch_convert_to_dataframe(X, y)

    if X.columns.size != 1:
        # Message must describe the single-column requirement enforced above.
        raise ValueError(
            "Time series timestamp input X must only contain 1 column."
        )
    if y.columns.size != 1:
        raise ValueError("Time series data input y must only contain 1 column.")

    if id_col is not None:
        (id_col,) = utils.batch_convert_to_dataframe(id_col)

        if id_col.columns.size != 1:
            raise ValueError(
                "Time series id input id_col must only contain 1 column."
            )

    self._bqml_model = self._bqml_model_factory.create_time_series_model(
        X,
        y,
        id_col=id_col,
        transforms=transforms,
        options=self._bqml_options,
    )
    return self

def predict(
self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
Expand All @@ -237,7 +253,7 @@ def predict(

Returns:
bigframes.dataframe.DataFrame: The predicted DataFrames. Which
contains 2 columns: "forecast_timestamp" and "forecast_value".
contains the "forecast_timestamp" column, an optional time-series id column, and the "forecast_value" column.
"""
if horizon < 1 or horizon > 1000:
raise ValueError(f"horizon must be [1, 1000], but is {horizon}.")
Expand Down Expand Up @@ -345,6 +361,7 @@ def score(
self,
X: utils.ArrayType,
y: utils.ArrayType,
id_col: Optional[utils.ArrayType] = None,
) -> bpd.DataFrame:
"""Calculate evaluation metrics of the model.

Expand All @@ -355,13 +372,22 @@ def score(
for the outputs relevant to this model type.

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
A BigQuery DataFrame only contains 1 column as
X (bigframes.dataframe.DataFrame or bigframes.series.Series
or pandas.core.frame.DataFrame or pandas.core.series.Series):
A dataframe or series only contains 1 column as
evaluation timestamp. The timestamp must be within the horizon
of the model, which by default is 1000 data points.
y (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
A BigQuery DataFrame only contains 1 column as
y (bigframes.dataframe.DataFrame or bigframes.series.Series
or pandas.core.frame.DataFrame or pandas.core.series.Series):
A dataframe or series only contains 1 column as
evaluation numeric values.
id_col (Optional[bigframes.dataframe.DataFrame]
or Optional[bigframes.series.Series]
or Optional[pandas.core.frame.DataFrame]
or Optional[pandas.core.series.Series]
or None, default None):
An optional dataframe or series that contains at least 1 column to be
used as the evaluation id column.

Returns:
bigframes.dataframe.DataFrame: A DataFrame as evaluation result.
Expand All @@ -371,6 +397,10 @@ def score(
X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)

input_data = X.join(y, how="outer")
if id_col is not None:
(id_col,) = utils.batch_convert_to_dataframe(id_col)
input_data = input_data.join(id_col, how="outer")

return self._bqml_model.evaluate(input_data)

def summary(
Expand Down
Loading