Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic implementaton of Transition node #439

Merged
merged 9 commits into from
Jan 22, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ DiffResults = "1.1.0"
Distributions = "0.24, 0.25"
DomainIntegrals = "0.3.2, 0.4"
DomainSets = "0.5.2, 0.6, 0.7"
ExponentialFamily = "1.6.0"
ExponentialFamily = "1.7.0"
ExponentialFamilyProjection = "1.2"
FastCholesky = "1.3.0"
FastGaussQuadrature = "0.4, 0.5"
Expand Down
1 change: 1 addition & 0 deletions src/nodes/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include("predefined/gamma_shape_rate.jl")
include("predefined/beta.jl")
include("predefined/categorical.jl")
include("predefined/matrix_dirichlet.jl")
include("predefined/tensor_dirichlet.jl")
include("predefined/dirichlet.jl")
include("predefined/bernoulli.jl")
include("predefined/gcv.jl")
Expand Down
10 changes: 10 additions & 0 deletions src/nodes/predefined/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import SpecialFunctions: loggamma
import Base.Broadcast: BroadcastFunction

@node TensorDirichlet Stochastic [out, a]

@average_energy TensorDirichlet (q_out::TensorDirichlet, q_a::PointMass) = begin
m_a = mean(q_a)
logmean = mean(BroadcastFunction(log), q_out)
return sum(-loggamma.(sum(m_a, dims = 1)) .+ sum(loggamma.(m_a), dims = 1) .- sum((m_a .- 1.0) .* logmean, dims = 1))
end
53 changes: 51 additions & 2 deletions src/nodes/predefined/transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,30 @@

struct Transition end

@node Transition Stochastic [out, in, a]
ReactiveMP.sdtype(::Type{Transition}) = ReactiveMP.Stochastic()
ReactiveMP.is_predefined_node(::Type{Transition}) = ReactiveMP.PredefinedNodeFunctionalForm()

function ReactiveMP.prepare_interfaces_generic(fform::Type{Transition}, interfaces::AbstractVector)
return map(enumerate(interfaces)) do (index, (name, variable))
return ReactiveMP.NodeInterface(ReactiveMP.alias_interface(fform, index, name), variable)

Check warning on line 12 in src/nodes/predefined/transition.jl

View check run for this annotation

Codecov / codecov/patch

src/nodes/predefined/transition.jl#L10-L12

Added lines #L10 - L12 were not covered by tests
end
end

function ReactiveMP.alias_interface(::Type{Transition}, index, name)
if name === :out && index === 1
return :out
elseif name === :in && index === 2
return :in
elseif name === :in && index === 3
return :a
elseif name === :in && index >= 4
return Symbol(:T, index - 3)
end
end

function ReactiveMP.collect_factorisation(::Type{Transition}, t::Tuple)
return t
end

