From ec27cb5e29061f8a5ac7ae170b8997b337d09b3e Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 3 Nov 2022 16:13:56 +0100 Subject: [PATCH 1/2] Updated Gaussian Mixture update rules Updated Gaussian Mixture update and refined tests --- src/rules/normal_mixture/out.jl | 23 +++-------------------- test/rules/normal_mixture/test_out.jl | 26 +++++++++++++++++--------- 2 files changed, 20 insertions(+), 29 deletions(-) diff --git a/src/rules/normal_mixture/out.jl b/src/rules/normal_mixture/out.jl index a5c028ff5..2a64378c5 100644 --- a/src/rules/normal_mixture/out.jl +++ b/src/rules/normal_mixture/out.jl @@ -1,12 +1,7 @@ -@rule NormalMixture{N}(:out, Marginalisation) (q_switch::Any, q_m::ManyOf{N, UnivariateNormalDistributionsFamily}, q_p::ManyOf{N, GammaDistributionsFamily}) where {N} = begin - πs = probvec(q_switch) - return NormalMeanPrecision(sum(πs .* mean.(q_m)), sum(πs .* mean.(q_p))) -end -@rule NormalMixture{N}(:out, Marginalisation) (q_switch::Any, q_m::ManyOf{N, MultivariateNormalDistributionsFamily}, q_p::ManyOf{N, Wishart}) where {N} = begin +@rule NormalMixture{N}(:out, Marginalisation) (q_switch::Any, q_m::ManyOf{N, Any}, q_p::ManyOf{N, Any}) where {N} = begin πs = probvec(q_switch) - d = ndims(first(q_m)) # Better to preinitialize q_p_m = mean.(q_p) @@ -15,20 +10,8 @@ end W = mapreduce(x -> x[1] * x[2], +, zip(πs, q_p_m)) ξ = mapreduce(x -> x[1] * x[2] * x[3], +, zip(πs, q_p_m, q_m_m)) - return MvNormalWeightedMeanPrecision(ξ, W) -end + F = variate_form(ξ) -@rule NormalMixture{N}(:out, Marginalisation) (q_switch::Any, q_m::ManyOf{N, PointMass{T} where T <: Real}, q_p::ManyOf{N, PointMass{T} where T <: Real}) where {N} = begin - πs = probvec(q_switch) - return NormalMeanPrecision(sum(πs .* mean.(q_m)), sum(πs .* mean.(q_p))) + return convert(promote_variate_type(F, NormalWeightedMeanPrecision), ξ, W) end -@rule NormalMixture{N}(:out, Marginalisation) (q_switch::Any, q_m::ManyOf{N, PointMass{<:AbstractVector}}, q_p::ManyOf{N, PointMass{<:AbstractMatrix}}) where {N} = begin - πs = probvec(q_switch) - d = ndims(first(q_m)) - - w = mapreduce(x -> x[1] * mean(x[2]), +, zip(πs, q_p)) - xi = mapreduce(x -> x[1] * mean(x[2]) * mean(x[3]), +, zip(πs, q_p, q_m)) - - return MvNormalWeightedMeanPrecision(xi, w) -end diff --git a/test/rules/normal_mixture/test_out.jl b/test/rules/normal_mixture/test_out.jl index 1b4d1bca4..bdf67d7a9 100644 --- a/test/rules/normal_mixture/test_out.jl +++ b/test/rules/normal_mixture/test_out.jl @@ -12,36 +12,44 @@ import ReactiveMP: @test_rules @test_rules [with_float_conversions = true] NormalMixture{2}(:out, Marginalisation) [ ( input = (q_switch = Categorical([0.5, 0.5]), q_m = ManyOf(PointMass(1.0), PointMass(1.0)), q_p = ManyOf(PointMass(1.0), PointMass(1.0))), - output = NormalMeanPrecision(1.0, 1.0) + output = NormalWeightedMeanPrecision(1.0, 1.0) ), ( input = (q_switch = Categorical([1.0, 0.0]), q_m = ManyOf(PointMass(1.0), PointMass(2.0)), q_p = ManyOf(PointMass(2.0), PointMass(1.0))), - output = NormalMeanPrecision(1.0, 2.0) + output = NormalWeightedMeanPrecision(2.0, 2.0) ), ( input = (q_switch = Categorical([0.5, 0.5]), q_m = ManyOf(PointMass(1.0), PointMass(1.0)), q_p = ManyOf(PointMass(1.0), PointMass(1.0))), - output = NormalMeanPrecision(1.0, 1.0) + output = NormalWeightedMeanPrecision(1.0, 1.0) ), ( input = (q_switch = Categorical([1.0, 0.0]), q_m = ManyOf(PointMass(1.0), PointMass(2.0)), q_p = ManyOf(PointMass(2.0), PointMass(1.0))), - output = NormalMeanPrecision(1.0, 2.0) + output = NormalWeightedMeanPrecision(2.0, 2.0) ), ( input = (q_switch = Categorical([0.0, 1.0]), q_m = ManyOf(PointMass(2.0), PointMass(-3.0)), q_p = ManyOf(PointMass(4.0), PointMass(3.0))), - output = NormalMeanPrecision(-3.0, 3.0) + output = NormalWeightedMeanPrecision(-9.0, 3.0) ) ] end @testset "Variational : (m_μ::UnivariateNormalDistributionsFamily..., m_p::GammaDistributionsFamily...)" begin @test_rules [with_float_conversions = true] NormalMixture{2}(:out, Marginalisation) [ + ( + input = ( + q_switch = Bernoulli(0.2), + q_m = ManyOf(NormalMeanVariance(5.0, 2.0), NormalMeanVariance(10.0, 3.0)), + q_p = ManyOf(GammaShapeRate(1.0, 2.0), GammaShapeRate(2.0, 1.0)) + ), + output = NormalWeightedMeanPrecision(33 / 2, 17 / 10) + ), ( input = ( q_switch = Categorical([0.5, 0.5]), q_m = ManyOf(NormalMeanVariance(1.0, 2.0), NormalMeanPrecision(-2.0, 3.0)), q_p = ManyOf(GammaShapeRate(1.0, 1.0), GammaShapeScale(2.0, 0.1)) ), - output = NormalMeanPrecision(-1 / 2, 6 / 10) + output = NormalWeightedMeanPrecision(3 / 10, 6 / 10) ), ( input = ( @@ -49,7 +57,7 @@ import ReactiveMP: @test_rules q_m = ManyOf(NormalWeightedMeanPrecision(-1.0, 2.0), NormalMeanPrecision(2.0, 3.0)), q_p = ManyOf(GammaShapeScale(1.0, 1.0), GammaShapeRate(2.0, 0.1)) ), - output = NormalMeanPrecision(1 / 8, 5.75) + output = NormalWeightedMeanPrecision(77 / 8, 23 / 4) ), ( input = ( @@ -57,8 +65,8 @@ import ReactiveMP: @test_rules q_m = ManyOf(NormalMeanVariance(1.0, 2.0), NormalMeanPrecision(-2.0, 3.0)), q_p = ManyOf(GammaShapeRate(1.0, 1.0), GammaShapeScale(2.0, 0.1)) ), - output = NormalMeanPrecision(1, 1) - ) + output = NormalWeightedMeanPrecision(1, 1) + ), ] end From b76c16fb157ee47d97709358342da2506bcbd18f Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 3 Nov 2022 16:17:12 +0100 Subject: [PATCH 2/2] Style make format --- src/rules/normal_mixture/out.jl | 2 -- test/rules/normal_mixture/test_out.jl | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/rules/normal_mixture/out.jl b/src/rules/normal_mixture/out.jl index 2a64378c5..416f6172a 100644 --- a/src/rules/normal_mixture/out.jl +++ b/src/rules/normal_mixture/out.jl @@ -1,5 +1,4 @@ - @rule NormalMixture{N}(:out, Marginalisation) (q_switch::Any, q_m::ManyOf{N, Any}, q_p::ManyOf{N, Any}) where {N} = begin πs = probvec(q_switch) @@ -14,4 +13,3 @@ return convert(promote_variate_type(F, NormalWeightedMeanPrecision), ξ, W) end - diff --git a/test/rules/normal_mixture/test_out.jl b/test/rules/normal_mixture/test_out.jl index bdf67d7a9..ff6f39056 100644 --- a/test/rules/normal_mixture/test_out.jl +++ b/test/rules/normal_mixture/test_out.jl @@ -66,7 +66,7 @@ import ReactiveMP: @test_rules q_p = ManyOf(GammaShapeRate(1.0, 1.0), GammaShapeScale(2.0, 0.1)) ), output = NormalWeightedMeanPrecision(1, 1) - ), + ) ] end