Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimize usage of get_extension inside our extensions #288

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Pigeons"
uuid = "0eb8d820-af6a-4919-95ae-11206f830c31"
authors = ["Alexandre Bouchard-Côté <[email protected]>, Nikola Surjanovic <[email protected]>, Paul Tiede <[email protected]>, Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"]
version = "0.4.5"
version = "0.4.6"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand Down
28 changes: 17 additions & 11 deletions ext/PigeonsEnzymeExt/PigeonsEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,32 @@ else
using ..LogDensityProblemsAD
end

# TODO: currently, the concrete versions of ADGradientWrapper are defined only
# in the extensions of LogDensityProblemsAD. Therefore, it is impossible to
# dispatch on them; see
# https://github.com/tpapp/LogDensityProblemsAD.jl/issues/32
# This is a HACK to extract that type
const EnzymeGradientLogDensity = if isdefined(Base, :get_extension)
Base.get_extension(LogDensityProblemsAD, :LogDensityProblemsADEnzymeExt).EnzymeGradientLogDensity
else
LogDensityProblemsAD.LogDensityProblemsADEnzymeExt.EnzymeGradientLogDensity
# A simpler version of the wrapper defined in LogDensityProblemsAD's extension
struct EnzymeWrapper{TLP} <: Pigeons.ADWrapper
log_potential::TLP
end

# special ADgradient constructor for Enzyme
function LogDensityProblemsAD.ADgradient(
kind::Val{:Enzyme},
log_potential,
buffers::Pigeons.Augmentation
)
d = LogDensityProblems.dimension(log_potential)
buffer = Pigeons.get_buffer(buffers, :gradient_buffer, d)
enclosed = EnzymeWrapper(log_potential)
Pigeons.BufferedAD(enclosed, buffer, nothing, nothing)
end

# adapted from LogDensityProblemsAD to use the Replica's buffer
function LogDensityProblems.logdensity_and_gradient(
b::Pigeons.BufferedAD{<:EnzymeGradientLogDensity},
b::Pigeons.BufferedAD{<:EnzymeWrapper},
x::AbstractVector
)
∂ℓ_∂x = fill!(b.buffer, zero(eltype(b.buffer))) # NB: Enzyme gives erroneous answer if buffer is not zeroed first
_, y = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, LogDensityProblems.logdensity, Enzyme.Active,
Enzyme.Const(b.enclosed.), Enzyme.Duplicated(x, ∂ℓ_∂x)
Enzyme.Const(b.enclosed.log_potential), Enzyme.Duplicated(x, ∂ℓ_∂x)
)
y, ∂ℓ_∂x
end
Expand Down
26 changes: 11 additions & 15 deletions ext/PigeonsForwardDiffExt/PigeonsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,10 @@ else
import ..ForwardDiff: DiffResults
end

# TODO: currently, the concrete versions of ADGradientWrapper are defined only
# in the extensions of LogDensityProblemsAD. Therefore, it is impossible to
# dispatch on them; see
# https://github.com/tpapp/LogDensityProblemsAD.jl/issues/32
# This is a HACK to extract that type
const ForwardDiffLogDensity = if isdefined(Base, :get_extension)
Base.get_extension(LogDensityProblemsAD, :LogDensityProblemsADForwardDiffExt).ForwardDiffLogDensity
else
LogDensityProblemsAD.LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity
# A simpler version of the wrapper defined in LogDensityProblemsAD's extension
struct ForwardDiffWrapper{TLP, TGC <: ForwardDiff.GradientConfig} <: Pigeons.ADWrapper
log_potential::TLP
gradient_config::TGC
end

# special ADgradient constructor for ForwardDiff
Expand All @@ -33,21 +28,22 @@ function LogDensityProblemsAD.ADgradient(
buffers::Pigeons.Augmentation
)
d = LogDensityProblems.dimension(log_potential)
buffer = Pigeons.get_buffer(buffers, :gradient_buffer, d)
enclosed = ADgradient(kind, log_potential; x = buffer)
buffer = Pigeons.get_buffer(buffers, :gradient_buffer, d)
lp_fix = Base.Fix1(LogDensityProblems.logdensity, log_potential)
gradient_config = ForwardDiff.GradientConfig(lp_fix, buffer, ForwardDiff.Chunk(d))
enclosed = ForwardDiffWrapper(log_potential, gradient_config)
diff_result = DiffResults.MutableDiffResult(zero(eltype(buffer)), (buffer, ))
Pigeons.BufferedAD(enclosed, diff_result, nothing, nothing)
end

# adapted from LogDensityProblemsAD to use the Replica's buffer
function LogDensityProblems.logdensity_and_gradient(
b::Pigeons.BufferedAD{<:ForwardDiffLogDensity},
b::Pigeons.BufferedAD{<:ForwardDiffWrapper},
x::AbstractVector
)
diff_result = b.buffer
ℓ_fix = Base.Fix1(LogDensityProblems.logdensity, b.enclosed.ℓ)
ForwardDiff.gradient!(diff_result, ℓ_fix, x, b.enclosed.gradient_config)

lp_fix = Base.Fix1(LogDensityProblems.logdensity, b.enclosed.log_potential)
ForwardDiff.gradient!(diff_result, lp_fix, x, b.enclosed.gradient_config)
return (DiffResults.value(diff_result), DiffResults.gradient(diff_result))
end

Expand Down
43 changes: 22 additions & 21 deletions ext/PigeonsReverseDiffExt/PigeonsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,23 @@ else
import ..ReverseDiff: DiffResults
end

