Make Lazy caches support get_tmp #102

Merged · 4 commits · Feb 11, 2024

This PR adds `get_tmp` methods for `LazyBufferCache` and `GeneralLazyBufferCache`, making `get_tmp(b, u)` equivalent to the existing `b[u]` indexing across all cache types; the rest of the diff is formatting updates to the docs and tests.
4 changes: 2 additions & 2 deletions README.md

````diff
@@ -268,11 +268,11 @@ lbc = GeneralLazyBufferCache(function (p)
 end)
 ```
 
-then `lbc[p]` will be smart and reuse the caches. A full example looks like the following:
+then `lbc[p]` (or, equivalently, `get_tmp(lbc, p)`) will be smart and reuse the caches. A full example looks like the following:
 
 ```julia
 using Random, DifferentialEquations, LinearAlgebra, Optimization, OptimizationNLopt,
-      OptimizationOptimJL, PreallocationTools
+    OptimizationOptimJL, PreallocationTools
 
 lbc = GeneralLazyBufferCache(function (p)
     DifferentialEquations.init(ODEProblem(ode_fnc, y₀, (0.0, T), p), Tsit5(); saveat = t)
````
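For readers skimming the diff, here is a minimal standalone sketch of the equivalence the README now documents; the `Vector`-returning builder is an illustrative stand-in for the README's ODE-integrator factory:

```julia
using PreallocationTools

# GeneralLazyBufferCache keys its internal Dict on the type of the argument,
# so the builder function runs once per type and the result is reused afterwards.
glbc = GeneralLazyBufferCache(p -> Vector{typeof(p)}(undef, 100))

buf1 = glbc[1.0]           # builds and stores a Vector{Float64}
buf2 = get_tmp(glbc, 2.0)  # same key (Float64): returns the cached buffer
@assert buf1 === buf2
```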
2 changes: 1 addition & 1 deletion docs/src/index.md

````diff
@@ -267,7 +267,7 @@ then `lbc[p]` will be smart and reuse the caches. A full example looks like the
 
 ```julia
 using Random, DifferentialEquations, LinearAlgebra, Optimization, OptimizationNLopt,
-      OptimizationOptimJL, PreallocationTools
+    OptimizationOptimJL, PreallocationTools
 
 lbc = GeneralLazyBufferCache(function (p)
     DifferentialEquations.init(ODEProblem(ode_fnc, y₀, (0.0, T), p), Tsit5(); saveat = t)
````
9 changes: 6 additions & 3 deletions src/PreallocationTools.jl

```diff
@@ -216,14 +216,16 @@ function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N}
     typeof(similar(x, ntuple(Returns(1), N)))
 end
 
-# override the [] method
-function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray}
+function get_tmp(b::LazyBufferCache, u::T) where {T <: AbstractArray}
     s = b.sizemap(size(u)) # required buffer size
     get!(b.bufs, (T, s)) do
         similar(u, s) # buffer to allocate if it was not found in b.bufs
     end::similar_type(u, s) # declare type since b.bufs dictionary is untyped
 end
 
+# override the [] method
+Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} = get_tmp(b, u)
+
 # GeneralLazyBufferCache
 
 """
@@ -246,11 +248,12 @@ struct GeneralLazyBufferCache{F <: Function}
     GeneralLazyBufferCache(f::F = identity) where {F <: Function} = new{F}(Dict(), f) # start with empty dict
 end
 
-function Base.getindex(b::GeneralLazyBufferCache, u::T) where {T}
+function get_tmp(b::GeneralLazyBufferCache, u::T) where {T}
     get!(b.bufs, T) do
         b.f(u)
     end
 end
+Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
 
 export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
 export get_tmp
```
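The net effect of these source changes, as a minimal sketch (the buffer size and `rand` input are illustrative, not from the PR):

```julia
using PreallocationTools

lbc = LazyBufferCache()   # buffers are keyed by (type, size) and created on demand
u = rand(10)

tmp1 = get_tmp(lbc, u)    # new in this PR: get_tmp now works on LazyBufferCache
tmp2 = lbc[u]             # the existing indexing syntax forwards to get_tmp
@assert tmp1 === tmp2     # both return the same cached Vector{Float64}
```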
28 changes: 18 additions & 10 deletions test/core_dispatch.jl

```diff
@@ -1,6 +1,6 @@
 using LinearAlgebra,
-      Test, PreallocationTools, ForwardDiff, LabelledArrays,
-      RecursiveArrayTools
+    Test, PreallocationTools, ForwardDiff, LabelledArrays,
+    RecursiveArrayTools
 
 function test(u0, dual, chunk_size)
     cache = PreallocationTools.DiffCache(u0, chunk_size)
@@ -53,8 +53,10 @@ results = test(u0, dual, chunk_size)
 
 chunk_size = 5
 u0_B = ones(5, 5)
-dual_B = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
-    chunk_size}, 2, 2)
+dual_B = zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
+        chunk_size},
+    2, 2)
 cache_B = FixedSizeDiffCache(u0_B, chunk_size)
 tmp_du_BA = get_tmp(cache_B, u0_B)
 tmp_dual_du_BA = get_tmp(cache_B, dual_B)
@@ -102,9 +104,11 @@ results = test(u0, dual, chunk_size)
 #ArrayPartition tests
 chunk_size = 2
 u0 = ArrayPartition(ones(2, 2), ones(3, 3))
-dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
+dual_a = zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
         chunk_size}, 2, 2)
-dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
+dual_b = zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
         chunk_size}, 3, 3)
 dual = ArrayPartition(dual_a, dual_b)
 results = test(u0, dual, chunk_size)
@@ -128,10 +132,14 @@ results = test(u0, dual, chunk_size)
 @test eltype(results[7]) == eltype(dual)
 
 u0_AP = ArrayPartition(ones(2, 2), ones(3, 3))
-dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
-    chunk_size}, 2, 2)
-dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
-    chunk_size}, 3, 3)
+dual_a = zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
+        chunk_size},
+    2, 2)
+dual_b = zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
+        chunk_size},
+    3, 3)
 dual_AP = ArrayPartition(dual_a, dual_b)
 cache_AP = FixedSizeDiffCache(u0_AP, chunk_size)
 tmp_du_APA = get_tmp(cache_AP, u0_AP)
```
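This file only reformats existing tests, but the dispatch behavior they exercise is worth a condensed sketch (sizes are illustrative):

```julia
using PreallocationTools, ForwardDiff

u0 = ones(5, 5)
cache = FixedSizeDiffCache(u0, 5)  # preallocates a plain buffer and a dual buffer

get_tmp(cache, u0)    # ordinary input returns the Float64 buffer

# Dual input returns the dual buffer, reinterpreted from the preallocated storage.
dual = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64, 5}, 2, 2)
get_tmp(cache, dual)
```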
13 changes: 8 additions & 5 deletions test/core_nesteddual.jl

```diff
@@ -1,6 +1,6 @@
 using LinearAlgebra,
-      OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, Optimization,
-      OptimizationOptimJL
+    OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, Optimization,
+    OptimizationOptimJL
 
 randmat = rand(5, 3)
 sto = similar(randmat)
@@ -23,7 +23,8 @@ end
 In setting up the DiffCache, we are setting chunk_size to [1, 1], because we differentiate
 only with respect to τ. This initializes the cache with the minimum memory needed. =#
 stod = DiffCache(sto, [1, 1])
-df3 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
+df3 = ForwardDiff.derivative(
+    τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
         τ), 0.3)
 
 #= taking the second derivative of claytonsample! with respect to τ with auto-detected chunk-size.
@@ -32,7 +33,8 @@ than what's needed (1+1), the auto-allocated cache is big enough to handle the n
 if we don't specify the keyword argument levels = 2. This should in general not be relied on to work,
 especially if more levels of nesting occur (see optimization example below). =#
 stod = DiffCache(sto)
-df4 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
+df4 = ForwardDiff.derivative(
+    τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
         τ), 0.3)
 
 @test df3 ≈ df4
@@ -41,7 +43,8 @@ df4 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(s
 For the given size of sto, ForwardDiff's heuristic chooses chunk_size = 8 and with keyword arg levels = 2,
 the created cache size is larger than what's needed (even more so than the last example). =#
 stod = DiffCache(sto, levels = 2)
-df5 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
+df5 = ForwardDiff.derivative(
+    τ -> ForwardDiff.derivative(ξ -> claytonsample!(stod, ξ, 0.0),
         τ), 0.3)
 
 @test df3 ≈ df5
```
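A stripped-down sketch of the nested-dual pattern these test comments describe; the broadcast fill is a hypothetical placeholder for `claytonsample!`:

```julia
using PreallocationTools, ForwardDiff

sto = zeros(5, 3)
stod = DiffCache(sto, levels = 2)  # preallocate dual storage for two nesting levels

# Inside a nested derivative, τ is a Dual-of-Dual; get_tmp hands back a buffer
# whose element type matches τ, so the inner closure can work in place.
f(τ) = sum(get_tmp(stod, τ) .= τ)
d2 = ForwardDiff.derivative(τ -> ForwardDiff.derivative(f, τ), 0.3)  # second derivative of a linear map: 0.0
```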
4 changes: 2 additions & 2 deletions test/core_odes.jl

```diff
@@ -1,6 +1,6 @@
 using LinearAlgebra,
-      OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
-      RecursiveArrayTools
+    OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
+    RecursiveArrayTools
 
 #Base array
 function foo(du, u, (A, tmp), t)
```
6 changes: 4 additions & 2 deletions test/general_lbc.jl

```diff
@@ -1,6 +1,6 @@
 using Random,
-      OrdinaryDiffEq, LinearAlgebra, Optimization, OptimizationOptimJL,
-      PreallocationTools
+    OrdinaryDiffEq, LinearAlgebra, Optimization, OptimizationOptimJL,
+    PreallocationTools
 
 lbc = GeneralLazyBufferCache(function (p)
     init(ODEProblem(ode_fnc, y₀,
@@ -40,6 +40,7 @@ x = rand(1000)
 y = view(x, 1:900)
 @inferred cache[y]
 @test 0 == @allocated cache[y]
+@test cache[y] === get_tmp(cache, y)
 
 cache_17 = LazyBufferCache(Returns(17))
 x = 1:10
@@ -52,3 +53,4 @@ cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000))
 # @inferred cache[Float64]
 cache[Float64] # generate the buffer
 @test 0 == @allocated cache[Float64]
+@test get_tmp(cache, Float64) === cache[Float64]
```
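The two new assertions check that both entry points hit the same buffer, including for views; a minimal standalone version of that check (array sizes are illustrative):

```julia
using PreallocationTools, Test

cache = LazyBufferCache()
x = rand(1000)
y = view(x, 1:900)                    # works for views, not just plain Arrays

cache[y]                              # first access allocates the 900-element buffer
@test 0 == @allocated cache[y]        # later accesses are allocation-free
@test cache[y] === get_tmp(cache, y)  # indexing and get_tmp share one buffer
```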
14 changes: 9 additions & 5 deletions test/gpu_all.jl

```diff
@@ -1,14 +1,16 @@
 using LinearAlgebra,
-      OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff
+    OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff
 
 # upstream
 OrdinaryDiffEq.DiffEqBase.anyeltypedual(x::FixedSizeDiffCache, counter = 0) = Any
 
 #Dispatch tests
 chunk_size = 5
 u0_CU = cu(ones(5, 5))
-dual_CU = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
-    chunk_size}, 2, 2))
+dual_CU = cu(zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
+        chunk_size},
+    2, 2))
 dual_N = ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, 5}(0)
 cache_CU = DiffCache(u0_CU, chunk_size)
 tmp_du_CUA = get_tmp(cache_CU, u0_CU)
@@ -32,8 +34,10 @@ tmp_dual_du_CUN = get_tmp(cache_CU, dual_N)
 
 chunk_size = 5
 u0_B = cu(ones(5, 5))
-dual_B = cu(zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
-    chunk_size}, 2, 2))
+dual_B = cu(zeros(
+    ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32,
+        chunk_size},
+    2, 2))
 cache_B = FixedSizeDiffCache(u0_B, chunk_size)
 tmp_du_BA = get_tmp(cache_B, u0_B)
 tmp_dual_du_BA = get_tmp(cache_B, dual_B)
```
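These tests only reformat existing GPU checks; the pattern itself, as a minimal sketch assuming a working CUDA device (shapes are illustrative):

```julia
using PreallocationTools, CUDA, ForwardDiff

u0 = cu(ones(5, 5))       # Float32 CuArray
cache = DiffCache(u0, 5)  # chunk size 5, buffers live on the GPU

tmp = get_tmp(cache, u0)  # CuArray{Float32} buffer for the primal pass
dual = ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float32}, Float32, 5}(0)
tmp_dual = get_tmp(cache, dual)  # dual-valued GPU buffer for the AD pass
```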