Skip to content

Commit

Permalink
Merge pull request #262 from ACEsuit/co/repulsion
Browse files Browse the repository at this point in the history
Translate repulsion restraint to 0.8
  • Loading branch information
cortner authored Sep 14, 2024
2 parents dd227fe + 03861b8 commit 23965a2
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 95 deletions.
3 changes: 1 addition & 2 deletions src/ACEpotentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@ include("ace1_compat.jl")
# Fitting
include("atoms_data.jl")
include("fit_model.jl")
include("repulsion_restraint.jl")

# Data
include("example_data.jl")

# Misc
# TODO: all of this just needs to be moved from JuLIP to AtomsBase
include("analysis/dataset_analysis.jl")
include("analysis/potential_analysis.jl")
include("descriptor.jl")


# TODO: to be completely rewritten
# include("io.jl")
# include("export.jl")

# Experimental
Expand Down
27 changes: 15 additions & 12 deletions src/atoms_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ function _getfuzzy(coll, key)
end

_issimilarkey(k1, k2) = lowercase(String(k1)) == lowercase(String(k2))
_issimilarkey(k1::Nothing, k2) = false
_issimilarkey(k1, k2::Nothing) = false
_issimilarkey(k1::Nothing, k2::Nothing) = false

function _find_similar_key(coll, key)
for k in keys(coll)
Expand All @@ -43,38 +46,38 @@ function _find_similar_key(coll, key)
return nothing
end

function _find_similar_key(sys::ExtXYZ.Atoms, key)
for k in keys(sys.system_data)
function _find_similar_key(sys::AbstractSystem, key)
for k in keys(sys)
if _issimilarkey(k, key)
return k
end
end
for k in keys(sys.atom_data)
for k in atomkeys(sys)
if _issimilarkey(k, key)
return k
end
end
return nothing
end

function _get_data_fuzzy(sys::ExtXYZ.Atoms, key)
function _get_data_fuzzy(sys::AbstractSystem, key)
k = _find_similar_key(sys, key)
if k == nothing
error("Couldn't find $key or similar in collection with keys $(keys(sys))")
end
if haskey(sys.system_data, k)
return sys.system_data[k]
if haskey(sys, k)
return sys[k]
end
return sys.atom_data[k]
return sys[:, k]
end

_has_similar_key(coll, key) = (_find_similar_key(coll, key) != nothing)

function _get_data(sys::ExtXYZ.Atoms, key)
if haskey(sys.system_data, key)
return sys.system_data[key]
elseif haskey(sys.atom_data, key)
return sys.atom_data[key]
function _get_data(sys::AbstractSystem, key)
if haskey(sys, key)
return sys[key]
elseif hasatomkey(sys, key)
return sys[:, key]
else
error("Couldn't find $key in System")
end
Expand Down
10 changes: 9 additions & 1 deletion src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,15 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
end

if repulsion_restraint
error("Repulsion restraint is currently not implemented")
restraint_data = _rep_dimer_data_atomsbase(
model;
weight = restraint_weight,
energy_key = Symbol(energy_key),
kwargs...
)
append!(data, restraint_data)
# return nothing
# error("Repulsion restraint is currently not implemented")
# if eltype(data) == AtomsData
# append!(data, _rep_dimer_data(model; weight = restraint_weight))
# else
Expand Down
File renamed without changes.
File renamed without changes.
108 changes: 29 additions & 79 deletions src/repulsion_restraint.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,46 @@
# ---------------- Implementation of the repulsion restraint

import ACEpotentials: ACEPotential
import ACEpotentials.Models: ACEModel

function _rep_dimer_data_atomsbase(
model;
weight=0.01,
energy_key=:energy,
group_key=:config_type,
kwargs...
)
zz = model.basis.BB[1].zlist.list
model::ACEPotential{<: ACEModel};
weight=0.01,
energy_key=:energy,
group_key=:config_type,
kwargs...
)
B_pair = model.model.pairbasis
zz = B_pair._i2z
restraints = []
B_pair = model.basis.BB[1]
if !isa(B_pair, ACE1.PolyPairBasis)
error("repulsion restraints only implemented for PolyPairBasis")
end

