@@ -963,3 +963,115 @@ def float_parser(row):
963963 cleanup_function_assets (
964964 float_parser_mf , session .bqclient , ignore_failures = False
965965 )
966+
967+
968+ def test_managed_function_df_where (session , dataset_id , scalars_dfs ):
969+ try :
970+
971+ # The return type has to be bool type for callable where condition.
972+ def is_sum_positive (a , b ):
973+ return a + b > 0
974+
975+ is_sum_positive_mf = session .udf (
976+ input_types = [int , int ],
977+ output_type = bool ,
978+ dataset = dataset_id ,
979+ name = prefixer .create_prefix (),
980+ )(is_sum_positive )
981+
982+ scalars_df , scalars_pandas_df = scalars_dfs
983+ int64_cols = ["int64_col" , "int64_too" ]
984+
985+ bf_int64_df = scalars_df [int64_cols ]
986+ bf_int64_df_filtered = bf_int64_df .dropna ()
987+ pd_int64_df = scalars_pandas_df [int64_cols ]
988+ pd_int64_df_filtered = pd_int64_df .dropna ()
989+
990+ # Use callable condition in dataframe.where method.
991+ bf_result = bf_int64_df_filtered .where (is_sum_positive_mf ).to_pandas ()
992+ # Pandas doesn't support such case, use following as workaround.
993+ pd_result = pd_int64_df_filtered .where (pd_int64_df_filtered .sum (axis = 1 ) > 0 )
994+
995+ # Ignore any dtype difference.
996+ pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
997+
998+ # Make sure the read_gbq_function path works for this function.
999+ is_sum_positive_ref = session .read_gbq_function (
1000+ function_name = is_sum_positive_mf .bigframes_bigquery_function
1001+ )
1002+
1003+ bf_result_gbq = bf_int64_df_filtered .where (
1004+ is_sum_positive_ref , - bf_int64_df_filtered
1005+ ).to_pandas ()
1006+ pd_result_gbq = pd_int64_df_filtered .where (
1007+ pd_int64_df_filtered .sum (axis = 1 ) > 0 , - pd_int64_df_filtered
1008+ )
1009+
1010+ # Ignore any dtype difference.
1011+ pandas .testing .assert_frame_equal (
1012+ bf_result_gbq , pd_result_gbq , check_dtype = False
1013+ )
1014+
1015+ finally :
1016+ # Clean up the gcp assets created for the managed function.
1017+ cleanup_function_assets (
1018+ is_sum_positive_mf , session .bqclient , ignore_failures = False
1019+ )
1020+
1021+
1022+ def test_managed_function_df_where_series (session , dataset_id , scalars_dfs ):
1023+ try :
1024+
1025+ # The return type has to be bool type for callable where condition.
1026+ def is_sum_positive_series (s ):
1027+ return s ["int64_col" ] + s ["int64_too" ] > 0
1028+
1029+ is_sum_positive_series_mf = session .udf (
1030+ input_types = bigframes .series .Series ,
1031+ output_type = bool ,
1032+ dataset = dataset_id ,
1033+ name = prefixer .create_prefix (),
1034+ )(is_sum_positive_series )
1035+
1036+ scalars_df , scalars_pandas_df = scalars_dfs
1037+ int64_cols = ["int64_col" , "int64_too" ]
1038+
1039+ bf_int64_df = scalars_df [int64_cols ]
1040+ bf_int64_df_filtered = bf_int64_df .dropna ()
1041+ pd_int64_df = scalars_pandas_df [int64_cols ]
1042+ pd_int64_df_filtered = pd_int64_df .dropna ()
1043+
1044+ # Use callable condition in dataframe.where method.
1045+ bf_result = bf_int64_df_filtered .where (is_sum_positive_series ).to_pandas ()
1046+ pd_result = pd_int64_df_filtered .where (is_sum_positive_series )
1047+
1048+ # Ignore any dtype difference.
1049+ pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
1050+
1051+ # Make sure the read_gbq_function path works for this function.
1052+ is_sum_positive_series_ref = session .read_gbq_function (
1053+ function_name = is_sum_positive_series_mf .bigframes_bigquery_function ,
1054+ is_row_processor = True ,
1055+ )
1056+
1057+ # This is for callable `other` arg in dataframe.where method.
1058+ def func_for_other (x ):
1059+ return - x
1060+
1061+ bf_result_gbq = bf_int64_df_filtered .where (
1062+ is_sum_positive_series_ref , func_for_other
1063+ ).to_pandas ()
1064+ pd_result_gbq = pd_int64_df_filtered .where (
1065+ is_sum_positive_series , func_for_other
1066+ )
1067+
1068+ # Ignore any dtype difference.
1069+ pandas .testing .assert_frame_equal (
1070+ bf_result_gbq , pd_result_gbq , check_dtype = False
1071+ )
1072+
1073+ finally :
1074+ # Clean up the gcp assets created for the managed function.
1075+ cleanup_function_assets (
1076+ is_sum_positive_series_mf , session .bqclient , ignore_failures = False
1077+ )
0 commit comments