From 2ed9fd0950b755c0fea891333e5765097a6de7db Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Thu, 26 Jan 2023 18:36:39 +0100 Subject: [PATCH 1/8] Fix issue #249 --- .../mv_normal_mean_covariance.jl | 2 +- src/distributions/mv_normal_mean_precision.jl | 2 +- .../mv_normal_weighted_mean_precision.jl | 2 +- src/distributions/normal.jl | 5 ++ src/helpers/algebra/common.jl | 26 +++++++ test/algebra/test_common.jl | 13 ++++ test/distributions/test_normal.jl | 75 ++++++++++++------- 7 files changed, 94 insertions(+), 31 deletions(-) diff --git a/src/distributions/mv_normal_mean_covariance.jl b/src/distributions/mv_normal_mean_covariance.jl index 7938bf160..35e3b2227 100644 --- a/src/distributions/mv_normal_mean_covariance.jl +++ b/src/distributions/mv_normal_mean_covariance.jl @@ -55,7 +55,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanCovariance, x::AbstractVect for i in 1:length(r) @inbounds r[i] = μ[i] - x[i] end - return dot(r, invcov(dist), r) # x' * A * x + return xT_A_y(r, invcov(dist), r) # x' * A * x end Base.eltype(::MvNormalMeanCovariance{T}) where {T} = T diff --git a/src/distributions/mv_normal_mean_precision.jl b/src/distributions/mv_normal_mean_precision.jl index 5fc2a16fa..28afad052 100644 --- a/src/distributions/mv_normal_mean_precision.jl +++ b/src/distributions/mv_normal_mean_precision.jl @@ -45,7 +45,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanPrecision, x::AbstractVecto for i in 1:length(r) @inbounds r[i] = μ[i] - x[i] end - return dot(r, invcov(dist), r) + return xT_A_y(r, invcov(dist), r) # x' * A * x end Base.eltype(::MvNormalMeanPrecision{T}) where {T} = T diff --git a/src/distributions/mv_normal_weighted_mean_precision.jl b/src/distributions/mv_normal_weighted_mean_precision.jl index 0d1daafc2..6f1863625 100644 --- a/src/distributions/mv_normal_weighted_mean_precision.jl +++ b/src/distributions/mv_normal_weighted_mean_precision.jl @@ -54,7 +54,7 @@ function Distributions.sqmahal!(r, 
dist::MvNormalWeightedMeanPrecision, x::Abstr for i in 1:length(r) @inbounds r[i] = μ[i] - x[i] end - return dot(r, invcov(dist), r) + return xT_A_y(r, invcov(dist), r) # x' * A * x end Base.eltype(::MvNormalWeightedMeanPrecision{T}) where {T} = T diff --git a/src/distributions/normal.jl b/src/distributions/normal.jl index 389b7f8f4..c8647632d 100644 --- a/src/distributions/normal.jl +++ b/src/distributions/normal.jl @@ -191,6 +191,11 @@ promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanVariance}) promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanPrecision}) = MvNormalMeanPrecision promote_variate_type(::Type{Multivariate}, ::Type{<:NormalWeightedMeanPrecision}) = MvNormalWeightedMeanPrecision +# Conversion to gaussian distributions from `Distributions.jl` + +Base.convert(::Type{ Normal }, dist::UnivariateNormalDistributionsFamily) = Normal(mean_std(dist)...) +Base.convert(::Type{ MvNormal }, dist::MultivariateNormalDistributionsFamily) = MvNormal(mean_cov(dist)...) + # Conversion to mean - variance parametrisation function Base.convert(::Type{NormalMeanVariance{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real} diff --git a/src/helpers/algebra/common.jl b/src/helpers/algebra/common.jl index cf0b93d1b..099069eb5 100644 --- a/src/helpers/algebra/common.jl +++ b/src/helpers/algebra/common.jl @@ -164,6 +164,32 @@ function v_a_vT(v1, a, v2) return result end +""" + xT_A_y(x, A, y) + +Computes `dot(x, A, y)`. The built-in Julia 3-arg `dot` is not compatible with the auto-differentiation packages, +such as `ForwardDiff`. We use our own implementation in some cases but ultimately fallback to the `dot`. +""" +xT_A_y(x, A, y) = dot(x, A, y) + +function xT_A_y(x::AbstractVector, A::AbstractMatrix, y::AbstractVector) + (axes(x)..., axes(y)...) 
== axes(A) || throw(DimensionMismatch()) + T = typeof(dot(first(x), first(A), first(y))) + s = zero(T) + i₁ = first(eachindex(x)) + x₁ = first(x) + @inbounds for j in eachindex(y) + yj = y[j] + temp = zero(adjoint(A[i₁,j]) * x₁) + @simd for i in eachindex(x) + temp += adjoint(A[i,j]) * x[i] + end + s += dot(temp, yj) + end + return s +end + + """ mvbeta(x) diff --git a/test/algebra/test_common.jl b/test/algebra/test_common.jl index 98bd95669..cee614ca1 100644 --- a/test/algebra/test_common.jl +++ b/test/algebra/test_common.jl @@ -84,6 +84,19 @@ using LinearAlgebra @test ReactiveMP.mul_trace(a, b) ≈ a * b end end + + @testset "xT_A_y" begin + import ReactiveMP: xT_A_y + + rng = MersenneTwister(1234) + for size in 2:5, T1 in (Float32, Float64), T2 in (Float32, Float64), T3 in (Float32, Float64) + x = rand(T1, size) + A = rand(T2, size, size) + y = rand(T3, size) + @test dot(x, A, y) ≈ xT_A_y(x, A, y) + end + + end end end diff --git a/test/distributions/test_normal.jl b/test/distributions/test_normal.jl index f2ad7eeb2..b8a50fd61 100644 --- a/test/distributions/test_normal.jl +++ b/test/distributions/test_normal.jl @@ -5,28 +5,40 @@ using ReactiveMP using Random using LinearAlgebra using Distributions +using ForwardDiff import ReactiveMP: convert_eltype @testset "Normal" begin @testset "Univariate conversions" begin - check_basic_statistics = (left, right) -> begin + check_basic_statistics = (left, right; include_extended_methods = true) -> begin @test mean(left) ≈ mean(right) @test median(left) ≈ median(right) @test mode(left) ≈ mode(right) - @test weightedmean(left) ≈ weightedmean(right) @test var(left) ≈ var(right) @test std(left) ≈ std(right) - @test cov(left) ≈ cov(right) - @test invcov(left) ≈ invcov(right) - @test precision(left) ≈ precision(right) @test entropy(left) ≈ entropy(right) - @test pdf(left, 1.0) ≈ pdf(right, 1.0) - @test pdf(left, -1.0) ≈ pdf(right, -1.0) - @test pdf(left, 0.0) ≈ pdf(right, 0.0) - @test logpdf(left, 1.0) ≈ logpdf(right, 1.0) - 
@test logpdf(left, -1.0) ≈ logpdf(right, -1.0) - @test logpdf(left, 0.0) ≈ logpdf(right, 0.0) + + for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand()) + @test pdf(left, value) ≈ pdf(right, value) + @test logpdf(left, value) ≈ logpdf(right, value) + @test all(ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value])) + @test all(ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value])) + end + + # These methods are not defined for distributions from `Distributions.jl + if include_extended_methods + @test cov(left) ≈ cov(right) + @test invcov(left) ≈ invcov(right) + @test weightedmean(left) ≈ weightedmean(right) + @test precision(left) ≈ precision(right) + @test all(mean_cov(left) .≈ mean_cov(right)) + @test all(mean_invcov(left) .≈ mean_invcov(right)) + @test all(mean_precision(left) .≈ mean_precision(right)) + @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) + @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) + @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + end end types = ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float64}) @@ -36,6 +48,7 @@ import ReactiveMP: convert_eltype for type in types left = convert(type, rand(rng, Float64), rand(rng, Float64)) + check_basic_statistics(left, convert(Normal, left); include_extended_methods = false) for type in [types..., etypes...] 
right = convert(type, left) check_basic_statistics(left, right) @@ -56,31 +69,36 @@ import ReactiveMP: convert_eltype end @testset "Multivariate conversions" begin - check_basic_statistics = (left, right, dims) -> begin + check_basic_statistics = (left, right, dims; include_extended_methods = true) -> begin @test mean(left) ≈ mean(right) @test mode(left) ≈ mode(right) - @test weightedmean(left) ≈ weightedmean(right) @test var(left) ≈ var(right) @test cov(left) ≈ cov(right) - @test invcov(left) ≈ invcov(right) @test logdetcov(left) ≈ logdetcov(right) - @test precision(left) ≈ precision(right) @test length(left) === length(right) - @test ndims(left) === ndims(right) @test size(left) === size(right) @test entropy(left) ≈ entropy(right) - @test all(mean_cov(left) .≈ mean_cov(right)) - @test all(mean_invcov(left) .≈ mean_invcov(right)) - @test all(mean_precision(left) .≈ mean_precision(right)) - @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) - @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) - @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) - @test pdf(left, fill(1.0, dims)) ≈ pdf(right, fill(1.0, dims)) - @test pdf(left, fill(-1.0, dims)) ≈ pdf(right, fill(-1.0, dims)) - @test pdf(left, fill(0.0, dims)) ≈ pdf(right, fill(0.0, dims)) - @test logpdf(left, fill(1.0, dims)) ≈ logpdf(right, fill(1.0, dims)) - @test logpdf(left, fill(-1.0, dims)) ≈ logpdf(right, fill(-1.0, dims)) - @test logpdf(left, fill(0.0, dims)) ≈ logpdf(right, fill(0.0, dims)) + + for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims)) + @test pdf(left, value) ≈ pdf(right, value) + @test logpdf(left, value) ≈ logpdf(right, value) + @test all(isapprox.(ForwardDiff.gradient((x) -> logpdf(left, x), value), ForwardDiff.gradient((x) -> logpdf(right, x), value), atol = 1e-14)) + @test all(isapprox.(ForwardDiff.hessian((x) -> logpdf(left, x), value), ForwardDiff.hessian((x) -> logpdf(right, x), value), atol 
= 1e-14)) + end + + # These methods are not defined for distributions from `Distributions.jl + if include_extended_methods + @test ndims(left) === ndims(right) + @test invcov(left) ≈ invcov(right) + @test weightedmean(left) ≈ weightedmean(right) + @test precision(left) ≈ precision(right) + @test all(mean_cov(left) .≈ mean_cov(right)) + @test all(mean_invcov(left) .≈ mean_invcov(right)) + @test all(mean_precision(left) .≈ mean_precision(right)) + @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) + @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) + @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + end end types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64}) @@ -92,6 +110,7 @@ import ReactiveMP: convert_eltype for dim in dims for type in types left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(rand(rng, Float64, dim)))) + check_basic_statistics(left, convert(MvNormal, left), dim; include_extended_methods = false) for type in [types..., etypes...] 
right = convert(type, left) check_basic_statistics(left, right, dim) From 501046a0f1e6cc55ccfdc3e5eb31e7c3e0911410 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Thu, 26 Jan 2023 18:43:56 +0100 Subject: [PATCH 2/8] remove 3arg dot function --- src/distributions/normal.jl | 2 +- src/nodes/autoregressive.jl | 6 +++--- src/rules/dot_product/out.jl | 2 +- src/rules/multiplication/A.jl | 2 +- src/rules/multiplication/in.jl | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/distributions/normal.jl b/src/distributions/normal.jl index c8647632d..d3e0ed3d1 100644 --- a/src/distributions/normal.jl +++ b/src/distributions/normal.jl @@ -317,7 +317,7 @@ function Base.prod( n = length(left) v_inv, v_logdet = cholinv_logdet(v) m = m_left - m_right - return -(v_logdet + n * log2π) / 2 - dot(m, v_inv, m) / 2 + return -(v_logdet + n * log2π) / 2 - xT_A_y(m, v_inv, m) / 2 end ## Friendly functions diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl index ad59b2634..1694ba58b 100644 --- a/src/nodes/autoregressive.jl +++ b/src/nodes/autoregressive.jl @@ -50,8 +50,8 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici my1, Vy1 = first(myx), first(Vyx) Vy1x = ar_slice(getvform(meta), Vyx, 1, (order + 1):(2order)) - # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 - AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2 + # Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 + AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2 # correction if is_multivariate(meta) @@ -76,7 +76,7 @@ end my1, Vy1 = first(my), first(Vy) - AE 
= -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx))) + AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx))) # correction if is_multivariate(meta) diff --git a/src/rules/dot_product/out.jl b/src/rules/dot_product/out.jl index aa2f78d8e..1b7ccf247 100644 --- a/src/rules/dot_product/out.jl +++ b/src/rules/dot_product/out.jl @@ -6,5 +6,5 @@ end @rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin A = mean(m_in1) in2_mean, in2_cov = mean_cov(m_in2) - return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A)) + return NormalMeanVariance(dot(A, in2_mean), xT_A_y(A, in2_cov, A)) end diff --git a/src/rules/multiplication/A.jl b/src/rules/multiplication/A.jl index 96181e752..5330aa45e 100644 --- a/src/rules/multiplication/A.jl +++ b/src/rules/multiplication/A.jl @@ -18,7 +18,7 @@ end @rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin A = mean(m_in) ξ_out, W_out = weightedmean_precision(m_out) - W = correction!(meta, dot(A, W_out, A)) + W = correction!(meta, xT_A_y(A, W_out, A)) return NormalWeightedMeanPrecision(dot(A, ξ_out), W) end diff --git a/src/rules/multiplication/in.jl b/src/rules/multiplication/in.jl index a03123a5e..040e222b1 100644 --- a/src/rules/multiplication/in.jl +++ b/src/rules/multiplication/in.jl @@ -18,7 +18,7 @@ end @rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin A = mean(m_A) ξ_out, W_out = weightedmean_precision(m_out) - W = correction!(meta, dot(A, W_out, A)) + W = correction!(meta, xT_A_y(A, W_out, A)) return 
NormalWeightedMeanPrecision(dot(A, ξ_out), W) end From c376fd44a94e47f0e6f3dce97f354dd44462f150 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Thu, 26 Jan 2023 18:47:42 +0100 Subject: [PATCH 3/8] style: make format --- src/distributions/normal.jl | 4 +- src/helpers/algebra/common.jl | 5 +- test/algebra/test_common.jl | 3 +- test/distributions/test_normal.jl | 116 +++++++++++++++--------------- 4 files changed, 64 insertions(+), 64 deletions(-) diff --git a/src/distributions/normal.jl b/src/distributions/normal.jl index d3e0ed3d1..66117b40a 100644 --- a/src/distributions/normal.jl +++ b/src/distributions/normal.jl @@ -193,8 +193,8 @@ promote_variate_type(::Type{Multivariate}, ::Type{<:NormalWeightedMeanPrecision} # Conversion to gaussian distributions from `Distributions.jl` -Base.convert(::Type{ Normal }, dist::UnivariateNormalDistributionsFamily) = Normal(mean_std(dist)...) -Base.convert(::Type{ MvNormal }, dist::MultivariateNormalDistributionsFamily) = MvNormal(mean_cov(dist)...) +Base.convert(::Type{Normal}, dist::UnivariateNormalDistributionsFamily) = Normal(mean_std(dist)...) +Base.convert(::Type{MvNormal}, dist::MultivariateNormalDistributionsFamily) = MvNormal(mean_cov(dist)...) 
# Conversion to mean - variance parametrisation diff --git a/src/helpers/algebra/common.jl b/src/helpers/algebra/common.jl index 099069eb5..d4e366566 100644 --- a/src/helpers/algebra/common.jl +++ b/src/helpers/algebra/common.jl @@ -180,16 +180,15 @@ function xT_A_y(x::AbstractVector, A::AbstractMatrix, y::AbstractVector) x₁ = first(x) @inbounds for j in eachindex(y) yj = y[j] - temp = zero(adjoint(A[i₁,j]) * x₁) + temp = zero(adjoint(A[i₁, j]) * x₁) @simd for i in eachindex(x) - temp += adjoint(A[i,j]) * x[i] + temp += adjoint(A[i, j]) * x[i] end s += dot(temp, yj) end return s end - """ mvbeta(x) diff --git a/test/algebra/test_common.jl b/test/algebra/test_common.jl index cee614ca1..925aab234 100644 --- a/test/algebra/test_common.jl +++ b/test/algebra/test_common.jl @@ -85,7 +85,7 @@ using LinearAlgebra end end - @testset "xT_A_y" begin + @testset "xT_A_y" begin import ReactiveMP: xT_A_y rng = MersenneTwister(1234) @@ -95,7 +95,6 @@ using LinearAlgebra y = rand(T3, size) @test dot(x, A, y) ≈ xT_A_y(x, A, y) end - end end diff --git a/test/distributions/test_normal.jl b/test/distributions/test_normal.jl index b8a50fd61..be789d83e 100644 --- a/test/distributions/test_normal.jl +++ b/test/distributions/test_normal.jl @@ -11,35 +11,36 @@ import ReactiveMP: convert_eltype @testset "Normal" begin @testset "Univariate conversions" begin - check_basic_statistics = (left, right; include_extended_methods = true) -> begin - @test mean(left) ≈ mean(right) - @test median(left) ≈ median(right) - @test mode(left) ≈ mode(right) - @test var(left) ≈ var(right) - @test std(left) ≈ std(right) - @test entropy(left) ≈ entropy(right) - - for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand()) - @test pdf(left, value) ≈ pdf(right, value) - @test logpdf(left, value) ≈ logpdf(right, value) - @test all(ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value])) - @test all(ForwardDiff.hessian((x) -> logpdf(left, x[1]), 
[value]) .≈ ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value])) - end - - # These methods are not defined for distributions from `Distributions.jl - if include_extended_methods - @test cov(left) ≈ cov(right) - @test invcov(left) ≈ invcov(right) - @test weightedmean(left) ≈ weightedmean(right) - @test precision(left) ≈ precision(right) - @test all(mean_cov(left) .≈ mean_cov(right)) - @test all(mean_invcov(left) .≈ mean_invcov(right)) - @test all(mean_precision(left) .≈ mean_precision(right)) - @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) - @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) - @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + check_basic_statistics = + (left, right; include_extended_methods = true) -> begin + @test mean(left) ≈ mean(right) + @test median(left) ≈ median(right) + @test mode(left) ≈ mode(right) + @test var(left) ≈ var(right) + @test std(left) ≈ std(right) + @test entropy(left) ≈ entropy(right) + + for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand()) + @test pdf(left, value) ≈ pdf(right, value) + @test logpdf(left, value) ≈ logpdf(right, value) + @test all(ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value])) + @test all(ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value])) + end + + # These methods are not defined for distributions from `Distributions.jl + if include_extended_methods + @test cov(left) ≈ cov(right) + @test invcov(left) ≈ invcov(right) + @test weightedmean(left) ≈ weightedmean(right) + @test precision(left) ≈ precision(right) + @test all(mean_cov(left) .≈ mean_cov(right)) + @test all(mean_invcov(left) .≈ mean_invcov(right)) + @test all(mean_precision(left) .≈ mean_precision(right)) + @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) + @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) + @test 
all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + end end - end types = ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float64}) etypes = ReactiveMP.union_types(UnivariateNormalDistributionsFamily) @@ -69,37 +70,38 @@ import ReactiveMP: convert_eltype end @testset "Multivariate conversions" begin - check_basic_statistics = (left, right, dims; include_extended_methods = true) -> begin - @test mean(left) ≈ mean(right) - @test mode(left) ≈ mode(right) - @test var(left) ≈ var(right) - @test cov(left) ≈ cov(right) - @test logdetcov(left) ≈ logdetcov(right) - @test length(left) === length(right) - @test size(left) === size(right) - @test entropy(left) ≈ entropy(right) - - for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims)) - @test pdf(left, value) ≈ pdf(right, value) - @test logpdf(left, value) ≈ logpdf(right, value) - @test all(isapprox.(ForwardDiff.gradient((x) -> logpdf(left, x), value), ForwardDiff.gradient((x) -> logpdf(right, x), value), atol = 1e-14)) - @test all(isapprox.(ForwardDiff.hessian((x) -> logpdf(left, x), value), ForwardDiff.hessian((x) -> logpdf(right, x), value), atol = 1e-14)) - end + check_basic_statistics = + (left, right, dims; include_extended_methods = true) -> begin + @test mean(left) ≈ mean(right) + @test mode(left) ≈ mode(right) + @test var(left) ≈ var(right) + @test cov(left) ≈ cov(right) + @test logdetcov(left) ≈ logdetcov(right) + @test length(left) === length(right) + @test size(left) === size(right) + @test entropy(left) ≈ entropy(right) + + for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims)) + @test pdf(left, value) ≈ pdf(right, value) + @test logpdf(left, value) ≈ logpdf(right, value) + @test all(isapprox.(ForwardDiff.gradient((x) -> logpdf(left, x), value), ForwardDiff.gradient((x) -> logpdf(right, x), value), atol = 1e-14)) + @test all(isapprox.(ForwardDiff.hessian((x) -> logpdf(left, x), value), 
ForwardDiff.hessian((x) -> logpdf(right, x), value), atol = 1e-14)) + end - # These methods are not defined for distributions from `Distributions.jl - if include_extended_methods - @test ndims(left) === ndims(right) - @test invcov(left) ≈ invcov(right) - @test weightedmean(left) ≈ weightedmean(right) - @test precision(left) ≈ precision(right) - @test all(mean_cov(left) .≈ mean_cov(right)) - @test all(mean_invcov(left) .≈ mean_invcov(right)) - @test all(mean_precision(left) .≈ mean_precision(right)) - @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) - @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) - @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + # These methods are not defined for distributions from `Distributions.jl + if include_extended_methods + @test ndims(left) === ndims(right) + @test invcov(left) ≈ invcov(right) + @test weightedmean(left) ≈ weightedmean(right) + @test precision(left) ≈ precision(right) + @test all(mean_cov(left) .≈ mean_cov(right)) + @test all(mean_invcov(left) .≈ mean_invcov(right)) + @test all(mean_precision(left) .≈ mean_precision(right)) + @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) + @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) + @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + end end - end types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64}) etypes = ReactiveMP.union_types(MultivariateNormalDistributionsFamily) From ad608cf4dde5340f9bbfb0b38fab205333aded94 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 27 Jan 2023 11:28:31 +0100 Subject: [PATCH 4/8] feat: use multiple dispatch for fix --- docs/src/extra/contributing.md | 8 +++++- src/ReactiveMP.jl | 1 + .../mv_normal_mean_covariance.jl | 2 +- src/distributions/mv_normal_mean_precision.jl | 2 +- .../mv_normal_weighted_mean_precision.jl | 2 +- src/distributions/normal.jl | 2 +- src/fixes.jl | 27 +++++++++++++++++++ 
src/helpers/algebra/common.jl | 25 ----------------- src/nodes/autoregressive.jl | 4 +-- src/rules/dot_product/out.jl | 2 +- src/rules/multiplication/A.jl | 2 +- src/rules/multiplication/in.jl | 2 +- test/algebra/test_common.jl | 12 --------- 13 files changed, 44 insertions(+), 47 deletions(-) create mode 100644 src/fixes.jl diff --git a/docs/src/extra/contributing.md b/docs/src/extra/contributing.md index 14f6e59d6..79f7707fa 100644 --- a/docs/src/extra/contributing.md +++ b/docs/src/extra/contributing.md @@ -69,6 +69,12 @@ In addition tests can be evaluated by running following command in the ReactiveM make test ``` +### Fixes to external libraries + +If a bug has been discovered in an external dependency of `ReactiveMP.jl`, it is best to open an issue +directly in the dependency's GitHub repository. You can use the `fixes.jl` file for hot-fixes before +a new release of the broken dependency is available. + ### Makefile `ReactiveMP.jl` uses `Makefile` for most common operations: @@ -80,4 +86,4 @@ make test - `make docs`: Compile documentation - `make benchmark`: Run simple benchmark - `make lint`: Check codestyle -- `make format`: Check and fix codestyle \ No newline at end of file +- `make format`: Check and fix codestyle diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index abf6cff0e..2c708d7e3 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -7,6 +7,7 @@ using TinyHugeNumbers # Reexport `tiny` and `huge` from the `TinyHugeNumbers` export tiny, huge +include("fixes.jl") include("helpers/macrohelpers.jl") include("helpers/helpers.jl") diff --git a/src/distributions/mv_normal_mean_covariance.jl b/src/distributions/mv_normal_mean_covariance.jl index 35e3b2227..7938bf160 100644 --- a/src/distributions/mv_normal_mean_covariance.jl +++ b/src/distributions/mv_normal_mean_covariance.jl @@ -55,7 +55,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanCovariance, x::AbstractVect for i in 1:length(r) @inbounds r[i] = μ[i] - x[i] end - return 
xT_A_y(r, invcov(dist), r) # x' * A * x + return dot(r, invcov(dist), r) # x' * A * x end Base.eltype(::MvNormalMeanCovariance{T}) where {T} = T diff --git a/src/distributions/mv_normal_mean_precision.jl b/src/distributions/mv_normal_mean_precision.jl index 28afad052..889b00146 100644 --- a/src/distributions/mv_normal_mean_precision.jl +++ b/src/distributions/mv_normal_mean_precision.jl @@ -45,7 +45,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanPrecision, x::AbstractVecto for i in 1:length(r) @inbounds r[i] = μ[i] - x[i] end - return xT_A_y(r, invcov(dist), r) # x' * A * x + return dot(r, invcov(dist), r) # x' * A * x end Base.eltype(::MvNormalMeanPrecision{T}) where {T} = T diff --git a/src/distributions/mv_normal_weighted_mean_precision.jl b/src/distributions/mv_normal_weighted_mean_precision.jl index 6f1863625..86aa046ec 100644 --- a/src/distributions/mv_normal_weighted_mean_precision.jl +++ b/src/distributions/mv_normal_weighted_mean_precision.jl @@ -54,7 +54,7 @@ function Distributions.sqmahal!(r, dist::MvNormalWeightedMeanPrecision, x::Abstr for i in 1:length(r) @inbounds r[i] = μ[i] - x[i] end - return xT_A_y(r, invcov(dist), r) # x' * A * x + return dot(r, invcov(dist), r) # x' * A * x end Base.eltype(::MvNormalWeightedMeanPrecision{T}) where {T} = T diff --git a/src/distributions/normal.jl b/src/distributions/normal.jl index 66117b40a..9caeb15bf 100644 --- a/src/distributions/normal.jl +++ b/src/distributions/normal.jl @@ -317,7 +317,7 @@ function Base.prod( n = length(left) v_inv, v_logdet = cholinv_logdet(v) m = m_left - m_right - return -(v_logdet + n * log2π) / 2 - xT_A_y(m, v_inv, m) / 2 + return -(v_logdet + n * log2π) / 2 - dot(m, v_inv, m) / 2 end ## Friendly functions diff --git a/src/fixes.jl b/src/fixes.jl new file mode 100644 index 000000000..a0380c5f4 --- /dev/null +++ b/src/fixes.jl @@ -0,0 +1,27 @@ +# This file implements various hot-fixes for external dependencies +# This file can be empty, which is fine. 
It only means that all external dependencies released a new version +# that is now fixed + +# Fix for 3-argument `dot` product and `ForwardDiff.hessian`, see +# https://github.com/JuliaDiff/ForwardDiff.jl/issues/551 +# https://github.com/JuliaDiff/ForwardDiff.jl/pull/481 +# https://github.com/JuliaDiff/ForwardDiff.jl/issues/480 +import LinearAlgebra: dot +import ForwardDiff + +function dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector{D}) where {D <: ForwardDiff.Dual} + (axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch()) + T = typeof(dot(first(x), first(A), first(y))) + s = zero(T) + i₁ = first(eachindex(x)) + x₁ = first(x) + @inbounds for j in eachindex(y) + yj = y[j] + temp = zero(adjoint(A[i₁, j]) * x₁) + @simd for i in eachindex(x) + temp += adjoint(A[i, j]) * x[i] + end + s += dot(temp, yj) + end + return s +end diff --git a/src/helpers/algebra/common.jl b/src/helpers/algebra/common.jl index d4e366566..cf0b93d1b 100644 --- a/src/helpers/algebra/common.jl +++ b/src/helpers/algebra/common.jl @@ -164,31 +164,6 @@ function v_a_vT(v1, a, v2) return result end -""" - xT_A_y(x, A, y) - -Computes `dot(x, A, y)`. The built-in Julia 3-arg `dot` is not compatible with the auto-differentiation packages, -such as `ForwardDiff`. We use our own implementation in some cases but ultimately fallback to the `dot`. -""" -xT_A_y(x, A, y) = dot(x, A, y) - -function xT_A_y(x::AbstractVector, A::AbstractMatrix, y::AbstractVector) - (axes(x)..., axes(y)...) 
== axes(A) || throw(DimensionMismatch()) - T = typeof(dot(first(x), first(A), first(y))) - s = zero(T) - i₁ = first(eachindex(x)) - x₁ = first(x) - @inbounds for j in eachindex(y) - yj = y[j] - temp = zero(adjoint(A[i₁, j]) * x₁) - @simd for i in eachindex(x) - temp += adjoint(A[i, j]) * x[i] - end - s += dot(temp, yj) - end - return s -end - """ mvbeta(x) diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl index 1694ba58b..07a620fe5 100644 --- a/src/nodes/autoregressive.jl +++ b/src/nodes/autoregressive.jl @@ -51,7 +51,7 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici Vy1x = ar_slice(getvform(meta), Vyx, 1, (order + 1):(2order)) # Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 - AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2 + AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2 # correction if is_multivariate(meta) @@ -76,7 +76,7 @@ end my1, Vy1 = first(my), first(Vy) - AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx))) + AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx))) # correction if is_multivariate(meta) diff --git a/src/rules/dot_product/out.jl b/src/rules/dot_product/out.jl index 1b7ccf247..aa2f78d8e 100644 --- a/src/rules/dot_product/out.jl +++ b/src/rules/dot_product/out.jl @@ -6,5 +6,5 @@ end @rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin A = mean(m_in1) in2_mean, in2_cov = 
mean_cov(m_in2) - return NormalMeanVariance(dot(A, in2_mean), xT_A_y(A, in2_cov, A)) + return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A)) end diff --git a/src/rules/multiplication/A.jl b/src/rules/multiplication/A.jl index 5330aa45e..96181e752 100644 --- a/src/rules/multiplication/A.jl +++ b/src/rules/multiplication/A.jl @@ -18,7 +18,7 @@ end @rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin A = mean(m_in) ξ_out, W_out = weightedmean_precision(m_out) - W = correction!(meta, xT_A_y(A, W_out, A)) + W = correction!(meta, dot(A, W_out, A)) return NormalWeightedMeanPrecision(dot(A, ξ_out), W) end diff --git a/src/rules/multiplication/in.jl b/src/rules/multiplication/in.jl index 040e222b1..a03123a5e 100644 --- a/src/rules/multiplication/in.jl +++ b/src/rules/multiplication/in.jl @@ -18,7 +18,7 @@ end @rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin A = mean(m_A) ξ_out, W_out = weightedmean_precision(m_out) - W = correction!(meta, xT_A_y(A, W_out, A)) + W = correction!(meta, dot(A, W_out, A)) return NormalWeightedMeanPrecision(dot(A, ξ_out), W) end diff --git a/test/algebra/test_common.jl b/test/algebra/test_common.jl index 925aab234..98bd95669 100644 --- a/test/algebra/test_common.jl +++ b/test/algebra/test_common.jl @@ -84,18 +84,6 @@ using LinearAlgebra @test ReactiveMP.mul_trace(a, b) ≈ a * b end end - - @testset "xT_A_y" begin - import ReactiveMP: xT_A_y - - rng = MersenneTwister(1234) - for size in 2:5, T1 in (Float32, Float64), T2 in (Float32, Float64), T3 in (Float32, Float64) - x = rand(T1, size) - A = rand(T2, size, size) - y = rand(T3, size) - @test dot(x, A, y) ≈ xT_A_y(x, A, y) - end - end end end From b276fbb71cc42be1d0148c5c822f537542c5125c Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 27 Jan 
2023 13:59:10 +0100 Subject: [PATCH 5/8] test: add test for ReactiveMP ForwardDiffGrad wrapper --- test/approximations/test_grad.jl | 53 ++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 test/approximations/test_grad.jl diff --git a/test/approximations/test_grad.jl b/test/approximations/test_grad.jl new file mode 100644 index 000000000..cd80aeac6 --- /dev/null +++ b/test/approximations/test_grad.jl @@ -0,0 +1,53 @@ +module ForwardDiffGradTest + +using Test +using ReactiveMP +using Random +using LinearAlgebra +using Distributions +using ForwardDiff + +import ReactiveMP: convert_eltype + +@testset "ForwardDiffGrad" begin + grad = ForwardDiffGrad() + + check_basic_statistics = + (left, right, dims) -> begin + for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims)) + @test all(isapprox.(ReactiveMP.compute_gradient(grad, (x) -> logpdf(left, x), value), ReactiveMP.compute_gradient(grad, (x) -> logpdf(left, x), value), atol = 1e-14)) + @test all(isapprox.(ReactiveMP.compute_hessian(grad, (x) -> logpdf(left, x), value), ReactiveMP.compute_hessian(grad, (x) -> logpdf(left, x), value), atol = 1e-14)) + end + end + + types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64}) + etypes = ReactiveMP.union_types(MultivariateNormalDistributionsFamily) + + dims = (2, 3, 5) + rng = MersenneTwister(1234) + + for dim in dims + for type in types + left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(rand(rng, Float64, dim)))) + check_basic_statistics(left, convert(MvNormal, left), dim) + for type in [types..., etypes...] 
+ right = convert(type, left) + check_basic_statistics(left, right, dim) + + p1 = prod(ProdPreserveTypeLeft(), left, right) + @test typeof(p1) <: typeof(left) + + p2 = prod(ProdPreserveTypeRight(), left, right) + @test typeof(p2) <: typeof(right) + + p3 = prod(ProdAnalytical(), left, right) + + check_basic_statistics(p1, p2, dim) + check_basic_statistics(p2, p3, dim) + check_basic_statistics(p1, p3, dim) + end + end + end +end + +end From c6d8bc665ea4dfaad2ff53deddb769a796b8a129 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 27 Jan 2023 14:05:00 +0100 Subject: [PATCH 6/8] test: get rid of the copy-paste tests from normal (ForwardDiffGradTest) --- test/approximations/test_grad.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/approximations/test_grad.jl b/test/approximations/test_grad.jl index cd80aeac6..b6bf25f8d 100644 --- a/test/approximations/test_grad.jl +++ b/test/approximations/test_grad.jl @@ -35,11 +35,7 @@ import ReactiveMP: convert_eltype check_basic_statistics(left, right, dim) p1 = prod(ProdPreserveTypeLeft(), left, right) - @test typeof(p1) <: typeof(left) - p2 = prod(ProdPreserveTypeRight(), left, right) - @test typeof(p2) <: typeof(right) - p3 = prod(ProdAnalytical(), left, right) check_basic_statistics(p1, p2, dim) From f37aa0e2966ba951a2e4f544c283ddabcffdd8fd Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 27 Jan 2023 14:29:10 +0100 Subject: [PATCH 7/8] test: simplify ForwardDiffGradTest --- test/approximations/test_grad.jl | 35 ++++---------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/test/approximations/test_grad.jl b/test/approximations/test_grad.jl index b6bf25f8d..1c79aaa0a 100644 --- a/test/approximations/test_grad.jl +++ b/test/approximations/test_grad.jl @@ -12,38 +12,11 @@ import ReactiveMP: convert_eltype @testset "ForwardDiffGrad" begin grad = ForwardDiffGrad() - check_basic_statistics = - (left, right, dims) -> begin - for value in (fill(1.0, dims), fill(-1.0, 
dims), fill(0.0, dims), mean(left), mean(right), rand(dims)) - @test all(isapprox.(ReactiveMP.compute_gradient(grad, (x) -> logpdf(left, x), value), ReactiveMP.compute_gradient(grad, (x) -> logpdf(left, x), value), atol = 1e-14)) - @test all(isapprox.(ReactiveMP.compute_hessian(grad, (x) -> logpdf(left, x), value), ReactiveMP.compute_hessian(grad, (x) -> logpdf(left, x), value), atol = 1e-14)) - end - end - - types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64}) - etypes = ReactiveMP.union_types(MultivariateNormalDistributionsFamily) - - dims = (2, 3, 5) - rng = MersenneTwister(1234) - - for dim in dims - for type in types - left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(rand(rng, Float64, dim)))) - check_basic_statistics(left, convert(MvNormal, left), dim) - for type in [types..., etypes...] - right = convert(type, left) - check_basic_statistics(left, right, dim) - - p1 = prod(ProdPreserveTypeLeft(), left, right) - p2 = prod(ProdPreserveTypeRight(), left, right) - p3 = prod(ProdAnalytical(), left, right) - - check_basic_statistics(p1, p2, dim) - check_basic_statistics(p2, p3, dim) - check_basic_statistics(p1, p3, dim) - end - end + for i in 1:100 + @test ReactiveMP.compute_gradient(grad, (x) -> sum(x)^2, [i]) ≈ [2*i] + @test ReactiveMP.compute_hessian(grad, (x) -> sum(x)^2, [i]) ≈ [2;;] end + end end From bc5b02bd96b6d8e421e3a4f91808c40df1c1cfec Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 27 Jan 2023 14:31:59 +0100 Subject: [PATCH 8/8] style: :art: --- test/approximations/test_grad.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/approximations/test_grad.jl b/test/approximations/test_grad.jl index 1c79aaa0a..93dfd9873 100644 --- a/test/approximations/test_grad.jl +++ b/test/approximations/test_grad.jl @@ -13,10 +13,9 @@ import ReactiveMP: convert_eltype grad = ForwardDiffGrad() for i in 1:100 - @test ReactiveMP.compute_gradient(grad, (x) -> sum(x)^2, [i]) ≈ [2*i] + @test 
ReactiveMP.compute_gradient(grad, (x) -> sum(x)^2, [i]) ≈ [2 * i] @test ReactiveMP.compute_hessian(grad, (x) -> sum(x)^2, [i]) ≈ [2;;] end - end end