# TODO: currently, the concrete versions of ADGradientWrapper are defined only
# in the extensions of LogDensityProblemsAD. Therefore, it is impossible to
# dispatch on them; see
# https://github.com/tpapp/LogDensityProblemsAD.jl/issues/32
# This is a HACK to extract that type
const ReverseDiffLogDensity = if isdefined(Base, :get_extension)
Base.get_extension(LogDensityProblemsAD, :LogDensityProblemsADReverseDiffExt).ReverseDiffLogDensity
else
LogDensityProblemsAD.LogDensityProblemsADReverseDiffExt.ReverseDiffLogDensity
# A simpler version of the wrapper defined in LogDensityProblemsAD's extension
struct ReverseDiffWrapper{TLP, TCT} <: Pigeons.ADWrapper
log_potential::TLP
compiled_tape::TCT
end
function make_compiled_tape(log_potential, x)
lp_fix = Base.Fix1(LogDensityProblems.logdensity, log_potential)
tape = ReverseDiff.GradientTape(lp_fix, x)
return ReverseDiff.compile(tape)
end
compute_gradient!(rdw::ReverseDiffWrapper, diff_result, x) =
ReverseDiff.gradient!(
diff_result,
Base.Fix1(LogDensityProblems.logdensity, rdw.log_potential), x
)
compute_gradient!(rdw::ReverseDiffWrapper{<:Any,<:ReverseDiff.CompiledTape}, diff_result, x) =
ReverseDiff.gradient!(diff_result, rdw.compiled_tape, x)

# special ADgradient constructor for ReverseDiff
function LogDensityProblemsAD.ADgradient(
Expand All @@ -34,9 +41,8 @@ function LogDensityProblemsAD.ADgradient(
)
d = LogDensityProblems.dimension(log_potential)
buffer = Pigeons.get_buffer(buffers, :gradient_buffer, d)
compile_tape = Pigeons.get_tape_compilation_strategy()

if compile_tape
should_compile = Pigeons.get_tape_compilation_strategy()
if should_compile
@info """

Using ReverseDiff with tape compilation, which usually results in huge performance gains.
Expand All @@ -51,24 +57,19 @@ function LogDensityProblemsAD.ADgradient(
by calling `Pigeons.set_tape_compilation_strategy!(true)`.
""" maxlog=1
end

enclosed = ADgradient(kind, log_potential; x = buffer, compile=Val{compile_tape}())
compiled_tape = should_compile ? make_compiled_tape(log_potential, buffer) : nothing
enclosed = ReverseDiffWrapper(log_potential, compiled_tape)
diff_result = DiffResults.MutableDiffResult(zero(eltype(buffer)), (buffer, ))
Pigeons.BufferedAD(enclosed, diff_result, nothing, nothing)
end

# adapted from LogDensityProblemsAD to use the Replica's buffer
function LogDensityProblems.logdensity_and_gradient(
b::Pigeons.BufferedAD{<:ReverseDiffLogDensity},
b::Pigeons.BufferedAD{<:ReverseDiffWrapper},
x::AbstractVector
)
diff_result = b.buffer
compiled_tape = b.enclosed.compiledtape
if compiled_tape === nothing
ReverseDiff.gradient!(diff_result, Base.Fix1(LogDensityProblems.logdensity, b.enclosed.ℓ), x)
else
ReverseDiff.gradient!(diff_result, compiled_tape, x)
end
compute_gradient!(b.enclosed, diff_result, x)
return (DiffResults.value(diff_result), DiffResults.gradient(diff_result))
end

Expand Down
7 changes: 7 additions & 0 deletions src/explorers/BufferedAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,10 @@ Currently this is only used for [ReverseDiff](https://github.com/JuliaDiff/Rever
function set_tape_compilation_strategy!(compile::Bool)
COMPILE_TAPE[] = compile
end

# used in the AD extensions
abstract type ADWrapper end
LogDensityProblems.logdensity(adw::ADWrapper, x::AbstractVector) =
LogDensityProblems.logdensity(adw.log_potential, x)
LogDensityProblems.dimension(adw::ADWrapper) =
LogDensityProblems.dimension(adw.log_potential)
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ArgMacros = "dbc42088-9de8-42a0-8ec8-2cd114e1ea3e"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
#Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
2 changes: 0 additions & 2 deletions test/supporting/HetPrecisionNormalLogPotential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,3 @@ end

LogDensityProblems.logdensity(log_potential::HetPrecisionNormalLogPotential, x) = log_potential(x)
LogDensityProblems.dimension(log_potential::HetPrecisionNormalLogPotential) = length(log_potential.precisions)
LogDensityProblemsAD.ADgradient(kind::Symbol, log_potential::HetPrecisionNormalLogPotential, buffers::Pigeons.Augmentation) =
LogDensityProblemsAD.ADgradient(kind, log_potential)
2 changes: 1 addition & 1 deletion test/supporting/setup.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# use single statement to avoid multiple precompile stages
using Pigeons,
ArgMacros,
Bijectors,
#Bijectors,
BridgeStan,
DelimitedFiles,
Distributions,
Expand Down
4 changes: 2 additions & 2 deletions test/test_DistributionLogPotential.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Bijectors
# using Bijectors
using DelimitedFiles

include("supporting/mpi_test_utils.jl")
Expand Down Expand Up @@ -30,7 +30,7 @@ end
@test abs(Pigeons.global_barrier(pt) - 3.15) < 0.1
end

@static if !is_windows_in_CI()
@static if false#!is_windows_in_CI()
@testset "DLP: Stan interface + iid sampling" begin
# recreate Fig. 6 in Ballnus et al. (2017)

Expand Down
Loading