Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4188e9f
ENH: Implement EA unary ops
sinhrks Oct 24, 2018
6e5cf3d
Fixed test for older NumPy ver
sinhrks Jan 15, 2019
3b3003f
Fixed tests for latest numpy
sinhrks Jan 15, 2019
e0b37ae
updated doc
sinhrks Jan 15, 2019
6476d0d
Additional fixes
sinhrks Jan 17, 2019
528e4d1
Remove unary method from DatetimeArray
sinhrks Jan 22, 2019
dbdce23
move __pos__ logic to ExtensionArray
sinhrks Jan 23, 2019
af84599
Merge remote-tracking branch 'upstream/master' into unary
WillAyd Jun 8, 2019
b0975f4
Reverted bad merge of 24 whatsnew
WillAyd Jun 8, 2019
8e3638c
Removed Panel-specific code
WillAyd Jun 8, 2019
63430e7
Added stdlib operator import
WillAyd Jun 10, 2019
9fe6085
Merge remote-tracking branch 'upstream/master' into unary
WillAyd Jun 10, 2019
b6430f1
Removed old np13 check
WillAyd Jun 10, 2019
736fc27
Fixed wrong error msg expectation
WillAyd Jun 10, 2019
87b8cab
lint fixup
WillAyd Jun 10, 2019
c6437ba
Merge remote-tracking branch 'upstream/master' into unary
WillAyd Jun 11, 2019
91a7fec
Added warnings catch for operator pos
WillAyd Jun 11, 2019
b221fd0
lint fixup
WillAyd Jun 11, 2019
0afef3c
Reverted warnings catch
WillAyd Jun 11, 2019
4c322b6
Reimplemented warnings catch for np dev
WillAyd Jun 11, 2019
2b480be
test fixup
WillAyd Jun 11, 2019
ad62830
Used OrderedDict for py35 compat
WillAyd Jun 11, 2019
0d6e0c5
Changed to tm.assert_produces_warning
WillAyd Jun 11, 2019
4f6d656
Merge remote-tracking branch 'upstream/master' into unary
WillAyd Jun 30, 2019
536d002
Fixed bad imports
WillAyd Jun 30, 2019
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions doc/source/whatsnew/v0.25.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,22 @@ the output will truncate, if it's wider than :attr:`options.display.width`
(default: 80 characters).


.. _whatsnew_0250.enhancements.extension_array_operators:

``ExtensionArray`` operator support
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A ``Series`` based on an ``ExtensionArray`` now supports arithmetic, comparison and unary
operators (:issue:`19577`, :issue:`23313`). There are two approaches for providing operator support for an ``ExtensionArray``:

1. Define each of the operators on your ``ExtensionArray`` subclass.
2. Use an operator implementation from pandas that depends on operators that are already defined
on the underlying elements (scalars) of the ``ExtensionArray``.

See the :ref:`ExtensionArray Operator Support
<extending.extension.operator>` documentation section for details on both
ways of adding operator support.

.. _whatsnew_0250.enhancements.other:

Other enhancements
Expand Down
60 changes: 60 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,20 @@ def all_compare_operators(request):
return request.param


@pytest.fixture(params=['__pos__', '__neg__', '__inv__', '__invert__',
'__abs__'])
def all_unary_operators(request):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this have a name reflecting the fact that these are the strings and not operator.pos, operator.neg, etc?

"""
Fixture for dunder names for common unary operations

* +
* -
* ~
* abs
"""
return request.param


@pytest.fixture(params=['__le__', '__lt__', '__ge__', '__gt__'])
def compare_operators_no_eq_ne(request):
"""
Expand Down Expand Up @@ -413,6 +427,7 @@ def tz_aware_fixture(request):
SIGNED_EA_INT_DTYPES = ["Int8", "Int16", "Int32", "Int64"]
ALL_INT_DTYPES = UNSIGNED_INT_DTYPES + SIGNED_INT_DTYPES
ALL_EA_INT_DTYPES = UNSIGNED_EA_INT_DTYPES + SIGNED_EA_INT_DTYPES
ALL_NUMPY_EA_INT_DTYPES = ALL_INT_DTYPES + ALL_EA_INT_DTYPES
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is fine


FLOAT_DTYPES = [float, "float32", "float64"]
COMPLEX_DTYPES = [complex, "complex64", "complex128"]
Expand Down Expand Up @@ -556,6 +571,51 @@ def any_int_dtype(request):
return request.param


@pytest.fixture(params=ALL_EA_INT_DTYPES)
def any_ea_int_dtype(request):
    """
    Parameterized fixture for any extension-array (EA) integer dtype.

    Parameterized over ``ALL_EA_INT_DTYPES``:

    * 'Int8'
    * 'UInt8'
    * 'Int16'
    * 'UInt16'
    * 'Int32'
    * 'UInt32'
    * 'Int64'
    * 'UInt64'
    """

    return request.param


@pytest.fixture(params=ALL_NUMPY_EA_INT_DTYPES)
def any_numpy_ea_int_dtype(request):
    """
    Parameterized fixture for any integer dtype, NumPy or extension-array.

    Parameterized over ``ALL_NUMPY_EA_INT_DTYPES``, i.e. the NumPy integer
    dtypes (``ALL_INT_DTYPES``) followed by the extension-array integer
    dtypes (``ALL_EA_INT_DTYPES``):

    * int
    * 'int8'
    * 'uint8'
    * 'int16'
    * 'uint16'
    * 'int32'
    * 'uint32'
    * 'int64'
    * 'uint64'
    * 'Int8'
    * 'UInt8'
    * 'Int16'
    * 'UInt16'
    * 'Int32'
    * 'UInt32'
    * 'Int64'
    * 'UInt64'
    """

    return request.param


@pytest.fixture(params=ALL_REAL_DTYPES)
def any_real_dtype(request):
"""
Expand Down
90 changes: 74 additions & 16 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,14 @@ def _add_comparison_ops(cls):
cls.__le__ = cls._create_comparison_method(operator.le)
cls.__ge__ = cls._create_comparison_method(operator.ge)

