From a0d8cc8fa19365cbda81a70df2e865c245e9787a Mon Sep 17 00:00:00 2001 From: Sergey B Kirpichev Date: Thu, 17 Oct 2024 09:47:13 +0300 Subject: [PATCH] Avoid OverflowError in mpz.__rshift__() Closes #524 --- src/gmpy2_mpz_bitops.c | 31 +++++++++++++++++++++++++++++-- test/test_mpz.py | 6 ++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/gmpy2_mpz_bitops.c b/src/gmpy2_mpz_bitops.c index c18e6098..801ade42 100644 --- a/src/gmpy2_mpz_bitops.c +++ b/src/gmpy2_mpz_bitops.c @@ -592,8 +592,35 @@ GMPy_MPZ_Rshift_Slot(PyObject *self, PyObject *other) MPZ_Object *result, *tempx; count = GMPy_Integer_AsMpBitCnt(other); - if ((count == (mp_bitcnt_t)(-1)) && PyErr_Occurred()) - return NULL; + if ((count == (mp_bitcnt_t)(-1)) && PyErr_Occurred()) { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) { + /* LCOV_EXCL_START */ + return NULL; + /* LCOV_EXCL_STOP */ + } + PyErr_Clear(); + + PyObject *tmp = PyNumber_Long(other); + + if (!tmp) { + /* LCOV_EXCL_START */ + return NULL; + /* LCOV_EXCL_STOP */ + } + if (PyLong_IsNegative(tmp)) { + VALUE_ERROR("negative shift count"); + Py_DECREF(tmp); + return NULL; + } + Py_DECREF(tmp); + if (!(result = GMPy_MPZ_New(NULL))) { + /* LCOV_EXCL_START */ + return NULL; + /* LCOV_EXCL_STOP */ + } + mpz_set_si(result->z, mpz_sgn(MPZ(self)) < 0 ? -1 : 0); + return (PyObject*)result; + } if (!(result = GMPy_MPZ_New(NULL))) return NULL; diff --git a/test/test_mpz.py b/test/test_mpz.py index 261151a5..0222ebed 100644 --- a/test/test_mpz.py +++ b/test/test_mpz.py @@ -1383,8 +1383,10 @@ def test_mpz_rshift(): assert a>>1 == mpz(61) assert int(a)>>mpz(1) == mpz(61) + assert a>>111111111111111111111 == mpz(0) + assert (-a)>>111111111111111111111 == mpz(-1) - raises(OverflowError, lambda: a>>-2) + raises(ValueError, lambda: a>>-2) assert a>>0 == mpz(123) @@ -1426,7 +1428,7 @@ def test_mpz_ilshift_irshift(): x >>= mpfr(2) with raises(TypeError): x <<= mpfr(2) - with raises(OverflowError): + with raises(ValueError): x >>= -1 with raises(OverflowError): x <<= -5