Skip to content

Commit

Permalink
Merge pull request #263 from ACEsuit/zbl
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner authored Sep 15, 2024
2 parents 23965a2 + 5ba196f commit 1a19e4b
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 37 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2"
EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc"
ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Expand Down
14 changes: 13 additions & 1 deletion src/ace1_compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const _kw_defaults = Dict(:elements => nothing,
:pair_envelope => (:r, 2),
#
:Eref => missing,
:ZBL => false,
#
:variable_cutoffs => false,
)
Expand Down Expand Up @@ -300,8 +301,12 @@ function _pair_basis(kwargs)
# here we use a similar convention, just need to convert to ace1-style
envelope = kwargs[:pair_envelope]
if envelope isa Tuple && envelope[1] == :r
if kwargs[:ZBL]
@warn("""It is not recommended to combine the ZBL reference potential
with a repulsive pair basis. Use `pair_envelope = (:x, 0, q)` instead.""")
end
envelope = (:r_ace1, envelope[2])
end
end

pair_basis = ace_learnable_Rnlrzz(; spec = pair_spec,
maxq = maxq,
Expand All @@ -322,6 +327,11 @@ end

function ace1_model(; kwargs...)

# change the default for the envelope if ZBL is used
if haskey(kwargs, :ZBL) && kwargs[:ZBL] && !haskey(kwargs, :envelope)
kwargs = (; pair_envelope = (:x, 0, 2), kwargs...)
end

kwargs = _clean_args(kwargs)

elements = _get_elements(kwargs)
Expand Down Expand Up @@ -370,9 +380,11 @@ function ace1_model(; kwargs...)
E0s = Dict([ key => val * u"eV" for (key, val) in Eref]...)
end


model = Models.ace_model(; elements=elements,
order = cor_order,
Ytype = :spherical,
ZBL = kwargs[:ZBL],
E0s = E0s,
rbasis = rbasis,
pair_basis = pairbasis,
Expand Down
38 changes: 9 additions & 29 deletions src/models/ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,16 @@ function _make_idx_AA_spec(AA_spec, A_spec)
return AA_spec_idx
end

function _make_Vref_E0s(rbasis, E0s::Nothing)
NZ = _get_nz(rbasis)
return _make_Vref_E0s(rbasis, [ _i2z(rbasis, i) => 0.0 for i = 1:NZ ] )
end

# E0s can be anything with (key, value) pairs
function _make_Vref_E0s(rbasis, E0s)
_convert_E0s(E0s::Union{Dict, NamedTuple}) = E0s
_convert_E0s(E0s::Union{AbstractVector, Tuple}) = Dict(E0s...)
_convert_E0s(E0s) = error("E0s must be nothing, a NamedTuple, Dict or list of pairs")

NZ = _get_nz(rbasis)
V0 = OneBody(_convert_E0s(E0s))
if length(V0.E0) != NZ
error("E0s must have the right number of elements")
end

return V0
end


function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector,
Vref,
level = TotalDegree(),
pair_basis = nothing,
E0s = nothing,
Vref = _make_Vref_E0s(rbasis, E0s), )
)

# storing E0s with unit
model_meta = Dict{String, Any}("E0s" => deepcopy(E0s))
# # storing E0s with unit
# model_meta = Dict{String, Any}("E0s" => deepcopy(E0s))
model_meta = Dict{String, Any}()

# generate the coupling coefficients
cgen = EquivariantModels.Rot3DCoeffs_real(0)
Expand Down Expand Up @@ -170,8 +151,8 @@ end
# since it is implicitly already encoded in AA_spec. We need a
# function `auto_level` that generates level automagically from AA_spec.
function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level,
pair_basis, E0s = nothing)
return _generate_ace_model(rbasis, Ytype, AA_spec, level, pair_basis, E0s)
pair_basis, Vref)
return _generate_ace_model(rbasis, Ytype, AA_spec, Vref, level, pair_basis)
end

