Skip to content

Commit

Permalink
Merge pull request #116 from ErikQQY/qqy/sct
Browse files Browse the repository at this point in the history
Add extension for SparseConnectivityTracer
  • Loading branch information
ChrisRackauckas authored Feb 13, 2025
2 parents 9b1b2d7 + d494f72 commit b19ffd6
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
group:
- Core
version:
- '1'
- '1.10'
os:
- ubuntu-latest
- macos-latest
Expand Down
42 changes: 24 additions & 18 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PreallocationTools"
uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
authors = ["Chris Rackauckas <[email protected]>"]
version = "0.4.24"
version = "0.4.25"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -10,31 +10,36 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[weakdeps]
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"

[extensions]
PreallocationToolsReverseDiffExt = "ReverseDiff"
PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer"

[compat]
Adapt = "3.4, 4"
Aqua = "0.8"
ArrayInterface = "7.7"
ForwardDiff = "0.10.19"
LabelledArrays = "1.15"
LinearAlgebra = "1"
Optimization = "3.19"
OptimizationOptimJL = "0.1.5"
OrdinaryDiffEq = "6.65"
Pkg = "1"
Random = "1"
RecursiveArrayTools = "3.2"
ReverseDiff = "1"
Adapt = "4.1.1"
ADTypes = "1.13"
Aqua = "0.8.11"
ArrayInterface = "7.18.0"
ForwardDiff = "0.10.38"
LabelledArrays = "1.16.0"
LinearAlgebra = "1.10"
Optimization = "4.1.1"
OptimizationOptimJL = "0.4.1"
OrdinaryDiffEq = "6.91.0"
Pkg = "1.10"
Random = "1.10.8"
RecursiveArrayTools = "3.29.0"
ReverseDiff = "1.15.3"
SafeTestsets = "0.1"
SparseArrays = "1"
Symbolics = "5.12"
Test = "1"
SparseArrays = "1.10"
SparseConnectivityTracer = "0.6.12"
Symbolics = "6.29.0"
Test = "1.10"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -47,8 +52,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics"]
test = ["Aqua", "ADTypes", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
46 changes: 46 additions & 0 deletions ext/PreallocationToolsSparseConnectivityTracerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module PreallocationToolsSparseConnectivityTracerExt

using PreallocationTools
isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) :
(import ..SparseConnectivityTracer)

function PreallocationTools.get_tmp(
dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual}
if isbitstype(T)
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
PreallocationTools.enlargediffcache!(dc, nelem)
end
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
else
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
end
end

function PreallocationTools.get_tmp(
dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual}
if isbitstype(T)
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
PreallocationTools.enlargediffcache!(dc, nelem)
end
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
else
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
end
end

function PreallocationTools.get_tmp(
dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual}
if isbitstype(T)
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
PreallocationTools.enlargediffcache!(dc, nelem)
end
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
else
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
end
end

end
6 changes: 4 additions & 2 deletions test/core_nesteddual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,13 @@ newtonsol = solve(optprob, Newton())
cache = DiffCache(zeros(ps, ps), [4, 4, 2])
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(ps, ps), (0.0, 1.0),
(coeffs, cache))
realsol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
realsol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = 2)),
saveat = 0.0:0.1:10.0, reltol = 1e-8)

function objfun(x, prob, realsol, cache)
prob = remake(prob, u0 = eltype(x).(prob.u0), p = (x, cache))
sol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
sol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = 2)),
saveat = 0.0:0.1:10.0, reltol = 1e-8)

ofv = 0.0
if any((s.retcode != ReturnCode.Success for s in sol))
Expand Down
10 changes: 5 additions & 5 deletions test/core_odes.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra,
OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
RecursiveArrayTools
RecursiveArrayTools, ADTypes

#Base array
function foo(du, u, (A, tmp), t)
Expand All @@ -10,17 +10,17 @@ function foo(du, u, (A, tmp), t)
nothing
end
#with defined chunk_size
chunk_size = 5
chunk_size = 9
u0 = ones(5, 5)
A = ones(5, 5)
cache = DiffCache(zeros(5, 5), chunk_size)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
sol = solve(prob, Rodas5P(autodiff = AutoForwardDiff(chunksize = chunk_size)))
@test sol.retcode == ReturnCode.Success

cache = FixedSizeDiffCache(zeros(5, 5), chunk_size)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
sol = solve(prob, Rodas5P(autodiff = AutoForwardDiff(chunksize = chunk_size)))
@test sol.retcode == ReturnCode.Success

#with auto-detected chunk_size
Expand Down Expand Up @@ -60,7 +60,7 @@ end
chunk_size = 4
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0),
(A, DiffCache(c, chunk_size)))
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
sol = solve(prob, Rodas5P(autodiff = AutoForwardDiff(chunksize = chunk_size)))
@test sol.retcode == ReturnCode.Success
#with auto-detected chunk_size
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, DiffCache(c)))
Expand Down
8 changes: 4 additions & 4 deletions test/gpu_all.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearAlgebra,
OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff
OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff, ADTypes

# upstream
OrdinaryDiffEq.DiffEqBase.anyeltypedual(x::FixedSizeDiffCache, counter = 0) = Any
Expand Down Expand Up @@ -56,17 +56,17 @@ function foo(du, u, (A, tmp), t)
nothing
end
#with specified chunk_size
chunk_size = 10
chunk_size = 9
u0 = cu(rand(10, 10)) #example kept small for test purposes.
A = cu(-randn(10, 10))
cache = DiffCache(cu(zeros(10, 10)), chunk_size)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0f0, 1.0f0), (A, cache))
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
sol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = chunk_size)))
@test sol.retcode == ReturnCode.Success

cache = FixedSizeDiffCache(cu(zeros(10, 10)), chunk_size)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0f0, 1.0f0), (A, cache))
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
sol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = chunk_size)))
@test sol.retcode == ReturnCode.Success

#with auto-detected chunk_size
Expand Down

0 comments on commit b19ffd6

Please sign in to comment.