Skip to content

Commit

Permalink
Move dtype setting logic to the Policy class to make dtype policy awareness more self-contained.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 522094585
  • Loading branch information
fchollet authored and tensorflower-gardener committed Apr 5, 2023
1 parent c225ac7 commit 5a77d20
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 45 deletions.
33 changes: 1 addition & 32 deletions keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from keras.engine import keras_tensor
from keras.engine import node as node_module
from keras.mixed_precision import autocast_variable
from keras.mixed_precision import loss_scale_optimizer
from keras.mixed_precision import policy
from keras.saving import serialization_lib
from keras.saving.legacy.saved_model import layer_serialization
Expand Down Expand Up @@ -2705,37 +2704,7 @@ def _outbound_nodes(self, value):

def _set_dtype_policy(self, dtype):
"""Sets self._dtype_policy."""
if isinstance(dtype, policy.Policy):
self._dtype_policy = dtype
elif isinstance(dtype, dict):
self._dtype_policy = policy.deserialize(dtype)
elif isinstance(dtype, str) and dtype in (
"mixed_float16",
"mixed_bfloat16",
):
# The isinstance check is required since np.dtype raises an error if
# compared to a non-dtype string.
self._dtype_policy = policy.Policy(dtype)
elif dtype:
self._dtype_policy = policy.Policy(tf.as_dtype(dtype).name)
else:
self._dtype_policy = policy.global_policy()
if (
self._dtype_policy.name == "mixed_float16"
and not loss_scale_optimizer.strategy_supports_loss_scaling()
):
# Although only loss scaling doesn't support certain strategies, to
# avoid confusion, we disallow the 'mixed_float16' policy with
# unsupported strategies. This is because 'mixed_float16' requires
# loss scaling for numeric stability.
strategy = tf.distribute.get_strategy()
raise ValueError(
"Mixed precision is not supported with the "
"tf.distribute.Strategy: %s. Either stop using mixed "
'precision by removing the use of the "%s" policy or '
"use a different Strategy, e.g. a MirroredStrategy."
% (strategy.__class__.__name__, self._dtype_policy.name)
)
self._dtype_policy = policy.get_policy(dtype)

# Performance optimization: cache the compute dtype as a Dtype object or
# None, so that str to Dtype conversion doesn't happen in
Expand Down
5 changes: 1 addition & 4 deletions keras/mixed_precision/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,7 @@ def test_unsupported_strategy(self):
with strategy.scope(), self.assertRaisesRegex(
ValueError,
"Mixed precision is not supported with the "
"tf.distribute.Strategy: CentralStorageStrategy. Either "
"stop using mixed precision by removing the use of the "
'"mixed_float16" policy or use a different Strategy, e.g. '
"a MirroredStrategy.",
"tf.distribute.Strategy: CentralStorageStrategy.",
):
mp_test_util.MultiplyLayer(dtype="mixed_float16")
# Non-mixed policies are fine
Expand Down
47 changes: 41 additions & 6 deletions keras/mixed_precision/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from keras import backend
from keras.engine import base_layer_utils
from keras.mixed_precision import device_compatibility_check
from keras.mixed_precision import loss_scale_optimizer
from keras.saving import serialization_lib

# isort: off
Expand Down Expand Up @@ -191,7 +192,7 @@ def __init__(self, name):
if isinstance(name, tf.DType):
raise TypeError(
"'name' must be a string, not a DType. "
"Instead, pass DType.name. Got: %s" % (name.name,)
f"Instead, pass DType.name. Received: name={name.name}"
)
elif not isinstance(name, str):
raise TypeError(f"'name' must be a string, but got: {name}")
Expand Down Expand Up @@ -246,12 +247,11 @@ def _parse_name(self, name):
try:
dtype = tf.as_dtype(name).name
except TypeError:
error = (
"Cannot convert value %s to a mixed precision Policy. "
raise ValueError(
f"Cannot convert value {name} to a mixed precision Policy. "
"Valid policies include 'mixed_float16', 'mixed_bfloat16', "
"and the name of any dtype such as 'float32'." % (name,)
"and the name of any dtype such as 'float32'."
)
raise ValueError(error)
return dtype, dtype

@property
Expand Down Expand Up @@ -440,7 +440,7 @@ def set_global_policy(policy):
raise ValueError(
"set_global_policy can only be used to set the global "
'policy to floating-point policies, such as "float32" and '
'"mixed_float16", but got policy: %s' % (policy.name,)
f'"mixed_float16", but got policy: {policy.name}'
)
_global_policy = policy
tf.__internal__.train.set_using_mixed_precision_policy(is_mixed_policy)
Expand All @@ -465,6 +465,41 @@ def policy_scope(policy):
set_global_policy(old_policy)


def get_policy(identifier):
    """Resolves `identifier` into a dtype policy object.

    Resolution rules, checked in order:

    * a `Policy` instance is used as-is;
    * a `dict` is passed to `deserialize`;
    * the strings `"mixed_float16"` and `"mixed_bfloat16"` are passed
      directly to `Policy`;
    * any other truthy value is converted with `tf.as_dtype` and the
      resulting dtype name is passed to `Policy`;
    * a falsy value (e.g. `None`) falls back to `global_policy()`.

    Args:
        identifier: A `Policy`, a dict, a string, a value accepted by
            `tf.as_dtype`, or a falsy value.

    Returns:
        The resolved policy object.

    Raises:
        ValueError: If the resolved policy is named `"mixed_float16"` and
            `loss_scale_optimizer.strategy_supports_loss_scaling()` reports
            that loss scaling is not supported.
    """
    if isinstance(identifier, Policy):
        resolved = identifier
    elif isinstance(identifier, dict):
        resolved = deserialize(identifier)
    elif isinstance(identifier, str) and identifier in (
        "mixed_float16",
        "mixed_bfloat16",
    ):
        # Keep the `isinstance` guard ahead of the membership test: np.dtype
        # raises an error if compared to a non-dtype string.
        resolved = Policy(identifier)
    elif identifier:
        dtype_name = tf.as_dtype(identifier).name
        resolved = Policy(dtype_name)
    else:
        resolved = global_policy()

    # Only 'mixed_float16' needs the strategy check; every other policy, and
    # 'mixed_float16' under a supporting strategy, is returned right away.
    if (
        resolved.name != "mixed_float16"
        or loss_scale_optimizer.strategy_supports_loss_scaling()
    ):
        return resolved

    # Strictly speaking only loss scaling is unsupported by certain
    # strategies, but 'mixed_float16' requires loss scaling for numeric
    # stability, so the policy itself is disallowed to avoid confusion.
    strategy_name = tf.distribute.get_strategy().__class__.__name__
    raise ValueError(
        "Mixed precision is not supported with the "
        f"tf.distribute.Strategy: {strategy_name}. "
        "Either stop using mixed precision by removing the use of "
        f"the {resolved.name} policy or "
        "use a different Strategy, e.g. a MirroredStrategy."
    )


def _is_convertible_to_dtype(dtype):
try:
tf.as_dtype(dtype)
Expand Down
4 changes: 1 addition & 3 deletions keras/mixed_precision/policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def test_policy_errors(self):

# Test passing a DType
with self.assertRaisesRegex(
TypeError,
"'name' must be a string, not a DType. "
"Instead, pass DType.name. Got: float16",
TypeError, "'name' must be a string, not a DType. "
):
mp_policy.Policy(tf.float16)

Expand Down

0 comments on commit 5a77d20

Please sign in to comment.