diff --git a/tensorflow_compression/python/entropy_models/continuous_base.py b/tensorflow_compression/python/entropy_models/continuous_base.py
index 932efcb..ddc5b13 100644
--- a/tensorflow_compression/python/entropy_models/continuous_base.py
+++ b/tensorflow_compression/python/entropy_models/continuous_base.py
@@ -68,8 +68,9 @@ def __init__(self,
         Elias gamma code embedded into the range coder.
       bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
         Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
-      laplace_tail_mass: Float. If non-zero, will augment the prior with a
-        `NoisyLaplace` mixture component for training stability. (experimental)
+      laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
+        will augment the prior with a `NoisyLaplace` mixture component for
+        training stability. Must be less than 1. (experimental)
     """
     super().__init__()
     self._prior = None  # This will be set by subclasses, if appropriate.
@@ -83,14 +84,12 @@ def __init__(self,
     if bottleneck_dtype is None:
       bottleneck_dtype = tf.keras.backend.floatx()
     self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype)
-    self._laplace_tail_mass = float(laplace_tail_mass)
+    self._laplace_tail_mass = laplace_tail_mass
 
     if self.coding_rank < 0:
       raise ValueError("`coding_rank` must be at least 0.")
     if not 0 < self.tail_mass < 1:
       raise ValueError("`tail_mass` must be between 0 and 1.")
-    if not 0 <= self.laplace_tail_mass < 1:
-      raise ValueError("`laplace_tail_mass` must be between 0 and 1.")
 
   def _check_compression(self):
     if not self.compression:
@@ -299,23 +298,44 @@ def loop_body(i, cdf):
   def _log_prob(self, prior, bottleneck_perturbed):
     """Evaluates prior.log_prob(bottleneck + noise)."""
     bottleneck_perturbed = tf.cast(bottleneck_perturbed, prior.dtype)
-    if self.laplace_tail_mass:
+    laplace_tail_mass = self.laplace_tail_mass
+
+    def mixture_log_prob_fn():
+      # Cast so that a Python float is also converted to `prior.dtype`;
+      # `tf.math.log` of a bare Python float would yield float32.
+      laplace_weight = tf.cast(laplace_tail_mass, prior.dtype)
+      tf.debugging.assert_less(
+          laplace_weight,
+          tf.constant(1.0, prior.dtype),
+          message="`laplace_tail_mass` must be less than 1.")
       laplace_prior = uniform_noise.NoisyLaplace(
           loc=tf.constant(0, dtype=prior.dtype),
           scale=tf.constant(1, dtype=prior.dtype))
       probs = prior.prob(bottleneck_perturbed)
-      probs = ((1 - self.laplace_tail_mass) * probs +
-               self.laplace_tail_mass *
+      probs = ((1 - laplace_weight) * probs +
+               laplace_weight *
                laplace_prior.prob(bottleneck_perturbed))
       probs_too_small = probs < 1e-10
       probs_bounded = tf.maximum(probs, 1e-10)
       return tf.where(
           probs_too_small,
-          tf.math.log(self.laplace_tail_mass) +
+          tf.math.log(laplace_weight) +
           laplace_prior.log_prob(bottleneck_perturbed),
           tf.math.log(probs_bounded))
+
+    prior_log_prob_fn = lambda: prior.log_prob(bottleneck_perturbed)
+
+    if isinstance(laplace_tail_mass, tf.Tensor):
+      # Do all the computation in tf (graph mode compatible).
+      laplace_tail_mass = tf.cast(laplace_tail_mass, prior.dtype)
+      use_laplace_tail_mass = tf.greater(laplace_tail_mass, 0.0)
+      return tf.cond(use_laplace_tail_mass, mixture_log_prob_fn,
+                     prior_log_prob_fn)
     else:
-      return prior.log_prob(bottleneck_perturbed)
+      if laplace_tail_mass > 0:
+        return mixture_log_prob_fn()
+      else:
+        return prior_log_prob_fn()
 
   @abc.abstractmethod
   def get_config(self):
@@ -340,7 +360,7 @@ def get_config(self):
         tail_mass=self.tail_mass,
         cdf_shapes=(self.cdf.shape[0], self.cdf_offset.shape[0]),
         bottleneck_dtype=self.bottleneck_dtype.name,
-        laplace_tail_mass=self.laplace_tail_mass,
+        laplace_tail_mass=float(self.laplace_tail_mass),
     )
 
   def get_weights(self):
diff --git a/tensorflow_compression/python/entropy_models/continuous_batched.py b/tensorflow_compression/python/entropy_models/continuous_batched.py
index fd17430..27f08e1 100644
--- a/tensorflow_compression/python/entropy_models/continuous_batched.py
+++ b/tensorflow_compression/python/entropy_models/continuous_batched.py
@@ -171,8 +171,9 @@ def __init__(self,
         use. If provided (not `None`), then `offset_heuristic` is ineffective.
       decode_sanity_check: Boolean. If `True`, an raises an error if the binary
         strings passed into `decompress` are not completely decoded.
-      laplace_tail_mass: Float. If non-zero, will augment the prior with a
-        `NoisyLaplace` mixture component for training stability. (experimental)
+      laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
+        will augment the prior with a `NoisyLaplace` mixture component for
+        training stability. Must be less than 1. (experimental)
     """
     if (prior is None) == (prior_shape is None):
       raise ValueError("Either `prior` or `prior_shape` must be provided.")
diff --git a/tensorflow_compression/python/entropy_models/continuous_batched_test.py b/tensorflow_compression/python/entropy_models/continuous_batched_test.py
index 85ffd19..87ffc99 100644
--- a/tensorflow_compression/python/entropy_models/continuous_batched_test.py
+++ b/tensorflow_compression/python/entropy_models/continuous_batched_test.py
@@ -241,6 +241,37 @@ def test_small_bitcost_for_dirac_prior(self):
     # Quantization noise should be between -.5 and .5
     self.assertAllClose(x, x_decoded, rtol=0., atol=.5)
 
+  def test_laplace_tail_mass(self):
+    noisy = uniform_noise.NoisyNormal(loc=0., scale=1.)
+    em = ContinuousBatchedEntropyModel(noisy, 1, laplace_tail_mass=0.0)
+    self.assertEqual(em.laplace_tail_mass, 0.0)
+    em = ContinuousBatchedEntropyModel(noisy, 1,
+                                       laplace_tail_mass=tf.constant(1e-3))
+    self.assertEqual(em.laplace_tail_mass, tf.constant(1e-3))
+    log_prob = em._log_prob(noisy, tf.constant(0.0))
+    self.assertEqual(log_prob.dtype, tf.float32)
+
+  def test_laplace_tail_mass_works_in_tf_function(self):
+    noisy = uniform_noise.NoisyNormal(loc=0., scale=1.)
+    samples = noisy.base.sample([100])
+
+    # Since tf.function traces each function twice, and only allows variable
+    # creation in the first call, we need to have a stateful object in which we
+    # create the entropy model only the first time the function is called, and
+    # store it for the second time.
+
+    class EntropyModel:
+
+      def log_prob(self, values):
+        if not hasattr(self, "em"):
+          self.em = ContinuousBatchedEntropyModel(
+              noisy, 1, laplace_tail_mass=tf.constant(1e-3))
+        return self.em._log_prob(noisy, values)
+
+    values_eager = EntropyModel().log_prob(samples)
+    values_function = tf.function(EntropyModel().log_prob)(samples)
+    self.assertAllEqual(values_eager, values_function)
+
 
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow_compression/python/entropy_models/continuous_indexed.py b/tensorflow_compression/python/entropy_models/continuous_indexed.py
index d5796b5..e8f968f 100644
--- a/tensorflow_compression/python/entropy_models/continuous_indexed.py
+++ b/tensorflow_compression/python/entropy_models/continuous_indexed.py
@@ -189,8 +189,9 @@ def __init__(self,
         computations. Defaults to `tf.float32`.
       decode_sanity_check: Boolean. If `True`, an raises an error if the binary
         strings passed into `decompress` are not completely decoded.
-      laplace_tail_mass: Float. If non-zero, will augment the prior with a
-        `NoisyLaplace` mixture component for training stability. (experimental)
+      laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
+        will augment the prior with a `NoisyLaplace` mixture component for
+        training stability. Must be less than 1. (experimental)
     """
     if not callable(prior_fn):
       raise TypeError("`prior_fn` must be a class or factory function.")
@@ -496,8 +497,9 @@ def __init__(self,
         Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
       prior_dtype: `tf.dtypes.DType`. Data type of prior and probability
         computations. Defaults to `tf.float32`.
-      laplace_tail_mass: Float. If non-zero, will augment the prior with a
-        `NoisyLaplace` mixture component for training stability. (experimental)
+      laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
+        will augment the prior with a `NoisyLaplace` mixture component for
+        training stability. Must be less than 1. (experimental)
     """
     num_scales = int(num_scales)
     super().__init__(