Skip to content

Commit

Permalink
Remove dtype safeguards
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684100032
  • Loading branch information
vroulet authored and OptaxDev committed Oct 10, 2024
1 parent 6e23da8 commit ef05704
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 146 deletions.
39 changes: 11 additions & 28 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,32 +234,15 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):

params_dtype = jax.dtypes.canonicalize_dtype(params_dtype)
params = jnp.array([0.0, 0.0], dtype=params_dtype)
state_has_lower_dtype = (
jnp.promote_types(params_dtype, jnp.dtype(state_dtype)) == params_dtype
)
if state_dtype is None or state_has_lower_dtype:
state = opt.init(params)

with self.subTest('Test that attribute dtype is correct'):
if state_dtype is None:
expected_dtype = params_dtype
else:
expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype)
attribute = otu.tree_get(state, attribute_name)
self.assertEqual(expected_dtype, attribute.dtype)

with self.subTest(
'Verifies that the updates keep the same type as params'
):
updates, _ = opt.update(jnp.ones_like(params), state, params)
self.assertEqual(updates.dtype, params.dtype)
else:
with self.subTest(
'Test that we forbid setting dtype s.t. updates dtype get promoted to'
' the state dtype'
):
with self.assertRaises(ValueError):
opt.init(params)
state = opt.init(params)

with self.subTest('Test that attribute dtype is correct'):
if state_dtype is None:
expected_dtype = params_dtype
else:
expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype)
attribute = otu.tree_get(state, attribute_name)
self.assertEqual(expected_dtype, attribute.dtype)

# Not testing with `without_device=True` because without_device sets the
# variables to the host, which appears to then convert the dtype, so we
Expand All @@ -269,7 +252,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):
)
@parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32'))
def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers return updates of same dtype as params."""
"""Test that the optimizers return updates of same dtype as gradients."""
# When debugging this test, note that operations like
# x = 0.5**jnp.asarray(1, dtype=jnp.int32)
# (appearing in e.g. optax.tree_utils.tree_bias_correction)
Expand All @@ -291,7 +274,7 @@ def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
self.assertEqual(updates.dtype, params.dtype)
self.assertEqual(updates.dtype, grads.dtype)

@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
Expand Down
6 changes: 0 additions & 6 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,17 +293,11 @@ def scale_by_adam(
Returns:
A `GradientTransformation` object.
Raises:
ValueError: If the selected ``mu_dtype`` induces a dtype promotion of the
dtypes of the parameters.
"""

mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
if mu_dtype is not None:
otu.tree_assert_dtype_preserved(params, mu_dtype)
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
nu = otu.tree_zeros_like(params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
Expand Down
1 change: 0 additions & 1 deletion optax/contrib/_schedule_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def init_fn(params: base.Params) -> ScheduleFreeState:
# parameters and updates.
params_dtype = otu.tree_dtype(params, 'lowest')
if state_dtype is not None:
otu.tree_assert_dtype_preserved(params, state_dtype)
z = otu.tree_cast(params, dtype=state_dtype)
else:
z = params
Expand Down
31 changes: 7 additions & 24 deletions optax/contrib/_schedule_free_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,31 +141,14 @@ def test_explicit_dtype(self, params_dtype, state_dtype):

params_dtype = jax.dtypes.canonicalize_dtype(params_dtype)
params = jnp.array([0.0, 0.0], dtype=params_dtype)
state_has_lower_dtype = (
jnp.promote_types(params_dtype, state_dtype) == params_dtype
)
if state_dtype is None or state_has_lower_dtype:
state = opt.init(params)
state = opt.init(params)

with self.subTest('Test that attribute dtype is correct'):
if state_dtype is None:
expected_dtype = params_dtype
else:
expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype)
self.assertEqual(expected_dtype, getattr(state, 'z').dtype)

with self.subTest(
'Verifies that the updates keep the same type as params'
):
updates, _ = opt.update(jnp.ones_like(params), state, params)
self.assertEqual(getattr(updates, 'dtype'), params.dtype)
else:
with self.subTest(
'Test that we forbid setting dtype s.t. updates dtype get promoted to'
' the state dtype'
):
with self.assertRaises(ValueError):
opt.init(params)
with self.subTest('Test that attribute dtype is correct'):
if state_dtype is None:
expected_dtype = params_dtype
else:
expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype)
self.assertEqual(expected_dtype, getattr(state, 'z').dtype)


if __name__ == '__main__':
Expand Down
6 changes: 0 additions & 6 deletions optax/transforms/_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,11 @@ def trace(
Returns:
A `GradientTransformation` object.
Raises:
ValueError: If the selected ``accumulator_dtype`` induces a dtype promotion
of the dtypes of the parameters.
"""

accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

def init_fn(params):
if accumulator_dtype is not None:
otu.tree_assert_dtype_preserved(params, accumulator_dtype)
return TraceState(
trace=otu.tree_zeros_like(params, dtype=accumulator_dtype))

Expand Down
1 change: 0 additions & 1 deletion optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""The tree_utils sub-package."""

# pylint: disable=g-importing-member
from optax.tree_utils._casting import tree_assert_dtype_preserved
from optax.tree_utils._casting import tree_cast
from optax.tree_utils._casting import tree_dtype
from optax.tree_utils._random import tree_random_like
Expand Down
39 changes: 0 additions & 39 deletions optax/tree_utils/_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,45 +142,6 @@ def tree_dtype(
)


def tree_assert_dtype_preserved(
    tree: chex.ArrayTree,
    dtype: chex.ArrayDType,
) -> None:
  """Raises if combining ``dtype`` with any leaf of ``tree`` promotes the leaf.

  Transformations such as :func:`optax.scale_by_adam` or :func:`optax.trace`
  let the user choose a dtype for part of their state (e.g. the momentum
  term). This check verifies that the chosen dtype would not promote the dtype
  of any leaf of ``tree``, so that the dtypes of the updates can stay
  consistent with the dtypes of the parameters.

  Args:
    tree: the tree to check.
    dtype: the dtype to check against.

  Raises:
    ValueError: If promoting any leaf's dtype with ``dtype`` yields a dtype
      different from the leaf's own dtype. The message lists one line per
      offending leaf, including its path in the tree.

  .. versionadded:: 0.2.4
  """

  def _promotion_error(path, leaf):
    # Returns an error message for an offending leaf, None otherwise.
    x_dtype = jnp.asarray(leaf).dtype
    if jnp.promote_types(x_dtype, dtype) == x_dtype:
      return None
    return f'{dtype=} induces dtype promotion for {path} with dtype {x_dtype}.'

  per_leaf_results = jax.tree_util.tree_map_with_path(_promotion_error, tree)
  # Keep only actual error messages before deciding whether to raise.
  failures = [
      message
      for message in jax.tree.leaves(per_leaf_results)
      if message is not None
  ]
  if failures:
    raise ValueError('\n'.join(failures))


def _tree_assert_all_dtypes_equal(
tree: chex.ArrayTree, dtype: chex.ArrayDType
) -> None:
Expand Down
41 changes: 0 additions & 41 deletions optax/tree_utils/_casting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,40 +88,6 @@ def test_tree_dtype(self):
self.assertRaises(ValueError, otu.tree_dtype, tree, 'lowest')
self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest')

def test_tree_assert_dtype_preserved(self):
  """Exercises tree_assert_dtype_preserved on mixed-precision and empty trees."""
  # One bfloat16 leaf and one float32 leaf: float32 promotes the former,
  # bfloat16 promotes neither.
  mixed_tree = {
      'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)},
      'c': jnp.array(2.0, dtype=jnp.float32),
  }

  with self.subTest(
      'Check that it raises an error if given dtype induces promotion of at'
      ' least one element.'
  ), self.assertRaises(ValueError):
    otu.tree_assert_dtype_preserved(mixed_tree, jnp.float32)

  with self.subTest(
      'Check that it runs fine if no element gets promoted by given dtype.'
  ):
    otu.tree_assert_dtype_preserved(mixed_tree, jnp.bfloat16)

  with self.subTest(
      'Check that it naturally succeeds when considering lowest common dtype.'
  ):
    lowest_dtype = otu.tree_dtype(mixed_tree, 'lowest')
    otu.tree_assert_dtype_preserved(mixed_tree, lowest_dtype)

  with self.subTest(
      'Check that it naturally fails when considering highest common dtype.'
  ), self.assertRaises(ValueError):
    # Computed inside the assertRaises scope, as in the original test.
    highest_dtype = otu.tree_dtype(mixed_tree, 'highest')
    otu.tree_assert_dtype_preserved(mixed_tree, highest_dtype)

  with self.subTest('Check that it works with empty trees.'):
    # No arrays to inspect, so the check must pass for every empty container.
    for empty_tree in ((), {}, None):
      otu.tree_assert_dtype_preserved(empty_tree, jnp.float32)

@parameterized.named_parameters(
dict(testcase_name='empty_dict', tree={}),
dict(testcase_name='empty_list', tree=[]),
Expand All @@ -136,13 +102,6 @@ def test_tree_dtype_utilities_with_empty_trees(self, tree):
dtype = otu.tree_dtype(tree)
self.assertEqual(dtype, default_dtype)

with self.subTest(
'Check tree_assert_dtype_preserved succeeds with any dtype for'
' empty trees.'
):
# There is no array in the tree to check, so it should succeed.
otu.tree_assert_dtype_preserved(tree, jnp.complex64)


if __name__ == '__main__':
absltest.main()

0 comments on commit ef05704

Please sign in to comment.