diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4a2887045fb1..b4a4568f08b9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -428,8 +428,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None, if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - # Note: don't canonicalize old_dtype because x64 context might - # cause un-canonicalized operands to be passed in. + # Don't canonicalize old_dtype because x64 context might cause + # un-canonicalized operands to be passed in. old_dtype = np.result_type(operand) old_weak_type = dtypes.is_weakly_typed(operand) @@ -441,6 +441,14 @@ def convert_element_type(operand: Array, new_dtype: DType = None, msg = "Casting complex values to real discards the imaginary part" warnings.warn(msg, np.ComplexWarning, stacklevel=2) + # Python has big integers, but convert_element_type(2 ** 100, np.float32) need + # not be an error since the target dtype fits the value. Handle this case by + # converting to a NumPy array before calling bind. Without this step, we'd + # first canonicalize the input to a value of dtype int32 or int64, leading to + # an overflow error. + if type(operand) is int: + operand = np.asarray(operand, new_dtype) + if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type) and isinstance(operand, (core.Tracer, xla.DeviceArray))): return operand diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index 734957cdd1b0..57c75a0774c8 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -68,9 +68,8 @@ def _zeros_like_python_scalar(t, x): return np.array(0, dtypes.python_scalar_dtypes[t]) def _make_concrete_python_scalar(t, x): - return ConcreteArray( - np.array(x, dtype=dtypes.python_scalar_dtypes[t]), - weak_type=True) + return ConcreteArray(np.array(x, dtype=dtypes.python_scalar_dtypes[t]), + weak_type=True) for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t) diff --git a/tests/api_test.py b/tests/api_test.py index 4b38afd46111..3bf5a9c31852 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2393,6 +2393,12 @@ def f(_): expected = jnp.arange(1) + 1 self.assertAllClose(ans, expected) + def test_large_python_int_to_float(self): + # https://github.com/google/jax/pull/6165 + jnp.multiply(2 ** 100, 3.) # doesn't crash + out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash + self.assertArraysEqual(out, np.float32(2 ** 100)) + class RematTest(jtu.JaxTestCase): diff --git a/tests/random_test.py b/tests/random_test.py index 41fa226bc71a..3eb3e6b54bf5 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -746,6 +746,8 @@ def f(x): grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2)) def testNoOpByOpUnderHash(self): + if not config.omnistaging_enabled: + raise SkipTest("test requires omnistaging") def fail(*args, **kwargs): assert False apply_primitive, xla.apply_primitive = xla.apply_primitive, fail try: