Skip to content

Commit

Permalink
solve_shifted_system! method for LBFGS solving step in recursive way (#…
Browse files Browse the repository at this point in the history
…338)

Co-authored-by: Dominique <[email protected]>
  • Loading branch information
farhadrclass and dpo authored Oct 8, 2024
1 parent 4c92b94 commit e33a43d
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Function | Description
`size` | Return the size of a linear operator
`symmetric` | Determine whether the operator is symmetric
`normest` | Estimate the 2-norm
`solve_shifted_system!` | Solve the linear system $(B + \sigma I) x = b$, where $B$ is a forward L-BFGS operator and $\sigma \geq 0$


## Other Operations on Operators
Expand Down
6 changes: 6 additions & 0 deletions src/lbfgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ mutable struct LBFGSData{T, I <: Integer}
b::Vector{Vector{T}}
insert::I
Ax::Vector{T}
shifted_p::Matrix{T} # Temporary matrix used in the computation solve_shifted_system!
shifted_v::Vector{T}
shifted_u::Vector{T}
end

function LBFGSData(
Expand Down Expand Up @@ -43,6 +46,9 @@ function LBFGSData(
inverse ? Vector{T}(undef, 0) : [zeros(T, n) for _ = 1:mem],
1,
Vector{T}(undef, n),
Array{T}(undef, (n, 2*mem)),
Vector{T}(undef, 2*mem),
Vector{T}(undef, n)
)
end

Expand Down
141 changes: 140 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export check_ctranspose, check_hermitian, check_positive_definite, normest
export check_ctranspose, check_hermitian, check_positive_definite, normest, solve_shifted_system!, ldiv!
import LinearAlgebra.ldiv!

"""
normest(S) estimates the matrix 2-norm of S.
Expand Down Expand Up @@ -145,3 +146,141 @@ end

# Convenience method for dense matrices: wrap `M` in a `LinearOperator` and
# forward every keyword argument unchanged to the operator-based method.
function check_positive_definite(M::AbstractMatrix; kwargs...)
  return check_positive_definite(LinearOperator(M); kwargs...)
end


"""
    solve_shifted_system!(x, B, b, σ)

Solve the linear system `(B + σI) x = b`, where `B` is a forward L-BFGS operator and `σ ≥ 0`.

### Parameters

- `x::AbstractVector{T}`: preallocated vector of length `n`; it is overwritten with the solution.
- `B::LBFGSOperator`: forward L-BFGS operator that models a matrix of size `n × n`.
- `b::AbstractVector{T}`: right-hand side vector of length `n`.
- `σ::T`: nonnegative shift.

### Returns

- `x::AbstractVector{T}`: the solution vector (the same object that was passed as `x`).

### Throws

- `ArgumentError` if `σ < 0`.
- `DimensionMismatch` if `x` or `b` does not have the length of the operator's work vectors.

### Notes

- `x` must not alias `b`: `x` is overwritten before `b` is read for the last time.
- The scratch buffers `shifted_p`, `shifted_v` and `shifted_u` stored in `B.data` are overwritten.

### Method

The method uses a two-loop recursion-like approach with modifications to handle the shift `σ`.
Starting from `x = b / (1 / scaling_factor + σ)`, it sweeps over `2 * mem` rank-one terms,
alternating between the stored `a` and `b` vectors of the operator, and accumulates one
correction to `x` per term.

### Example

```julia
using LinearAlgebra, LinearOperators
# Problem setup
n = 100         # size of the problem
mem = 10        # L-BFGS memory size
scaling = true  # enable scaling
# Create an L-BFGS operator
B = LBFGSOperator(n, mem = mem, scaling = scaling)
# Add random {s, y} pairs to the L-BFGS operator
for _ = 1:10
  s = rand(n)
  y = rand(n)
  push!(B, s, y)  # Add the {s, y} pair to B
end
# Prepare vectors for the system
x = zeros(n)  # Preallocated solution vector
b = rand(n)   # Right-hand side vector
σ = 0.1       # Small shift value
# Solve the shifted system
result = solve_shifted_system!(x, B, b, σ)
# Check that the solution is close enough (residual test)
@assert norm(B * x + σ * x - b) / norm(b) < 1e-8
```

### References

Erway, J. B., Jain, V., & Marcia, R. F. Shifted L-BFGS Systems. Optimization Methods and Software, 29(5), pp. 992-1004, 2014.
"""
function solve_shifted_system!(
  x::AbstractVector{T},
  B::LBFGSOperator{T, I, F1, F2, F3},
  b::AbstractVector{T},
  σ::T,
) where {T, I, F1, F2, F3}
  if σ < 0
    throw(ArgumentError("σ must be nonnegative"))
  end
  data = B.data

  # Validate sizes up front so that `x` is never left half-overwritten and a
  # length-1 `b` cannot be silently broadcast.
  n = length(data.shifted_u)
  if length(x) != n || length(b) != n
    throw(
      DimensionMismatch(
        "x and b must have length $n, got $(length(x)) and $(length(b))",
      ),
    )
  end

  insert = data.insert

  # Solution of the system with the scaled-identity part only.
  γ_inv = 1 / data.scaling_factor
  x_0 = 1 / (γ_inv + σ)
  @. x = x_0 * b

  u = data.shifted_u  # current rank-one vector
  P = data.shifted_p  # column i holds the work vector of term i
  v = data.shifted_v  # v[i] holds the scalar factor of term i

  max_i = 2 * data.mem
  sign_i = 1  # +1 selects data.a[k], -1 selects data.b[k]

  for i = 1:max_i
    j = (i + 1) ÷ 2                         # index of the stored pair (two terms per pair)
    k = mod(insert + j - 1, data.mem) + 1   # position of that pair in the circular storage
    u .= (sign_i == -1) ? data.b[k] : data.a[k]

    p_i = view(P, :, i)
    @. p_i = x_0 * u

    # Apply the corrections of all previously processed terms to p_i.
    sign_t = 1
    for t = 1:(i - 1)
      p_t = view(P, :, t)
      c = sign_t * v[t] * dot(p_t, u)
      @. p_i += c * p_t
      sign_t = -sign_t
    end

    v[i] = 1 / (1 - sign_i * dot(u, p_i))
    coeff = sign_i * v[i] * dot(p_i, b)
    @. x += coeff * p_i
    sign_i = -sign_i
  end
  return x
end


"""
    ldiv!(x, B, b)

Solve the linear system `Bx = b`, where `B` is a forward L-BFGS operator.

This is `solve_shifted_system!` with a zero shift.

### Arguments

- `x::AbstractVector{T}`: preallocated vector of length n that is used to store the solution x.
- `B::LBFGSOperator`: forward L-BFGS operator that models a matrix of size n x n.
- `b::AbstractVector{T}`: right-hand side vector of length n.

### Returns

- `x::AbstractVector{T}`: the modified solution vector containing the solution to the linear system.

### Examples

```julia
using LinearAlgebra, LinearOperators
# Create an L-BFGS operator
B = LBFGSOperator(10)
# Generate random vectors
x = rand(10)
b = rand(10)
# Solve the linear system
ldiv!(x, B, b)
# The vector `x` now contains the solution
```
"""
function ldiv!(
  x::AbstractVector{T},
  B::LBFGSOperator{T, I, F1, F2, F3},
  b::AbstractVector{T},
) where {T, I, F1, F2, F3}
  # A zero shift reduces (B + σI) x = b to B x = b; `zero(T)` works for any
  # numeric element type, not only those constructible from a Float64 literal.
  solve_shifted_system!(x, B, b, zero(T))
  return x
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ include("test_deprecated.jl")
include("test_normest.jl")
include("test_diag.jl")
include("test_chainrules.jl")
include("test_solve_shifted_system.jl")
64 changes: 64 additions & 0 deletions test/test_solve_shifted_system.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using Test
using LinearOperators
using LinearAlgebra

# Build a random shifted L-BFGS test problem whose exact solution is known.
#
# Keyword arguments:
# - `M`: memory of both operators.
# - `n`: problem dimension.
# - `scaling`: whether the operators use scaling; applied to BOTH `B` and `H`
#   so that `H` remains the inverse of `B`.
# - `σ`: shift used to build the right-hand side.
#
# Returns `(B, H, b, σ, x0, x)` where `B` is the forward operator, `H` the
# inverse operator fed with the same pairs, `b = B * x + σ * x`, `x0` is a
# zero vector to be used as the solution buffer, and `x` is the true solution.
function setup_test_val(; M = 5, n = 100, scaling = false, σ = 0.1)
  B = LBFGSOperator(n, mem = M, scaling = scaling)
  H = InverseLBFGSOperator(n, mem = M, scaling = scaling)

  # Feed identical {s, y} pairs to both operators.
  for _ = 1:10
    s = rand(n)
    y = rand(n)
    push!(B, s, y)
    push!(H, s, y)
  end

  x = randn(n)
  b = B * x + σ .* x  # so we know the true answer is x

  return B, H, b, σ, zeros(n), x
end

# Test suite for `solve_shifted_system!` and the `ldiv!` method built on it.
function test_solve_shifted_system()
  @testset "solve_shifted_system! Default setup test" begin
    # Default problem produced by setup_test_val.
    B, _, rhs, shift, x_out, x_ref = setup_test_val(n = 100, M = 5)

    returned = solve_shifted_system!(x_out, B, rhs, shift)

    # The returned vector has the same length as the right-hand side.
    @test length(returned) == length(rhs)

    # The solve happens in place: the returned object is the buffer itself.
    @test returned === x_out

    # No NaN or Inf entries in the solution.
    @test all(isfinite, returned)

    # The computed solution matches the vector used to build the right-hand side.
    @test isapprox(x_out, x_ref, atol = 1e-6, rtol = 1e-6)
  end

  @testset "solve_shifted_system! Negative σ test" begin
    B, _, rhs, _, x_out, _ = setup_test_val(n = 100, M = 5)

    # A negative shift must be rejected with an ArgumentError.
    @test_throws ArgumentError solve_shifted_system!(x_out, B, rhs, -0.1)
  end

  @testset "ldiv! test" begin
    # Zero shift, so the right-hand side is exactly B * x_ref.
    B, H, rhs, _, x_out, x_ref = setup_test_val(n = 100, M = 5, σ = 0.0)

    # Solve B x = rhs in place through ldiv!.
    ldiv!(x_out, B, rhs)

    # The result agrees with applying the inverse operator H ...
    @test isapprox(x_out, H * rhs, atol = 1e-6, rtol = 1e-6)
    # ... and with the known solution.
    @test isapprox(x_out, x_ref, atol = 1e-6, rtol = 1e-6)
  end
end

# Run the test suite as soon as this file is included.
test_solve_shifted_system()

0 comments on commit e33a43d

Please sign in to comment.