|
1 | 1 | import numpy as np |
2 | 2 | import pytest |
3 | 3 |
|
| 4 | +from pandas.core.dtypes.common import pandas_dtype |
| 5 | + |
4 | 6 | from pandas import ( |
5 | 7 | NA, |
6 | 8 | DataFrame, |
|
19 | 21 | def get_dtype(dtype, coerce_int=None): |
20 | 22 | if coerce_int is False and "int" in dtype: |
21 | 23 | return None |
22 | | - if dtype != "category": |
23 | | - return np.dtype(dtype) |
24 | | - return dtype |
| 24 | + return pandas_dtype(dtype) |
25 | 25 |
|
26 | 26 |
|
27 | 27 | @pytest.mark.parametrize( |
@@ -66,21 +66,23 @@ def get_dtype(dtype, coerce_int=None): |
66 | 66 | ], |
67 | 67 | ) |
68 | 68 | def test_series_dtypes(method, data, expected_data, coerce_int, dtypes, min_periods): |
69 | | - s = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int)) |
70 | | - if dtypes in ("m8[ns]", "M8[ns]") and method != "count": |
| 69 | + ser = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int)) |
| 70 | + rolled = ser.rolling(2, min_periods=min_periods) |
| 71 | + |
| 72 | + if dtypes in ("m8[ns]", "M8[ns]", "datetime64[ns, UTC]") and method != "count": |
71 | 73 | msg = "No numeric types to aggregate" |
72 | 74 | with pytest.raises(DataError, match=msg): |
73 | | - getattr(s.rolling(2, min_periods=min_periods), method)() |
| 75 | + getattr(rolled, method)() |
74 | 76 | else: |
75 | | - result = getattr(s.rolling(2, min_periods=min_periods), method)() |
| 77 | + result = getattr(rolled, method)() |
76 | 78 | expected = Series(expected_data, dtype="float64") |
77 | 79 | tm.assert_almost_equal(result, expected) |
78 | 80 |
|
79 | 81 |
|
80 | 82 | def test_series_nullable_int(any_signed_int_ea_dtype): |
81 | 83 | # GH 43016 |
82 | | - s = Series([0, 1, NA], dtype=any_signed_int_ea_dtype) |
83 | | - result = s.rolling(2).mean() |
| 84 | + ser = Series([0, 1, NA], dtype=any_signed_int_ea_dtype) |
| 85 | + result = ser.rolling(2).mean() |
84 | 86 | expected = Series([np.nan, 0.5, np.nan]) |
85 | 87 | tm.assert_series_equal(result, expected) |
86 | 88 |
|
@@ -130,14 +132,15 @@ def test_series_nullable_int(any_signed_int_ea_dtype): |
130 | 132 | ], |
131 | 133 | ) |
132 | 134 | def test_dataframe_dtypes(method, expected_data, dtypes, min_periods): |
133 | | - if dtypes == "category": |
134 | | - pytest.skip("Category dataframe testing not implemented.") |
| 135 | + |
135 | 136 | df = DataFrame(np.arange(10).reshape((5, 2)), dtype=get_dtype(dtypes)) |
136 | | - if dtypes in ("m8[ns]", "M8[ns]") and method != "count": |
| 137 | + rolled = df.rolling(2, min_periods=min_periods) |
| 138 | + |
| 139 | + if dtypes in ("m8[ns]", "M8[ns]", "datetime64[ns, UTC]") and method != "count": |
137 | 140 | msg = "No numeric types to aggregate" |
138 | 141 | with pytest.raises(DataError, match=msg): |
139 | | - getattr(df.rolling(2, min_periods=min_periods), method)() |
| 142 | + getattr(rolled, method)() |
140 | 143 | else: |
141 | | - result = getattr(df.rolling(2, min_periods=min_periods), method)() |
| 144 | + result = getattr(rolled, method)() |
142 | 145 | expected = DataFrame(expected_data, dtype="float64") |
143 | 146 | tm.assert_frame_equal(result, expected) |
0 commit comments