Skip to content

Commit

Permalink
Merge pull request #24070 from jakevdp:issubdtype-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681502075
  • Loading branch information
Google-ML-Automation committed Oct 2, 2024
2 parents 152a873 + 4495dae commit 78b65dd
Showing 1 changed file with 84 additions and 3 deletions.
87 changes: 84 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
save = np.save
savez = np.savez

@util.implements(np.dtype)

def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
copy: bool = False) -> DType:
"""Similar to np.dtype, but respects JAX dtype defaults."""
Expand Down Expand Up @@ -436,8 +436,50 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
"""
return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2)

@util.implements(np.issubdtype)

def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
"""Return True if arg1 is equal or lower than arg2 in the type hierarchy.
JAX implementation of :func:`numpy.issubdtype`.
The main difference in JAX's implementation is that it properly handles
dtype extensions such as :code:`bfloat16`.
Args:
arg1: dtype-like object. In typical usage, this will be a dtype specifier,
such as ``"float32"`` (i.e. a string), ``np.dtype('int32')`` (i.e. an
instance of :class:`numpy.dtype`), ``jnp.complex64`` (i.e. a JAX scalar
constructor), or ``np.uint8`` (i.e. a NumPy scalar type).
arg2: dtype-like object. In typical usage, this will be a generic scalar
type, such as ``jnp.integer``, ``jnp.floating``, or ``jnp.complexfloating``.
Returns:
True if arg1 represents a dtype that is equal or lower in the type
hierarchy than arg2.
See also:
- :func:`jax.numpy.isdtype`: similar function aligning with the array API standard.
Examples:
>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False
Be aware that while this is very similar to :func:`numpy.issubdtype`, the
results of these differ in the case of JAX's custom floating point types:
>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True
"""
return dtypes.issubdtype(arg1, arg2)

@util.implements(np.isscalar)
Expand All @@ -448,8 +490,47 @@ def isscalar(element: Any) -> bool:

iterable = np.iterable

@util.implements(np.result_type)

def result_type(*args: Any) -> DType:
"""Return the result of applying JAX promotion rules to the inputs.
JAX implementation of :func:`numpy.result_type`.
JAX's dtype promotion behavior is described in :ref:`type-promotion`.
Args:
args: one or more arrays or dtype-like objects.
Returns:
A :class:`numpy.dtype` instance representing the result of type
promotion for the inputs.
Examples:
Inputs can be dtype specifiers:
>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')
Inputs may also be scalars or arrays:
>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')
Be aware that the result type will be canonicalized based on the state
of the ``jax_enable_x64`` configuration flag, meaning that 64-bit types
may be downcast to 32-bit:
>>> jnp.result_type('float64')
dtype('float32')
For details on 64-bit values, refer to `Sharp bits - double precision`_:
.. _Sharp bits - double precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
"""
return dtypes.result_type(*args)


Expand Down

0 comments on commit 78b65dd

Please sign in to comment.