for i = 1:length(zz), j = i:length(zz)
z1, z2 = zz[i], zz[j]
s1, s2 = chemical_symbol.((z1, z2))
r0_est = 1.0 # could try to get this from the model meta-data
_rin = r0_est / 100 # can't take 0 since we'd end up with ∞ / ∞
Pr_ij = B_pair.J[i, j]
if !isa(Pr_ij, ACE1.OrthPolys.TransformedPolys)
error("repulsion restraints only implemented for TransformedPolys")
end
envfun = Pr_ij.envelope
if !isa(envfun, ACE1.OrthPolys.PolyEnvelope)
error("repulsion restraints only implemented for PolyEnvelope")
end
if !(envfun.p >= 0)
error("repulsion restraints only implemented for PolyEnvelope with p >= 0")
end
env_rin = ACE1.evaluate(envfun, _rin)
T_ij = model.model.pairbasis.transforms[i, j]
env_ij = model.model.pairbasis.envelopes[i, j]
env_rin = ACEpotentials.Models.evaluate(env_ij, _rin, T_ij(_rin))

a1 = Atom(zz[1].z, zeros(3)u"Å")
a2 = Atom(zz[2].z, [_rin, 0, 0]u"Å")
cell = [ [_rin+1, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]u"Å"
boundary_conditions = [DirichletZero(), DirichletZero(), DirichletZero()]
data = FlexibleSystem([a1, a2], cell, boundary_conditions)
a1 = Atom(z1, zeros(3)u"Å")
a2 = Atom(z2, [_rin, 0, 0]u"Å")
cell = tuple([SA[_rin+1, 0.0, 0.0], SA[0.0, 1.0, 0.0], SA[0.0, 0.0, 1.0]]u"Å" ...)
pbc = (false, false, false)
system = FlexibleSystem([a1, a2], cell, pbc)
system.data[energy_key] = env_rin
system.data[:config_type] = "restraint"

data = AtomsData(system;
energy_key = energy_key,
force_key = nothing,
virial_key = nothing,
weights = Dict("restraint" => Dict("E" => weight)),
v_ref = _get_Vref(model)
)

# add weight to the structure
kwargs =[
energy_key => env_rin,
group_key => "restraint",
:energy_weight => weight,
]
data = FlexibleSystem(data; kwargs...)

push!(restraints, data)
end

return restraints
end

function _rep_dimer_data(model;
weight = 0.01
)
zz = model.basis.BB[1].zlist.list
restraints = []
restraint_weights = Dict("restraint" => Dict("E" => weight, "F" => 0.0, "V" => 0.0))
B_pair = model.basis.BB[1]
if !isa(B_pair, ACE1.PolyPairBasis)
error("repulsion restraints only implemented for PolyPairBasis")
end

for i = 1:length(zz), j = i:length(zz)
z1, z2 = zz[i], zz[j]
s1, s2 = chemical_symbol.((z1, z2))
r0_est = 1.0 # could try to get this from the model meta-data
_rin = r0_est / 100 # can't take 0 since we'd end up with ∞ / ∞
Pr_ij = B_pair.J[i, j]
if !isa(Pr_ij, ACE1.OrthPolys.TransformedPolys)
error("repulsion restraints only implemented for TransformedPolys")
end
envfun = Pr_ij.envelope
if !isa(envfun, ACE1.OrthPolys.PolyEnvelope)
error("repulsion restraints only implemented for PolyEnvelope")
end
if !(envfun.p >= 0)
error("repulsion restraints only implemented for PolyEnvelope with p >= 0")
end
env_rin = ACE1.evaluate(envfun, _rin)
at = at_dimer(_rin, z1, z2)
set_data!(at, "REF_energy", env_rin)
set_data!(at, "config_type", "restraint")
# AtomsData(atoms::Atoms; energy_key, force_key, virial_key, weights, v_ref, weight_key)
dat = ACEpotentials.AtomsData(at, energy_key = "REF_energy",
force_key = "REF_forces",
virial_key = "REF_virial",
weights = restraint_weights,
v_ref = model.Vref)
push!(restraints, dat)
end

return restraints
end
20 changes: 19 additions & 1 deletion test/test_silicon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,22 @@ using AtomsBuilder
sys = rattle!(bulk(:Si, cubic=true) * 2, 0.1)
X = site_descriptors(sys, model)
X234 = site_descriptors(sys, model; domain = [2,3,4])
X234 == X[2:4]
println_slim( @test X234 == X[2:4] )


##
# rerun fit with repulsion restraint

@info("Run a fit with repulsion restraint")

acefit!(data, model;
data_keys...,
weights = weights,
solver = ACEfit.BLR(),
repulsion_restraint = true)

r = 0.001 + 0.01 * rand()
zSi = atomic_number(ChemicalSpecies(:Si))
Bpair = model.model.pairbasis(r, zSi, zSi, NamedTuple(), NamedTuple())
V2 = sum(model.ps.Wpair[:, 1] .* Bpair)
println_slim(@test V2 > 10_000)

0 comments on commit 23965a2

Please sign in to comment.