|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | 19 | from typing import cast, Literal, Optional, Union |
| 20 | +import warnings |
20 | 21 |
|
21 | 22 | import bigframes |
22 | 23 | from bigframes import clients, constants |
23 | 24 | from bigframes.core import blocks |
24 | 25 | from bigframes.ml import base, core, globals, utils |
25 | 26 | import bigframes.pandas as bpd |
26 | 27 |
|
27 | | -_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT = "text-bison" |
28 | | -_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT = "text-bison-32k" |
29 | | -_TEXT_GENERATE_RESULT_COLUMN = "ml_generate_text_llm_result" |
| 28 | +_TEXT_GENERATOR_BISON_ENDPOINT = "text-bison" |
| 29 | +_TEXT_GENERATOR_BISON_32K_ENDPOINT = "text-bison-32k" |
| 30 | +_TEXT_GENERATOR_ENDPOINTS = ( |
| 31 | + _TEXT_GENERATOR_BISON_ENDPOINT, |
| 32 | + _TEXT_GENERATOR_BISON_32K_ENDPOINT, |
| 33 | +) |
30 | 34 |
|
31 | | -_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT = "textembedding-gecko" |
32 | | -_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT = ( |
33 | | - "textembedding-gecko-multilingual" |
| 35 | +_EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko" |
| 36 | +_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual" |
| 37 | +_EMBEDDING_GENERATOR_ENDPOINTS = ( |
| 38 | + _EMBEDDING_GENERATOR_GECKO_ENDPOINT, |
| 39 | + _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT, |
34 | 40 | ) |
35 | | -_EMBED_TEXT_RESULT_COLUMN = "text_embedding" |
| 41 | + |
| 42 | +_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" |
| 43 | +_ML_EMBED_TEXT_STATUS = "ml_embed_text_status" |
36 | 44 |
|
37 | 45 |
|
38 | 46 | class PaLM2TextGenerator(base.Predictor): |
@@ -90,18 +98,16 @@ def _create_bqml_model(self): |
90 | 98 | connection_id=connection_name_parts[2], |
91 | 99 | iam_role="aiplatform.user", |
92 | 100 | ) |
93 | | - if self.model_name == _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT: |
94 | | - options = { |
95 | | - "endpoint": _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT, |
96 | | - } |
97 | | - elif self.model_name == _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT: |
98 | | - options = { |
99 | | - "endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT, |
100 | | - } |
101 | | - else: |
| 101 | + |
| 102 | + if self.model_name not in _TEXT_GENERATOR_ENDPOINTS: |
102 | 103 | raise ValueError( |
103 | | - f"Model name {self.model_name} is not supported. We only support {_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT}." |
| 104 | + f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_GENERATOR_ENDPOINTS)}." |
104 | 105 | ) |
| 106 | + |
| 107 | + options = { |
| 108 | + "endpoint": self.model_name, |
| 109 | + } |
| 110 | + |
105 | 111 | return self._bqml_model_factory.create_remote_model( |
106 | 112 | session=self.session, connection_name=self.connection_name, options=options |
107 | 113 | ) |
@@ -182,7 +188,16 @@ def predict( |
182 | 188 | "top_p": top_p, |
183 | 189 | "flatten_json_output": True, |
184 | 190 | } |
185 | | - return self._bqml_model.generate_text(X, options) |
| 191 | + |
| 192 | + df = self._bqml_model.generate_text(X, options) |
| 193 | + |
| 194 | + if (df[_ML_GENERATE_TEXT_STATUS] != "").any(): |
| 195 | + warnings.warn( |
| 196 | + f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", |
| 197 | + RuntimeWarning, |
| 198 | + ) |
| 199 | + |
| 200 | + return df |
186 | 201 |
|
187 | 202 |
|
188 | 203 | class PaLM2TextEmbeddingGenerator(base.Predictor): |
@@ -241,19 +256,15 @@ def _create_bqml_model(self): |
241 | 256 | connection_id=connection_name_parts[2], |
242 | 257 | iam_role="aiplatform.user", |
243 | 258 | ) |
244 | | - if self.model_name == "textembedding-gecko": |
245 | | - options = { |
246 | | - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT, |
247 | | - } |
248 | | - elif self.model_name == _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT: |
249 | | - options = { |
250 | | - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT, |
251 | | - } |
252 | | - else: |
| 259 | + |
| 260 | + if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: |
253 | 261 | raise ValueError( |
254 | | - f"Model name {self.model_name} is not supported. We only support {_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT}." |
| 262 | + f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}." |
255 | 263 | ) |
256 | 264 |
|
| 265 | + options = { |
| 266 | + "endpoint": self.model_name, |
| 267 | + } |
257 | 268 | return self._bqml_model_factory.create_remote_model( |
258 | 269 | session=self.session, connection_name=self.connection_name, options=options |
259 | 270 | ) |
@@ -284,4 +295,13 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: |
284 | 295 | options = { |
285 | 296 | "flatten_json_output": True, |
286 | 297 | } |
287 | | - return self._bqml_model.generate_text_embedding(X, options) |
| 298 | + |
| 299 | + df = self._bqml_model.generate_text_embedding(X, options) |
| 300 | + |
| 301 | + if (df[_ML_EMBED_TEXT_STATUS] != "").any(): |
| 302 | + warnings.warn( |
| 303 | + f"Some predictions failed. Check column {_ML_EMBED_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", |
| 304 | + RuntimeWarning, |
| 305 | + ) |
| 306 | + |
| 307 | + return df |