diff --git a/Project.toml b/Project.toml index d1867ceb0..4d926cf93 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ DiffResults = "1.1.0" Distributions = "0.24, 0.25" DomainIntegrals = "0.3.2, 0.4" DomainSets = "0.5.2, 0.6, 0.7" -ExponentialFamily = "1.6.0" +ExponentialFamily = "1.7.0" ExponentialFamilyProjection = "1.2" FastCholesky = "1.3.0" FastGaussQuadrature = "0.4, 0.5" diff --git a/src/nodes/predefined.jl b/src/nodes/predefined.jl index 4a82cdb90..6e2d510f0 100644 --- a/src/nodes/predefined.jl +++ b/src/nodes/predefined.jl @@ -14,6 +14,7 @@ include("predefined/gamma_shape_rate.jl") include("predefined/beta.jl") include("predefined/categorical.jl") include("predefined/matrix_dirichlet.jl") +include("predefined/tensor_dirichlet.jl") include("predefined/dirichlet.jl") include("predefined/bernoulli.jl") include("predefined/gcv.jl") diff --git a/src/nodes/predefined/tensor_dirichlet.jl b/src/nodes/predefined/tensor_dirichlet.jl new file mode 100644 index 000000000..9bfd88167 --- /dev/null +++ b/src/nodes/predefined/tensor_dirichlet.jl @@ -0,0 +1,10 @@ +import SpecialFunctions: loggamma +import Base.Broadcast: BroadcastFunction + +@node TensorDirichlet Stochastic [out, a] + +@average_energy TensorDirichlet (q_out::TensorDirichlet, q_a::PointMass) = begin + m_a = mean(q_a) + logmean = mean(BroadcastFunction(log), q_out) + return sum(-loggamma.(sum(m_a, dims = 1)) .+ sum(loggamma.(m_a), dims = 1) .- sum((m_a .- 1.0) .* logmean, dims = 1)) +end diff --git a/src/nodes/predefined/transition.jl b/src/nodes/predefined/transition.jl index d30f6d7c5..c39af4969 100644 --- a/src/nodes/predefined/transition.jl +++ b/src/nodes/predefined/transition.jl @@ -4,7 +4,30 @@ import Base.Broadcast: BroadcastFunction struct Transition end -@node Transition Stochastic [out, in, a] +ReactiveMP.sdtype(::Type{Transition}) = ReactiveMP.Stochastic() +ReactiveMP.is_predefined_node(::Type{Transition}) = ReactiveMP.PredefinedNodeFunctionalForm() + +function ReactiveMP.prepare_interfaces_generic(fform::Type{Transition}, interfaces::AbstractVector) + return map(enumerate(interfaces)) do (index, (name, variable)) + return ReactiveMP.NodeInterface(ReactiveMP.alias_interface(fform, index, name), variable) + end +end + +function ReactiveMP.alias_interface(::Type{Transition}, index, name) + if name === :out && index === 1 + return :out + elseif name === :in && index === 2 + return :in + elseif name === :in && index === 3 + return :a + elseif name === :in && index >= 4 + return Symbol(:T, index - 3) + end +end + +function ReactiveMP.collect_factorisation(::Type{Transition}, t::Tuple) + return t +end @average_energy Transition (q_out::Any, q_in::Any, q_a::MatrixDirichlet) = begin return -probvec(q_out)' * mean(BroadcastFunction(log), q_a) * probvec(q_in) @@ -19,9 +42,35 @@ end # The reason is that we don't want to take log of zeros in the matrix `q_a` (if there are any) # The trick here is that if RHS matrix has zero inputs, than the corresponding entries of the `contingency_matrix` matrix # should also be zeros (see corresponding @marginalrule), so at the end `log(tiny) * 0` should not influence the result. - return -ReactiveMP.mul_trace(components(q_out_in)', mean(BroadcastFunction(clamplog), q_a)) + result = -ReactiveMP.mul_trace(components(q_out_in)', mean(BroadcastFunction(clamplog), q_a)) + return result end @average_energy Transition (q_out::Any, q_in::Any, q_a::PointMass) = begin return -probvec(q_out)' * mean(BroadcastFunction(clamplog), q_a) * probvec(q_in) end + +function score(::AverageEnergy, ::Type{<:Transition}, ::Val{mnames}, marginals::Tuple{<:Marginal{<:Contingency}, <:Marginal{<:TensorDirichlet}}, ::Nothing) where {mnames} + q_contingency, q_a = getdata.(marginals) + return -sum(mean(BroadcastFunction(log), q_a) .* components(q_contingency)) +end + +function __reduce_td_from_messages(messages, q_A, interface_index) + vmp = clamp.(exp.(mean(BroadcastFunction(log), q_A)), tiny, Inf) + probvecs = probvec.(messages) + for (i, vector) in enumerate(probvecs) + if i ≥ interface_index + actual_index = i + 1 + else + actual_index = i + end + v = view(vector, :) + localdims = ntuple(x -> x == actual_index::Int64 ? length(v) : 1, ndims(vmp)) + vmp .*= reshape(v, localdims) + end + dims = ntuple(x -> x ≥ interface_index ? x + 1 : x, ndims(vmp) - 1) + vmp = sum(vmp, dims = dims) + msg = reshape(vmp, :) + msg ./= sum(msg) + return Categorical(msg) +end diff --git a/src/rules/predefined.jl b/src/rules/predefined.jl index ee0090855..08a4a17f7 100644 --- a/src/rules/predefined.jl +++ b/src/rules/predefined.jl @@ -106,6 +106,7 @@ include("transition/marginals.jl") include("transition/out.jl") include("transition/in.jl") include("transition/a.jl") +include("transition/t.jl") include("continuous_transition/y.jl") include("continuous_transition/x.jl") @@ -184,3 +185,6 @@ include("delta/cvi/marginals.jl") include("half_normal/out.jl") include("binomial_polya/beta.jl") + +include("tensor_dirichlet/out.jl") +include("tensor_dirichlet/marginals.jl") diff --git a/src/rules/tensor_dirichlet/marginals.jl b/src/rules/tensor_dirichlet/marginals.jl new file mode 100644 index 000000000..41f1922bc --- /dev/null +++ b/src/rules/tensor_dirichlet/marginals.jl @@ -0,0 +1,4 @@ + +@marginalrule TensorDirichlet(:out_a) (m_out::TensorDirichlet, m_a::PointMass) = begin + return convert_paramfloattype((out = prod(ClosedProd(), TensorDirichlet(mean(m_a)), m_out), a = m_a)) +end diff --git a/src/rules/tensor_dirichlet/out.jl b/src/rules/tensor_dirichlet/out.jl new file mode 100644 index 000000000..0a810e314 --- /dev/null +++ b/src/rules/tensor_dirichlet/out.jl @@ -0,0 +1,4 @@ + +@rule TensorDirichlet(:out, Marginalisation) (m_a::PointMass,) = TensorDirichlet(mean(m_a)) + +@rule TensorDirichlet(:out, Marginalisation) (q_a::PointMass,) = TensorDirichlet(mean(q_a)) diff --git a/src/rules/transition/a.jl b/src/rules/transition/a.jl index 2a9d2d1be..d74b76c88 100644 --- a/src/rules/transition/a.jl +++ b/src/rules/transition/a.jl @@ -6,3 +6,16 @@ end @rule Transition(:a, Marginalisation) (q_out_in::Contingency,) = begin return MatrixDirichlet(components(q_out_in) .+ 1) end + +ReactiveMP.rule( + fform::Type{<:Transition}, + on::Val{:a}, + vconstraint::Marginalisation, + messages_names::Nothing, + messages::Nothing, + marginals_names::Val{m_names} where {m_names}, + marginals::Tuple, + meta::Any, + addons::Any, + ::Any +) = TensorDirichlet(components(getdata(first(marginals))) .+ 1), addons diff --git a/src/rules/transition/in.jl b/src/rules/transition/in.jl index c76928adf..c059ba144 100644 --- a/src/rules/transition/in.jl +++ b/src/rules/transition/in.jl @@ -26,3 +26,18 @@ end normalize!(p, 1) return Categorical(p) end + +function ReactiveMP.rule( + fform::Type{<:Transition}, + on::Val{:in}, + vconstraint::Marginalisation, + messages_names::Val{m_names}, + messages::Tuple, + marginals_names::Val{(:a,)}, + marginals::Tuple, + meta::Any, + addons::Any, + ::Any +) where {m_names} + return __reduce_td_from_messages(messages, first(marginals), 2), addons +end diff --git a/src/rules/transition/marginals.jl b/src/rules/transition/marginals.jl index 5f886629b..f4b4bd2b0 100644 --- a/src/rules/transition/marginals.jl +++ b/src/rules/transition/marginals.jl @@ -23,3 +23,16 @@ end m_in_2 = @call_rule Transition(:in, Marginalisation) (m_out = m_out, m_a = m_a, meta = meta) return convert_paramfloattype((out = m_out, in = prod(ClosedProd(), m_in_2, m_in), a = m_a)) end + +@marginalrule Transition(:out_in) (m_out::PointMass, m_in::Categorical, q_a::PointMass) = begin + m_in_2 = @call_rule Transition(:in, Marginalisation) (m_out = m_out, q_a = q_a) + return convert_paramfloattype((out = m_out, in = prod(ClosedProd(), m_in, m_in_2))) +end + +outer_product(vs) = prod.(Iterators.product(vs...)) + +function marginalrule( + ::Type{<:Transition}, ::Val{marginal_symbol}, ::Val{message_names}, messages::Tuple, ::Val{marginal_names}, marginals::Tuple, ::Any, ::Any +) where {marginal_symbol, message_names, marginal_names} + return Contingency(outer_product(probvec.(messages)) .* clamp.(exp.(mean(BroadcastFunction(log), first(marginals))), tiny, huge)) +end diff --git a/src/rules/transition/out.jl b/src/rules/transition/out.jl index 326cc0feb..6d9101f4d 100644 --- a/src/rules/transition/out.jl +++ b/src/rules/transition/out.jl @@ -34,3 +34,18 @@ end @logscale 0 return @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = q_a, meta = meta, addons = getaddons()) end + +function ReactiveMP.rule( + fform::Type{<:Transition}, + on::Val{:out}, + vconstraint::Marginalisation, + messages_names::Val{m_names}, + messages::Tuple, + marginals_names::Val{(:a,)}, + marginals::Tuple, + meta::Any, + addons::Any, + ::Any +) where {m_names} + return __reduce_td_from_messages(messages, first(marginals), 1), addons +end diff --git a/src/rules/transition/t.jl b/src/rules/transition/t.jl new file mode 100644 index 000000000..01e0734cd --- /dev/null +++ b/src/rules/transition/t.jl @@ -0,0 +1,17 @@ +import Base.Broadcast: BroadcastFunction + +function ReactiveMP.rule( + fform::Type{<:Transition}, + on::Val{S}, + vconstraint::Marginalisation, + messages_names::Val{m_names}, + messages::Tuple, + marginals_names::Val{(:a,)}, + marginals::Tuple, + meta::Any, + addons::Any, + ::Any +) where {S, m_names} + interface_index = parse(Int, String(S)[2:end]) + 2 + return __reduce_td_from_messages(messages, first(marginals), interface_index), addons +end diff --git a/test/nodes/predefined/tensor_dirichlet.jl b/test/nodes/predefined/tensor_dirichlet.jl new file mode 100644 index 000000000..85ee2aa7b --- /dev/null +++ b/test/nodes/predefined/tensor_dirichlet.jl @@ -0,0 +1,56 @@ + +@testitem "TensorDirichletNode" begin + using ReactiveMP, Random, BayesBase, ExponentialFamily, Distributions, StableRNGs + + @testset "AverageEnergy" begin + begin + rng = StableRNG(123456) + for i in 1:100 + α = rand(rng, 2, 2) + a = rand(rng, 2, 2) + q_out = TensorDirichlet(α) + q_a = PointMass(a) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing)) + avg_energy = score(AverageEnergy(), TensorDirichlet, Val{(:out, :a)}(), marginals, nothing) + + q_out = MatrixDirichlet(α) + q_a = PointMass(a) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing)) + avg_energy_matrix = score(AverageEnergy(), MatrixDirichlet, Val{(:out, :a)}(), marginals, nothing) + + @test avg_energy ≈ avg_energy_matrix + end + end + + begin + for rank in 3:5 + for dim in 2:5 + for i in 1:100 + dims = ntuple(d -> dim, rank) + α = rand(rng, dims...) + a = rand(rng, dims...) + + q_out = TensorDirichlet(α) + q_a = PointMass(a) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing)) + avg_energy = score(AverageEnergy(), TensorDirichlet, Val{(:out, :a)}(), marginals, nothing) + + q_out = Dirichlet.(eachslice(α, dims = ntuple(d -> d + 1, rank - 1))) + q_a = PointMass.(eachslice(a, dims = ntuple(d -> d + 1, rank - 1))) + + avg_energy_matrix = 0.0 + for (dir, a) in zip(q_out, q_a) + marginals = (Marginal(dir, false, false, nothing), Marginal(a, false, false, nothing)) + avg_energy_matrix += score(AverageEnergy(), Dirichlet, Val{(:out, :a)}(), marginals, nothing) + end + + @test avg_energy ≈ avg_energy_matrix + end + end + end + end + end +end diff --git a/test/nodes/predefined/transition_tests.jl b/test/nodes/predefined/transition_tests.jl new file mode 100644 index 000000000..c33bcb407 --- /dev/null +++ b/test/nodes/predefined/transition_tests.jl @@ -0,0 +1,86 @@ +@testitem "TransitionNode" begin + using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily + + @testset "Transition node properties" begin + @test ReactiveMP.sdtype(Transition) == Stochastic() + @test ReactiveMP.alias_interface(Transition, 1, :out) == :out + @test ReactiveMP.alias_interface(Transition, 2, :in) == :in + @test ReactiveMP.alias_interface(Transition, 3, :in) == :a + @test ReactiveMP.alias_interface(Transition, 4, :in) == :T1 + + @test ReactiveMP.collect_factorisation(Transition, ()) == () + end + @testset "AverageEnergy(q_out_in::Contingency, q_a::MatrixDirichlet)" begin end + + @testset "AverageEnergy(q_out_in::Contingency, q_a::PointMass)" begin + contingency_matrix = [0.2 0.3; 0.4 0.1] + a_matrix = [0.7 0.3; 0.2 0.8] + + q_out_in = Contingency(contingency_matrix) + q_a = PointMass(a_matrix) + + marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing)) + + # Expected value calculated by hand + expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) + + @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected + + contingency_matrix = [0.2 0.3; 0.4 0.1] + a_matrix = [1.0 0.0; 0.0 1.0] + + q_out_in = Contingency(contingency_matrix) + q_a = PointMass(a_matrix) + + marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing)) + + expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) + + @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected + + contingency_matrix = prod.(Iterators.product([0, 1, 0], [0.1, 0.4, 0.5])) + a_matrix = diageye(3) + + q_out_in = Contingency(contingency_matrix) + q_a = PointMass(a_matrix) + + marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing)) + + expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) + + @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected + + contingency_matrix = [0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0] + q_out_in = Contingency(contingency_matrix) + q_a = PointMass(diageye(3)) + + marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing)) + + expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf))) + @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected + end + + @testset "AverageEnergy(q_out::Any, q_in::Any, q_a::PointMass)" begin + q_out = Categorical([0.3, 0.7]) + q_in = Categorical([0.8, 0.2]) + q_a = PointMass([0.7 0.3; 0.2 0.8]) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_in, false, false, nothing), Marginal(q_a, false, false, nothing)) + + contingency = probvec(q_out) * probvec(q_in)' + expected = -sum(contingency .* log.(clamp.(mean(q_a), tiny, Inf))) + + @test score(AverageEnergy(), Transition, Val{(:out, :in, :a)}(), marginals, nothing) ≈ expected + + q_out = Categorical([0.0, 1.0]) + q_in = Categorical([0.0, 1.0]) + q_a = PointMass([1.0 0.0; 1.0 0.0]) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_in, false, false, nothing), Marginal(q_a, false, false, nothing)) + + contingency = probvec(q_out) * probvec(q_in)' + + expected = -sum(contingency .* log.(clamp.(mean(q_a), tiny, Inf))) + @test score(AverageEnergy(), Transition, Val{(:out, :in, :a)}(), marginals, nothing) ≈ expected + end +end diff --git a/test/rules/transition/a_tests.jl b/test/rules/transition/a_tests.jl index 970c94500..10969a428 100644 --- a/test/rules/transition/a_tests.jl +++ b/test/rules/transition/a_tests.jl @@ -18,4 +18,16 @@ input = (q_out_in = Contingency(diageye(3)),), output = MatrixDirichlet([1.333333333333333 1 1; 1 1.3333333333333 1; 1 1 1.33333333333333333]) )] end + + @testset "Variational Bayes: (q_out_in_t1::Contingency)" begin + @test_rules [check_type_promotion = false] Transition(:a, Marginalisation) [( + input = (q_out_in_t1 = Contingency(ones(3, 3, 3)),), output = TensorDirichlet(ones(3, 3, 3) .+ (1 / 27)) + )] + end + + @testset "Variational Bayes: (q_out_in_t1_t2::Contingency)" begin + @test_rules [check_type_promotion = false] Transition(:a, Marginalisation) [( + input = (q_out_in_t1_t2 = Contingency(ones(3, 3, 3, 3)),), output = TensorDirichlet(ones(3, 3, 3, 3) .+ (1 / 81)) + )] + end end diff --git a/test/rules/transition/in_tests.jl b/test/rules/transition/in_tests.jl index 026f848bb..78819f1e0 100644 --- a/test/rules/transition/in_tests.jl +++ b/test/rules/transition/in_tests.jl @@ -44,4 +44,134 @@ output = Categorical([0.23000000000000004, 0.43, 0.33999999999999997]) )] end + + @testset "Belief Propagation: (m_out::Categorical, q_a::TensorDirichlet, m_t1::Categorical)" begin + @test_rules [check_type_promotion = false] Transition(:in, Marginalisation) [ + ( + input = ( + m_out = Categorical([0.06363608348699812, 0.4635487496635592, 0.47281516684944275]), + q_a = TensorDirichlet([3.0 2.0 9.0; 9.0 10.0 9.0; 4.0 6.0 3.0;;; 8.0 8.0 4.0; 9.0 1.0 6.0; 6.0 3.0 9.0;;; 3.0 7.0 8.0; 6.0 4.0 5.0; 6.0 1.0 10.0]), + m_T1 = Categorical([0.5271858992772847, 0.07706246907875924, 0.3957516316439561]) + ), + output = Categorical([0.37646565409395055, 0.3158727786171196, 0.30766156728892985]) + ), + ( + input = ( + m_out = Categorical([0.5103858588726022, 0.42556134873724166, 0.06405279239015611]), + q_a = TensorDirichlet([9.0 7.0 4.0; 4.0 1.0 7.0; 2.0 2.0 10.0;;; 7.0 8.0 6.0; 6.0 7.0 2.0; 4.0 7.0 8.0;;; 4.0 3.0 2.0; 2.0 8.0 7.0; 10.0 6.0 7.0]), + m_T1 = Categorical([0.6160127621173446, 0.2777189566460366, 0.10626828123661897]) + ), + output = Categorical([0.3781041880107084, 0.36464328449743966, 0.2572525274918519]) + ), + ( + input = ( + m_out = Categorical([0.4453670227558059, 0.2630035661457053, 0.2916294110984888]), + q_a = TensorDirichlet([7.0 1.0 1.0; 9.0 9.0 1.0; 7.0 2.0 3.0;;; 6.0 9.0 10.0; 4.0 9.0 1.0; 10.0 9.0 10.0;;; 4.0 7.0 9.0; 8.0 4.0 10.0; 9.0 3.0 6.0]), + m_T1 = Categorical([0.15208943638244485, 0.6704322113566465, 0.17747835226090863]) + ), + output = Categorical([0.3259798540726559, 0.3289172250099544, 0.3451029209173897]) + ), + ( + input = ( + m_out = Categorical([0.15440187143581133, 0.8335492681493561, 0.012048860414832555]), + q_a = TensorDirichlet([3.0 4.0 4.0; 9.0 4.0 8.0; 8.0 1.0 8.0;;; 7.0 8.0 9.0; 9.0 4.0 1.0; 7.0 3.0 10.0;;; 7.0 1.0 1.0; 3.0 3.0 4.0; 2.0 3.0 2.0]), + m_T1 = Categorical([0.38463636429622916, 0.4014483333701451, 0.21391530233362574]) + ), + output = Categorical([0.36339604758350774, 0.35400347953287076, 0.28260047288362156]) + ), + ( + input = ( + m_out = Categorical([0.4161210892223872, 0.4941277161962706, 0.0897511945813421]), + q_a = TensorDirichlet([6.0 3.0 5.0; 3.0 6.0 9.0; 8.0 6.0 6.0;;; 9.0 8.0 10.0; 7.0 2.0 8.0; 2.0 10.0 9.0;;; 10.0 4.0 6.0; 6.0 4.0 2.0; 4.0 2.0 3.0]), + m_T1 = Categorical([0.2835337016406116, 0.26073332343890476, 0.45573297492048376]) + ), + output = Categorical([0.35054919761788883, 0.3174728738964184, 0.3319779284856927]) + ) + ] + end + + @testset "Belief Propagation: (m_out::Categorical, q_a::TensorDirichlet, m_t1::Categorical, m_t2::Categorical)" begin + @test_rules [check_type_promotion = false] Transition(:in, Marginalisation) [ + ( + input = ( + m_out = Categorical([0.08799332630703943, 0.29132551818215013, 0.6206811555108104]), + q_a = TensorDirichlet( + [ + 14.0 10.0 1.0; 6.0 8.0 1.0; 3.0 10.0 7.0;;; 14.0 3.0 3.0; 4.0 9.0 3.0; 5.0 5.0 14.0;;; 9.0 10.0 1.0; 3.0 7.0 4.0; 8.0 2.0 12.0;;;; + 13.0 9.0 1.0; 9.0 7.0 8.0; 5.0 1.0 11.0;;; 8.0 7.0 4.0; 1.0 14.0 7.0; 4.0 10.0 6.0;;; 15.0 5.0 5.0; 7.0 6.0 5.0; 7.0 3.0 10.0;;;; + 15.0 7.0 8.0; 5.0 10.0 9.0; 6.0 3.0 14.0;;; 7.0 1.0 6.0; 3.0 12.0 5.0; 3.0 5.0 9.0;;; 15.0 5.0 9.0; 9.0 12.0 10.0; 3.0 7.0 10.0;;;; + 13.0 4.0 8.0; 4.0 8.0 2.0; 5.0 10.0 11.0;;; 15.0 4.0 8.0; 9.0 15.0 2.0; 5.0 2.0 12.0;;; 14.0 1.0 8.0; 10.0 6.0 5.0; 5.0 6.0 13.0 + ] + ), + m_T1 = Categorical([0.11240998953463174, 0.5372891244719414, 0.35030088599342685]), + m_T2 = Categorical([0.09240952821382971, 0.3184996197469503, 0.3995918737894391, 0.18949897824978093]) + ), + output = Categorical([0.2516472953196383, 0.3420258883005958, 0.40632681637976575]) + ), + ( + input = ( + m_out = Categorical([0.41399930903334414, 0.2569572285438312, 0.32904346242282473]), + q_a = TensorDirichlet( + [ + 15.0 4.0 4.0; 6.0 11.0 5.0; 1.0 1.0 14.0;;; 6.0 9.0 8.0; 10.0 10.0 5.0; 6.0 2.0 13.0;;; 13.0 2.0 7.0; 3.0 9.0 6.0; 5.0 5.0 7.0;;;; + 7.0 2.0 9.0; 6.0 8.0 2.0; 6.0 4.0 12.0;;; 13.0 4.0 3.0; 1.0 10.0 2.0; 7.0 7.0 8.0;;; 9.0 9.0 4.0; 1.0 10.0 3.0; 4.0 4.0 12.0;;;; + 13.0 5.0 5.0; 1.0 7.0 1.0; 7.0 5.0 11.0;;; 9.0 10.0 6.0; 6.0 13.0 1.0; 9.0 2.0 8.0;;; 13.0 9.0 4.0; 6.0 12.0 2.0; 10.0 1.0 13.0;;;; + 9.0 6.0 8.0; 3.0 7.0 9.0; 8.0 4.0 11.0;;; 12.0 3.0 4.0; 2.0 11.0 9.0; 6.0 7.0 11.0;;; 13.0 6.0 10.0; 3.0 9.0 9.0; 10.0 10.0 6.0 + ] + ), + m_T1 = Categorical([0.2712971200834005, 0.3827161804909996, 0.3459866994255999]), + m_T2 = Categorical([0.2284278504147311, 0.1734009515489395, 0.1685400610240271, 0.4296311370123022]) + ), + output = Categorical([0.3580424644840306, 0.31051652805186924, 0.3314410074641002]) + ), + ( + input = ( + m_out = Categorical([0.28007415705382577, 0.362168131823555, 0.35775771112261917]), + q_a = TensorDirichlet( + [ + 12.0 9.0 1.0; 2.0 8.0 7.0; 10.0 1.0 13.0;;; 11.0 8.0 7.0; 10.0 14.0 7.0; 5.0 2.0 15.0;;; 8.0 4.0 4.0; 7.0 11.0 5.0; 4.0 10.0 6.0;;;; + 14.0 3.0 7.0; 8.0 15.0 2.0; 5.0 8.0 15.0;;; 7.0 7.0 4.0; 6.0 11.0 2.0; 10.0 9.0 12.0;;; 11.0 7.0 4.0; 2.0 7.0 4.0; 1.0 1.0 15.0;;;; + 13.0 7.0 5.0; 6.0 15.0 7.0; 9.0 5.0 14.0;;; 7.0 2.0 7.0; 1.0 7.0 2.0; 5.0 8.0 9.0;;; 13.0 5.0 2.0; 3.0 7.0 5.0; 9.0 10.0 14.0;;;; + 14.0 10.0 3.0; 8.0 15.0 9.0; 8.0 2.0 13.0;;; 14.0 10.0 5.0; 2.0 11.0 4.0; 6.0 3.0 11.0;;; 9.0 8.0 10.0; 9.0 6.0 8.0; 5.0 3.0 6.0 + ] + ), + m_T1 = Categorical([0.24450475682493267, 0.4963764838006123, 0.25911875937445494]), + m_T2 = Categorical([0.013289905256330505, 0.03916840092344259, 0.6158904501712571, 0.3316512436489697]) + ), + output = Categorical([0.31784517029280823, 0.34141019417661717, 0.3407446355305746]) + ), + ( + input = ( + m_out = Categorical([0.231721871481526, 0.43974647264393085, 0.3285316558745432]), + q_a = TensorDirichlet( + [ + 11.0 8.0 5.0; 5.0 7.0 10.0; 2.0 1.0 11.0;;; 15.0 8.0 4.0; 8.0 12.0 3.0; 6.0 6.0 14.0;;; 13.0 10.0 8.0; 1.0 8.0 7.0; 8.0 8.0 14.0;;;; + 8.0 4.0 8.0; 9.0 13.0 3.0; 8.0 2.0 7.0;;; 10.0 8.0 10.0; 5.0 11.0 8.0; 2.0 8.0 8.0;;; 10.0 3.0 2.0; 9.0 11.0 5.0; 10.0 4.0 15.0;;;; + 10.0 7.0 9.0; 1.0 14.0 10.0; 10.0 5.0 15.0;;; 7.0 9.0 7.0; 5.0 14.0 1.0; 9.0 9.0 6.0;;; 15.0 7.0 3.0; 7.0 8.0 4.0; 2.0 2.0 14.0;;;; + 14.0 7.0 2.0; 3.0 7.0 6.0; 4.0 9.0 15.0;;; 13.0 1.0 7.0; 3.0 12.0 4.0; 1.0 7.0 14.0;;; 7.0 7.0 3.0; 2.0 7.0 4.0; 8.0 9.0 10.0 + ] + ), + m_T1 = Categorical([0.16173494859799328, 0.5444108919070189, 0.2938541594949879]), + m_T2 = Categorical([0.3924160844603588, 0.34208347405766765, 0.0742389939993073, 0.19126144748266627]) + ), + output = Categorical([0.305396640756791, 0.36076561844428456, 0.3338377407989245]) + ), + ( + input = ( + m_out = Categorical([0.3438709572699468, 0.327896945058581, 0.3282320976714722]), + q_a = TensorDirichlet( + [ + 10.0 9.0 2.0; 6.0 8.0 10.0; 7.0 1.0 11.0;;; 8.0 5.0 2.0; 5.0 7.0 3.0; 8.0 1.0 8.0;;; 15.0 5.0 7.0; 4.0 13.0 6.0; 3.0 8.0 7.0;;;; + 10.0 9.0 4.0; 6.0 12.0 10.0; 6.0 6.0 12.0;;; 12.0 8.0 3.0; 3.0 15.0 3.0; 6.0 3.0 10.0;;; 6.0 5.0 8.0; 8.0 6.0 10.0; 8.0 5.0 11.0;;;; + 15.0 2.0 7.0; 9.0 14.0 3.0; 7.0 4.0 11.0;;; 15.0 5.0 8.0; 10.0 6.0 2.0; 1.0 8.0 12.0;;; 8.0 2.0 8.0; 3.0 12.0 6.0; 5.0 8.0 10.0;;;; + 14.0 2.0 2.0; 10.0 10.0 2.0; 1.0 4.0 9.0;;; 10.0 1.0 7.0; 1.0 6.0 3.0; 1.0 10.0 13.0;;; 6.0 8.0 4.0; 4.0 9.0 3.0; 3.0 2.0 9.0 + ] + ), + m_T1 = Categorical([0.08064616359815222, 0.3016652714857759, 0.6176885649160718]), + m_T2 = Categorical([0.11248216922354308, 0.026994470113754852, 0.4375353107069617, 0.4229880499557404]) + ), + output = Categorical([0.33437642195805395, 0.3321767052616158, 0.33344687278033025]) + ) + ] + end end diff --git a/test/rules/transition/marginals_tests.jl b/test/rules/transition/marginals_tests.jl new file mode 100644 index 000000000..841272917 --- /dev/null +++ b/test/rules/transition/marginals_tests.jl @@ -0,0 +1,79 @@ + +@testitem "marginalrules:Transition" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra + + import ReactiveMP: @test_marginalrules + + @testset "out_in: (m_out::Categorical, m_in::Categorical, q_a::MatrixDirichlet)" begin + @test_marginalrules [check_type_promotion = false] Transition(:out_in) [( + input = (m_out = Categorical([0.2, 0.5, 0.3]), m_in = Categorical([0.7, 0.1, 0.2]), q_a = MatrixDirichlet([3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0])), + output = Contingency( + [ + 0.1986102968597683 0.017209033482868178 0.034418066965736356 + 0.30115808595019306 0.07093224887848869 0.08604516741434089 + 0.18069485157011583 0.025813550224302262 0.08511869865418642 + ] + ) + )] + end + + @testset "out_in_t1: (m_out::Categorical, m_in::Categorical, m_t1::Categorical, q_a::TensorDirichlet)" begin + @test_marginalrules [check_type_promotion = false] Transition(:out_in_t1) [ + ( + input = ( + m_out = Categorical([0.2, 0.5, 0.3]), + m_in = Categorical([0.7, 0.1, 0.2]), + m_t1 = Categorical([0.01, 0.9, 0.09]), + q_a = TensorDirichlet([3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0;;; 3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0;;; 3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0]) + ), + output = Contingency( + [ + 0.001986102968597683 0.0001720903348286818 0.0003441806696573636; 0.0030115808595019304 0.0007093224887848868 0.0008604516741434088; 0.0018069485157011583 0.0002581355022430226 0.0008511869865418641;;; + 0.17874926717379147 0.015488130134581363 0.030976260269162725; 0.2710422773551738 0.06383902399063981 0.0774406506729068; 0.16262536641310427 0.023232195201872037 0.07660682878876778;;; + 0.017874926717379145 0.001548813013458136 0.003097626026916272; 0.027104227735517378 0.006383902399063981 0.007744065067290678; 0.016262536641310426 0.0023232195201872037 0.007660682878876776 + ] + ) + ), + ( + input = ( + m_out = Categorical([0, 1, 0]), + m_in = Categorical([0, 1, 0]), + m_t1 = Categorical([0, 0, 1]), + q_a = TensorDirichlet([3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0;;; 3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0;;; 3.0 2.0 2.0; 2.0 3.0 2.0; 2.0 2.0 3.0]) + ), + output = Contingency([ + 0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;; + 0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;; + 0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0 + ]) + ) + ] + end + + @testset "out_in_t1_t2: (m_out::Categorical, m_in::Categorical, m_t1::Categorical, m_t2::Categorical, q_a::TensorDirichlet)" begin + @test_marginalrules [check_type_promotion = false] Transition(:out_in_t1_t2) [( + input = ( + m_out = Categorical([0.2, 0.5, 0.3]), + m_in = Categorical([0.7, 0.1, 0.2]), + m_t1 = Categorical([0.01, 0.9, 0.09]), + m_t2 = Categorical([0.25, 0.01, 0.09, 0.65]), + q_a = TensorDirichlet( + [ + 3.4814561121678347 2.5351658244027844 2.0637422006197856; 2.2919979590901685 3.5854980740024467 2.057024456382512; 2.961498802369847 2.2641205050393607 3.2344282382034804;;; 3.748169016349685 2.1522033841434904 2.9468022556183513; 2.3868319648098764 3.3058305246781945 2.6555313055477683; 2.153603001551738 2.1909039151153378 3.218338677959591;;; 3.76759376279165 2.67577869934414 2.9092268547954774; 2.2886069210422426 3.7986205864251543 2.5056888207498655; 2.2768291735341766 2.2200857998842514 3.057692286732935;;;; + 3.2742052774819848 2.4495891683271696 2.910828066613324; 2.5896672232696503 3.469476528069095 2.3827465484202577; 2.2438457120549247 2.6260212280800417 3.295102517531366;;; 3.0224009062118107 2.9849350862212414 2.3587228609218913; 2.2727605127367476 3.5102889560557706 2.8243063953618273; 2.511594139684071 2.7428760858629246 3.910232940256628;;; 3.1338212505134146 2.303530993403952 2.7961593523809443; 2.9308944983075493 3.8464245809821307 2.666661277436305; 2.3613769215551983 2.0716049642237397 3.8418737601142907;;;; + 3.199501489096538 2.137518268035324 2.2408411282891665; 2.018092973870658 3.9235861067050033 2.7024020409406586; 2.6474910666238243 2.8381488044892156 3.707495905126347;;; 3.8476286872449337 2.3304782998090983 2.4655689133178282; 2.9433720367760063 3.5064531916942263 2.1306271953512823; 2.1371102642348183 2.867051597898455 3.6355716472274127;;; 3.871567868158627 2.398172641112016 2.4104536086573325; 2.862391328670923 3.796451684660049 2.019185201609586; 2.7614244667946997 2.8789865894397155 3.7383014331832483;;;; + 3.881225479462249 2.2954511016342427 2.7405225916860303; 2.6750638439694003 3.341534956301027 2.6703959167706186; 2.8461081506237194 2.7026253701192577 3.0188263005741622;;; 3.0488670662223054 2.0029971596452283 2.160458108078882; 2.5907129699378246 3.1715956194931714 2.9269873994895623; 2.756186325418192 2.3609090734349567 3.688167146789853;;; 3.906810674829052 2.808930999947986 2.5749185772777476; 2.3358640383658136 3.3937346187125024 2.6150136277384557; 2.2931807105355295 2.6214291139700827 3.6118464498460283 + ] + ) + ), + output = Contingency( + [ + 0.00012722338894496782 1.3022463411054439e-5 2.3158773040985717e-5; 0.00019265723394137817 4.9093360894474115e-5 5.7655750159444724e-5; 0.00015788411353765338 1.6983987794485747e-5 6.010648836077194e-5;;; 0.013178927757294632 0.0010542583344193288 0.0026629508290819644; 0.019279198227000676 0.004434759171177195 0.005875968179603495; 0.010170243238045414 0.0016174242002469851 0.004432132400365538;;; 0.001318205158406199 0.00012042898201068178 0.00027369494386151537; 0.0018193995993215945 0.0004542886984016782 0.0005713274331452671; 0.0010846223508451313 0.0001435049394841303 0.0004355069391263673;;;; + 5.1286316330728594e-6 4.894739000052876e-7 1.1997751386684061e-6; 9.695571155233414e-6 1.852908417921917e-6 2.3524534129870656e-6; 4.87320647236301e-6 7.993604640097737e-7 2.082874375208284e-6;;; 0.0004374150538359077 5.14976787531794e-5 7.872612278913694e-5; 0.0007733517629102456 0.00015564227997362512 0.0002451226065796317; 0.0005250959255472306 6.982301738650833e-5 0.0002149358268849241;;; 4.208793120980026e-5 4.2529473919388285e-6 9.455515421222158e-6; 9.720951505197967e-5 1.9560393175950377e-5 2.2325990735847952e-5; 4.4868373640979835e-5 5.579402390992366e-6 2.056044991429916e-5;;;; + 4.6405678398183326e-5 3.55905436914347e-6 7.785118026404092e-6; 6.598607617445748e-5 1.8395813506114248e-5 2.4506351808675896e-5; 5.554843967911595e-5 7.567494073295783e-6 2.1321688043455643e-5;;; 0.004518107504302185 0.00036554421801970194 0.0008316632808738873; 0.008270032858940277 0.0014899847467708608 0.001732538023087781; 0.003350992023479249 0.0007057614806289696 0.001977797041473378;;; 0.0004263541855616303 3.624499174801884e-5 8.155365325340587e-5; 0.000749567905431587 0.0001562183319223116 0.0001630999799901975; 0.00043079572234991777 6.787271668048047e-5 0.00020590461660051907;;;; + 0.000346689198530034 3.0121466426002132e-5 7.400594751234022e-5; 0.0005602929900117366 0.00011831603447926764 0.00017931562091896792; 0.0003621908181066448 5.520705962851366e-5 0.00012459361828002287;;; 0.026586609038591547 0.002540844420799181 0.004759775296735686; 0.054679569154396444 0.011161289980470982 0.017264242519374518; 0.03535978938336349 0.004691761062668513 0.013568616711277714;;; 0.0034822255538199177 0.00032680711645289906 0.0005899745083016601; 0.004730174550604292 0.001021187375290388 0.0015029285989940897; 0.002773625987891417 0.00045101159461917554 0.001320470818930216 + ] + ) + )] + end +end diff --git a/test/rules/transition/out_tests.jl b/test/rules/transition/out_tests.jl index c0fd0adff..39a8f03a3 100644 --- a/test/rules/transition/out_tests.jl +++ b/test/rules/transition/out_tests.jl @@ -66,4 +66,134 @@ (input = (m_in = PointMass([1, 0, 0]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), output = Categorical([0.1 / 0.9, 0.6 / 0.9, 0.2 / 0.9])) ] end + + @testset "Belief Propagation: (m_in::Categorical, q_a::TensorDirichlet, m_t1::Categorical)" begin + @test_rules [check_type_promotion = false] Transition(:out, Marginalisation) [ + ( + input = ( + m_in = Categorical([0.08799332630703943, 0.29132551818215013, 0.6206811555108104]), + q_a = TensorDirichlet([10.0 2.0 9.0; 1.0 6.0 1.0; 10.0 7.0 10.0;;; 2.0 4.0 8.0; 2.0 3.0 7.0; 5.0 3.0 8.0;;; 1.0 7.0 8.0; 2.0 10.0 10.0; 8.0 9.0 3.0]), + m_T1 = Categorical([0.11240998953463174, 0.5372891244719414, 0.35030088599342685]) + ), + output = Categorical([0.3434052652024092, 0.32330935346400147, 0.3332853813335894]) + ), + ( + input = ( + m_in = Categorical([0.41399930903334414, 0.2569572285438312, 0.32904346242282473]), + q_a = TensorDirichlet([9.0 6.0 8.0; 8.0 9.0 2.0; 2.0 5.0 8.0;;; 3.0 10.0 6.0; 4.0 6.0 1.0; 7.0 8.0 3.0;;; 8.0 8.0 7.0; 1.0 8.0 4.0; 6.0 4.0 2.0]), + m_T1 = Categorical([0.2712971200834005, 0.3827161804909996, 0.3459866994255999]) + ), + output = Categorical([0.4482838053046239, 0.24279036189260314, 0.3089258328027729]) + ), + ( + input = ( + m_in = Categorical([0.28007415705382577, 0.362168131823555, 0.35775771112261917]), + q_a = TensorDirichlet([5.0 1.0 10.0; 9.0 5.0 4.0; 7.0 2.0 6.0;;; 9.0 7.0 5.0; 4.0 3.0 3.0; 2.0 4.0 8.0;;; 8.0 9.0 9.0; 6.0 10.0 10.0; 8.0 9.0 6.0]), + m_T1 = Categorical([0.24450475682493267, 0.4963764838006123, 0.25911875937445494]) + ), + output = Categorical([0.3951613829932443, 0.2988663229112348, 0.3059722940955208]) + ), + ( + input = ( + m_in = Categorical([0.231721871481526, 0.43974647264393085, 0.3285316558745432]), + q_a = TensorDirichlet([6.0 9.0 8.0; 9.0 9.0 4.0; 1.0 6.0 8.0;;; 3.0 6.0 8.0; 3.0 8.0 5.0; 4.0 5.0 2.0;;; 10.0 1.0 8.0; 2.0 7.0 1.0; 7.0 7.0 3.0]), + m_T1 = Categorical([0.16173494859799328, 0.5444108919070189, 0.2938541594949879]) + ), + output = Categorical([0.3845876425194768, 0.3317654183958069, 0.2836469390847161]) + ), + ( + input = ( + m_in = Categorical([0.3050714581189855, 0.42249711097760284, 0.27243143090341165]), + q_a = TensorDirichlet([10.0 5.0 1.0; 10.0 8.0 4.0; 10.0 4.0 1.0;;; 9.0 9.0 4.0; 1.0 10.0 3.0; 7.0 4.0 1.0;;; 4.0 8.0 6.0; 4.0 9.0 4.0; 3.0 9.0 5.0]), + m_T1 = Categorical([0.08064616359815222, 0.3016652714857759, 0.6176885649160718]) + ), + output = Categorical([0.3824006540022958, 0.334754745371225, 0.28284460062647926]) + ) + ] + end + + @testset "Belief Propagation: (m_in::Categorical, q_a::TensorDirichlet, m_t1::Categorical, m_t2::Categorical)" begin + @test_rules [check_type_promotion = false] Transition(:out, Marginalisation) [ + ( + input = ( + m_in = Categorical([0.08799332630703943, 0.29132551818215013, 0.6206811555108104]), + q_a = TensorDirichlet( + [ + 14.0 10.0 1.0; 6.0 8.0 1.0; 3.0 10.0 7.0;;; 14.0 3.0 3.0; 4.0 9.0 3.0; 5.0 5.0 14.0;;; 9.0 10.0 1.0; 3.0 7.0 4.0; 8.0 2.0 12.0;;;; + 13.0 9.0 1.0; 9.0 7.0 8.0; 5.0 1.0 11.0;;; 8.0 7.0 4.0; 1.0 14.0 7.0; 4.0 10.0 6.0;;; 15.0 5.0 5.0; 7.0 6.0 5.0; 7.0 3.0 10.0;;;; + 15.0 7.0 8.0; 5.0 10.0 9.0; 6.0 3.0 14.0;;; 7.0 1.0 6.0; 3.0 12.0 5.0; 3.0 5.0 9.0;;; 15.0 5.0 9.0; 9.0 12.0 10.0; 3.0 7.0 10.0;;;; + 13.0 4.0 8.0; 4.0 8.0 2.0; 5.0 10.0 11.0;;; 15.0 4.0 8.0; 9.0 15.0 2.0; 5.0 2.0 12.0;;; 14.0 1.0 8.0; 10.0 6.0 5.0; 5.0 6.0 13.0 + ] + ), + m_T1 = Categorical([0.11240998953463174, 0.5372891244719414, 0.35030088599342685]), + m_T2 = Categorical([0.09240952821382971, 0.3184996197469503, 0.3995918737894391, 0.18949897824978093]) + ), + output = Categorical([0.27036967103523596, 0.34093292224483207, 0.3886974067199319]) + ), + ( + input = ( + m_in = Categorical([0.41399930903334414, 0.2569572285438312, 0.32904346242282473]), + q_a = TensorDirichlet( + [ + 15.0 4.0 4.0; 6.0 11.0 5.0; 1.0 1.0 14.0;;; 6.0 9.0 8.0; 10.0 10.0 5.0; 6.0 2.0 13.0;;; 13.0 2.0 7.0; 3.0 9.0 6.0; 5.0 5.0 7.0;;;; + 7.0 2.0 9.0; 6.0 8.0 2.0; 6.0 4.0 12.0;;; 13.0 4.0 3.0; 1.0 10.0 2.0; 7.0 7.0 8.0;;; 9.0 9.0 4.0; 1.0 10.0 3.0; 4.0 4.0 12.0;;;; + 13.0 5.0 5.0; 1.0 7.0 1.0; 7.0 5.0 11.0;;; 9.0 10.0 6.0; 6.0 13.0 1.0; 9.0 2.0 8.0;;; 13.0 9.0 4.0; 6.0 12.0 2.0; 10.0 1.0 13.0;;;; + 9.0 6.0 8.0; 3.0 7.0 9.0; 8.0 4.0 11.0;;; 12.0 3.0 4.0; 2.0 11.0 9.0; 6.0 7.0 11.0;;; 13.0 6.0 10.0; 3.0 9.0 9.0; 10.0 10.0 6.0 + ] + ), + m_T1 = Categorical([0.2712971200834005, 0.3827161804909996, 0.3459866994255999]), + m_T2 = Categorical([0.2284278504147311, 0.1734009515489395, 0.1685400610240271, 0.4296311370123022]) + ), + output = Categorical([0.37903960313665536, 0.27173800064227394, 0.34922239622107076]) + ), + ( + input = ( + m_in = Categorical([0.28007415705382577, 0.362168131823555, 0.35775771112261917]), + q_a = TensorDirichlet( + [ + 12.0 9.0 1.0; 2.0 8.0 7.0; 10.0 1.0 13.0;;; 11.0 8.0 7.0; 10.0 14.0 7.0; 5.0 2.0 15.0;;; 8.0 4.0 4.0; 7.0 11.0 5.0; 4.0 10.0 6.0;;;; + 14.0 3.0 7.0; 8.0 15.0 2.0; 5.0 8.0 15.0;;; 7.0 7.0 4.0; 6.0 11.0 2.0; 10.0 9.0 12.0;;; 11.0 7.0 4.0; 2.0 7.0 4.0; 1.0 1.0 15.0;;;; + 13.0 7.0 5.0; 6.0 15.0 7.0; 9.0 5.0 14.0;;; 7.0 2.0 7.0; 1.0 7.0 2.0; 5.0 8.0 9.0;;; 13.0 5.0 2.0; 3.0 7.0 5.0; 9.0 10.0 14.0;;;; + 14.0 10.0 3.0; 8.0 15.0 9.0; 8.0 2.0 13.0;;; 14.0 10.0 5.0; 2.0 11.0 4.0; 6.0 3.0 11.0;;; 9.0 8.0 10.0; 9.0 6.0 8.0; 5.0 3.0 6.0 + ] + ), + m_T1 = Categorical([0.24450475682493267, 0.4963764838006123, 0.25911875937445494]), + m_T2 = Categorical([0.013289905256330505, 0.03916840092344259, 0.6158904501712571, 0.3316512436489697]) + ), + output = Categorical([0.3350657512455311, 0.27378913306703423, 0.39114511568743465]) + ), + ( + input = ( + m_in = Categorical([0.231721871481526, 0.43974647264393085, 0.3285316558745432]), + q_a = TensorDirichlet( + [ + 11.0 8.0 5.0; 5.0 7.0 10.0; 2.0 1.0 11.0;;; 15.0 8.0 4.0; 8.0 12.0 3.0; 6.0 6.0 14.0;;; 13.0 10.0 8.0; 1.0 8.0 7.0; 8.0 8.0 14.0;;;; + 8.0 4.0 8.0; 9.0 13.0 3.0; 8.0 2.0 7.0;;; 10.0 8.0 10.0; 5.0 11.0 8.0; 2.0 8.0 8.0;;; 10.0 3.0 2.0; 9.0 11.0 5.0; 10.0 4.0 15.0;;;; + 10.0 7.0 9.0; 1.0 14.0 10.0; 10.0 5.0 15.0;;; 7.0 9.0 7.0; 5.0 14.0 1.0; 9.0 9.0 6.0;;; 15.0 7.0 3.0; 7.0 8.0 4.0; 2.0 2.0 14.0;;;; + 14.0 7.0 2.0; 3.0 7.0 6.0; 4.0 9.0 15.0;;; 13.0 1.0 7.0; 3.0 12.0 4.0; 1.0 7.0 14.0;;; 7.0 7.0 3.0; 2.0 7.0 4.0; 8.0 9.0 10.0 + ] + ), + m_T1 = Categorical([0.16173494859799328, 0.5444108919070189, 0.2938541594949879]), + m_T2 = Categorical([0.3924160844603588, 0.34208347405766765, 0.0742389939993073, 0.19126144748266627]) + ), + output = Categorical([0.3324688090125486, 0.32875902717085775, 0.3387721638165937]) + ), + ( + input = ( + m_in = Categorical([0.3050714581189855, 0.42249711097760284, 0.27243143090341165]), + q_a = TensorDirichlet( + [ + 10.0 9.0 2.0; 6.0 8.0 10.0; 7.0 1.0 11.0;;; 8.0 5.0 2.0; 5.0 7.0 3.0; 8.0 1.0 8.0;;; 15.0 5.0 7.0; 4.0 13.0 6.0; 3.0 8.0 7.0;;;; + 10.0 9.0 4.0; 6.0 12.0 10.0; 6.0 6.0 12.0;;; 12.0 8.0 3.0; 3.0 15.0 3.0; 6.0 3.0 10.0;;; 6.0 5.0 8.0; 8.0 6.0 10.0; 8.0 5.0 11.0;;;; + 15.0 2.0 7.0; 9.0 14.0 3.0; 7.0 4.0 11.0;;; 15.0 5.0 8.0; 10.0 6.0 2.0; 1.0 8.0 12.0;;; 8.0 2.0 8.0; 3.0 12.0 6.0; 5.0 8.0 10.0;;;; + 14.0 2.0 2.0; 10.0 10.0 2.0; 1.0 4.0 9.0;;; 10.0 1.0 7.0; 1.0 6.0 3.0; 1.0 10.0 13.0;;; 6.0 8.0 4.0; 4.0 9.0 3.0; 3.0 2.0 9.0 + ] + ), + m_T1 = Categorical([0.08064616359815222, 0.3016652714857759, 0.6176885649160718]), + m_T2 = Categorical([0.11248216922354308, 0.026994470113754852, 0.4375353107069617, 0.4229880499557404]) + ), + output = Categorical([0.3438709572699468, 0.327896945058581, 0.3282320976714722]) + ) + ] + end end diff --git a/test/rules/transition/t_tests.jl b/test/rules/transition/t_tests.jl new file mode 100644 index 000000000..0015a19aa --- /dev/null +++ b/test/rules/transition/t_tests.jl @@ -0,0 +1,220 @@ +@testitem "rules:Transition:in" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + + import ReactiveMP: @test_rules + + @testset "Belief Propagation: (m_out::Categorical, m_in::Categorical, q_a::TensorDirichlet)" begin + @test_rules [check_type_promotion = false] Transition(:t1, Marginalisation) [ + ( + input = ( + m_out = Categorical([0.0510559014089735, 0.05387178800920238, 0.8950723105818241]), + m_in = Categorical([0.054929868794045565, 0.04163049496789134, 0.9034396362380631]), + q_a = TensorDirichlet([1.0 6.0 32.0; 2.0 2.0 9.0; 5.0 5.0 6.0;;; 9.0 5.0 6.0; 4.0 10.0 6.0; 10.0 6.0 32.0;;; 6.0 1.0 8.0; 2.0 10.0 7.0; 1.0 3.0 8.0]) + ), + output = Categorical([0.16071994681128632, 0.5608867375006992, 0.27839331568801456]) + ), + ( + input = ( + m_out = Categorical([0.03485782178222314, 0.01245130553310446, 0.9526908726846723]), + m_in = Categorical([0.06468167584028221, 0.0649680496982824, 0.8703502744614353]), + q_a = TensorDirichlet([8.0 9.0 37.0; 5.0 9.0 8.0; 6.0 9.0 3.0;;; 3.0 4.0 7.0; 7.0 4.0 3.0; 8.0 6.0 34.0;;; 5.0 10.0 2.0; 3.0 2.0 3.0; 4.0 10.0 7.0]) + ), + output = Categorical([0.08216104573408281, 0.5240321267913837, 0.3938068274745334]) + ), + ( + input = ( + m_out = Categorical([0.013906945163896326, 0.014657675872327216, 0.9714353789637764]), + m_in = Categorical([0.06213023407554826, 0.07345989690934168, 0.8644098690151101]), + q_a = TensorDirichlet([10.0 1.0 36.0; 1.0 1.0 7.0; 2.0 9.0 4.0;;; 2.0 5.0 9.0; 1.0 9.0 1.0; 8.0 7.0 33.0;;; 6.0 2.0 1.0; 7.0 3.0 2.0; 1.0 7.0 4.0]) + ), + output = Categorical([0.10371781881719036, 0.5257747986340537, 0.37050738254875604]) + ), + ( + input = ( + m_out = Categorical([0.061610939826165154, 0.06994219293747578, 0.868446867236359]), + m_in = Categorical([0.038785774762193144, 0.017316838524837008, 0.9438973867129699]), + q_a = TensorDirichlet([9.0 1.0 37.0; 9.0 9.0 8.0; 10.0 2.0 7.0;;; 5.0 10.0 2.0; 9.0 9.0 8.0; 8.0 5.0 33.0;;; 5.0 9.0 1.0; 2.0 4.0 10.0; 2.0 1.0 5.0]) + ), + output = Categorical([0.1524663903301974, 0.5879220226683617, 0.2596115870014409]) + ), + ( + input = ( + m_out = Categorical([0.059854086369599824, 0.044893636911381464, 0.8952522767190187]), + m_in = Categorical([0.0056135853425655965, 0.03575690970940535, 0.958629504948029]), + q_a = TensorDirichlet([6.0 1.0 32.0; 8.0 7.0 1.0; 9.0 1.0 9.0;;; 10.0 10.0 4.0; 2.0 9.0 9.0; 3.0 2.0 33.0;;; 1.0 3.0 5.0; 7.0 8.0 7.0; 7.0 8.0 2.0]) + ), + output = Categorical([0.22345398380365966, 0.6247689354236513, 0.1517770807726892]) + ) + ] + end + + @testset "Belief Propagation: (m_out::Categorical, m_in::Categorical, q_a::TensorDirichlet, m_t2::Categorical)" begin + @test_rules [check_type_promotion = false] Transition(:T1, Marginalisation) [ + ( + input = ( + m_out = Categorical([0.08799332630703943, 0.29132551818215013, 0.6206811555108104]), + m_in = Categorical([0.11240998953463174, 0.5372891244719414, 0.35030088599342685]), + q_a = TensorDirichlet( + [ + 14.0 10.0 1.0; 6.0 8.0 1.0; 3.0 10.0 7.0;;; 14.0 3.0 3.0; 4.0 9.0 3.0; 5.0 5.0 14.0;;; 9.0 10.0 1.0; 3.0 7.0 4.0; 8.0 2.0 12.0;;;; + 13.0 9.0 1.0; 9.0 7.0 8.0; 5.0 1.0 11.0;;; 8.0 7.0 4.0; 1.0 14.0 7.0; 4.0 10.0 6.0;;; 15.0 5.0 5.0; 7.0 6.0 5.0; 7.0 3.0 10.0;;;; + 15.0 7.0 8.0; 5.0 10.0 9.0; 6.0 3.0 14.0;;; 7.0 1.0 6.0; 3.0 12.0 5.0; 3.0 5.0 9.0;;; 15.0 5.0 9.0; 9.0 12.0 10.0; 3.0 7.0 10.0;;;; + 13.0 4.0 8.0; 4.0 8.0 2.0; 5.0 10.0 11.0;;; 15.0 4.0 8.0; 9.0 15.0 2.0; 5.0 2.0 12.0;;; 14.0 1.0 8.0; 10.0 6.0 5.0; 5.0 6.0 13.0 + ] + ), + m_T2 = Categorical([0.09240952821382971, 0.3184996197469503, 0.3995918737894391, 0.18949897824978093]) + ), + output = Categorical([0.3204124841029019, 0.34442455887321344, 0.3351629570238847]) + ), + ( + input = ( + m_out = Categorical([0.41399930903334414, 0.2569572285438312, 0.32904346242282473]), + m_in = Categorical([0.2712971200834005, 0.3827161804909996, 0.3459866994255999]), + q_a = TensorDirichlet( + [ + 15.0 4.0 4.0; 6.0 11.0 5.0; 1.0 1.0 14.0;;; 6.0 9.0 8.0; 10.0 10.0 5.0; 6.0 2.0 13.0;;; 13.0 2.0 7.0; 3.0 9.0 6.0; 5.0 5.0 7.0;;;; + 7.0 2.0 9.0; 6.0 8.0 2.0; 6.0 4.0 12.0;;; 13.0 4.0 3.0; 1.0 10.0 2.0; 7.0 7.0 8.0;;; 9.0 9.0 4.0; 1.0 10.0 3.0; 4.0 4.0 12.0;;;; + 13.0 5.0 5.0; 1.0 7.0 1.0; 7.0 5.0 11.0;;; 9.0 10.0 6.0; 6.0 13.0 1.0; 9.0 2.0 8.0;;; 13.0 9.0 4.0; 6.0 12.0 2.0; 10.0 1.0 13.0;;;; + 9.0 6.0 8.0; 3.0 7.0 9.0; 8.0 4.0 11.0;;; 12.0 3.0 4.0; 2.0 11.0 9.0; 6.0 7.0 11.0;;; 13.0 6.0 10.0; 3.0 9.0 9.0; 10.0 10.0 6.0 + ] + ), + m_T2 = Categorical([0.2284278504147311, 0.1734009515489395, 0.1685400610240271, 0.4296311370123022]) + ), + output = Categorical([0.33252217570544323, 0.3303090112079051, 0.33716881308665164]) + ), + ( + input = ( + m_out = Categorical([0.28007415705382577, 0.362168131823555, 0.35775771112261917]), + m_in = Categorical([0.24450475682493267, 0.4963764838006123, 0.25911875937445494]), + q_a = TensorDirichlet( + [ + 12.0 9.0 1.0; 2.0 8.0 7.0; 10.0 1.0 13.0;;; 11.0 8.0 7.0; 10.0 14.0 7.0; 5.0 2.0 15.0;;; 8.0 4.0 4.0; 7.0 11.0 5.0; 4.0 10.0 6.0;;;; + 14.0 3.0 7.0; 8.0 15.0 2.0; 5.0 8.0 15.0;;; 7.0 7.0 4.0; 6.0 11.0 2.0; 10.0 9.0 12.0;;; 11.0 7.0 4.0; 2.0 7.0 4.0; 1.0 1.0 15.0;;;; + 13.0 7.0 5.0; 6.0 15.0 7.0; 9.0 5.0 14.0;;; 7.0 2.0 7.0; 1.0 7.0 2.0; 5.0 8.0 9.0;;; 13.0 5.0 2.0; 3.0 7.0 5.0; 9.0 10.0 14.0;;;; + 14.0 10.0 3.0; 8.0 15.0 9.0; 8.0 2.0 13.0;;; 14.0 10.0 5.0; 2.0 11.0 4.0; 6.0 3.0 11.0;;; 9.0 8.0 10.0; 9.0 6.0 8.0; 5.0 3.0 6.0 + ] + ), + m_T2 = Categorical([0.013289905256330505, 0.03916840092344259, 0.6158904501712571, 0.3316512436489697]) + ), + output = Categorical([0.33852639912537374, 0.3291567669590365, 0.3323168339155897]) + ), + ( + input = ( + m_out = Categorical([0.231721871481526, 0.43974647264393085, 0.3285316558745432]), + m_in = Categorical([0.16173494859799328, 0.5444108919070189, 0.2938541594949879]), + q_a = TensorDirichlet( + [ + 11.0 8.0 5.0; 5.0 7.0 10.0; 2.0 1.0 11.0;;; 15.0 8.0 4.0; 8.0 12.0 3.0; 6.0 6.0 14.0;;; 13.0 10.0 8.0; 1.0 8.0 7.0; 8.0 8.0 14.0;;;; + 8.0 4.0 8.0; 9.0 13.0 3.0; 8.0 2.0 7.0;;; 10.0 8.0 10.0; 5.0 11.0 8.0; 2.0 8.0 8.0;;; 10.0 3.0 2.0; 9.0 11.0 5.0; 10.0 4.0 15.0;;;; + 10.0 7.0 9.0; 1.0 14.0 10.0; 10.0 5.0 15.0;;; 7.0 9.0 7.0; 5.0 14.0 1.0; 9.0 9.0 6.0;;; 15.0 7.0 3.0; 7.0 8.0 4.0; 2.0 2.0 14.0;;;; + 14.0 7.0 2.0; 3.0 7.0 6.0; 4.0 9.0 15.0;;; 13.0 1.0 7.0; 3.0 12.0 4.0; 1.0 7.0 14.0;;; 7.0 7.0 3.0; 2.0 7.0 4.0; 8.0 9.0 10.0 + ] + ), + m_T2 = Categorical([0.3924160844603588, 0.34208347405766765, 0.0742389939993073, 0.19126144748266627]) + ), + output = Categorical([0.3336500717081419, 0.33405984098175184, 0.33229008731010634]) + ), + ( + input = ( + m_out = Categorical([0.3438709572699468, 0.327896945058581, 0.3282320976714722]), + m_in = Categorical([0.08064616359815222, 0.3016652714857759, 0.6176885649160718]), + q_a = TensorDirichlet( + [ + 10.0 9.0 2.0; 6.0 8.0 10.0; 7.0 1.0 11.0;;; 8.0 5.0 2.0; 5.0 7.0 3.0; 8.0 1.0 8.0;;; 15.0 5.0 7.0; 4.0 13.0 6.0; 3.0 8.0 7.0;;;; + 10.0 9.0 4.0; 6.0 12.0 10.0; 6.0 6.0 12.0;;; 12.0 8.0 3.0; 3.0 15.0 3.0; 6.0 3.0 10.0;;; 6.0 5.0 8.0; 8.0 6.0 10.0; 8.0 5.0 11.0;;;; + 15.0 2.0 7.0; 9.0 14.0 3.0; 7.0 4.0 11.0;;; 15.0 5.0 8.0; 10.0 6.0 2.0; 1.0 8.0 12.0;;; 8.0 2.0 8.0; 3.0 12.0 6.0; 5.0 8.0 10.0;;;; + 14.0 2.0 2.0; 10.0 10.0 2.0; 1.0 4.0 9.0;;; 10.0 1.0 7.0; 1.0 6.0 3.0; 1.0 10.0 13.0;;; 6.0 8.0 4.0; 4.0 9.0 3.0; 3.0 2.0 9.0 + ] + ), + m_T2 = Categorical([0.11248216922354308, 0.026994470113754852, 0.4375353107069617, 0.4229880499557404]) + ), + output = Categorical([0.3314741303273284, 0.33435785748140795, 0.3341680121912637]) + ) + ] + end + + @testset "Belief Propagation: (m_out::Categorical, m_in::Categorical, q_a::TensorDirichlet, m_t1::Categorical)" begin + @test_rules [check_type_promotion = false] Transition(:t2, Marginalisation) [ + ( + input = ( + m_out = Categorical([0.08799332630703943, 0.29132551818215013, 0.6206811555108104]), + m_in = Categorical([0.2516472953196383, 0.3420258883005958, 0.40632681637976575]), + q_a = TensorDirichlet( + [ + 14.0 10.0 1.0; 6.0 8.0 1.0; 3.0 10.0 7.0;;; 14.0 3.0 3.0; 4.0 9.0 3.0; 5.0 5.0 14.0;;; 9.0 10.0 1.0; 3.0 7.0 4.0; 8.0 2.0 12.0;;;; + 13.0 9.0 1.0; 9.0 7.0 8.0; 5.0 1.0 11.0;;; 8.0 7.0 4.0; 1.0 14.0 7.0; 4.0 10.0 6.0;;; 15.0 5.0 5.0; 7.0 6.0 5.0; 7.0 3.0 10.0;;;; + 15.0 7.0 8.0; 5.0 10.0 9.0; 6.0 3.0 14.0;;; 7.0 1.0 6.0; 3.0 12.0 5.0; 3.0 5.0 9.0;;; 15.0 5.0 9.0; 9.0 12.0 10.0; 3.0 7.0 10.0;;;; + 13.0 4.0 8.0; 4.0 8.0 2.0; 5.0 10.0 11.0;;; 15.0 4.0 8.0; 9.0 15.0 2.0; 5.0 2.0 12.0;;; 14.0 1.0 8.0; 10.0 6.0 5.0; 5.0 6.0 13.0 + ] + ), + m_T1 = Categorical([0.11240998953463174, 0.5372891244719414, 0.35030088599342685]) + ), + output = Categorical([0.2736744772456484, 0.23881821104809484, 0.2388250291499574, 0.2486822825562994]) + ), + ( + input = ( + m_out = Categorical([0.41399930903334414, 0.2569572285438312, 0.32904346242282473]), + m_in = Categorical([0.3580424644840306, 0.31051652805186924, 0.3314410074641002]), + q_a = TensorDirichlet( + [ + 15.0 4.0 4.0; 6.0 11.0 5.0; 1.0 1.0 14.0;;; 6.0 9.0 8.0; 10.0 10.0 5.0; 6.0 2.0 13.0;;; 13.0 2.0 7.0; 3.0 9.0 6.0; 5.0 5.0 7.0;;;; + 7.0 2.0 9.0; 6.0 8.0 2.0; 6.0 4.0 12.0;;; 13.0 4.0 3.0; 1.0 10.0 2.0; 7.0 7.0 8.0;;; 9.0 9.0 4.0; 1.0 10.0 3.0; 4.0 4.0 12.0;;;; + 13.0 5.0 5.0; 1.0 7.0 1.0; 7.0 5.0 11.0;;; 9.0 10.0 6.0; 6.0 13.0 1.0; 9.0 2.0 8.0;;; 13.0 9.0 4.0; 6.0 12.0 2.0; 10.0 1.0 13.0;;;; + 9.0 6.0 8.0; 3.0 7.0 9.0; 8.0 4.0 11.0;;; 12.0 3.0 4.0; 2.0 11.0 9.0; 6.0 7.0 11.0;;; 13.0 6.0 10.0; 3.0 9.0 9.0; 10.0 10.0 6.0 + ] + ), + m_T1 = Categorical([0.2712971200834005, 0.3827161804909996, 0.3459866994255999]) + ), + output = Categorical([0.2456412839560082, 0.25071835307616863, 0.25386159128651975, 0.24977877168130333]) + ), + ( + input = ( + m_out = Categorical([0.28007415705382577, 0.362168131823555, 0.35775771112261917]), + m_in = Categorical([0.31784517029280823, 0.34141019417661717, 0.3407446355305746]), + q_a = TensorDirichlet( + [ + 12.0 9.0 1.0; 2.0 8.0 7.0; 10.0 1.0 13.0;;; 11.0 8.0 7.0; 10.0 14.0 7.0; 5.0 2.0 15.0;;; 8.0 4.0 4.0; 7.0 11.0 5.0; 4.0 10.0 6.0;;;; + 14.0 3.0 7.0; 8.0 15.0 2.0; 5.0 8.0 15.0;;; 7.0 7.0 4.0; 6.0 11.0 2.0; 10.0 9.0 12.0;;; 11.0 7.0 4.0; 2.0 7.0 4.0; 1.0 1.0 15.0;;;; + 13.0 7.0 5.0; 6.0 15.0 7.0; 9.0 5.0 14.0;;; 7.0 2.0 7.0; 1.0 7.0 2.0; 5.0 8.0 9.0;;; 13.0 5.0 2.0; 3.0 7.0 5.0; 9.0 10.0 14.0;;;; + 14.0 10.0 3.0; 8.0 15.0 9.0; 8.0 2.0 13.0;;; 14.0 10.0 5.0; 2.0 11.0 4.0; 6.0 3.0 11.0;;; 9.0 8.0 10.0; 9.0 6.0 8.0; 5.0 3.0 6.0 + ] + ), + m_T1 = Categorical([0.24450475682493267, 0.4963764838006123, 0.25911875937445494]) + ), + output = Categorical([0.2521503739599087, 0.25077911660660274, 0.25010306073581035, 0.24696744869767814]) + ), + ( + input = ( + m_out = Categorical([0.231721871481526, 0.43974647264393085, 0.3285316558745432]), + m_in = Categorical([0.305396640756791, 0.36076561844428456, 0.3338377407989245]), + q_a = TensorDirichlet( + [ + 11.0 8.0 5.0; 5.0 7.0 10.0; 2.0 1.0 11.0;;; 15.0 8.0 4.0; 8.0 12.0 3.0; 6.0 6.0 14.0;;; 13.0 10.0 8.0; 1.0 8.0 7.0; 8.0 8.0 14.0;;;; + 8.0 4.0 8.0; 9.0 13.0 3.0; 8.0 2.0 7.0;;; 10.0 8.0 10.0; 5.0 11.0 8.0; 2.0 8.0 8.0;;; 10.0 3.0 2.0; 9.0 11.0 5.0; 10.0 4.0 15.0;;;; + 10.0 7.0 9.0; 1.0 14.0 10.0; 10.0 5.0 15.0;;; 7.0 9.0 7.0; 5.0 14.0 1.0; 9.0 9.0 6.0;;; 15.0 7.0 3.0; 7.0 8.0 4.0; 2.0 2.0 14.0;;;; + 14.0 7.0 2.0; 3.0 7.0 6.0; 4.0 9.0 15.0;;; 13.0 1.0 7.0; 3.0 12.0 4.0; 1.0 7.0 14.0;;; 7.0 7.0 3.0; 2.0 7.0 4.0; 8.0 9.0 10.0 + ] + ), + m_T1 = Categorical([0.16173494859799328, 0.5444108919070189, 0.2938541594949879]) + ), + output = Categorical([0.24769360939236879, 0.2570979400410685, 0.24636800327908906, 0.24884044728747362]) + ), + ( + input = ( + m_out = Categorical([0.3438709572699468, 0.327896945058581, 0.3282320976714722]), + m_in = Categorical([0.33437642195805395, 0.3321767052616158, 0.33344687278033025]), + q_a = TensorDirichlet( + [ + 10.0 9.0 2.0; 6.0 8.0 10.0; 7.0 1.0 11.0;;; 8.0 5.0 2.0; 5.0 7.0 3.0; 8.0 1.0 8.0;;; 15.0 5.0 7.0; 4.0 13.0 6.0; 3.0 8.0 7.0;;;; + 10.0 9.0 4.0; 6.0 12.0 10.0; 6.0 6.0 12.0;;; 12.0 8.0 3.0; 3.0 15.0 3.0; 6.0 3.0 10.0;;; 6.0 5.0 8.0; 8.0 6.0 10.0; 8.0 5.0 11.0;;;; + 15.0 2.0 7.0; 9.0 14.0 3.0; 7.0 4.0 11.0;;; 15.0 5.0 8.0; 10.0 6.0 2.0; 1.0 8.0 12.0;;; 8.0 2.0 8.0; 3.0 12.0 6.0; 5.0 8.0 10.0;;;; + 14.0 2.0 2.0; 10.0 10.0 2.0; 1.0 4.0 9.0;;; 10.0 1.0 7.0; 1.0 6.0 3.0; 1.0 10.0 13.0;;; 6.0 8.0 4.0; 4.0 9.0 3.0; 3.0 2.0 9.0 + ] + ), + m_T1 = Categorical([0.08064616359815222, 0.3016652714857759, 0.6176885649160718]) + ), + output = Categorical([0.2505213451093424, 0.25068976800157844, 0.25100559750160933, 0.2477832893874698]) + ) + ] + end +end