Skip to content

Commit

Permalink
Swap Optimization.jl and OptimizationOptim.jl for only Optim.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
aris-mav authored Jan 6, 2025
2 parents 352b798 + 15e3304 commit 78f0f96
Show file tree
Hide file tree
Showing 10 changed files with 367 additions and 196 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
name = "NMRInversions"
uuid = "55c20db2-0166-4687-95c3-62a9c7afb29b"
authors = ["Aristarchos Mavridis <[email protected]>"]
version = "0.9.2"
version = "0.9.3"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NativeFileDialog = "e1fe445b-aa65-4df4-81c1-2041507f0fd4"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolygonOps = "647866c9-e3ac-4575-94e7-e3d426903924"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -29,8 +28,7 @@ DelimitedFiles = "1"
GLMakie = "0.10"
JuMP = "1"
NativeFileDialog = "0.2"
Optimization = "3, 4"
OptimizationOptimJL = "0.3, 0.4"
Optim = "1.10.0"
PolygonOps = "0.1"
QuadraticModels = "0.9"
RipQP = "0.6"
Expand Down
16 changes: 16 additions & 0 deletions docs/src/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ invert(::Type{<:pulse_sequence2D}, ::AbstractVector, ::AbstractVector, ::Abstrac
invert(::input2D)
```

# Finding alpha
Here we provide two options for finding the optimal value for alpha,
namely Generalized Cross Validation (GCV) or L-curve.
Generally gcv seems slightly more reliable in NMR, but it's far from
perfect, so it's good to have alternatives and cross-check.
The following methods can be used as inputs for the `alpha` argument in the
`invert` function:

```@docs
gcv()
gcv(start; kwargs...)
gcv(lower, upper ; kwargs...)
lcurve(lowest_value, highest_value, number_of_steps)
lcurve(start; kwargs...)
lcurve(lower, upper ; kwargs...)
```

# Exponential fit functions

Expand Down
4 changes: 1 addition & 3 deletions src/NMRInversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ using LinearAlgebra
using SparseArrays
using NativeFileDialog
using PolygonOps
import Optimization, OptimizationOptimJL
using Optim

"""
to do list:
- add gcv for reci method
- differentiate between Mitchell GCV and optimization GCV
- introduce faf, flip angle fraction, to the kernel functions. 1 would be a perfect pulse, 0 would be no pulse.
- add precompilation
Expand Down
14 changes: 7 additions & 7 deletions src/exp_fits.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

export mexp
"""
mexp(seq, u, x)
Expand Down Expand Up @@ -61,7 +60,7 @@ Arguments:
- `y` : acquisition y parameter (magnetization).
Optional arguments:
- `solver` : OptimizationOptimJL solver, defeault choice is BFGS().
- `solver` : Optim solver, defeault choice is IPNewton().
- `normalize` : Normalize the data before fitting? (default is true).
- `L` : An integer specifying which norm of the residuals you want to minimize (default is 2).
Expand Down Expand Up @@ -91,7 +90,7 @@ function expfit(
seq::Type{<:NMRInversions.pulse_sequence1D},
x::Vector,
y::Vector;
solver=OptimizationOptimJL.BFGS(),
solver= IPNewton(),
normalize::Bool=true,
L::Int = 2
)
Expand All @@ -116,10 +115,11 @@ function expfit(

end

# Solve the optimization
optf = Optimization.OptimizationFunction(mexp_loss, Optimization.AutoForwardDiff())
prob = Optimization.OptimizationProblem(optf, u0, (x, y, seq, L), lb=zeros(length(u0)), ub=Inf .* ones(length(u0)))
u = OptimizationOptimJL.solve(prob, solver, maxiters=5000, maxtime=100)
u = optimize(
u -> mexp_loss(u, (x, y, seq, L)),
zeros(length(u0)), Inf .* ones(length(u0)), u0,
solver
).minimizer

# Determine what's the x-axis of the seq (time or bfactor)
seq == NMRInversions.PFG ? x_ax = "b" : x_ax = "t"
Expand Down
145 changes: 115 additions & 30 deletions src/finding_alpha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,61 @@ function gcv_score(α, r, s, x; next_alpha=true)
end
end

function gcv_cost::Real,
svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}})

#=display("Testing α = $(round(α,sigdigits=3))")=#
f, r = solve_regularization(svds.K, svds.g, α, solver)
return gcv_score(α, r, svds.S, (svds.V' * f), next_alpha = false)
end

"""
Compute the curvature of the L-curve at a given point.
(Hansen 2010 page 92-93)
- `f` : solution vector
- `r` : residuals
- `α` : smoothing term
- `A` : Augmented kernel matrix (`K` and `αI` stacked vertically)
- `b` : Augmented residuals (`r` and `0` stacked vertically)
"""
function l_curvature(f, r, α, A, b)

ξ = f'f
ρ = r'r
λ = α

z = NMRInversions.solve_ls(A, b)

∂ξ∂λ = (4 / λ) * f'z

= 2 ** ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) / ((α * ξ^2 + ρ^2)^(3 / 2))

return

end

function l_cost(K, g, α, solver)

f, r = NMRInversions.solve_regularization(K, g, α, solver)

A = sparse([K; (α) * LinearAlgebra.I ])
b = sparse([r; zeros(size(A, 1) - size(r, 1))])

return l_curvature(f, r, α, A, b)

end


"""
Solve repeatedly until the GCV score stops decreasing.
Solve repeatedly until the GCV score stops decreasing, following Mitchell 2012 paper.
Select the solution with minimum gcv score and return it, along with the residuals.
"""
function solve_gcv(svds::svd_kernel_struct, solver::Union{regularization_solver, Type{<:regularization_solver}})
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::gcv_mitchell)

= svds.S
= length(s̃)
Expand Down Expand Up @@ -82,29 +131,68 @@ end


"""
Compute the curvature of the L-curve at a given point.
(Hansen 2010 page 92-93)
Find alpha via univariate optimization.
"""
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::find_alpha_univariate
)

- `f` : solution vector
- `r` : residuals
- `α` : smoothing term
- `A` : Augmented kernel matrix (`K` and `αI` stacked vertically)
- `b` : Augmented residuals (`r` and `0` stacked vertically)
local f
if mode.search_method == :gcv
f = x -> gcv_cost(x, svds, solver)

elseif mode.search_method == :lcurve
f = x -> l_cost(svds.K, svds.g, x, solver)

end

sol = optimize(
f,
mode.lower, mode.upper,
mode.algorithm,
abs_tol = mode.abs_tol
)

α = sol.minimizer
display("Converged at α =$(round(α,sigdigits=3)), after $(sol.f_calls) calls.")

f, r = NMRInversions.solve_regularization(svds.K, svds.g, α, solver)

return f, r, α

end

"""
function l_curvature(f, r, α, A, b)
Find alpha using Fminbox optimization.
"""
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::find_alpha_box
)

ξ = f'f
ρ = r'r
λ = α
local f
if mode.search_method == :gcv
f = x -> gcv_cost(first(x), svds, solver)

z = NMRInversions.solve_ls(A, b)
elseif mode.search_method == :lcurve

∂ξ∂λ = (4 / λ) * f'z
f = x -> l_cost(svds.K, svds.g, first(x), solver)
end

= 2 ** ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) / ((α * ξ^2 + ρ^2)^(3 / 2))
sol = optimize(
f,
[0], [Inf], [mode.start],
Fminbox(mode.algorithm),
mode.opts
)

return
α = sol.minimizer[1]

display("Converged at α =$(round(α,sigdigits=3)), after $(sol.f_calls) calls.")
f, r = NMRInversions.solve_regularization(svds.K, svds.g, α, solver)

return f, r, α

end

Expand All @@ -113,30 +201,27 @@ end
Test `n` alpha values between `lower` and `upper` and select the one
which is at the heel of the L curve, accoding to Hansen 2010.
"""
function solve_l_curve(K, g, solver, lower, upper, n)
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::lcurve_range
)

alphas = exp10.(range(log10(mode.lowest_value),
log10(mode.highest_value),
mode.number_of_steps))

alphas = exp10.(range(log10(lower), log10(upper), n))
curvatures = zeros(length(alphas))

for (i, α) in enumerate(alphas)
display("Testing α = $(round(α,sigdigits=3))")

f, r = NMRInversions.solve_regularization(K, g, α, solver)

A = sparse([K; (α) * LinearAlgebra.I ])
b = sparse([r; zeros(size(A, 1) - size(r, 1))])

c = l_curvature(f, r, α, A, b)

curvatures[i] = c
curvatures[i] = l_cost(svds.K, svds.g, α, solver)

end

α = alphas[argmin(curvatures)]
display("The optimal α is $(round(α,sigdigits=3))")

f, r = NMRInversions.solve_regularization(K, g, α, solver)

f, r = NMRInversions.solve_regularization(svds.K, svds.g, α, solver)
return f, r, α

end
43 changes: 15 additions & 28 deletions src/inversion_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ logarithmically spaced values
(default is (-5, 1, 128) for relaxation and (-10, -7, 128) for diffusion).
Alternatiively, a vector of values can be used directly, if more freedom is needed
(e.g. `lims=exp10.(range(-4, 1, 128))`).
- `alpha` determines the smoothing term. Use a real number for a fixed alpha. No selection will lead to automatically determining alpha through the defeault method, which is `gcv`.
- `alpha` determines the smoothing term. Use a real number for a fixed alpha.
No selection will lead to automatically determining alpha through the
defeault method, which is `gcv()`.
- `solver` is the algorithm used to do the inversion math. Default is `brd`.
- `normalize` will normalize `y` to 1 at the max value of `y`. Default is `true`.
"""
function invert(seq::Type{<:pulse_sequence1D}, x::AbstractArray, y::Vector;
lims::Union{Tuple{Real, Real, Int}, AbstractVector, Type{<:pulse_sequence1D}}=seq,
alpha::Union{Real, smoothing_optimizer, Type{<:smoothing_optimizer}}=gcv,
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd,
normalize::Bool = true)
lims::Union{Tuple{Real, Real, Int}, AbstractVector, Type{<:pulse_sequence1D}}=seq,
alpha::Union{Real, alpha_optimizer}=gcv(),
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd(),
normalize::Bool = true
)

if normalize
y = y ./ y[argmax(real(y))]
y = y ./ y[argmax(abs.(real(y)))]
end

if isa(lims, Tuple)
Expand All @@ -51,33 +54,24 @@ function invert(seq::Type{<:pulse_sequence1D}, x::AbstractArray, y::Vector;
end

ker_struct = create_kernel(seq, x, X, y)
α = 1.0 #placeholder, will be replaced below
α = 0.0 #placeholder, will be replaced below

if isa(alpha, Real)

α = alpha

f, r = solve_regularization(ker_struct.K, ker_struct.g, α, solver)

elseif alpha == gcv

f, r, α = solve_gcv(ker_struct, solver)
else

elseif isa(alpha, lcurve)
f, r, α = solve_l_curve(ker_struct.K, ker_struct.g, solver,
alpha.lowest_value, alpha.highest_value, alpha.number_of_steps)
f, r, α = find_alpha(ker_struct, solver, alpha)

else
error("alpha must be a real number or a smoothing_optimizer type.")

end

x_fit = exp10.(range(log10(1e-8), log10(1.1 * x[end]), 512))
y_fit = create_kernel(seq, x_fit, X) * f

isreal(y) ? SNR = NaN : SNR = calc_snr(y)


if seq == PFG
X .= X ./ 1e9
end
Expand Down Expand Up @@ -131,8 +125,8 @@ function invert(
seq::Type{<:pulse_sequence2D}, x_direct::AbstractVector, x_indirect::AbstractVector, Data::AbstractMatrix;
lims1::Union{Tuple{Real, Real, Int}, AbstractVector}=(-5, 1, 100),
lims2::Union{Tuple{Real, Real, Int}, AbstractVector}=(-5, 1, 100),
alpha::Union{Real, smoothing_optimizer, Type{<:smoothing_optimizer}}=gcv,
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd,
alpha::Union{Real, alpha_optimizer} = gcv(),
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd(),
normalize::Bool=true)

if normalize
Expand Down Expand Up @@ -160,15 +154,8 @@ function invert(
α = alpha
f, r = solve_regularization(ker_struct.K, ker_struct.g, α, solver)

elseif alpha == gcv
f, r, α = solve_gcv(ker_struct, solver)

elseif isa(alpha, lcurve)
f, r, α = solve_l_curve(ker_struct.K, ker_struct.g, solver,
alpha.lowest_value, alpha.highest_value, alpha.number_of_steps)

else
error("alpha must be a real number or a smoothing_optimizer type.")
f, r, α = find_alpha(ker_struct, solver, alpha)

end

Expand Down
Loading

0 comments on commit 78f0f96

Please sign in to comment.