Skip to content

Commit

Permalink
implement masked_arith_op to de-duplicate ops code (pandas-dev#22182)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and victor committed Sep 30, 2018
1 parent 6281f75 commit afe0011
Showing 1 changed file with 56 additions and 50 deletions.
106 changes: 56 additions & 50 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pandas.core.dtypes.common import (
needs_i8_conversion,
is_datetimelike_v_numeric,
is_period_dtype,
is_integer_dtype, is_categorical_dtype,
is_object_dtype, is_timedelta64_dtype,
is_datetime64_dtype, is_datetime64tz_dtype,
Expand All @@ -41,7 +42,7 @@
from pandas.core.dtypes.generic import (
ABCSeries,
ABCDataFrame, ABCPanel,
ABCIndex,
ABCIndex, ABCIndexClass,
ABCSparseSeries, ABCSparseArray)


Expand Down Expand Up @@ -788,6 +789,57 @@ def mask_cmp_op(x, y, op, allowed_types):
return result


def masked_arith_op(x, y, op):
"""
If the given arithmetic operation fails, attempt it again on
only the non-null elements of the input array(s).
Parameters
----------
x : np.ndarray
y : np.ndarray, Series, Index
op : binary operator
"""
# For Series `x` is 1D so ravel() is a no-op; calling it anyway makes
# the logic valid for both Series and DataFrame ops.
xrav = x.ravel()
assert isinstance(x, (np.ndarray, ABCSeries)), type(x)
if isinstance(y, (np.ndarray, ABCSeries, ABCIndexClass)):
dtype = find_common_type([x.dtype, y.dtype])
result = np.empty(x.size, dtype=dtype)

# PeriodIndex.ravel() returns int64 dtype, so we have
# to work around that case. See GH#19956
yrav = y if is_period_dtype(y) else y.ravel()
mask = notna(xrav) & notna(yrav)

if yrav.shape != mask.shape:
# FIXME: GH#5284, GH#5035, GH#19448
# Without specifically raising here we get mismatched
# errors in Py3 (TypeError) vs Py2 (ValueError)
# Note: Only = an issue in DataFrame case
raise ValueError('Cannot broadcast operands together.')

if mask.any():
with np.errstate(all='ignore'):
result[mask] = op(xrav[mask],
com.values_from_object(yrav[mask]))

else:
assert is_scalar(y), type(y)
assert isinstance(x, np.ndarray), type(x)
# mask is only meaningful for x
result = np.empty(x.size, dtype=x.dtype)
mask = notna(xrav)
if mask.any():
with np.errstate(all='ignore'):
result[mask] = op(xrav[mask], y)

result, changed = maybe_upcast_putmask(result, ~mask, np.nan)
result = result.reshape(x.shape) # 2D compat
return result


def invalid_comparison(left, right, op):
"""
If a comparison has mismatched types and is not necessarily meaningful,
Expand Down Expand Up @@ -880,8 +932,7 @@ def _get_method_wrappers(cls):
return arith_flex, comp_flex, arith_special, comp_special, bool_special


def _create_methods(cls, arith_method, comp_method, bool_method,
special=False):
def _create_methods(cls, arith_method, comp_method, bool_method, special):
# creates actual methods based upon arithmetic, comp and bool method
# constructors.

Expand Down Expand Up @@ -1136,19 +1187,7 @@ def na_op(x, y):
try:
result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs)
except TypeError:
if isinstance(y, (np.ndarray, ABCSeries, pd.Index)):
dtype = find_common_type([x.dtype, y.dtype])
result = np.empty(x.size, dtype=dtype)
mask = notna(x) & notna(y)
result[mask] = op(x[mask], com.values_from_object(y[mask]))
else:
assert isinstance(x, np.ndarray)
assert is_scalar(y)
result = np.empty(len(x), dtype=x.dtype)
mask = notna(x)
result[mask] = op(x[mask], y)

result, changed = maybe_upcast_putmask(result, ~mask, np.nan)
result = masked_arith_op(x, y, op)

result = missing.fill_zeros(result, x, y, op_name, fill_zeros)
return result
Expand Down Expand Up @@ -1675,40 +1714,7 @@ def na_op(x, y):
try:
result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs)
except TypeError:
xrav = x.ravel()
if isinstance(y, (np.ndarray, ABCSeries)):
dtype = find_common_type([x.dtype, y.dtype])
result = np.empty(x.size, dtype=dtype)
yrav = y.ravel()
mask = notna(xrav) & notna(yrav)
xrav = xrav[mask]

if yrav.shape != mask.shape:
# FIXME: GH#5284, GH#5035, GH#19448
# Without specifically raising here we get mismatched
# errors in Py3 (TypeError) vs Py2 (ValueError)
raise ValueError('Cannot broadcast operands together.')

yrav = yrav[mask]
if xrav.size:
with np.errstate(all='ignore'):
result[mask] = op(xrav, yrav)

elif isinstance(x, np.ndarray):
# mask is only meaningful for x
result = np.empty(x.size, dtype=x.dtype)
mask = notna(xrav)
xrav = xrav[mask]
if xrav.size:
with np.errstate(all='ignore'):
result[mask] = op(xrav, y)
else:
raise TypeError("cannot perform operation {op} between "
"objects of type {x} and {y}"
.format(op=op_name, x=type(x), y=type(y)))

result, changed = maybe_upcast_putmask(result, ~mask, np.nan)
result = result.reshape(x.shape)
result = masked_arith_op(x, y, op)

result = missing.fill_zeros(result, x, y, op_name, fill_zeros)

Expand Down

0 comments on commit afe0011

Please sign in to comment.