Skip to content

Commit

Permalink
Enable part of tfp.stats for numpy and Jax (sample_stats).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 288674522
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Jan 8, 2020
1 parent f9d45f5 commit fc84041
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from tensorflow_probability.python.stats._jax.leave_one_out import log_loomean_exp
from tensorflow_probability.python.stats._jax.leave_one_out import log_loosum_exp
from tensorflow_probability.python.stats._jax.leave_one_out import log_soomean_exp
from tensorflow_probability.python.stats._jax.sample_stats import * # pylint: disable=wildcard-import
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
('nuts', 'sample_annealed_importance', 'sample_halton_sequence',
'slice_sampler_kernel'),
'stats':
('calibration', 'quantiles', 'ranking', 'sample_stats')
('calibration', 'quantiles', 'ranking')
}
LIBS = ('bijectors', 'distributions', 'math', 'mcmc', 'stats', 'util')
INTERNALS = ('assert_util', 'distribution_util', 'dtype_util',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tensorflow_probability.python.internal import _numpy as internal
from tensorflow_probability.python.math import _numpy as math
from tensorflow_probability.python.mcmc import _numpy as mcmc
from tensorflow_probability.python.stats import _numpy as stats
from tensorflow_probability.python.util import _numpy as util

from tensorflow_probability.python.internal.backend import numpy as tf2numpy
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ py_library(
name = "stats",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow_probability/python/stats:stats.jax"],
deps = ["//tensorflow_probability/python/stats:stats.numpy"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from tensorflow_probability.python.stats._numpy.moving_stats import assign_log_moving_mean_exp
from tensorflow_probability.python.stats._numpy.moving_stats import assign_moving_mean_variance
from tensorflow_probability.python.stats._numpy.moving_stats import moving_mean_variance_zero_debias
from tensorflow_probability.python.stats._numpy.sample_stats import * # pylint: disable=wildcard-import
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,11 @@ def _zeros_like(input, dtype=None, name=None): # pylint: disable=redefined-buil
tf.range,
lambda start, limit=None, delta=1, dtype=None, name='range': np.arange( # pylint: disable=g-long-lambda
start, limit, delta).astype(utils.numpy_dtype(
dtype or np.array(start).dtype)))
dtype or utils.common_dtype([start], np.int32))))

rank = utils.copy_docstring(
tf.rank,
lambda input, name=None: np.array(input).ndim) # pylint: disable=redefined-builtin,g-long-lambda
lambda input, name=None: np.int32(np.array(input).ndim)) # pylint: disable=redefined-builtin,g-long-lambda

reshape = utils.copy_docstring(
tf.reshape,
Expand Down
7 changes: 4 additions & 3 deletions tensorflow_probability/python/stats/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
load(
"//tensorflow_probability/python:build_defs.bzl",
"multi_substrate_py_library",
"multi_substrate_py_test",
)

package(
Expand All @@ -39,7 +40,6 @@ multi_substrate_py_library(
":calibration",
":quantiles",
":ranking",
":sample_stats",
],
deps = [
":calibration",
Expand Down Expand Up @@ -124,7 +124,7 @@ py_test(
],
)

py_library(
multi_substrate_py_library(
name = "sample_stats",
srcs = ["sample_stats.py"],
srcs_version = "PY2AND3",
Expand All @@ -140,10 +140,11 @@ py_library(
],
)

py_test(
multi_substrate_py_test(
name = "sample_stats_test",
size = "small",
srcs = ["sample_stats_test.py"],
numpy_tags = ["notap"],
deps = [
# numpy dep,
# tensorflow dep,
Expand Down
4 changes: 3 additions & 1 deletion tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,13 @@ def covariance(x,

# If we get lucky and axis is statically defined, we can do some checks.
if _is_list_like(event_axis) and _is_list_like(sample_axis):
event_axis = tuple(map(int, event_axis))
sample_axis = tuple(map(int, sample_axis))
if set(event_axis).intersection(sample_axis):
raise ValueError(
'sample_axis ({}) and event_axis ({}) overlapped'.format(
sample_axis, event_axis))
if (np.diff(sorted(event_axis)) > 1).any():
if (np.diff(np.array(sorted(event_axis))) > 1).any():
raise ValueError(
'event_axis must be contiguous. Found: {}'.format(event_axis))
batch_axis = list(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,15 +544,15 @@ def test_independent_uniform_samples(self):
class LogAverageProbsTest(test_util.TestCase):

def test_mathematical_correctness_bernoulli(self):
logits = tf.random.normal([10, 3, 4], seed=42)
logits = tf.random.normal([10, 3, 4], seed=test_util.test_seed())
# The "expected" calculation is numerically naive.
probs = tf.math.sigmoid(logits)
expected = tf.math.log(tf.reduce_mean(probs, axis=0))
actual = tfp.stats.log_average_probs(logits, validate_args=True)
self.assertAllClose(*self.evaluate([expected, actual]), rtol=1e-5, atol=0.)

def test_mathematical_correctness_categorical(self):
logits = tf.random.normal([10, 3, 4], seed=43)
logits = tf.random.normal([10, 3, 4], seed=test_util.test_seed())
# The "expected" calculation is numerically naive.
probs = tf.math.softmax(logits, axis=-1)
expected = tf.math.log(tf.reduce_mean(probs, axis=0))
Expand All @@ -561,7 +561,7 @@ def test_mathematical_correctness_categorical(self):
self.assertAllClose(*self.evaluate([expected, actual]), rtol=1e-5, atol=0.)

def test_bad_axis_static(self):
logits = tf.random.normal([10, 3, 4], seed=44)
logits = tf.random.normal([10, 3, 4], seed=test_util.test_seed())
with self.assertRaisesRegexp(ValueError, r'.*must be distinct.'):
tfp.stats.log_average_probs(
logits,
Expand Down

0 comments on commit fc84041

Please sign in to comment.