Skip to content

Commit

Permalink
Fix nested and single output IfElse logp
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Jun 5, 2023
1 parent a30e0d4 commit 5b68edc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,7 @@ def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)

return ifelse(if_var, logps_then, logps_else)
logps = ifelse(if_var, logps_then, logps_else)
if len(logps) == 1:
return logps[0]
return logps
22 changes: 21 additions & 1 deletion tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
as_index_constant,
)

from pymc.logprob.basic import factorized_joint_logprob
from pymc.logprob.basic import factorized_joint_logprob, logp
from pymc.logprob.mixture import MixtureRV, expand_indices
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.logprob.utils import dirac_delta
Expand Down Expand Up @@ -1112,3 +1112,23 @@ def test_joint_logprob_subtensor():
logp_vals = logp_vals_fn(A_idx_value, I_value)

np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)


def test_nested_ifelse():
idx = pt.scalar("idx", dtype=int)

dist0 = pt.random.normal(-5, 1)
dist1 = pt.random.normal(0, 1)
dist2 = pt.random.normal(5, 1)
mix = ifelse(pt.eq(idx, 0), dist0, ifelse(pt.eq(idx, 1), dist1, dist2))
mix.name = "mix"

value = mix.clone()
mix_logp = logp(mix, value)
assert mix_logp.name == "mix_logprob"
mix_logp_fn = pytensor.function([idx, value], mix_logp)

test_value = 0.25
np.testing.assert_almost_equal(mix_logp_fn(0, test_value), sp.norm.logpdf(test_value, -5, 1))
np.testing.assert_almost_equal(mix_logp_fn(1, test_value), sp.norm.logpdf(test_value, 0, 1))
np.testing.assert_almost_equal(mix_logp_fn(2, test_value), sp.norm.logpdf(test_value, 5, 1))

0 comments on commit 5b68edc

Please sign in to comment.