@classmethod
def _add_unary_ops(cls):
    """
    Attach the unary operator methods to ``cls``.

    Every method is produced by ``cls._create_unary_method`` from the
    matching function of the stdlib ``operator`` module. Both ``__inv__``
    and ``__invert__`` are built from ``operator.invert``.
    """
    unary_ops = (
        ('__pos__', operator.pos),
        ('__neg__', operator.neg),
        ('__inv__', operator.invert),
        ('__invert__', operator.invert),
        ('__abs__', operator.abs),
    )
    for dunder_name, unary_op in unary_ops:
        setattr(cls, dunder_name, cls._create_unary_method(unary_op))


class ExtensionScalarOpsMixin(ExtensionOpsMixin):
"""
Expand Down Expand Up @@ -1039,7 +1047,7 @@ def _create_method(cls, op, coerce_to_dtype=True):
Parameters
----------
op : function
An operator that takes arguments op(a, b)
An operator that takes binary arguments op(a, b)
coerce_to_dtype : bool, default True
boolean indicating whether to attempt to convert
the result to the underlying ExtensionArray dtype.
Expand Down Expand Up @@ -1087,24 +1095,12 @@ def convert_values(param):
# a TypeError should be raised
res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]

def _maybe_convert(arr):
if coerce_to_dtype:
# https://github.com/pandas-dev/pandas/issues/22850
# We catch all regular exceptions here, and fall back
# to an ndarray.
try:
res = self._from_sequence(arr)
except Exception:
res = np.asarray(arr)
else:
res = np.asarray(arr)
return res

if op.__name__ in {'divmod', 'rdivmod'}:
a, b = zip(*res)
res = _maybe_convert(a), _maybe_convert(b)
res = (self._maybe_convert(a, coerce_to_dtype),
self._maybe_convert(b, coerce_to_dtype))
else:
res = _maybe_convert(res)
res = self._maybe_convert(res, coerce_to_dtype)
return res

op_name = ops._get_op_name(op, True)
Expand All @@ -1117,3 +1113,65 @@ def _create_arithmetic_method(cls, op):
@classmethod
def _create_comparison_method(cls, op):
return cls._create_method(op, coerce_to_dtype=False)

@classmethod
def _create_unary_method(cls, op, coerce_to_dtype=True):
    """
    A class method that returns a method that will correspond to a unary
    operator for an ExtensionArray subclass, by dispatching to the
    relevant operator defined on the individual elements of the
    ExtensionArray.

    Parameters
    ----------
    op : function
        A unary operator that takes a single argument, op(a)
    coerce_to_dtype : bool, default True
        boolean indicating whether to attempt to convert
        the result to the underlying ExtensionArray dtype.
        If it's not possible to create a new ExtensionArray with the
        values, an ndarray is returned instead.

    Returns
    -------
    Callable[[Any], Union[ndarray, ExtensionArray]]
        A method that can be bound to a class. When used, the method
        receives the instance of this class, and should return an
        ExtensionArray or an ndarray.

        Returning an ndarray may be necessary when the result of the
        `op` cannot be stored in the ExtensionArray. The dtype of the
        ndarray uses NumPy's normal inference rules.

    Examples
    --------
    Given an ExtensionArray subclass called MyExtensionArray, use

    >>> __neg__ = cls._create_unary_method(operator.neg)

    in the class definition of MyExtensionArray to create the operator
    for negative, that will be based on the operator implementation
    of the underlying elements of the ExtensionArray
    """

    def _unaryop(self):
        # Apply ``op`` to each element in turn.
        # If the operator is not defined for the underlying objects,
        # a TypeError should be raised
        res = [op(a) for a in self]
        # Rebuild an array from the per-element results (ExtensionArray
        # when coercion is requested and possible, ndarray otherwise).
        res = self._maybe_convert(res, coerce_to_dtype)
        return res

    op_name = ops._get_op_name(op, True)
    return set_function_name(_unaryop, op_name, cls)

def _maybe_convert(self, arr, coerce_to_dtype):
    """
    Convert a sequence of operation results into an array.

    Parameters
    ----------
    arr : sequence
        The values to convert.
    coerce_to_dtype : bool
        Whether to first try building the result with
        ``self._from_sequence``.

    Returns
    -------
    ExtensionArray or ndarray
        The value of ``self._from_sequence(arr)`` when ``coerce_to_dtype``
        is true and that call succeeds; ``np.asarray(arr)`` otherwise.
    """
    if not coerce_to_dtype:
        return np.asarray(arr)

    # https://github.com/pandas-dev/pandas/issues/22850
    # Any regular exception raised while rebuilding the array is caught
    # here, and we fall back to an ndarray.
    try:
        return self._from_sequence(arr)
    except Exception:
        return np.asarray(arr)
4 changes: 4 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,10 @@ def _add_delta(self, delta):
new_values = super()._add_delta(delta)
return type(self)._from_sequence(new_values, tz=self.tz, freq='infer')

def __pos__(self):
    """
    Reject unary plus for this array.

    Raises
    ------
    TypeError
        Always; the message includes ``self.dtype``.
    """
    message = "Unary plus expects numeric dtype, not {}".format(self.dtype)
    raise TypeError(message)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for PeriodArray?

# -----------------------------------------------------------------
# Timezone Conversion and Localization Methods

Expand Down
11 changes: 11 additions & 0 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,20 @@ def integer_arithmetic_method(self, other):
name = '__{name}__'.format(name=op.__name__)
return set_function_name(integer_arithmetic_method, name, cls)

@classmethod
def _create_unary_method(cls, op):
    """
    Build a unary operator method for this array class.

    Parameters
    ----------
    op : function
        A unary operator that takes a single argument, ``op(a)``
        (e.g. ``operator.neg``). It is applied to the whole ``_data``
        array at once rather than element by element.

    Returns
    -------
    function
        A method named ``__<op.__name__>__`` that returns a new instance
        of ``type(self)`` built from ``op(self._data)`` and a copy of
        ``self._mask``.
    """

    def integer_unary_method(self):
        # NOTE(review): presumably silences warnings triggered by the
        # placeholder values stored under the mask -- confirm.
        with np.errstate(all='ignore'):
            result = op(self._data)
            # Copy the mask so that the result does not share it with the
            # operand; otherwise an in-place update of either array's
            # mask would also change the other one.
            return type(self)(result, self._mask.copy())

    name = '__{name}__'.format(name=op.__name__)
    return set_function_name(integer_unary_method, name, cls)


IntegerArray._add_arithmetic_ops()
IntegerArray._add_comparison_ops()
IntegerArray._add_unary_ops()


module = sys.modules[__name__]
Expand Down
4 changes: 1 addition & 3 deletions pandas/core/arrays/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,9 +1723,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
return type(self)(result)

def __abs__(self):
return np.abs(self)

# ------------------------------------------------------------------------
# Ops
# ------------------------------------------------------------------------
Expand Down Expand Up @@ -1828,6 +1825,7 @@ def _add_unary_ops(cls):
cls.__pos__ = cls._create_unary_method(operator.pos)
cls.__neg__ = cls._create_unary_method(operator.neg)
cls.__invert__ = cls._create_unary_method(operator.invert)
cls.__abs__ = cls._create_unary_method(operator.abs)

@classmethod
def _add_comparison_ops(cls):
Expand Down
3 changes: 3 additions & 0 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,9 @@ def __rdivmod__(self, other):
return res1, res2

# Note: TimedeltaIndex overrides this in call to cls._add_numeric_methods
def __pos__(self):
    """
    Unary plus: return the result of ``self.copy()``.
    """
    return self.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a copy or view?


def __neg__(self):
if self.freq is not None:
return type(self)(-self._data, freq=-self.freq)
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3302,7 +3302,7 @@ def _setitem_frame(self, key, value):

self._check_inplace_setting(value)
self._check_setitem_copy()
self._where(-key, value, inplace=True)
self._where(~key, value, inplace=True)

def _ensure_valid_index(self, value):
"""
Expand Down Expand Up @@ -4587,11 +4587,11 @@ def drop_duplicates(self, subset=None, keep='first', inplace=False):
duplicated = self.duplicated(subset, keep=keep)

if inplace:
inds, = (-duplicated)._ndarray_values.nonzero()
inds, = (~duplicated)._ndarray_values.nonzero()
new_data = self._data.take(inds)
self._update_inplace(new_data)
else:
return self[-duplicated]
return self[~duplicated]

def duplicated(self, subset=None, keep='first'):
"""
Expand Down
Loading