From fd7cd5fa6287dd43f39149c73d1e88d0780ffe8e Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Wed, 8 Jan 2020 09:21:21 -0800 Subject: [PATCH] Enable JointDistribution tests, even though they don't currently pass (CI is disabled). The change to rewrite.py also enables importing/using the JD classes in np/jax tests like cholesky_lkj w/ the logdet fix. PiperOrigin-RevId: 288711875 --- tensorflow_probability/python/distributions/BUILD | 12 +++++++++--- .../python/experimental/substrates/meta/rewrite.py | 4 +--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index af6cfc7c1a..d3052fd533 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -2037,9 +2037,11 @@ multi_substrate_py_test( ], ) -py_test( +multi_substrate_py_test( name = "joint_distribution_coroutine_test", srcs = ["joint_distribution_coroutine_test.py"], + jax_tags = ["notap"], + numpy_tags = ["notap"], deps = [ # numpy dep, # tensorflow dep, @@ -2048,9 +2050,11 @@ py_test( ], ) -py_test( +multi_substrate_py_test( name = "joint_distribution_named_test", srcs = ["joint_distribution_named_test.py"], + jax_tags = ["notap"], + numpy_tags = ["notap"], deps = [ # absl/testing:parameterized dep, # tensorflow dep, @@ -2059,9 +2063,11 @@ py_test( ], ) -py_test( +multi_substrate_py_test( name = "joint_distribution_sequential_test", srcs = ["joint_distribution_sequential_test.py"], + jax_tags = ["notap"], + numpy_tags = ["notap"], deps = [ # absl/testing:parameterized dep, # numpy dep, diff --git a/tensorflow_probability/python/experimental/substrates/meta/rewrite.py b/tensorflow_probability/python/experimental/substrates/meta/rewrite.py index 8166581fb8..f57a10293f 100644 --- a/tensorflow_probability/python/experimental/substrates/meta/rewrite.py +++ b/tensorflow_probability/python/experimental/substrates/meta/rewrite.py @@ -47,9 +47,7 @@ 'bijectors': ('masked_autoregressive', 'scale_matvec_lu', 'real_nvp'), 'distributions': - ('joint_distribution', 'joint_distribution_coroutine', - 'joint_distribution_named', 'joint_distribution_sequential', - 'internal.moving_stats'), + ('internal.moving_stats',), 'math': ('ode', 'diag_jacobian', 'interpolation', 'minimize', 'root_search', 'sparse'),