diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index 1bc1a28ef0fe9..a6e32e574ef7d 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -1319,6 +1319,7 @@ Datetimelike - Bug in :class:`DatetimeIndex` and :class:`TimedeltaIndex` where indexing with ``Ellipsis`` would incorrectly lose the index's ``freq`` attribute (:issue:`21282`) - Clarified error message produced when passing an incorrect ``freq`` argument to :class:`DatetimeIndex` with ``NaT`` as the first entry in the passed data (:issue:`11587`) - Bug in :func:`to_datetime` where ``box`` and ``utc`` arguments were ignored when passing a :class:`DataFrame` or ``dict`` of unit mappings (:issue:`23760`) +- Bug in :class:`PeriodIndex` when comparing indexes of different lengths, ValueError is not raised (:issue:`23078`) Timedelta ^^^^^^^^^ diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 5311d6b8d9d90..58c8671472c0a 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -917,6 +917,10 @@ def _add_comparison_ops(cls): cls.__le__ = cls._create_comparison_method(operator.le) cls.__ge__ = cls._create_comparison_method(operator.ge) + def _validate_shape(self, other): + if len(self) != len(other): + raise ValueError('Lengths must match to compare') + class ExtensionScalarOpsMixin(ExtensionOpsMixin): """ diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index 5f4d98a81e5f2..4d84736244fdc 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -67,6 +67,8 @@ def wrapper(self, other): elif isinstance(other, cls): self._check_compatible_with(other) + self._validate_shape(other) + if not_implemented: return NotImplemented result = op(other.asi8) diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index cd5e55d9871b2..fd8053841d78e 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -119,6 +119,12 @@ def test_direct_arith_with_series_returns_not_implemented(self, data): "{} does not implement add".format(data.__class__.__name__) ) + def test_arith_diff_lengths(self, data, all_arithmetic_operators): + op = self.get_op_from_name(all_arithmetic_operators) + other = data[:3] + with pytest.raises(ValueError): + op(data, other) + class BaseComparisonOpsTests(BaseOpsUtil): """Various Series and DataFrame comparison ops methods.""" @@ -164,3 +170,9 @@ def test_direct_arith_with_series_returns_not_implemented(self, data): raise pytest.skip( "{} does not implement __eq__".format(data.__class__.__name__) ) + + def test_compare_diff_lengths(self, data, all_compare_operators): + op = self.get_op_from_name(all_compare_operators) + other = data[:3] + with pytest.raises(ValueError): + op(data, other) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 6281c5360cd03..338549b5ee032 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -299,6 +299,13 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError): def test_error(self): pass + # TODO + # Raise ValueError when carrying out arithmetic operation + # on two decimal arrays of different lengths + @pytest.mark.xfail(reason="raise of ValueError not implemented") + def test_arith_diff_lengths(self, data, all_compare_operators): + super().test_arith_diff_lengths(data, all_compare_operators) + class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests): @@ -324,6 +331,12 @@ def test_compare_array(self, data, all_compare_operators): for i in alter] self._compare_other(s, data, op_name, other) + # TODO: + # Raise ValueError when comparing decimal arrays of different lenghts + @pytest.mark.xfail(reason="raise of ValueError not implemented") + def test_compare_diff_lengths(self, data, all_compare_operators): + super().test_compare_diff_lenths(data, all_compare_operators) + class DecimalArrayWithoutFromSequence(DecimalArray): """Helper class for testing error handling in _from_sequence.""" diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index a35997b07fd83..209a85e300b54 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -291,9 +291,13 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError): s, op, other, exc=TypeError ) + def test_arith_diff_lengths(self): + pass + class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): - pass + def test_compare_diff_lengths(self): + pass class TestPrinting(BaseJSON, base.BasePrintingTests): diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index 6106bc3d58620..c8079babc9e91 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -215,6 +215,9 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError): s, op, other, exc=TypeError ) + def test_arith_diff_lengths(self): + pass + class TestComparisonOps(base.BaseComparisonOpsTests): @@ -233,3 +236,11 @@ def _compare_other(self, s, data, op_name, other): else: with pytest.raises(TypeError): op(data, other) + + @pytest.mark.parametrize('op_name', + ['__eq__', '__ne__']) + def test_compare_diff_lengths(self, data, op_name): + op = self.get_op_from_name(op_name) + other = data[:3] + with pytest.raises(ValueError): + op(data, other) diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index e21ca81bcf5c3..a88778b70c788 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -153,6 +153,13 @@ def check_opname(self, s, op_name, other, exc=None): def _compare_other(self, s, data, op_name, other): self.check_opname(s, op_name, other) + @pytest.mark.filterwarnings("ignore:elementwise:DeprecationWarning") + def test_compare_diff_lengths(self, data, all_compare_operators): + op = self.get_op_from_name(all_compare_operators) + other = data[:3] + with pytest.raises(ValueError): + op(data, other) + class TestInterface(base.BaseInterfaceTests): pass diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py index 08e21fc30ad10..a43f7aab93bd7 100644 --- a/pandas/tests/extension/test_period.py +++ b/pandas/tests/extension/test_period.py @@ -119,6 +119,12 @@ def test_add_series_with_extension_array(self, data): def test_error(self): pass + def test_arith_diff_lengths(self, data): + op = self.get_op_from_name('__sub__') + other = data[:3] + with pytest.raises(ValueError): + op(data, other) + def test_direct_arith_with_series_returns_not_implemented(self, data): # Override to use __sub__ instead of __add__ other = pd.Series(data) diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index ea849a78cda12..da269cbb4d036 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -321,6 +321,17 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators): all_arithmetic_operators ) + def test_arith_diff_lengths(self, data, all_arithmetic_operators): + from pandas.core.dtypes.common import is_float_dtype + + if is_float_dtype(data): + op = self.get_op_from_name(all_arithmetic_operators) + other = data[:3] + with pytest.raises(ValueError): + op(data, other) + else: + pass + class TestComparisonOps(BaseSparseTests, base.BaseComparisonOpsTests): @@ -348,6 +359,17 @@ def _compare_other(self, s, data, op_name, other): result = op(s, other) tm.assert_series_equal(result, expected) + def test_compare_diff_lengths(self, data, all_compare_operators): + from pandas.core.dtypes.common import is_float_dtype + + if is_float_dtype(data): + op = self.get_op_from_name(all_compare_operators) + other = data[:3] + with pytest.raises(ValueError): + op(data, other) + else: + pass + class TestPrinting(BaseSparseTests, base.BasePrintingTests): @pytest.mark.xfail(reason='Different repr', strict=True) diff --git a/pandas/tests/indexes/period/test_period.py b/pandas/tests/indexes/period/test_period.py index 37bfb9c0606a3..9f4a87352d5ff 100644 --- a/pandas/tests/indexes/period/test_period.py +++ b/pandas/tests/indexes/period/test_period.py @@ -555,6 +555,12 @@ def test_insert(self): result = period_range('2017Q1', periods=4, freq='Q').insert(1, na) tm.assert_index_equal(result, expected) + def test_comp_op(self): + # GH 23078 + index = period_range('2017', periods=12, freq="A-DEC") + with pytest.raises(ValueError, match="Lengths must match"): + index <= index[[0]] + def test_maybe_convert_timedelta(): pi = PeriodIndex(['2000', '2001'], freq='D')