@average_energy Transition (q_out::Any, q_in::Any, q_a::MatrixDirichlet) = begin
return -probvec(q_out)' * mean(BroadcastFunction(log), q_a) * probvec(q_in)
Expand All @@ -19,9 +42,35 @@
# The reason is that we don't want to take log of zeros in the matrix `q_a` (if there are any)
# The trick here is that if RHS matrix has zero inputs, than the corresponding entries of the `contingency_matrix` matrix
# should also be zeros (see corresponding @marginalrule), so at the end `log(tiny) * 0` should not influence the result.
return -ReactiveMP.mul_trace(components(q_out_in)', mean(BroadcastFunction(clamplog), q_a))
result = -ReactiveMP.mul_trace(components(q_out_in)', mean(BroadcastFunction(clamplog), q_a))
return result
end

@average_energy Transition (q_out::Any, q_in::Any, q_a::PointMass) = begin
return -probvec(q_out)' * mean(BroadcastFunction(clamplog), q_a) * probvec(q_in)
end

function score(::AverageEnergy, ::Type{<:Transition}, ::Val{mnames}, marginals::Tuple{<:Marginal{<:Contingency}, <:Marginal{<:TensorDirichlet}}, ::Nothing) where {mnames}
q_contingency, q_a = getdata.(marginals)
return -sum(mean(BroadcastFunction(log), q_a) .* components(q_contingency))

Check warning on line 55 in src/nodes/predefined/transition.jl

View check run for this annotation

Codecov / codecov/patch

src/nodes/predefined/transition.jl#L53-L55

Added lines #L53 - L55 were not covered by tests
end

function __reduce_td_from_messages(messages, q_A, interface_index)
vmp = clamp.(exp.(mean(BroadcastFunction(log), q_A)), tiny, Inf)
probvecs = probvec.(messages)
for (i, vector) in enumerate(probvecs)
if i ≥ interface_index
actual_index = i + 1
else
actual_index = i
end
v = view(vector, :)
localdims = ntuple(x -> x == actual_index::Int64 ? length(v) : 1, ndims(vmp))
vmp .*= reshape(v, localdims)
end
dims = ntuple(x -> x ≥ interface_index ? x + 1 : x, ndims(vmp) - 1)
vmp = sum(vmp, dims = dims)
msg = reshape(vmp, :)
msg ./= sum(msg)
return Categorical(msg)
end
4 changes: 4 additions & 0 deletions src/rules/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ include("transition/marginals.jl")
include("transition/out.jl")
include("transition/in.jl")
include("transition/a.jl")
include("transition/t.jl")

include("continuous_transition/y.jl")
include("continuous_transition/x.jl")
Expand Down Expand Up @@ -184,3 +185,6 @@ include("delta/cvi/marginals.jl")
include("half_normal/out.jl")

include("binomial_polya/beta.jl")

include("tensor_dirichlet/out.jl")
include("tensor_dirichlet/marginals.jl")
4 changes: 4 additions & 0 deletions src/rules/tensor_dirichlet/marginals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

@marginalrule TensorDirichlet(:out_a) (m_out::TensorDirichlet, m_a::PointMass) = begin
return convert_paramfloattype((out = prod(ClosedProd(), TensorDirichlet(mean(m_a)), m_out), a = m_a))
end
4 changes: 4 additions & 0 deletions src/rules/tensor_dirichlet/out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

@rule TensorDirichlet(:out, Marginalisation) (m_a::PointMass,) = TensorDirichlet(mean(m_a))

@rule TensorDirichlet(:out, Marginalisation) (q_a::PointMass,) = TensorDirichlet(mean(q_a))
13 changes: 13 additions & 0 deletions src/rules/transition/a.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,16 @@ end
@rule Transition(:a, Marginalisation) (q_out_in::Contingency,) = begin
return MatrixDirichlet(components(q_out_in) .+ 1)
end

ReactiveMP.rule(
fform::Type{<:Transition},
on::Val{:a},
vconstraint::Marginalisation,
messages_names::Nothing,
messages::Nothing,
marginals_names::Val{m_names} where {m_names},
marginals::Tuple,
meta::Any,
addons::Any,
::Any
) = TensorDirichlet(components(getdata(first(marginals))) .+ 1), addons
15 changes: 15 additions & 0 deletions src/rules/transition/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,18 @@ end
normalize!(p, 1)
return Categorical(p)
end

function ReactiveMP.rule(
fform::Type{<:Transition},
on::Val{:in},
vconstraint::Marginalisation,
messages_names::Val{m_names},
messages::Tuple,
marginals_names::Val{(:a,)},
marginals::Tuple,
meta::Any,
addons::Any,
::Any
) where {m_names}
return __reduce_td_from_messages(messages, first(marginals), 2), addons
end
13 changes: 13 additions & 0 deletions src/rules/transition/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,16 @@ end
m_in_2 = @call_rule Transition(:in, Marginalisation) (m_out = m_out, m_a = m_a, meta = meta)
return convert_paramfloattype((out = m_out, in = prod(ClosedProd(), m_in_2, m_in), a = m_a))
end

@marginalrule Transition(:out_in) (m_out::PointMass, m_in::Categorical, q_a::PointMass) = begin
m_in_2 = @call_rule Transition(:in, Marginalisation) (m_out = m_out, q_a = q_a)
return convert_paramfloattype((out = m_out, in = prod(ClosedProd(), m_in, m_in_2)))
end

outer_product(vs) = prod.(Iterators.product(vs...))

function marginalrule(
::Type{<:Transition}, ::Val{marginal_symbol}, ::Val{message_names}, messages::Tuple, ::Val{marginal_names}, marginals::Tuple, ::Any, ::Any
) where {marginal_symbol, message_names, marginal_names}
return Contingency(outer_product(probvec.(messages)) .* clamp.(exp.(mean(BroadcastFunction(log), first(marginals))), tiny, huge))
end
15 changes: 15 additions & 0 deletions src/rules/transition/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,18 @@ end
@logscale 0
return @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = q_a, meta = meta, addons = getaddons())
end

function ReactiveMP.rule(
fform::Type{<:Transition},
on::Val{:out},
vconstraint::Marginalisation,
messages_names::Val{m_names},
messages::Tuple,
marginals_names::Val{(:a,)},
marginals::Tuple,
meta::Any,
addons::Any,
::Any
) where {m_names}
return __reduce_td_from_messages(messages, first(marginals), 1), addons
end
17 changes: 17 additions & 0 deletions src/rules/transition/t.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import Base.Broadcast: BroadcastFunction

function ReactiveMP.rule(
fform::Type{<:Transition},
on::Val{S},
vconstraint::Marginalisation,
messages_names::Val{m_names},
messages::Tuple,
marginals_names::Val{(:a,)},
marginals::Tuple,
meta::Any,
addons::Any,
::Any
) where {S, m_names}
interface_index = parse(Int, String(S)[2:end]) + 2
return __reduce_td_from_messages(messages, first(marginals), interface_index), addons
end
56 changes: 56 additions & 0 deletions test/nodes/predefined/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

@testitem "TensorDirichletNode" begin
using ReactiveMP, Random, BayesBase, ExponentialFamily, Distributions, StableRNGs

@testset "AverageEnergy" begin
begin
rng = StableRNG(123456)
for i in 1:100
α = rand(rng, 2, 2)
a = rand(rng, 2, 2)
q_out = TensorDirichlet(α)
q_a = PointMass(a)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing))
avg_energy = score(AverageEnergy(), TensorDirichlet, Val{(:out, :a)}(), marginals, nothing)