# NOTE : a nicer convenience constructor is also provided in `ace_heuristics.jl`
Expand Down Expand Up @@ -320,9 +301,8 @@ function evaluate(model::ACEModel,
val += dot(Apair, (@view ps.Wpair[:, i_z0]))
end
# -------------------
# TODO - Vref : assume it is a OneBody
@assert model.Vref isa OneBody
val += model.Vref.E0[Z0]
# Vref
val += eval_site(model.Vref, Rs, Zs, Z0)
# -------------------

end # @no_escape
Expand Down
43 changes: 41 additions & 2 deletions src/models/ace_heuristics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Random
import ACEpotentials: DefaultHypers
import EmpiricalPotentials


# -------------------------------------------------------
Expand Down Expand Up @@ -101,9 +102,45 @@ end



_convert_E0s(E0s::Union{Dict, NamedTuple}) = E0s
_convert_E0s(E0s::Union{AbstractVector, Tuple}) = Dict(E0s...)
_convert_E0s(E0s) = error("E0s must be nothing, a NamedTuple, Dict or list of pairs")

# E0s can be anything with (key, value) pairs
_make_Vref_E0s(elements, E0s) = OneBody(_convert_E0s(E0s))

function _make_Vref_E0s(elements, E0s::Nothing)
NZ = length(elements)
zz = _convert_zlist(elements)
return _make_Vref_E0s(elements, [ z => 0.0 for z in zz ] )
end


function _make_Vref(elements, E0s, ZBL, rcut = nothing)

if !(isnothing(E0s))
E0s = _convert_E0s(E0s)
if (sort([elements...]) != sort(collect(keys(E0s))))
error("E0s keys must be the same as the list of elements")
end
end

Vref_E0s = _make_Vref_E0s(elements, E0s)

if ZBL
Vref_zbl = EmpiricalPotentials.ZBL(rcut*u"")
return SitePotentialStack((Vref_E0s, Vref_zbl))
else
return Vref_E0s
end
end



function ace_model(; elements = nothing,
order = nothing,
Ytype = :solid,
ZBL = false,
E0s = nothing,
rin0cuts = :auto,
# radial basis
Expand Down Expand Up @@ -170,11 +207,13 @@ function ace_model(; elements = nothing,
end
end


AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec,
level = level, max_level = max_level)

model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis, E0s)
rcut = maximum([ x.rcut for x in rin0cuts ])
Vref = _make_Vref(elements, E0s, ZBL, rcut)

model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis, Vref)
model.meta["init_WB"] = String(init_WB)
model.meta["init_Wpair"] = String(init_Wpair)

Expand Down
1 change: 1 addition & 0 deletions src/models/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ function length_basis end
include("elements.jl")

include("onebody.jl")
include("stacked_pot.jl")

include("radial_envelopes.jl")

Expand Down
4 changes: 2 additions & 2 deletions src/models/onebody.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ energy_unit(V::OneBody) = V.energy_unit
length_unit(V::OneBody) = V.length_unit

cutoff_radius(::OneBody{T}) where {T} =
sqrt(eps(T))

sqrt(eps(T)) * u"Å"
eval_site(V::OneBody, Rs, Zs, zi::Integer) =
V.E0[zi]

Expand Down
34 changes: 34 additions & 0 deletions src/models/stacked_pot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

import AtomsCalculators
import AtomsCalculatorsUtilities

import AtomsCalculatorsUtilities.SitePotentials: SitePotential,
cutoff_radius, eval_site, eval_grad_site,
energy_unit, length_unit


struct SitePotentialStack{TP} <: SitePotential
pots::TP
end

function energy_unit(pot::SitePotentialStack)
return energy_unit(pot.pots[1])
end

function length_unit(pot::SitePotentialStack)
return length_unit(pot.pots[1])
end

function cutoff_radius(pot::SitePotentialStack)
return maximum(cutoff_radius, pot.pots)
end

function eval_site(pot::SitePotentialStack, Rs, Zs, z0)
return sum( eval_site(p, Rs, Zs, z0) for p in pot.pots )
end

function eval_grad_site(pot::SitePotentialStack, Rs, Zs, z0)
return eval_site(pot, Rs, Zs, z0),
sum( eval_grad_site(p, Rs, Zs, z0)[2] for p in pot.pots )
end

1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004"
AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1"
AtomsCalculatorsUtilities = "9855a07e-8816-4d1b-ac92-859c17475477"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2"
ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand Down
58 changes: 58 additions & 0 deletions test/models/test_Vref.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
##

using Test
using Polynomials4ML.Testing: print_tf, println_slim

using ACEpotentials
M = ACEpotentials.Models

using Random, StaticArrays, LinearAlgebra, Unitful, EmpiricalPotentials
using AtomsCalculatorsUtilities.SitePotentials: eval_site, eval_grad_site
using AtomsBase
##

elements = [:C, :O, :H]
E0s = Dict(:C => -rand(), :O => -rand(), :H => -rand())
rcut = 5.5u""
zbl = ZBL(rcut)
V0 = M.OneBody(E0s)

Vref1 = M._make_Vref(elements, E0s, false)
Vref2 = M._make_Vref(elements, nothing, true, ustrip(rcut))
Vref3 = M._make_Vref(elements, E0s, true, ustrip(rcut))

zC = atomic_number(ChemicalSpecies(:C))
zO = atomic_number(ChemicalSpecies(:O))
zH = atomic_number(ChemicalSpecies(:H))

