Skip to content

Commit

Permalink
Merge pull request #258 from ACEsuit/co/fixfast
Browse files Browse the repository at this point in the history
Bugfixes for fast evaluator
  • Loading branch information
cortner authored Sep 14, 2024
2 parents ed2c3a9 + 40f8677 commit dd227fe
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 134 deletions.
17 changes: 17 additions & 0 deletions docs/src/tutorials/basic_julia_workflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,20 @@ end

# Finally, we delete the model to clean up.
rm("TiAl_model.json")

# ### Fast Evaluator
#
# `ACEpotentials.jl` provides an experimental "fast evaluator". This tries to
# merge some of the operations in the full model resulting in a "slimmer" and
# usually faster evaluator. In some cases the performance gain can be multiple
# factors up to an order of magnitude. This is particularly important when
# using a parameter estimation solver that sparsifies. In that case, the
# performance gain can be significant.
#
# To construct the fast evaluator, simply use
# ```julia
# fpot = fast_evaluator(model)
# ```
# An optional keyword argument `aa_static = true` can be used to optimize the
# n-correlation layer for very small models (at most a few hundred parameters).
# For larger models this leads to a stack overflow.
39 changes: 19 additions & 20 deletions examples/zuobench/zuo_asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Distributed, Random, SparseArrays
addprocs(10, exeflags="--project=$(Base.active_project())")
@everywhere using ACEpotentials, PrettyTables
using ACEpotentials.Models: fast_evaluator

##

Expand Down Expand Up @@ -33,32 +34,30 @@ At, yt, Wt = ACEpotentials.assemble(train_data, model)
Av, yv, Wv = ACEpotentials.assemble(val_data, model)

@info("Compute ASP Path")
solver = ACEfit.ASP(; P = P, select = :final, tsvd = true,
actMax = 1000, traceFlag=true )
solver = ACEfit.ASP(; P = P, select = :final, tsvd = true, actMax = 1000 )
asp_result = ACEfit.solve(solver, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv)

##

@info("Pick solutions for 100, 300, 1000 parameters, compute errors")

@show length(asp_result["path"])
path = asp_result["path"]
nnzs = [ nnz(p.solution) for p in path ]
I1000 = length(nnzs)
I300 = findfirst(nnzs .>= 300)
I100 = findfirst(nnzs .>= 100)

model_1000 = deepcopy(model)
set_parameters!(model_1000, path[I1000].solution)
model_300 = deepcopy(model)
set_parameters!(model_300, path[I300].solution)
model_100 = deepcopy(model)
set_parameters!(model_100, path[I100].solution)

err_100 = ACEpotentials.linear_errors(test_data, model_100)
err_300 = ACEpotentials.linear_errors(test_data, model_300)
err_1000 = ACEpotentials.linear_errors(test_data, model_1000)

# select models from the model path
model_1000 = set_parameters!( deepcopy(model),
ACEfit.asp_select(asp_result, :final)[1])
model_300 = set_parameters!( deepcopy(model),
ACEfit.asp_select(asp_result, (:bysize, 300))[1])
model_100 = set_parameters!( deepcopy(model),
ACEfit.asp_select(asp_result, (:bysize, 100))[1])

# generate sparsified, faster evaluators
pot_1000 = fast_evaluator(model_1000; aa_static = false) # static can cause stack overflow
pot_300 = fast_evaluator(model_300; aa_static = true)
pot_100 = fast_evaluator(model_100; aa_static = true)

@info("Evaluate errors on the test set")
err_100 = ACEpotentials.linear_errors(test_data, pot_100)
err_300 = ACEpotentials.linear_errors(test_data, pot_300)
err_1000 = ACEpotentials.linear_errors(test_data, pot_1000)

##

Expand Down
8 changes: 4 additions & 4 deletions src/ACEpotentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ import ACEpotentials.ACE1compat: ace1_model
import ACEpotentials.Models: algebraic_smoothness_prior,
exp_smoothness_prior,
gaussian_smoothness_prior,
set_parameters!
set_parameters!,
fast_evaluator
import JSON

export ace1_model,
length_basis,
algebraic_smoothness_prior,
exp_smoothness_prior,
gaussian_smoothness_prior,
set_parameters!

set_parameters!,
fast_evaluator

include("json_interface.jl")