q_out = MatrixDirichlet(α)
q_a = PointMass(a)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing))
avg_energy_matrix = score(AverageEnergy(), MatrixDirichlet, Val{(:out, :a)}(), marginals, nothing)

@test avg_energy ≈ avg_energy_matrix
end
end

begin
for rank in 3:5
for dim in 2:5
for i in 1:100
dims = ntuple(d -> dim, rank)
α = rand(rng, dims...)
a = rand(rng, dims...)

q_out = TensorDirichlet(α)
q_a = PointMass(a)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing))
avg_energy = score(AverageEnergy(), TensorDirichlet, Val{(:out, :a)}(), marginals, nothing)

q_out = Dirichlet.(eachslice(α, dims = ntuple(d -> d + 1, rank - 1)))
q_a = PointMass.(eachslice(a, dims = ntuple(d -> d + 1, rank - 1)))

avg_energy_matrix = 0.0
for (dir, a) in zip(q_out, q_a)
marginals = (Marginal(dir, false, false, nothing), Marginal(a, false, false, nothing))
avg_energy_matrix += score(AverageEnergy(), Dirichlet, Val{(:out, :a)}(), marginals, nothing)
end

@test avg_energy ≈ avg_energy_matrix
end
end
end
end
end
end
86 changes: 86 additions & 0 deletions test/nodes/predefined/transition_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
@testitem "TransitionNode" begin
using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily

