diff --git a/src/rules/transition/in.jl b/src/rules/transition/in.jl index 88399f074..c76928adf 100644 --- a/src/rules/transition/in.jl +++ b/src/rules/transition/in.jl @@ -1,6 +1,6 @@ import Base.Broadcast: BroadcastFunction -@rule Transition(:in, Marginalisation) (m_out::Categorical, m_a::PointMass) = begin +@rule Transition(:in, Marginalisation) (m_out::Union{DiscreteNonParametric, PointMass}, m_a::PointMass) = begin @logscale log(sum(mean(A)' * probvec(m_out))) p = mean(m_a)' * probvec(m_out) normalize!(p, 1) @@ -12,12 +12,12 @@ end return Categorical(a ./ sum(a)) end -@rule Transition(:in, Marginalisation) (m_out::Categorical, q_a::MatrixDirichlet) = begin +@rule Transition(:in, Marginalisation) (m_out::Union{DiscreteNonParametric, PointMass}, q_a::MatrixDirichlet) = begin a = clamp.(exp.(mean(BroadcastFunction(log), q_a))' * probvec(m_out), tiny, Inf) return Categorical(a ./ sum(a)) end -@rule Transition(:in, Marginalisation) (m_out::Categorical, q_a::PointMass, meta::Any) = begin +@rule Transition(:in, Marginalisation) (m_out::Union{DiscreteNonParametric, PointMass}, q_a::PointMass, meta::Any) = begin return @call_rule Transition(:in, Marginalisation) (m_out = m_out, m_a = q_a, meta = meta) end diff --git a/src/rules/transition/out.jl b/src/rules/transition/out.jl index 6832474b0..326cc0feb 100644 --- a/src/rules/transition/out.jl +++ b/src/rules/transition/out.jl @@ -3,7 +3,7 @@ import Base.Broadcast: BroadcastFunction # Belief Propagation # # --------------------------------- # -@rule Transition(:out, Marginalisation) (m_in::Categorical, m_a::PointMass) = begin +@rule Transition(:out, Marginalisation) (m_in::Union{PointMass, DiscreteNonParametric}, m_a::PointMass) = begin @logscale 0 p = mean(m_a) * probvec(m_in) normalize!(p, 1) @@ -20,17 +20,17 @@ end # Variational # # --------------------------------- # -@rule Transition(:out, Marginalisation) (q_in::Categorical, q_a::Any) = begin +@rule Transition(:out, Marginalisation) (q_in::DiscreteNonParametric, q_a::Any) = begin a = clamp.(exp.(mean(BroadcastFunction(log), q_a) * probvec(q_in)), tiny, Inf) return Categorical(a ./ sum(a)) end -@rule Transition(:out, Marginalisation) (m_in::Categorical, q_a::ContinuousMatrixDistribution) = begin +@rule Transition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::ContinuousMatrixDistribution) = begin a = clamp.(exp.(mean(BroadcastFunction(log), q_a)) * probvec(m_in), tiny, Inf) return Categorical(a ./ sum(a)) end -@rule Transition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::PointMass, meta::Any) = begin +@rule Transition(:out, Marginalisation) (m_in::Union{PointMass, DiscreteNonParametric}, q_a::PointMass, meta::Any) = begin @logscale 0 return @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = q_a, meta = meta, addons = getaddons()) end diff --git a/test/rules/transition/out_tests.jl b/test/rules/transition/out_tests.jl index 18ac2e973..c0fd0adff 100644 --- a/test/rules/transition/out_tests.jl +++ b/test/rules/transition/out_tests.jl @@ -59,4 +59,11 @@ ) ] end + + @testset "Variational Bayes: (m_in::PointMass, q_a::PointMass)" begin + @test_rules [check_type_promotion = false] Transition(:out, Marginalisation) [ + (input = (m_in = PointMass([0, 1, 0]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), output = Categorical([0.1, 0.3, 0.6])), + (input = (m_in = PointMass([1, 0, 0]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), output = Categorical([0.1 / 0.9, 0.6 / 0.9, 0.2 / 0.9])) + ] + end end