Skip to content

Commit

Permalink
Eliminate useless allocation for diagonal quasi-Newton operators (#337)
Browse files Browse the repository at this point in the history
Co-authored-by: Dominique <[email protected]>
  • Loading branch information
MohamedLaghdafHABIBOULLAH and dpo authored Oct 10, 2024
1 parent 28fd5ae commit ab0de4e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
42 changes: 20 additions & 22 deletions src/DiagonalHessianApproximation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,22 @@ end
# y = ∇f(x_{k+1}) - ∇f(x_k)
function push!(
B::DiagonalPSB{T, I, V, F},
s0::V,
y0::V,
s::V,
y::V,
) where {T <: Real, I <: Integer, V <: AbstractVector{T}, F}
s0Norm = norm(s0, 2)
if s0Norm == 0
sNorm = norm(s, 2)
if sNorm == 0
error("Cannot update DiagonalQN operator with s=0")
end
# sᵀBs = sᵀy can be scaled by ||s||² without changing the update
s = (si / s0Norm for si ∈ s0)
s2 = (si^2 for si ∈ s)
y = (yi / s0Norm for yi ∈ y0)
trA2 = dot(s2, s2)
sT_y = dot(s, y)
sT_B_s = dot(s2, B.d)
sNorm2 = sNorm^2
trA2 = dot(s2, s2) / sNorm2^2
sT_y = dot(s, y) / sNorm2
sT_B_s = dot(s2, B.d) / sNorm2
q = sT_y - sT_B_s
q /= trA2
B.d .+= q .* s .^ 2
B.d .+= q / sNorm2 .* s .^ 2
return B
end

Expand Down Expand Up @@ -126,25 +125,24 @@ end
# y = ∇f(x_{k+1}) - ∇f(x_k)
function push!(
B::DiagonalAndrei{T, I, V, F},
s0::V,
y0::V,
s::V,
y::V,
) where {T <: Real, I <: Integer, V <: AbstractVector{T}, F}
s0Norm = norm(s0, 2)
if s0Norm == 0
sNorm = norm(s, 2)
if sNorm == 0
error("Cannot update DiagonalQN operator with s=0")
end
# sᵀBs = sᵀy can be scaled by ||s||² without changing the update
s = (si / s0Norm for si ∈ s0)
s2 = (si^2 for si ∈ s)
y = (yi / s0Norm for yi ∈ y0)
trA2 = dot(s2, s2)
sT_y = dot(s, y)
sT_B_s = dot(s2, B.d)
sNorm2 = sNorm^2
trA2 = dot(s2, s2) / sNorm2^2
sT_y = dot(s, y) / sNorm2
sT_B_s = dot(s2, B.d) / sNorm2
q = sT_y - sT_B_s
sT_s = dot(s, s)
sT_s = dot(s, s) / sNorm2
q += sT_s
q /= trA2
B.d .+= q .* s .^ 2 .- 1
B.d .+= q / sNorm2 .* s .^ 2 .- 1
return B
end

Expand Down Expand Up @@ -199,7 +197,7 @@ function push!(
s::V,
y::V,
) where {T <: Real, I <: Integer, F, V <: AbstractVector{T}}
if all(s .== 0)
if all(x -> x == 0, s)
error("Cannot divide by zero and s .= 0")
end
B.d[1] = dot(s, y) / dot(s, s)
Expand Down
40 changes: 40 additions & 0 deletions test/test_diag.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,40 @@
"""
    @wrappedallocs(expr)

Measure the memory allocated by `expr` from inside a freshly generated function.

The macro builds a wrapper function with one parameter per element of
`expr.args`, places `@allocated` around a copy of `expr` in which every element
has been replaced by the corresponding parameter, and then calls the wrapper
with the original elements of `expr`. Evaluating the expression through
function arguments rather than directly gives more accurate allocation counts
when the operands are global variables (e.g. at the REPL).

Adapted from
https://github.com/JuliaAlgebra/TypedPolynomials.jl/blob/master/test/runtests.jl

# Examples

`@wrappedallocs(x + y)` expands to code along the lines of

```julia
function g(x1, x2)
  @allocated x1 + x2
end
g(x, y)
```

so a unit test can check that an operation does not allocate:

```julia
@test @wrappedallocs(x + y) == 0
```
"""
macro wrappedallocs(expr)
  # One fresh symbol per element of the expression; these become the formal
  # parameters of the wrapper function.
  params = map(_ -> gensym(), expr.args)
  # Same expression shape as `expr`, but written in terms of the parameters.
  measured = Expr(expr.head, params...)
  # Call of the wrapper with the caller's original (escaped) sub-expressions.
  invocation = Expr(:call, :g, map(esc, expr.args)...)
  return quote
    function g($(params...))
      @allocated $measured
    end
    $invocation
  end
end

# Points
x0 = [-1.0, 1.0, -1.0]
x1 = x0 + [1.0, 0.0, 1.0]
Expand Down Expand Up @@ -74,12 +111,15 @@ end
u = similar(v)
mul!(u, A, v)
@test (@allocated mul!(u, A, v)) == 0
@test (@wrappedallocs push!(A, u, v)) == 0
B = DiagonalPSB(d)
mul!(u, B, v)
@test (@allocated mul!(u, B, v)) == 0
@test (@wrappedallocs push!(B, u, v)) == 0
C = SpectralGradient(rand(), 5)
mul!(u, C, v)
@test (@allocated mul!(u, C, v)) == 0
@test (@wrappedallocs push!(C, u, v)) == 0
end

@testset "reset" begin
Expand Down

0 comments on commit ab0de4e

Please sign in to comment.