Skip to content

Commit

Permalink
Add support for the use case of laplace_tail_mass being a tf.Tensor, …
Browse files Browse the repository at this point in the history
…e.g., as set by a schedule during model training.

PiperOrigin-RevId: 467336716
Change-Id: I415b914b53dc27d8d009a8bff142d3e89440dc8c
  • Loading branch information
Googler authored and copybara-github committed Aug 13, 2022
1 parent c2ae0e1 commit b724663
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 17 deletions.
39 changes: 28 additions & 11 deletions tensorflow_compression/python/entropy_models/continuous_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def __init__(self,
Elias gamma code embedded into the range coder.
bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
laplace_tail_mass: Float. If non-zero, will augment the prior with a
`NoisyLaplace` mixture component for training stability. (experimental)
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
will augment the prior with a `NoisyLaplace` mixture component for
training stability. (experimental)
"""
super().__init__()
self._prior = None # This will be set by subclasses, if appropriate.
Expand All @@ -83,14 +84,12 @@ def __init__(self,
if bottleneck_dtype is None:
bottleneck_dtype = tf.keras.backend.floatx()
self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype)
self._laplace_tail_mass = float(laplace_tail_mass)
self._laplace_tail_mass = laplace_tail_mass

if self.coding_rank < 0:
raise ValueError("`coding_rank` must be at least 0.")
if not 0 < self.tail_mass < 1:
raise ValueError("`tail_mass` must be between 0 and 1.")
if not 0 <= self.laplace_tail_mass < 1:
raise ValueError("`laplace_tail_mass` must be between 0 and 1.")

def _check_compression(self):
if not self.compression:
Expand Down Expand Up @@ -299,23 +298,41 @@ def loop_body(i, cdf):
def _log_prob(self, prior, bottleneck_perturbed):
"""Evaluates prior.log_prob(bottleneck + noise)."""
bottleneck_perturbed = tf.cast(bottleneck_perturbed, prior.dtype)
if self.laplace_tail_mass:
laplace_tail_mass = self.laplace_tail_mass

def mixture_log_prob_fn():
tf.debugging.assert_less(
laplace_tail_mass,
tf.constant(1.0, prior.dtype),
message="`laplace_tail_mass` must be less than 1.")
laplace_prior = uniform_noise.NoisyLaplace(
loc=tf.constant(0, dtype=prior.dtype),
scale=tf.constant(1, dtype=prior.dtype))
probs = prior.prob(bottleneck_perturbed)
probs = ((1 - self.laplace_tail_mass) * probs +
self.laplace_tail_mass *
probs = ((1 - laplace_tail_mass) * probs +
laplace_tail_mass *
laplace_prior.prob(bottleneck_perturbed))
probs_too_small = probs < 1e-10
probs_bounded = tf.maximum(probs, 1e-10)
return tf.where(
probs_too_small,
tf.math.log(self.laplace_tail_mass) +
tf.math.log(laplace_tail_mass) +
laplace_prior.log_prob(bottleneck_perturbed),
tf.math.log(probs_bounded))

prior_log_prob_fn = lambda: prior.log_prob(bottleneck_perturbed)

if isinstance(laplace_tail_mass, tf.Tensor):
# Do all the computation in tf (graph mode compatible).
laplace_tail_mass = tf.cast(laplace_tail_mass, prior.dtype)
use_laplace_tail_mass = tf.greater(laplace_tail_mass, 0.0)
return tf.cond(use_laplace_tail_mass, mixture_log_prob_fn,
prior_log_prob_fn)
else:
return prior.log_prob(bottleneck_perturbed)
if laplace_tail_mass > 0:
return mixture_log_prob_fn()
else:
return prior_log_prob_fn()

@abc.abstractmethod
def get_config(self):
Expand All @@ -340,7 +357,7 @@ def get_config(self):
tail_mass=self.tail_mass,
cdf_shapes=(self.cdf.shape[0], self.cdf_offset.shape[0]),
bottleneck_dtype=self.bottleneck_dtype.name,
laplace_tail_mass=self.laplace_tail_mass,
laplace_tail_mass=float(self.laplace_tail_mass),
)

def get_weights(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ def __init__(self,
use. If provided (not `None`), then `offset_heuristic` is ineffective.
decode_sanity_check: Boolean. If `True`, raises an error if the binary
strings passed into `decompress` are not completely decoded.
laplace_tail_mass: Float. If non-zero, will augment the prior with a
`NoisyLaplace` mixture component for training stability. (experimental)
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
will augment the prior with a `NoisyLaplace` mixture component for
training stability. (experimental)
"""
if (prior is None) == (prior_shape is None):
raise ValueError("Either `prior` or `prior_shape` must be provided.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,37 @@ def test_small_bitcost_for_dirac_prior(self):
# Quantization noise should be between -.5 and .5
self.assertAllClose(x, x_decoded, rtol=0., atol=.5)

def test_laplace_tail_mass(self):
noisy = uniform_noise.NoisyNormal(loc=0., scale=1.)
em = ContinuousBatchedEntropyModel(noisy, 1, laplace_tail_mass=0.0)
self.assertEqual(em.laplace_tail_mass, 0.0)
em = ContinuousBatchedEntropyModel(noisy, 1,
laplace_tail_mass=tf.constant(1e-3))
self.assertEqual(em.laplace_tail_mass, tf.constant(1e-3))
log_prob = em._log_prob(noisy, tf.constant(0.0))
self.assertEqual(log_prob.dtype, tf.float32)

def test_laplace_tail_mass_works_in_tf_function(self):
noisy = uniform_noise.NoisyNormal(loc=0., scale=1.)
samples = noisy.base.sample([100])

# Since tf.function traces each function twice, and only allows variable
# creation in the first call, we need to have a stateful object in which we
# create the entropy model only the first time the function is called, and
# store it for the second time.

class EntropyModel:

def log_prob(self, values):
if not hasattr(self, "em"):
self.em = ContinuousBatchedEntropyModel(
noisy, 1, laplace_tail_mass=tf.constant(1e-3))
return self.em._log_prob(noisy, values)

values_eager = EntropyModel().log_prob(samples)
values_function = tf.function(EntropyModel().log_prob)(samples)
self.assertAllEqual(values_eager, values_function)


if __name__ == "__main__":
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def __init__(self,
computations. Defaults to `tf.float32`.
decode_sanity_check: Boolean. If `True`, raises an error if the binary
strings passed into `decompress` are not completely decoded.
laplace_tail_mass: Float. If non-zero, will augment the prior with a
`NoisyLaplace` mixture component for training stability. (experimental)
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
will augment the prior with a `NoisyLaplace` mixture component for
training stability. (experimental)
"""
if not callable(prior_fn):
raise TypeError("`prior_fn` must be a class or factory function.")
Expand Down Expand Up @@ -496,8 +497,9 @@ def __init__(self,
Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
prior_dtype: `tf.dtypes.DType`. Data type of prior and probability
computations. Defaults to `tf.float32`.
laplace_tail_mass: Float. If non-zero, will augment the prior with a
`NoisyLaplace` mixture component for training stability. (experimental)
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
will augment the prior with a `NoisyLaplace` mixture component for
training stability. (experimental)
"""
num_scales = int(num_scales)
super().__init__(
Expand Down

0 comments on commit b724663

Please sign in to comment.