@testset "Transition node properties" begin
@test ReactiveMP.sdtype(Transition) == Stochastic()
@test ReactiveMP.alias_interface(Transition, 1, :out) == :out
@test ReactiveMP.alias_interface(Transition, 2, :in) == :in
@test ReactiveMP.alias_interface(Transition, 3, :in) == :a
@test ReactiveMP.alias_interface(Transition, 4, :in) == :T1

@test ReactiveMP.collect_factorisation(Transition, ()) == ()
end
@testset "AverageEnergy(q_out_in::Contingency, q_a::MatrixDirichlet)" begin end

@testset "AverageEnergy(q_out_in::Contingency, q_a::PointMass)" begin
contingency_matrix = [0.2 0.3; 0.4 0.1]
a_matrix = [0.7 0.3; 0.2 0.8]

q_out_in = Contingency(contingency_matrix)
q_a = PointMass(a_matrix)

marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

# Expected value calculated by hand
expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))

@test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected

contingency_matrix = [0.2 0.3; 0.4 0.1]
a_matrix = [1.0 0.0; 0.0 1.0]

q_out_in = Contingency(contingency_matrix)
q_a = PointMass(a_matrix)

marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))

@test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected

contingency_matrix = prod.(Iterators.product([0, 1, 0], [0.1, 0.4, 0.5]))
a_matrix = diageye(3)

q_out_in = Contingency(contingency_matrix)
q_a = PointMass(a_matrix)

marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))

@test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected

contingency_matrix = [0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0]
q_out_in = Contingency(contingency_matrix)
q_a = PointMass(diageye(3))

marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))
@test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected
end

@testset "AverageEnergy(q_out::Any, q_in::Any, q_a::PointMass)" begin
q_out = Categorical([0.3, 0.7])
q_in = Categorical([0.8, 0.2])
q_a = PointMass([0.7 0.3; 0.2 0.8])

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_in, false, false, nothing), Marginal(q_a, false, false, nothing))

contingency = probvec(q_out) * probvec(q_in)'
expected = -sum(contingency .* log.(clamp.(mean(q_a), tiny, Inf)))

@test score(AverageEnergy(), Transition, Val{(:out, :in, :a)}(), marginals, nothing) ≈ expected

q_out = Categorical([0.0, 1.0])
q_in = Categorical([0.0, 1.0])
q_a = PointMass([1.0 0.0; 1.0 0.0])

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_in, false, false, nothing), Marginal(q_a, false, false, nothing))

contingency = probvec(q_out) * probvec(q_in)'

expected = -sum(contingency .* log.(clamp.(mean(q_a), tiny, Inf)))
@test score(AverageEnergy(), Transition, Val{(:out, :in, :a)}(), marginals, nothing) ≈ expected
end
end
12 changes: 12 additions & 0 deletions test/rules/transition/a_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,16 @@
input = (q_out_in = Contingency(diageye(3)),), output = MatrixDirichlet([1.333333333333333 1 1; 1 1.3333333333333 1; 1 1 1.33333333333333333])
)]
end

@testset "Variational Bayes: (q_out_in_t1::Contingency)" begin
@test_rules [check_type_promotion = false] Transition(:a, Marginalisation) [(
input = (q_out_in_t1 = Contingency(ones(3, 3, 3)),), output = TensorDirichlet(ones(3, 3, 3) .+ (1 / 27))
)]
end

@testset "Variational Bayes: (q_out_in_t1_t2::Contingency)" begin
@test_rules [check_type_promotion = false] Transition(:a, Marginalisation) [(
input = (q_out_in_t1_t2 = Contingency(ones(3, 3, 3, 3)),), output = TensorDirichlet(ones(3, 3, 3, 3) .+ (1 / 81))
)]
end
end
Loading
Loading