Rs0 = SVector{3, Float64}[]
Zs0 = Int[]
nZ = rand(3:5)
Rs1 = randn(SVector{3, Float64}, nZ)
Zs1 = rand([zC, zO, zH], nZ)

##

print_tf(@test (eval_site(zbl, Rs0, Zs0, zC) == 0.0 ))
print_tf(@test (eval_site(zbl, Rs0, Zs0, zO) == 0.0 ))
print_tf(@test (eval_site(zbl, Rs0, Zs0, zH) == 0.0 ))
print_tf(@test (eval_site(V0, Rs0, Zs0, zC) == E0s[:C] ))
print_tf(@test (eval_site(V0, Rs0, Zs0, zO) == E0s[:O] ))
print_tf(@test (eval_site(V0, Rs0, Zs0, zH) == E0s[:H] ))
println()

##

for (Rs, Zs) in [ (Rs0, Zs0), (Rs1, Zs1)], z in [zC, zO, zH]
print_tf(@test ( eval_site(Vref1, Rs, Zs, z) == eval_site(V0, Rs, Zs, z) ))
print_tf(@test ( eval_site(Vref2, Rs, Zs, z) == eval_site(zbl, Rs, Zs, z) ))
print_tf(@test ( eval_site(Vref3, Rs, Zs, z) == eval_site(V0, Rs, Zs, z) + eval_site(zbl, Rs, Zs, z) ))
print_tf(@test ( eval_grad_site(Vref1, Rs, Zs, z) == eval_grad_site(V0, Rs, Zs, z) ))
print_tf(@test ( eval_grad_site(Vref2, Rs, Zs, z) == eval_grad_site(zbl, Rs, Zs, z) ))
print_tf(@test ( eval_grad_site(Vref3, Rs, Zs, z)[2] == eval_grad_site(V0, Rs, Zs, z)[2] + eval_grad_site(zbl, Rs, Zs, z)[2] ))
end
println()

##
2 changes: 1 addition & 1 deletion test/models/test_ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ for ybasis in [:spherical, :solid]
@info("Test Rotation-Invariance of the Model")

for ntest = 1:50
local st1, Nat, Rs, Zs, Z0, val
local st1, Nat, Rs, Zs, Z0, val, Rs1, Zs1

Nat = rand(8:16)
Rs, Zs, Z0 = M.rand_atenv(model, Nat)
Expand Down
1 change: 1 addition & 0 deletions test/models/test_models.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

@testset "Vref" begin; include("test_Vref.jl"); end
@testset "Radial Envelopes" begin; include("test_radial_envelopes.jl"); end
@testset "Radial Transforms" begin; include("test_radial_transforms.jl"); end
@testset "Rnlrzz Basis" begin; include("test_Rnl.jl"); end
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_radialweights.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))

##

Expand Down
30 changes: 29 additions & 1 deletion test/test_silicon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

using ACEpotentials
using ACEbase.Testing: println_slim
using ExtXYZ, AtomsBase
using ExtXYZ, AtomsBase, Unitful, StaticArrays, AtomsCalculators
using AtomsCalculators: potential_energy
using Distributed
using LazyArtifacts
using Test
Expand Down Expand Up @@ -138,3 +139,30 @@ zSi = atomic_number(ChemicalSpecies(:Si))
Bpair = model.model.pairbasis(r, zSi, zSi, NamedTuple(), NamedTuple())
V2 = sum(model.ps.Wpair[:, 1] .* Bpair)
println_slim(@test V2 > 10_000)


##
# rerun a fit with ZBL reference

@info("Run a fit with ZBL reference potential")
@warn(" The tests here seem a bit weak, the ZBL implementation may be buggy.")

model = ACE1compat.ace1_model(; ZBL = true, pair_transform = (:agnesi, 1, 2),
params...)

acefit!(data, model;
data_keys...,
weights = weights,
repulsion_restraint = true,
solver = ACEfit.BLR())

r = 0.000001 + 0.000001 * rand()
zSi = atomic_number(ChemicalSpecies(:Si))
dimer = periodic_system(
[Atom(ChemicalSpecies(:Si), SA[0.0,0.0,0.0]u"Å", ),
Atom(ChemicalSpecies(:Si), SA[ r,0.0,0.0]u"Å", )],
( SA[ r+1, 0.0, 0.0 ]u"Å", SA[0.0,1.0,0.0]u"Å", SA[0.0,0.0,1.0]u"Å" ),
periodicity = (false, false, false) )
@show potential_energy(dimer, model)
println_slim(@test potential_energy(dimer, model) > 1e3u"eV")

0 comments on commit 1a19e4b

Please sign in to comment.