end
4 changes: 2 additions & 2 deletions src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ function linear_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
virial_key = "virial",
weights = default_weights(),
verbose = true,
return_efv = false
return_efv = false,
)
data = [ AtomsData(at; energy_key = energy_key, force_key=force_key,
virial_key = virial_key, weights = weights,
v_ref = _get_Vref(model))
v_ref = nothing)
for at in raw_data ]
return linear_errors(data, model; verbose=verbose, return_efv = return_efv)
end
Expand Down
26 changes: 24 additions & 2 deletions src/models/fasteval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ function FastACEInner(model::ACEPotential{<: ACEModel}, iz;
return FastACEinner(rbasis, ybasis, a_basis, aadot)
end

"""
fast_evaluator(model; aa_static = :auto)
Constructs an experimental "fast evaluator" for a fitted model, which merges
some operations resulting in a "slimmer" and usually faster evaluator.
In some cases the performance gain can be significant, especially when the
fitted parameters are sparse.
To construct the fast evaluator,
```julia
fpot = fast_evaluator(model)
```
An optional keyword argument `aa_static = true` can be used to enforce
optimizing the n-correlation layer for very small models (at most a few
hundred parameters). For larger models this results in a stack overflow.
"""
function fast_evaluator(model::ACEPotential{<: ACEModel};
aa_static = :auto)
if aa_static == :auto
Expand Down Expand Up @@ -224,8 +240,14 @@ function (aadot::AADot)(A)
end

function eval_and_grad!(∇φ_A, aadot::AADot, A)
φ = aadot(A)
P4ML.pullback!(∇φ_A, aadot.cc, aadot.aabasis, A)
@no_escape begin
AA = @alloc(eltype(A), length(aadot.aabasis))
P4ML.evaluate!(AA, aadot.aabasis, A)
φ = dot(aadot.cc, AA)
P4ML.unsafe_pullback!(∇φ_A, aadot.cc, aadot.aabasis, AA)
nothing
end

return φ
end

Expand Down
30 changes: 16 additions & 14 deletions test/test_fast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ model.ps.WB[Ism] .= 0.0

# construct the fast evaluator (is it actually fast??)
fpot = M.fast_evaluator(model)
fpot_d = M.fast_evaluator(model, aa_static=false)

##

Expand All @@ -47,9 +48,10 @@ for ntest = 1:20
E2 = M.eval_site(model, Rs, Zs, z0)
v1, ∇v1 = M.eval_grad_site(fpot, Rs, Zs, z0)
v2, ∇v2 = M.eval_grad_site(model, Rs, Zs, z0)
v3, ∇v3 = M.eval_grad_site(fpot_d, Rs, Zs, z0)

print_tf(@test E1 E2 v1 v2)
print_tf(@test all(∇v1 .≈ ∇v2))
print_tf(@test E1 E2 v1 v2 v3)
print_tf(@test all(∇v1 .≈ ∇v2 .≈ ∇v3))
end
println()

Expand All @@ -69,14 +71,15 @@ tolerance = 1e-10
rattle = 0.1

for ntest = 1:20
local at
local at, efv1, efv2, efv3
at = bulk(:Si, cubic=true) * 2
rattle!(at, rattle)
efv1 = energy_forces_virial(at, model)
efv2 = energy_forces_virial(at, fpot)
efv3 = energy_forces_virial(at, fpot_d)
print_tf(@test ustrip(abs(efv1.energy - efv2.energy)) < tolerance)
print_tf(@test all(efv1.forces .≈ efv2.forces))
print_tf(@test all(efv1.virial .≈ efv2.virial))
print_tf(@test all(efv1.forces .≈ efv2.forces .≈ efv3.forces))
print_tf(@test all(efv1.virial .≈ efv2.virial .≈ efv3.virial))
end
println()

Expand All @@ -101,6 +104,7 @@ model.ps.Wpair[:, :] = Diagonal((1:len2).^(-2)) * model.ps.Wpair[:, :]

@info("convert to UF_ACE format")
fpot = M.fast_evaluator(model)
fpot_d = M.fast_evaluator(model, aa_static=false)

##

Expand All @@ -112,9 +116,10 @@ for ntest = 1:20
E2 = M.eval_site(model, Rs, Zs, z0)
v1, ∇v1 = M.eval_grad_site(fpot, Rs, Zs, z0)
v2, ∇v2 = M.eval_grad_site(model, Rs, Zs, z0)
v3, ∇v3 = M.eval_grad_site(fpot_d, Rs, Zs, z0)

print_tf(@test E1 E2 v1 v2)
print_tf(@test all(∇v1 .≈ ∇v2))
print_tf(@test E1 E2 v1 v2 v3)
print_tf(@test all(∇v1 .≈ ∇v2 .≈ ∇v3))

end
println()
Expand All @@ -127,18 +132,15 @@ tolerance = 1e-12
rattle = 0.01

for ntest = 1:20
local sys
local sys, efv1, efv2, efv3
sys = rattle!(bulk(:Al, cubic=true) * 2, 0.1)
randz!(sys, [:Ti => 0.5, :Al => 0.5])

efv1 = energy_forces_virial(sys, model)
efv2 = energy_forces_virial(sys, fpot)
efv3 = energy_forces_virial(sys, fpot_d)
print_tf(@test ustrip(abs(efv1.energy - efv2.energy)) < tolerance)
print_tf(@test all(efv1.forces .≈ efv2.forces))
print_tf(@test all(efv1.virial .≈ efv2.virial))

# E1 = potential_energy(sys, model)
# E2 = potential_energy(sys, fpot)
# print_tf(@test ustrip(abs(E1 - E2)) < tolerance)
print_tf(@test all(efv1.forces .≈ efv2.forces .≈ efv3.forces))
print_tf(@test all(efv1.virial .≈ efv2.virial .≈ efv3.virial))
end
println()
1 change: 1 addition & 0 deletions test/test_io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ ACEpotentials.save_model(model, fname; model_spec = model_spec)
model1, meta = ACEpotentials.load_model(fname)

for ntest = 1:10
local sys
sys = rattle!(bulk(:Al, cubic=true) * 2, 0.1)
sys = randz!(sys, [:Ti => 0.5, :Al => 0.5])
print_tf( @test potential_energy(sys, model) potential_energy(sys, model1) )
Expand Down
92 changes: 0 additions & 92 deletions test/test_uface.jl

This file was deleted.

0 comments on commit dd227fe

Please sign in to comment.