@@ -3433,9 +3433,9 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
34333433 raise ValueError (f"na_action={ na_action } not supported" )
34343434
34353435 # TODO(shobs): Support **kwargs
3436- # Reproject as workaround to applying filter too late. This forces the filter
3437- # to be applied before passing data to remote function, protecting from bad
3438- # inputs causing errors.
3436+ # Reproject as workaround to applying filter too late. This forces the
3437+ # filter to be applied before passing data to remote function,
3438+ # protecting from bad inputs causing errors.
34393439 reprojected_df = DataFrame (self ._block ._force_reproject ())
34403440 return reprojected_df ._apply_unary_op (
34413441 ops .RemoteFunctionOp (func = func , apply_on_null = (na_action is None ))
@@ -3448,65 +3448,99 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
34483448 category = bigframes .exceptions .PreviewWarning ,
34493449 )
34503450
3451- # Early check whether the dataframe dtypes are currently supported
3452- # in the remote function
3453- # NOTE: Keep in sync with the value converters used in the gcf code
3454- # generated in remote_function_template.py
3455- remote_function_supported_dtypes = (
3456- bigframes .dtypes .INT_DTYPE ,
3457- bigframes .dtypes .FLOAT_DTYPE ,
3458- bigframes .dtypes .BOOL_DTYPE ,
3459- bigframes .dtypes .BYTES_DTYPE ,
3460- bigframes .dtypes .STRING_DTYPE ,
3461- )
3462- supported_dtypes_types = tuple (
3463- type (dtype )
3464- for dtype in remote_function_supported_dtypes
3465- if not isinstance (dtype , pandas .ArrowDtype )
3466- )
3467- # Check ArrowDtype separately since multiple BigQuery types map to
3468- # ArrowDtype, including BYTES and TIMESTAMP.
3469- supported_arrow_types = tuple (
3470- dtype .pyarrow_dtype
3471- for dtype in remote_function_supported_dtypes
3472- if isinstance (dtype , pandas .ArrowDtype )
3473- )
3474- supported_dtypes_hints = tuple (
3475- str (dtype ) for dtype in remote_function_supported_dtypes
3476- )
3477-
3478- for dtype in self .dtypes :
3479- if (
3480- # Not one of the pandas/numpy types.
3481- not isinstance (dtype , supported_dtypes_types )
3482- # And not one of the arrow types.
3483- and not (
3484- isinstance (dtype , pandas .ArrowDtype )
3485- and any (
3486- dtype .pyarrow_dtype .equals (arrow_type )
3487- for arrow_type in supported_arrow_types
3488- )
3489- )
3490- ):
3491- raise NotImplementedError (
3492- f"DataFrame has a column of dtype '{ dtype } ' which is not supported with axis=1."
3493- f" Supported dtypes are { supported_dtypes_hints } ."
3494- )
3495-
34963451 # Check if the function is a remote function
34973452 if not hasattr (func , "bigframes_remote_function" ):
34983453 raise ValueError ("For axis=1 a remote function must be used." )
34993454
3500- # Serialize the rows as json values
3501- block = self ._get_block ()
3502- rows_as_json_series = bigframes .series .Series (
3503- block ._get_rows_as_json_values ()
3504- )
3455+ is_row_processor = getattr (func , "is_row_processor" )
3456+ if is_row_processor :
3457+ # Early check whether the dataframe dtypes are currently supported
3458+ # in the remote function
3459+ # NOTE: Keep in sync with the value converters used in the gcf code
3460+ # generated in remote_function_template.py
3461+ remote_function_supported_dtypes = (
3462+ bigframes .dtypes .INT_DTYPE ,
3463+ bigframes .dtypes .FLOAT_DTYPE ,
3464+ bigframes .dtypes .BOOL_DTYPE ,
3465+ bigframes .dtypes .BYTES_DTYPE ,
3466+ bigframes .dtypes .STRING_DTYPE ,
3467+ )
3468+ supported_dtypes_types = tuple (
3469+ type (dtype )
3470+ for dtype in remote_function_supported_dtypes
3471+ if not isinstance (dtype , pandas .ArrowDtype )
3472+ )
3473+ # Check ArrowDtype separately since multiple BigQuery types map to
3474+ # ArrowDtype, including BYTES and TIMESTAMP.
3475+ supported_arrow_types = tuple (
3476+ dtype .pyarrow_dtype
3477+ for dtype in remote_function_supported_dtypes
3478+ if isinstance (dtype , pandas .ArrowDtype )
3479+ )
3480+ supported_dtypes_hints = tuple (
3481+ str (dtype ) for dtype in remote_function_supported_dtypes
3482+ )
35053483
3506- # Apply the function
3507- result_series = rows_as_json_series ._apply_unary_op (
3508- ops .RemoteFunctionOp (func = func , apply_on_null = True )
3509- )
3484+ for dtype in self .dtypes :
3485+ if (
3486+ # Not one of the pandas/numpy types.
3487+ not isinstance (dtype , supported_dtypes_types )
3488+ # And not one of the arrow types.
3489+ and not (
3490+ isinstance (dtype , pandas .ArrowDtype )
3491+ and any (
3492+ dtype .pyarrow_dtype .equals (arrow_type )
3493+ for arrow_type in supported_arrow_types
3494+ )
3495+ )
3496+ ):
3497+ raise NotImplementedError (
3498+ f"DataFrame has a column of dtype '{ dtype } ' which is not supported with axis=1."
3499+ f" Supported dtypes are { supported_dtypes_hints } ."
3500+ )
3501+
3502+ # Serialize the rows as json values
3503+ block = self ._get_block ()
3504+ rows_as_json_series = bigframes .series .Series (
3505+ block ._get_rows_as_json_values ()
3506+ )
3507+
3508+ # Apply the function
3509+ result_series = rows_as_json_series ._apply_unary_op (
3510+ ops .RemoteFunctionOp (func = func , apply_on_null = True )
3511+ )
3512+ else :
3513+ # This is a special case where we are providing not-pandas-like
3514+ # extension. If the remote function can take one or more params
3515+ # then we assume that here the user intention is to use the
3516+ # column values of the dataframe as arguments to the function.
3517+ # For this to work the following condition must be true:
3518+ # 1. The number or input params in the function must be same
3519+ # as the number of columns in the dataframe
3520+ # 2. The dtypes of the columns in the dataframe must be
3521+ # compatible with the data types of the input params
3522+ # 3. The order of the columns in the dataframe must correspond
3523+ # to the order of the input params in the function
3524+ udf_input_dtypes = getattr (func , "input_dtypes" )
3525+ if len (udf_input_dtypes ) != len (self .columns ):
3526+ raise ValueError (
3527+ f"Remote function takes { len (udf_input_dtypes )} arguments but DataFrame has { len (self .columns )} columns."
3528+ )
3529+ if udf_input_dtypes != tuple (self .dtypes .to_list ()):
3530+ raise ValueError (
3531+ f"Remote function takes arguments of types { udf_input_dtypes } but DataFrame dtypes are { tuple (self .dtypes )} ."
3532+ )
3533+
3534+ series_list = [self [col ] for col in self .columns ]
3535+ # Reproject as workaround to applying filter too late. This forces the
3536+ # filter to be applied before passing data to remote function,
3537+ # protecting from bad inputs causing errors.
3538+ reprojected_series = bigframes .series .Series (
3539+ series_list [0 ]._block ._force_reproject ()
3540+ )
3541+ result_series = reprojected_series ._apply_nary_op (
3542+ ops .NaryRemoteFunctionOp (func = func ), series_list [1 :]
3543+ )
35103544 result_series .name = None
35113545
35123546 # Return Series with materialized result so that any error in the remote
0 commit comments