1616 time ,
1717 timedelta ,
1818)
19+ from decimal import Decimal
1920from io import (
2021 BytesIO ,
2122 StringIO ,
@@ -79,6 +80,14 @@ def data(dtype):
7980 data = [1 , 0 ] * 4 + [None ] + [- 2 , - 1 ] * 44 + [None ] + [1 , 99 ]
8081 elif pa .types .is_unsigned_integer (pa_dtype ):
8182 data = [1 , 0 ] * 4 + [None ] + [2 , 1 ] * 44 + [None ] + [1 , 99 ]
83+ elif pa .types .is_decimal (pa_dtype ):
84+ data = (
85+ [Decimal ("1" ), Decimal ("0.0" )] * 4
86+ + [None ]
87+ + [Decimal ("-2.0" ), Decimal ("-1.0" )] * 44
88+ + [None ]
89+ + [Decimal ("0.5" ), Decimal ("33.123" )]
90+ )
8291 elif pa .types .is_date (pa_dtype ):
8392 data = (
8493 [date (2022 , 1 , 1 ), date (1999 , 12 , 31 )] * 4
@@ -188,6 +197,10 @@ def data_for_grouping(dtype):
188197 A = b"a"
189198 B = b"b"
190199 C = b"c"
200+ elif pa .types .is_decimal (pa_dtype ):
201+ A = Decimal ("-1.1" )
202+ B = Decimal ("0.0" )
203+ C = Decimal ("1.1" )
191204 else :
192205 raise NotImplementedError
193206 return pd .array ([B , B , None , None , A , A , B , C ], dtype = dtype )
@@ -250,17 +263,20 @@ def test_astype_str(self, data, request):
250263class TestConstructors (base .BaseConstructorsTests ):
def test_from_dtype(self, data, request):
    """Run the inherited ``test_from_dtype`` with xfail markers where needed.

    String and decimal pyarrow dtypes get an ``xfail`` marker attached to the
    requesting test node before delegating to the base-class test; every
    other dtype is delegated without a marker.
    """
    pa_dtype = data.dtype.pyarrow_dtype

    # Choose the xfail reason (if any) for this dtype.
    xfail_reason = None
    if pa.types.is_string(pa_dtype):
        xfail_reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
    elif pa.types.is_decimal(pa_dtype):
        xfail_reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } "

    if xfail_reason is not None:
        request.node.add_marker(pytest.mark.xfail(reason=xfail_reason))

    super().test_from_dtype(data)
262278
263- def test_from_sequence_pa_array (self , data , request ):
279+ def test_from_sequence_pa_array (self , data ):
264280 # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
265281 # data._data = pa.ChunkedArray
266282 result = type (data )._from_sequence (data ._data )
@@ -285,7 +301,9 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
285301 reason = "Nanosecond time parsing not supported." ,
286302 )
287303 )
288- elif pa_version_under11p0 and pa .types .is_duration (pa_dtype ):
304+ elif pa_version_under11p0 and (
305+ pa .types .is_duration (pa_dtype ) or pa .types .is_decimal (pa_dtype )
306+ ):
289307 request .node .add_marker (
290308 pytest .mark .xfail (
291309 raises = pa .ArrowNotImplementedError ,
@@ -392,7 +410,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
392410 raises = NotImplementedError ,
393411 )
394412 )
395- elif all_numeric_accumulations == "cumsum" and (pa .types .is_boolean (pa_type )):
413+ elif all_numeric_accumulations == "cumsum" and (
414+ pa .types .is_boolean (pa_type ) or pa .types .is_decimal (pa_type )
415+ ):
396416 request .node .add_marker (
397417 pytest .mark .xfail (
398418 reason = f"{ all_numeric_accumulations } not implemented for { pa_type } " ,
@@ -476,6 +496,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
476496 )
477497 if all_numeric_reductions in {"skew" , "kurt" }:
478498 request .node .add_marker (xfail_mark )
499+ elif (
500+ all_numeric_reductions in {"var" , "std" , "median" }
501+ and pa_version_under7p0
502+ and pa .types .is_decimal (pa_dtype )
503+ ):
504+ request .node .add_marker (xfail_mark )
479505 elif all_numeric_reductions == "sem" and pa_version_under8p0 :
480506 request .node .add_marker (xfail_mark )
481507
@@ -598,8 +624,26 @@ def test_in_numeric_groupby(self, data_for_grouping):
598624
599625
600626class TestBaseDtype (base .BaseDtypeTests ):
def test_check_dtype(self, data, request):
    """Run the inherited ``test_check_dtype``, xfailing decimals when flagged.

    When the data's pyarrow dtype is a decimal and ``pa_version_under8p0`` is
    truthy, the test node is marked as an expected failure raising
    ``ValueError`` before the base-class test runs.
    """
    arrow_type = data.dtype.pyarrow_dtype
    if pa.types.is_decimal(arrow_type) and pa_version_under8p0:
        # NOTE(review): the flag presumably means "installed pyarrow < 8.0";
        # it is defined outside this view - confirm.
        decimal_xfail = pytest.mark.xfail(
            raises=ValueError,
            reason="decimal string repr affects numpy comparison",
        )
        request.node.add_marker(decimal_xfail)
    super().test_check_dtype(data)
637+
601638 def test_construct_from_string_own_name (self , dtype , request ):
602639 pa_dtype = dtype .pyarrow_dtype
640+ if pa .types .is_decimal (pa_dtype ):
641+ request .node .add_marker (
642+ pytest .mark .xfail (
643+ raises = NotImplementedError ,
644+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } " ,
645+ )
646+ )
603647
604648 if pa .types .is_string (pa_dtype ):
605649 # We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
@@ -617,6 +661,13 @@ def test_is_dtype_from_name(self, dtype, request):
617661 # We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
618662 assert not type (dtype ).is_dtype (dtype .name )
619663 else :
664+ if pa .types .is_decimal (pa_dtype ):
665+ request .node .add_marker (
666+ pytest .mark .xfail (
667+ raises = NotImplementedError ,
668+ reason = f"pyarrow.type_for_alias cannot infer { pa_dtype } " ,
669+ )
670+ )
620671 super ().test_is_dtype_from_name (dtype )
621672
622673 def test_construct_from_string_another_type_raises (self , dtype ):
@@ -635,6 +686,7 @@ def test_get_common_dtype(self, dtype, request):
635686 )
636687 or (pa .types .is_duration (pa_dtype ) and pa_dtype .unit != "ns" )
637688 or pa .types .is_binary (pa_dtype )
689+ or pa .types .is_decimal (pa_dtype )
638690 ):
639691 request .node .add_marker (
640692 pytest .mark .xfail (
@@ -708,6 +760,13 @@ def test_EA_types(self, engine, data, request):
708760 request .node .add_marker (
709761 pytest .mark .xfail (raises = TypeError , reason = "GH 47534" )
710762 )
763+ elif pa .types .is_decimal (pa_dtype ):
764+ request .node .add_marker (
765+ pytest .mark .xfail (
766+ raises = NotImplementedError ,
767+ reason = f"Parameterized types { pa_dtype } not supported." ,
768+ )
769+ )
711770 elif pa .types .is_timestamp (pa_dtype ) and pa_dtype .unit in ("us" , "ns" ):
712771 request .node .add_marker (
713772 pytest .mark .xfail (
@@ -790,6 +849,13 @@ def test_argmin_argmax(
790849 reason = f"{ pa_dtype } only has 2 unique possible values" ,
791850 )
792851 )
852+ elif pa .types .is_decimal (pa_dtype ) and pa_version_under7p0 :
853+ request .node .add_marker (
854+ pytest .mark .xfail (
855+ reason = f"No pyarrow kernel for { pa_dtype } " ,
856+ raises = pa .ArrowNotImplementedError ,
857+ )
858+ )
793859 super ().test_argmin_argmax (data_for_sorting , data_missing_for_sorting , na_value )
794860
795861 @pytest .mark .parametrize (
@@ -808,6 +874,14 @@ def test_argmin_argmax(
808874 def test_argreduce_series (
809875 self , data_missing_for_sorting , op_name , skipna , expected , request
810876 ):
877+ pa_dtype = data_missing_for_sorting .dtype .pyarrow_dtype
878+ if pa .types .is_decimal (pa_dtype ) and pa_version_under7p0 and skipna :
879+ request .node .add_marker (
880+ pytest .mark .xfail (
881+ reason = f"No pyarrow kernel for { pa_dtype } " ,
882+ raises = pa .ArrowNotImplementedError ,
883+ )
884+ )
811885 super ().test_argreduce_series (
812886 data_missing_for_sorting , op_name , skipna , expected
813887 )
@@ -906,6 +980,21 @@ def test_basic_equals(self, data):
906980class TestBaseArithmeticOps (base .BaseArithmeticOpsTests ):
907981 divmod_exc = NotImplementedError
908982
@classmethod
def assert_equal(cls, left, right, **kwargs):
    """Compare ``left`` and ``right``, casting decimal-typed operands to float.

    The pyarrow type of each operand is read from its ``dtype`` (from the
    first column of both operands when ``left`` is a ``DataFrame``).  If
    either type is a decimal, both operands are cast to ``"float[pyarrow]"``
    before being handed to ``tm.assert_equal``; ``kwargs`` are forwarded
    unchanged.
    """
    # Only ``left`` decides whether the operands are treated as frames.
    frame_like = isinstance(left, pd.DataFrame)
    arrow_types = (
        (obj.iloc[:, 0] if frame_like else obj).dtype.pyarrow_dtype
        for obj in (left, right)
    )
    left_pa_type, right_pa_type = arrow_types

    if any(pa.types.is_decimal(t) for t in (left_pa_type, right_pa_type)):
        # decimal precision can resize in the result type depending on data,
        # so only the float values are compared
        left, right = (obj.astype("float[pyarrow]") for obj in (left, right))

    tm.assert_equal(left, right, **kwargs)
997+
909998 def get_op_from_name (self , op_name ):
910999 short_opname = op_name .strip ("_" )
9111000 if short_opname == "rtruediv" :
@@ -975,7 +1064,11 @@ def _get_scalar_exception(self, opname, pa_dtype):
9751064 pa .types .is_string (pa_dtype ) or pa .types .is_binary (pa_dtype )
9761065 ):
9771066 exc = None
978- elif not (pa .types .is_floating (pa_dtype ) or pa .types .is_integer (pa_dtype )):
1067+ elif not (
1068+ pa .types .is_floating (pa_dtype )
1069+ or pa .types .is_integer (pa_dtype )
1070+ or pa .types .is_decimal (pa_dtype )
1071+ ):
9791072 exc = pa .ArrowNotImplementedError
9801073 else :
9811074 exc = None
@@ -988,7 +1081,11 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
9881081
9891082 if (
9901083 opname == "__rpow__"
991- and (pa .types .is_floating (pa_dtype ) or pa .types .is_integer (pa_dtype ))
1084+ and (
1085+ pa .types .is_floating (pa_dtype )
1086+ or pa .types .is_integer (pa_dtype )
1087+ or pa .types .is_decimal (pa_dtype )
1088+ )
9921089 and not pa_version_under7p0
9931090 ):
9941091 mark = pytest .mark .xfail (
@@ -1006,14 +1103,32 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
10061103 ),
10071104 )
10081105 elif (
1009- opname in { "__rfloordiv__" }
1010- and pa .types .is_integer (pa_dtype )
1106+ opname == "__rfloordiv__"
1107+ and ( pa .types .is_integer (pa_dtype ) or pa . types . is_decimal ( pa_dtype ) )
10111108 and not pa_version_under7p0
10121109 ):
10131110 mark = pytest .mark .xfail (
10141111 raises = pa .ArrowInvalid ,
10151112 reason = "divide by 0" ,
10161113 )
1114+ elif (
1115+ opname == "__rtruediv__"
1116+ and pa .types .is_decimal (pa_dtype )
1117+ and not pa_version_under7p0
1118+ ):
1119+ mark = pytest .mark .xfail (
1120+ raises = pa .ArrowInvalid ,
1121+ reason = "divide by 0" ,
1122+ )
1123+ elif (
1124+ opname == "__pow__"
1125+ and pa .types .is_decimal (pa_dtype )
1126+ and pa_version_under7p0
1127+ ):
1128+ mark = pytest .mark .xfail (
1129+ raises = pa .ArrowInvalid ,
1130+ reason = "Invalid decimal function: power_checked" ,
1131+ )
10171132
10181133 return mark
10191134
@@ -1231,6 +1346,9 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
12311346 expected = ArrowDtype (pa .timestamp ("s" , "UTC" ))
12321347 assert dtype == expected
12331348
1349+ with pytest .raises (NotImplementedError , match = "Passing pyarrow type" ):
1350+ ArrowDtype .construct_from_string ("decimal(7, 2)[pyarrow]" )
1351+
12341352
12351353@pytest .mark .parametrize (
12361354 "interpolation" , ["linear" , "lower" , "higher" , "nearest" , "midpoint" ]
@@ -1257,7 +1375,11 @@ def test_quantile(data, interpolation, quantile, request):
12571375 ser .quantile (q = quantile , interpolation = interpolation )
12581376 return
12591377
1260- if pa .types .is_integer (pa_dtype ) or pa .types .is_floating (pa_dtype ):
1378+ if (
1379+ pa .types .is_integer (pa_dtype )
1380+ or pa .types .is_floating (pa_dtype )
1381+ or (pa .types .is_decimal (pa_dtype ) and not pa_version_under7p0 )
1382+ ):
12611383 pass
12621384 elif pa .types .is_temporal (data ._data .type ):
12631385 pass
@@ -1298,7 +1420,11 @@ def test_quantile(data, interpolation, quantile, request):
12981420 else :
12991421 # Just check the values
13001422 expected = pd .Series (data .take ([0 , 0 ]), index = [0.5 , 0.5 ])
1301- if pa .types .is_integer (pa_dtype ) or pa .types .is_floating (pa_dtype ):
1423+ if (
1424+ pa .types .is_integer (pa_dtype )
1425+ or pa .types .is_floating (pa_dtype )
1426+ or pa .types .is_decimal (pa_dtype )
1427+ ):
13021428 expected = expected .astype ("float64[pyarrow]" )
13031429 result = result .astype ("float64[pyarrow]" )
13041430 tm .assert_series_equal (result , expected )
0 commit comments