1818from __future__ import annotations
1919
2020import typing
21- from typing import Any , cast , List , Literal , Optional , Tuple , Union
21+ from typing import cast , Iterable , List , Literal , Optional , Tuple , Union
2222
2323import bigframes_vendored .sklearn .preprocessing ._data
2424import bigframes_vendored .sklearn .preprocessing ._discretization
@@ -43,11 +43,10 @@ def __init__(self):
4343 self ._bqml_model_factory = globals .bqml_model_factory ()
4444 self ._base_sql_generator = globals .base_sql_generator ()
4545
46- # TODO(garrettwu): implement __hash__
47- def __eq__ (self , other : Any ) -> bool :
48- return type (other ) is StandardScaler and self ._bqml_model == other ._bqml_model
46+ def _keys (self ):
47+ return (self ._bqml_model ,)
4948
50- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
49+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
5150 """Compile this transformer to a list of SQL expressions that can be included in
5251 a BQML TRANSFORM clause
5352
@@ -125,11 +124,10 @@ def __init__(self):
125124 self ._bqml_model_factory = globals .bqml_model_factory ()
126125 self ._base_sql_generator = globals .base_sql_generator ()
127126
128- # TODO(garrettwu): implement __hash__
129- def __eq__ (self , other : Any ) -> bool :
130- return type (other ) is MaxAbsScaler and self ._bqml_model == other ._bqml_model
127+ def _keys (self ):
128+ return (self ._bqml_model ,)
131129
132- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
130+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
133131 """Compile this transformer to a list of SQL expressions that can be included in
134132 a BQML TRANSFORM clause
135133
@@ -207,11 +205,10 @@ def __init__(self):
207205 self ._bqml_model_factory = globals .bqml_model_factory ()
208206 self ._base_sql_generator = globals .base_sql_generator ()
209207
210- # TODO(garrettwu): implement __hash__
211- def __eq__ (self , other : Any ) -> bool :
212- return type (other ) is MinMaxScaler and self ._bqml_model == other ._bqml_model
208+ def _keys (self ):
209+ return (self ._bqml_model ,)
213210
214- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
211+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
215212 """Compile this transformer to a list of SQL expressions that can be included in
216213 a BQML TRANSFORM clause
217214
@@ -301,18 +298,12 @@ def __init__(
301298 self ._bqml_model_factory = globals .bqml_model_factory ()
302299 self ._base_sql_generator = globals .base_sql_generator ()
303300
304- # TODO(garrettwu): implement __hash__
305- def __eq__ (self , other : Any ) -> bool :
306- return (
307- type (other ) is KBinsDiscretizer
308- and self .n_bins == other .n_bins
309- and self .strategy == other .strategy
310- and self ._bqml_model == other ._bqml_model
311- )
301+ def _keys (self ):
302+ return (self ._bqml_model , self .n_bins , self .strategy )
312303
313304 def _compile_to_sql (
314305 self ,
315- columns : List [str ],
306+ columns : Iterable [str ],
316307 X : bpd .DataFrame ,
317308 ) -> List [Tuple [str , str ]]:
318309 """Compile this transformer to a list of SQL expressions that can be included in
@@ -446,17 +437,10 @@ def __init__(
446437 self ._bqml_model_factory = globals .bqml_model_factory ()
447438 self ._base_sql_generator = globals .base_sql_generator ()
448439
449- # TODO(garrettwu): implement __hash__
450- def __eq__ (self , other : Any ) -> bool :
451- return (
452- type (other ) is OneHotEncoder
453- and self ._bqml_model == other ._bqml_model
454- and self .drop == other .drop
455- and self .min_frequency == other .min_frequency
456- and self .max_categories == other .max_categories
457- )
440+ def _keys (self ):
441+ return (self ._bqml_model , self .drop , self .min_frequency , self .max_categories )
458442
459- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
443+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
460444 """Compile this transformer to a list of SQL expressions that can be included in
461445 a BQML TRANSFORM clause
462446
@@ -572,16 +556,10 @@ def __init__(
572556 self ._bqml_model_factory = globals .bqml_model_factory ()
573557 self ._base_sql_generator = globals .base_sql_generator ()
574558
575- # TODO(garrettwu): implement __hash__
576- def __eq__ (self , other : Any ) -> bool :
577- return (
578- type (other ) is LabelEncoder
579- and self ._bqml_model == other ._bqml_model
580- and self .min_frequency == other .min_frequency
581- and self .max_categories == other .max_categories
582- )
559+ def _keys (self ):
560+ return (self ._bqml_model , self .min_frequency , self .max_categories )
583561
584- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
562+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
585563 """Compile this transformer to a list of SQL expressions that can be included in
586564 a BQML TRANSFORM clause
587565
@@ -672,18 +650,17 @@ class PolynomialFeatures(
672650 )
673651
674652 def __init__ (self , degree : int = 2 ):
653+ if degree not in range (1 , 5 ):
654+ raise ValueError (f"degree has to be [1, 4], input is { degree } ." )
675655 self .degree = degree
676656 self ._bqml_model : Optional [core .BqmlModel ] = None
677657 self ._bqml_model_factory = globals .bqml_model_factory ()
678658 self ._base_sql_generator = globals .base_sql_generator ()
679659
680- # TODO(garrettwu): implement __hash__
681- def __eq__ (self , other : Any ) -> bool :
682- return (
683- type (other ) is PolynomialFeatures and self ._bqml_model == other ._bqml_model
684- )
660+ def _keys (self ):
661+ return (self ._bqml_model , self .degree )
685662
686- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
663+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
687664 """Compile this transformer to a list of SQL expressions that can be included in
688665 a BQML TRANSFORM clause
689666
@@ -705,17 +682,18 @@ def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
705682 ]
706683
707684 @classmethod
708- def _parse_from_sql (cls , sql : str ) -> tuple [PolynomialFeatures , str ]:
709- """Parse SQL to tuple(PolynomialFeatures, column_label ).
685+ def _parse_from_sql (cls , sql : str ) -> tuple [PolynomialFeatures , tuple [ str , ...] ]:
686+ """Parse SQL to tuple(PolynomialFeatures, column_labels ).
710687
711688 Args:
712689 sql: SQL string of format "ML.POLYNOMIAL_EXPAND(STRUCT(col_label0, col_label1, ...), degree)"
713690
714691 Returns:
715692 tuple(MaxAbsScaler, column_label)"""
716- col_label = sql [sql .find ("STRUCT(" ) + 7 : sql .find (")" )]
693+ col_labels = sql [sql .find ("STRUCT(" ) + 7 : sql .find (")" )].split ("," )
694+ col_labels = [label .strip () for label in col_labels ]
717695 degree = int (sql [sql .rfind ("," ) + 1 : sql .rfind (")" )])
718- return cls (degree ), col_label
696+ return cls (degree ), tuple ( col_labels )
719697
720698 def fit (
721699 self ,
@@ -762,8 +740,6 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
762740 df [self ._output_names ],
763741 )
764742
765- # TODO(garrettwu): to_gbq()
766-
767743
768744PreprocessingType = Union [
769745 OneHotEncoder ,
@@ -772,4 +748,5 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
772748 MinMaxScaler ,
773749 KBinsDiscretizer ,
774750 LabelEncoder ,
751+ PolynomialFeatures ,
775752]
0 commit comments