Skip to content

Commit

Permalink
🤖 Format .jl files
Browse files Browse the repository at this point in the history
  • Loading branch information
dpo committed Oct 9, 2024
1 parent e33a43d commit a14c06b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/lbfgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ function LBFGSData(
inverse ? Vector{T}(undef, 0) : [zeros(T, n) for _ = 1:mem],
1,
Vector{T}(undef, n),
Array{T}(undef, (n, 2*mem)),
Vector{T}(undef, 2*mem),
Vector{T}(undef, n)
Array{T}(undef, (n, 2 * mem)),
Vector{T}(undef, 2 * mem),
Vector{T}(undef, n),
)
end

Expand Down
25 changes: 14 additions & 11 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export check_ctranspose, check_hermitian, check_positive_definite, normest, solve_shifted_system!, ldiv!
export check_ctranspose,
check_hermitian, check_positive_definite, normest, solve_shifted_system!, ldiv!
import LinearAlgebra.ldiv!

"""
Expand Down Expand Up @@ -147,8 +148,7 @@ end
check_positive_definite(M::AbstractMatrix; kwargs...) =
check_positive_definite(LinearOperator(M); kwargs...)


"""
"""
solve_shifted_system!(x, B, b, σ)
Solve linear system (B + σI) x = b, where B is a forward L-BFGS operator and σ ≥ 0.
Expand Down Expand Up @@ -209,14 +209,13 @@ function solve_shifted_system!(
B::LBFGSOperator{T, I, F1, F2, F3},
b::AbstractVector{T},
σ::T,
) where {T, I, F1, F2, F3}

) where {T, I, F1, F2, F3}
if σ < 0
throw(ArgumentError("σ must be nonnegative"))
end
data = B.data
insert = data.insert

γ_inv = 1 / data.scaling_factor
x_0 = 1 / (γ_inv + σ)
@. x = x_0 * b
Expand All @@ -234,20 +233,20 @@ function solve_shifted_system!(
sign_t = 1
for t = 1:(i - 1)
c0 = dot(view(data.shifted_p, :, t), data.shifted_u)
c1= sign_t .*data.shifted_v[t]
c1 = sign_t .* data.shifted_v[t]
c2 = c1 * c0
view(data.shifted_p, :, i) .+= c2 .* view(data.shifted_p, :, t)
sign_t = -sign_t
end

data.shifted_v[i] = 1 / (1 - sign_i * dot(data.shifted_u, view(data.shifted_p, :, i)))
x .+= sign_i *data.shifted_v[i] * (view(data.shifted_p, :, i)' * b) .* view(data.shifted_p, :, i)
x .+=
sign_i * data.shifted_v[i] * (view(data.shifted_p, :, i)' * b) .* view(data.shifted_p, :, i)
sign_i = -sign_i
end
return x
end


"""
ldiv!(x, B, b)
Expand Down Expand Up @@ -279,8 +278,12 @@ ldiv!(x, B, b)
# The vector `x` now contains the solution
"""

function ldiv!(x::AbstractVector{T}, B::LBFGSOperator{T, I, F1, F2, F3}, b::AbstractVector{T}) where {T, I, F1, F2, F3}
function ldiv!(
x::AbstractVector{T},
B::LBFGSOperator{T, I, F1, F2, F3},
b::AbstractVector{T},
) where {T, I, F1, F2, F3}
# Call solve_shifted_system! with σ = 0
solve_shifted_system!(x, B, b, T(0.0))
return x
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ include("test_deprecated.jl")
include("test_normest.jl")
include("test_diag.jl")
include("test_chainrules.jl")
include("test_solve_shifted_system.jl")
include("test_solve_shifted_system.jl")
6 changes: 3 additions & 3 deletions test/test_solve_shifted_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ function setup_test_val(; M = 5, n = 100, scaling = false, σ = 0.1)
x = randn(n)
b = B * x + σ .* x # so we know the true answer is x

return B, H , b, σ, zeros(n), x
return B, H, b, σ, zeros(n), x
end

function test_solve_shifted_system()
@testset "solve_shifted_system! Default setup test" begin
# Setup Test Case 1: Default setup from setup_test_val
B,_, b, σ, x_sol, x_true = setup_test_val(n = 100, M = 5)
B, _, b, σ, x_sol, x_true = setup_test_val(n = 100, M = 5)

result = solve_shifted_system!(x_sol, B, b, σ)

Expand All @@ -40,7 +40,7 @@ function test_solve_shifted_system()
end
@testset "solve_shifted_system! Negative σ test" begin
# Setup Test Case 2: Negative σ
B,_, b, _, x_sol, _ = setup_test_val(n = 100, M = 5)
B, _, b, _, x_sol, _ = setup_test_val(n = 100, M = 5)
σ = -0.1

# Expect an ArgumentError to be thrown
Expand Down

0 comments on commit a14c06b

Please sign in to comment.