From 6480aa82e8a861de909c31cef79d4a97de82fd70 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 2 May 2024 22:30:36 -0700 Subject: [PATCH 001/112] first steps towards new ACEpotentials kernels : envelopes --- Project.toml | 4 +- src/ACEpotentials.jl | 1 + src/models/models.jl | 7 +++ src/models/radial_envelopes.jl | 69 ++++++++++++++++++++++++++++ test/models/test_radial_envelopes.jl | 30 ++++++++++++ 5 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 src/models/models.jl create mode 100644 src/models/radial_envelopes.jl create mode 100644 test/models/test_radial_envelopes.jl diff --git a/Project.toml b/Project.toml index 12ae62c2..09577712 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,8 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" @@ -30,9 +32,9 @@ JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" StaticArrays = "1" +UltraFastACE = "0.0.2" YAML = "0.4" julia = "1.9" -UltraFastACE = "0.0.2" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/ACEpotentials.jl b/src/ACEpotentials.jl index f1cdf5c4..991328e9 100644 --- a/src/ACEpotentials.jl +++ b/src/ACEpotentials.jl @@ -19,6 +19,7 @@ include("analysis/potential_analysis.jl") include("analysis/dataset_analysis.jl") include("experimental.jl") +include("models/models.jl") include("outdated/fit.jl") include("outdated/data.jl") diff --git a/src/models/models.jl b/src/models/models.jl new file mode 100644 index 00000000..cb2522b1 --- /dev/null +++ b/src/models/models.jl @@ -0,0 +1,7 @@ + +module Models + +include("radial_envelopes.jl") + + +end \ No newline at end of file diff --git a/src/models/radial_envelopes.jl b/src/models/radial_envelopes.jl new file mode 100644 index 00000000..1a3435a9 --- /dev/null +++ b/src/models/radial_envelopes.jl @@ -0,0 +1,69 @@ + +abstract type AbstractEnvelope end + +struct PolyEnvelope1sR{T} + rcut::T + p::Int + # ------- + meta::Dict{String, Any} +end + + +PolyEnvelope1sR(rcut, p) = + PolyEnvelope1sR(rcut, p, Dict{String, Any}()) + +function evaluate(env::PolyEnvelope1sR, r::T) where T + if r >= env.rcut + return zero(T) + end + p = env.p + # return r^(-p) - env.rcut^(-p) - p*(env.rcut^(-p-1))*(r - env.rcut) + return ( (r/env.rcut)^(-p) - 1.0) * (1 - r / env.rcut) +end + +evaluate_d(env::PolyEnvelope1sR, r) = + ForwardDiff.derivative(x -> evaluate(env, x), r) + +# ---------------------------- + +struct PolyEnvelope2sX{T} + x1::T + x2::T + p1::Int + p2::Int + s::T + # ------- + meta::Dict{String, Any} +end + +function PolyEnvelope2sX(x1, x2, p1, p2) + if x1 == x2 + error("x1 and x2 must be different!") + end + if x1 > x2 + @warn("swapping x1, x2 to ensure x1 < x2") + x1, x2 = x2, x1 + p1, p2 = p2, p1 + end + s = 1 / (abs(x2 - x1)/2)^(p1+p2) + PolyEnvelope2sX(x1, x2, p1, p2, s, Dict{String, Any}()) +end + + +function evaluate(env::PolyEnvelope2sX, x::T) where T + x1, x2 = env.x1, env.x2 + p1, p2 = env.p1, env.p2 + s = env.s + + if !(x1 < x < x2) + return zero(T) + end + + return s * (x-x1)^p1 * (x2-x)^p2 +end + + +evaluate_d(env::PolyEnvelope2sX, x::T) where T = + ForwardDiff.derivative(x -> evaluate(env, x), x) + + diff --git a/test/models/test_radial_envelopes.jl b/test/models/test_radial_envelopes.jl new file mode 100644 index 00000000..7abc07aa --- /dev/null +++ b/test/models/test_radial_envelopes.jl @@ -0,0 +1,30 @@ + + +using Pkg; Pkg.activate("."); +using TestEnv; TestEnv.activate(); + +using ACEpotentials + +# there are no real tests for envelopes yet. The only thing we have is +# a plot of the envelopes to inspect manually. + +## + +#= +using Plots + +rcut = 2.0 +envpair = ACEpotentials.Models.PolyEnvelope1sR(rcut, 1) +rr = range(0.0001, rcut+0.5, length=200) +y2 = ACEpotentials.Models.evaluate.(Ref(envpair), rr) + +envmb = ACEpotentials.Models.PolyEnvelope2sX(0.0, 1.0, 2, 2) +ymb = ACEpotentials.Models.evaluate.(Ref(envmb), rr) + +plot(rr, y2, label="pair envelope", lw=2, legend=:topleft, ylims = (-1.0, 3.0)) +plot!(rr, ymb, label="mb envelope", lw=2, ) +=# + +## + + From 9ed960aec7db758d5c7f96f19c8d70ecb205bb62 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 3 May 2024 11:01:03 -0700 Subject: [PATCH 002/112] initial transfer and a bit of cleanup of agnesi transform --- Project.toml | 1 + src/models/models.jl | 1 + src/models/radial_transforms.jl | 187 ++++++++++++++++++++++++++ test/models/test_radial_envelopes.jl | 4 +- test/models/test_radial_transforms.jl | 49 +++++++ 5 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 src/models/radial_transforms.jl create mode 100644 test/models/test_radial_transforms.jl diff --git a/Project.toml b/Project.toml index 09577712..e88efd7a 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JuLIP = "945c410c-986d-556a-acb1-167a618e0462" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" diff --git a/src/models/models.jl b/src/models/models.jl index cb2522b1..c293f598 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -3,5 +3,6 @@ module Models include("radial_envelopes.jl") +include("radial_transforms.jl") end \ No newline at end of file diff --git a/src/models/radial_transforms.jl b/src/models/radial_transforms.jl new file mode 100644 index 00000000..ae20554f --- /dev/null +++ b/src/models/radial_transforms.jl @@ -0,0 +1,187 @@ + +import ForwardDiff + +struct GeneralizedAgnesiTransform{T} + p::Int + q::Int + a::T + rin::T + r0::T +end + +(t::GeneralizedAgnesiTransform)(r) = evaluate(t, r) + +write_dict(T::GeneralizedAgnesiTransform) = + Dict("__id__" => "ACEpotentials_GeneralizedAgnesiTransform", + "r0" => T.r0, "p" => T.p, "q" => T.q, "a" => T.a, "rin" => T.rin) + +GeneralizedAgnesiTransform(D::Dict) = + GeneralizedAgnesiTransform(D["r0"], D["p"], D["q"], + D["a"], D["rin"]) + +read_dict(::Val{:ACEpotentials_GeneralizedAgnesiTransform}, D::Dict) = + GeneralizedAgnesiTransform(D) + +function evaluate(t::GeneralizedAgnesiTransform{T}, r::Number) where {T} + if r <= t.rin + return one(promote_type(T, typeof(r))) + end + a, r0, q, p, rin = t.a, t.r0, t.q, t.p, t.rin + s = (r-t.rin)/(t.r0-t.rin) + return 1 / (1 + a * s^q / (1 + s^(q-p))) +end + +evaluate_d(t::GeneralizedAgnesiTransform, r::Number) = + ForwardDiff.derivative(r -> transform(t, r), r) + + + +# --------------------------------------------------------------------------- + + +# tested a wide range of methods. Brent seems robust + fastest +using Roots: find_zero, ITP, Brent + +struct NormalizedTransform{T, TT} + trans::TT + yin::T + ycut::T + rin::T + rcut::T +end + +function NormalizedTransform(trans, rin::Number, rcut::Number) + yin = trans(rin) + ycut = trans(rcut) + return NormalizedTransform(trans, yin, ycut, rin, rcut) +end + + +(t::NormalizedTransform)(r) = evaluate(t, r) + +function evaluate(t::NormalizedTransform, r::Number) + y = t.trans(r) + return min(max(zero(y), (y - t.yin) / (t.ycut - t.yin)), one(y)) +end + +# this is the old version from ACE1.jl; a neat idea. We could return to it. +# but it could be better to integrate this into the inner transform. +# return 1 - (y - t.y1) / (t.y0 - t.y1) * (1 - (r/t.rcut)^4) + +evaluate_d(t::NormalizedTransform, r::Number) = + ForwardDiff.derivative(r -> evaluate(t, r), r) + + +function inv_transform(t::NormalizedTransform{T}, x::Number) where {T} + T1 = promote_type(T, typeof(x)) + if x <= 0 + return convert(T1, t.rin) + elseif x >= 1 + return convert(T1, t.rcut) + end + + g = r -> transform(t, r) - x + r = find_zero(g, (t.rin, t.rcut), Brent()) + @assert t.rin <= r <= t.rcut + @assert abs(g(r)) < 1e-12 + return r +end + +write_dict(T::NormalizedTransform) = + Dict("__id__" => "ACEpotentials_NormalizedTransform", + "trans" => write_dict(T.trans), + "yin" => T.yin, "ycut" => T.ycut, + "rin" => T.rin, "rcut" => T.rcut ) + +read_dict(::Val{:ACEpotentials_NormalizedTransform}, D::Dict) = + NormalizedTransform(read_dict(D["trans"]), + D["yin"], D["ycut"], D["rin"], D["rcut"]) + +# --------------------------------------------------------------------------- + +function test_normalized_transform(t; nx = 1000) + x = range(0.0, 1.0, length = nx) + r = [inv_transform(t, xi) for xi in x] + @assert r[1] == 0.0 + @assert r[end] == t.rcut + x2 = [t(ri) for ri in r] + if !all(abs.(x .- x2) .< 1e-10) + error("Inverse transform failed!") + end + @assert all(r[2:end] - r[1:end-1] .>= 0) + + r = range(0.0, t.rcut, length = nx) + x = [t(ri) for ri in r] + @assert x[1] == t.yin + @assert x[end] == t.ycut + r2 = [inv_transform(t, xi) for xi in x] + if !all(abs.(r .- r2) .< 1e-10) + error("Inverse transform failed!") + end + @assert all(x[2:end] .- x[1:end-1] .>= 0) + + return true +end + +# test transform from ACE1 to be merged with the above. +# function test_transform(T, rrange, ntests = 100) + +# rmin, rmax = extrema(rrange) +# rr = rmin .+ rand(100) * (rmax-rmin) +# xx = [ transform(T, r) for r in rr ] +# # check syntactic sugar +# xx1 = [ T(r) for r in rr ] +# print_tf(@test xx1 == xx) +# # check inversion +# rr1 = inv_transform.(Ref(T), xx) +# print_tf(@test rr1 ≈ rr) +# # check gradient +# dx = transform_d.(Ref(T), rr) +# adx = ForwardDiff.derivative.(Ref(r -> transform(T, r)), rr) +# print_tf(@test dx ≈ adx) + +# # TODO: check that the transform doesn't allocate +# @allocated begin +# x = 0.0; +# for r in rr +# x += transform(T, r) +# end +# end +# end + + + + +@doc raw""" +`function agnesi_transform:` constructs a generalized agnesi transform. +``` +trans = agnesi_transform(r0, p, q) +``` +with `q >= p`. This generates an `AnalyticTransform` object that implements +```math + x(r) = \frac{1}{1 + a (r/r_0)^q / (1 + (r/r0)^(q-p))} +``` +with default `a` chosen such that $|x'(r)|$ is maximised at $r = r_0$. But `a` may also be specified directly as a keyword argument. + +The transform satisfies +```math + x(r) \sim \frac{1}{1 + a (r/r_0)^p} \quad \text{as} \quad r \to 0 + \quad \text{and} + \quad + x(r) \sim \frac{1}{1 + a (r/r_0)^p} \quad \text{as} r \to \infty. +``` + +As default parameters we recommend `p = 2, q = 4` and the defaults for `a`. +""" +function agnesi_transform(r0, rcut, p, q; + rin = zero(r0), + a = (-2 * q + p * (-2 + 4 * q)) / (p + p^2 + q + q^2) ) + @assert p > 0 + @assert q > 0 + @assert q >= p + @assert a > 0 + @assert 0 < r0 < rcut + return NormalizedTransform( + GeneralizedAgnesiTransform(p, q, a, rin, r0), + rin, rcut ) +end diff --git a/test/models/test_radial_envelopes.jl b/test/models/test_radial_envelopes.jl index 7abc07aa..07705cf3 100644 --- a/test/models/test_radial_envelopes.jl +++ b/test/models/test_radial_envelopes.jl @@ -1,7 +1,7 @@ -using Pkg; Pkg.activate("."); -using TestEnv; TestEnv.activate(); +# using Pkg; Pkg.activate("."); +# using TestEnv; TestEnv.activate(); using ACEpotentials diff --git a/test/models/test_radial_transforms.jl b/test/models/test_radial_transforms.jl new file mode 100644 index 00000000..9b2291a8 --- /dev/null +++ b/test/models/test_radial_transforms.jl @@ -0,0 +1,49 @@ + +using Pkg; Pkg.activate("."); +using TestEnv; TestEnv.activate(); + +using ACEpotentials + +# there are no real tests for envelopes yet. The only thing we have is +# a plot of the envelopes to inspect manually. + +## +# this code block should normally be just commented out and is just intended +# for a quick visual inspection of the transforms. + +#= +using Plots +rcut = 6.5 +r0 = 2.3 + +trans_2_2 = ACEpotentials.Models.agnesi_transform(r0, rcut, 2, 2) +trans_2_4 = ACEpotentials.Models.agnesi_transform(r0, rcut, 2, 4) +trans_1_3 = ACEpotentials.Models.agnesi_transform(r0, rcut, 1, 3) + +rr = range(-0.5, rcut+0.5, length=200) + +y_2_2 = ACEpotentials.Models.evaluate.(Ref(trans_2_2), rr) +y_2_4 = ACEpotentials.Models.evaluate.(Ref(trans_2_4), rr) +y_1_3 = ACEpotentials.Models.evaluate.(Ref(trans_1_3), rr) +dy_2_2 = ACEpotentials.Models.evaluate_d.(Ref(trans_2_2), rr) +dy_2_4 = ACEpotentials.Models.evaluate_d.(Ref(trans_2_4), rr) +dy_1_3 = ACEpotentials.Models.evaluate_d.(Ref(trans_1_3), rr) + +plt1 = plot(rr, y_2_2, label="Agnesi(2,2)", lw=2, legend=:topleft, ylims = (-0.2, 1.2)) +plot!(rr, y_2_4, label="Agnesi(2,4)", lw=2) +plot!(rr, y_1_3, label="Agnesi(1,3)", lw=2) +vline!([0.0, r0, rcut], ls=:dash, lw=2, label="rin, r0, rcut") + +plt2 = plot(rr, dy_2_2, label="∇Agnesi(2,2)", lw=2, legend=:topright, ylims = (-0.05, 0.35)) +plot!(rr, dy_2_4, label="∇Agnesi(2,4)", lw=2) +plot!(rr, dy_1_3, label="∇Agnesi(1,3)", lw=2) +vline!([0.0, r0, rcut], ls=:dash, lw=2, label="rin, r0, rcut") + +plot(plt1, plt2, layout=(1,2), size = (600, 800)) +=# + +## + + + + From 582af44a8b23bc12058e0288a341d721f0e77adb Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 3 May 2024 15:53:17 -0700 Subject: [PATCH 003/112] first draft of transforms finished --- src/models/radial_transforms.jl | 62 ++++++++++++++++++++++----- test/models/test_radial_transforms.jl | 11 ++++- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/src/models/radial_transforms.jl b/src/models/radial_transforms.jl index ae20554f..ef751acd 100644 --- a/src/models/radial_transforms.jl +++ b/src/models/radial_transforms.jl @@ -80,7 +80,7 @@ function inv_transform(t::NormalizedTransform{T}, x::Number) where {T} return convert(T1, t.rcut) end - g = r -> transform(t, r) - x + g = r -> evaluate(t, r) - x r = find_zero(g, (t.rin, t.rcut), Brent()) @assert t.rin <= r <= t.rcut @assert abs(g(r)) < 1e-12 @@ -100,27 +100,69 @@ read_dict(::Val{:ACEpotentials_NormalizedTransform}, D::Dict) = # --------------------------------------------------------------------------- function test_normalized_transform(t; nx = 1000) + fails = 0 + x = range(0.0, 1.0, length = nx) r = [inv_transform(t, xi) for xi in x] - @assert r[1] == 0.0 - @assert r[end] == t.rcut + if !(r[1] ≈ t.rin) + fails += 1 + @error("t⁻¹(0) ≈ rin fails") + end + if !(r[end] ≈ t.rcut) + fails += 1 + @error("t⁻¹(1) ≈ rcut fails") + end x2 = [t(ri) for ri in r] if !all(abs.(x .- x2) .< 1e-10) - error("Inverse transform failed!") + fails += 1 + @error("t⁻¹ ∘ t ≈ id failed!") + end + if !(all(r[2:end] - r[1:end-1] .>= - eps())) + fails += 1 + @error("t⁻¹ is not monotonically increasing") end - @assert all(r[2:end] - r[1:end-1] .>= 0) r = range(0.0, t.rcut, length = nx) x = [t(ri) for ri in r] - @assert x[1] == t.yin - @assert x[end] == t.ycut + if !(x[1] ≈ 0.0) + fails += 1 + @error("t(rin) ≈ yin fails") + end + if !(x[end] ≈ 1.0) + fails += 1 + @error("t(rcut) ≈ ycut fails") + end r2 = [inv_transform(t, xi) for xi in x] if !all(abs.(r .- r2) .< 1e-10) - error("Inverse transform failed!") + fails += 1 + @error("t ∘ t⁻¹ ≈ id failed!") + end + if !(all(x[2:end] - x[1:end-1] .>= - eps())) + fails += 1 + @error("t is not monotonically increasing") + end + + rr = r[2:end-1] + dx = evaluate_d.(Ref(t), rr) + adx = ForwardDiff.derivative.(Ref(r -> t(r)), rr) + if !all(abs.(dx - adx) .< 1e-10) + fails += 1 + @error("transform gradient test failed") + end + + if fails > 0 + @info("$fails transform tests fails") + end + + # TODO: check that the transform doesn't allocate + @allocated begin + x1 = 0.0; + for r1 in rr + x1 += evaluate(t, r1) + end end - @assert all(x[2:end] .- x[1:end-1] .>= 0) - return true + return (fails == 0) end # test transform from ACE1 to be merged with the above. diff --git a/test/models/test_radial_transforms.jl b/test/models/test_radial_transforms.jl index 9b2291a8..a5e8a4e5 100644 --- a/test/models/test_radial_transforms.jl +++ b/test/models/test_radial_transforms.jl @@ -2,7 +2,7 @@ using Pkg; Pkg.activate("."); using TestEnv; TestEnv.activate(); -using ACEpotentials +using ACEpotentials, Test # there are no real tests for envelopes yet. The only thing we have is # a plot of the envelopes to inspect manually. @@ -45,5 +45,14 @@ plot(plt1, plt2, layout=(1,2), size = (600, 800)) ## +rcut = 6.5 +r0 = 2.3 + +trans_2_2 = ACEpotentials.Models.agnesi_transform(r0, rcut, 2, 2) +trans_2_4 = ACEpotentials.Models.agnesi_transform(r0, rcut, 2, 4) +trans_1_3 = ACEpotentials.Models.agnesi_transform(r0, rcut, 1, 3) +for trans in [trans_2_2, trans_2_4, trans_1_3] + @test ACEpotentials.Models.test_normalized_transform(trans_2_2) +end From 7c6886d643e35e78a69300815a41d1e051373771 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 4 May 2024 22:48:15 -0700 Subject: [PATCH 004/112] draft learnable Rnlrzz basis --- Project.toml | 3 + src/models/Rnl_basis.jl | 133 ++++++++++++++++++++++++++ src/models/elements.jl | 54 +++++++++++ src/models/models.jl | 4 + src/models/radial_transforms.jl | 54 ++++------- test/Project.toml | 2 + test/models/test_learnable_Rnl.jl | 43 +++++++++ test/models/test_radial_transforms.jl | 6 +- 8 files changed, 262 insertions(+), 37 deletions(-) create mode 100644 src/models/Rnl_basis.jl create mode 100644 src/models/elements.jl create mode 100644 test/models/test_learnable_Rnl.jl diff --git a/Project.toml b/Project.toml index e88efd7a..c7b052b5 100644 --- a/Project.toml +++ b/Project.toml @@ -13,9 +13,12 @@ Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JuLIP = "945c410c-986d-556a-acb1-167a618e0462" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl new file mode 100644 index 00000000..882af7fa --- /dev/null +++ b/src/models/Rnl_basis.jl @@ -0,0 +1,133 @@ + +import LuxCore: AbstractExplicitLayer, + initialparameters, + initialstates +using StaticArrays: SMatrix +using Random: AbstractRNG + +abstract type AbstractRnlzzBasis <: AbstractExplicitLayer end + +# NOTEs: +# each smatrix in the types below indexes (i, j) +# where i is the center, j is neighbour + + +struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW} <: AbstractRnlzzBasis + _i2z::NTuple{NZ, Int} + polys::TPOLY + transforms::SMatrix{NZ, NZ, TT} + envelopes::SMatrix{NZ, NZ, TENV} + # rcut::SMatrix{NZ, NZ, T} # matrix of (rin, rout) + weights::SMatrix{NZ, NZ, TW} # learnable weights, nothing when using Lux + #-------------- + # meta should contain spec, rin0cuts + meta::Dict{String, Any} +end + + +# struct SplineRnlrzzBasis{NZ, SPL, ENV} <: AbstractRnlzzBasis +# _i2z::NTuple{NZ, Int} # iz -> z mapping +# splines::SMatrix{NZ, NZ, SPL} # matrix of splined radial bases +# envelopes::SMatrix{NZ, NZ, ENV} # matrix of radial envelopes +# rincut::SMatrix{NZ, NZ, Tuple{T, T}} # matrix of (rin, rout) + +# #-------------- +# # meta should contain spec +# meta::Dict{String, Any} +# end + +# a few getter functions for convenient access to those fields of matrices +_rincut_zz(obj, zi, zj) = obj.rincut[_z2i(obj, zi), _z2i(obj, zj)] +_envelope_zz(obj, zi, zj) = obj.envelopes[_z2i(obj, zi), _z2i(obj, zj)] +_spline_zz(obj, zi, zj) = obj.splines[_z2i(obj, zi), _z2i(obj, zj)] +_transform_zz(obj, zi, zj) = obj.transforms[_z2i(obj, zi), _z2i(obj, zj)] +_poly_zz(obj, zi, zj) = obj.poly[_z2i(obj, zi), _z2i(obj, zj)] + + +# ------------------------------------------------------------ +# CONSTRUCTORS AND UTILITIES +# ------------------------------------------------------------ + +# these _auto_trans are very poor and need to take care of a lot more +# cases, e.g. we may want to pass in the objects as a Matrix rather than +# SMatrix ... + +_auto_trans(t, NZ) = (t isa SMatrix) ? t : SMatrix{NZ, NZ}(fill(t, (NZ, NZ))) + +_auto_envel(env, NZ) = (env isa SMatrix) ? env : SMatrix{NZ, NZ}(fill(env, (NZ, NZ))) + +_auto_rincut(rincut, NZ) = (rincut isa SMatrix) ? rincut : SMatrix{NZ, NZ}(fill(rincut, (NZ, NZ))) + +_auto_weights(weights, NZ) = (weights isa SMatrix) ? weights : SMatrix{NZ, NZ}(fill(weights, (NZ, NZ))) + + +function LearnableRnlrzzBasis( + zlist, polys, transforms, envelopes, + rin0cuts, + spec::Vector{<: NamedTuple}; + weights=nothing, + meta=Dict{String, Any}()) + meta["rin0cuts"] = rin0cuts + meta["spec"] = spec + LearnableRnlrzzBasis(_convert_zlist(zlist), polys, + _auto_trans(transforms, length(zlist)), + _auto_envel(envelopes, length(zlist)), + # _auto_rincut(rincut, length(zlist)), + _auto_weights(weights, length(zlist)), + meta) +end + +Base.length(basis::LearnableRnlrzzBasis) = length(basis.meta["spec"]) + +function initialparameters(rng::AbstractRNG, + basis::LearnableRnlrzzBasis) + NZ = _get_nz(basis) + len_nl = length(basis) + len_q = length(basis.polys) + + function _W() + W = randn(rng, len_nl, len_q) + W = W ./ sqrt.(sum(W.^2, dims = 2)) + end + + return (W = [ _W() for i = 1:NZ, j = 1:NZ ], ) +end + +function initialstates(rng::AbstractRNG, + basis::LearnableRnlrzzBasis) + return NamedTuple() +end + + +# function learnable_Rnlrzz_basis(zlist; +# polys = :auto, +# transforms = :auto, +# envelopes = :auto, +# rincut = :auto, +# weight = :auto) + +# end + +function splinify(basis::LearnableRnlrzzBasis) + +end + +# ------------------------------------------------------------ +# EVALUATION INTERFACE +# ------------------------------------------------------------ + +import Polynomials4ML + +(l::LearnableRnlrzzBasis)(args...) = evaluate(l, args...) + +function evaluate(basis::LearnableRnlrzzBasis, r, Zi, Zj, ps, st) + iz = _z2i(basis, Zi) + jz = _z2i(basis, Zj) + Wij = ps.W[iz, jz] + trans_ij = basis.transforms[iz, jz] + x = trans_ij(r) + P = Polynomials4ML.evaluate(basis.polys, x) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + return Wij * (P .* e), st +end \ No newline at end of file diff --git a/src/models/elements.jl b/src/models/elements.jl new file mode 100644 index 00000000..aec87dca --- /dev/null +++ b/src/models/elements.jl @@ -0,0 +1,54 @@ + +using JuLIP: AtomicNumber +using StaticArrays: SMatrix +import ACE1x + + +_i2z(obj, i::Integer) = obj._i2z[i] + +_get_nz(obj) = length(obj._i2z) + +function _z2i(obj, Z) + for i_Z = 1:length(obj._i2z) + if obj._i2z[i_Z] == Z + return i_Z + end + end + error("_z2i : Z = $Z not found in obj._i2z") + return -1 # never reached +end + +# convert AtomicNumber -> Int is already defined in JuLIP +# we also want Symbol -> Int, but this would be terrible type piracy! +# so intead we make it a case distinction inside the _convert_zlist. +# not elegant but works for now. + +function _convert_zlist(zlist) + if eltype(zlist) == Symbol + return ntuple(i -> convert(Int, AtomicNumber(zlist[i])), length(zlist)) + end + return ntuple(i -> convert(Int, zlist[i]), length(zlist)) +end + +function _default_rin0cuts(zlist; rinfactor = 0.0, rcutfactor = 2.5) + function rin0cut(zi, zj) + r0 = ACE1x.get_r0(zi, zj) + return (rin = r0 * rinfactor, r0 = r0, rcut = r0 * rcutfactor) + end + NZ = length(zlist) + return SMatrix{NZ, NZ}([ rin0cut(zi, zj) for zi in zlist, zj in zlist ]) +end + + +# a one-hot embedding for the z variable. +# function embed_z(ace, Rs, Zs) +# TF = eltype(eltype(Rs)) +# Ez = acquire!(ace.pool, :Ez, (length(Zs), length(ace.rbasis)), TF) +# fill!(Ez, 0) +# for (j, z) in enumerate(Zs) +# iz = _z2i(ace.rbasis, z) +# Ez[j, iz] = 1 +# end +# return Ez +# end + diff --git a/src/models/models.jl b/src/models/models.jl index c293f598..de017067 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -1,8 +1,12 @@ module Models +include("elements.jl") + include("radial_envelopes.jl") include("radial_transforms.jl") +include("Rnl_basis.jl") + end \ No newline at end of file diff --git a/src/models/radial_transforms.jl b/src/models/radial_transforms.jl index ef751acd..930f27b8 100644 --- a/src/models/radial_transforms.jl +++ b/src/models/radial_transforms.jl @@ -42,6 +42,9 @@ evaluate_d(t::GeneralizedAgnesiTransform, r::Number) = # tested a wide range of methods. Brent seems robust + fastest using Roots: find_zero, ITP, Brent +""" +Maps the transform `trans` to the standardized interval [-1, 1] +""" struct NormalizedTransform{T, TT} trans::TT yin::T @@ -60,8 +63,9 @@ end (t::NormalizedTransform)(r) = evaluate(t, r) function evaluate(t::NormalizedTransform, r::Number) - y = t.trans(r) - return min(max(zero(y), (y - t.yin) / (t.ycut - t.yin)), one(y)) + y = t.trans(r) + 𝟙 = one(typeof(y)) + return min(max(-𝟙, -𝟙 + 2 * (y - t.yin) / (t.ycut - t.yin)), 𝟙) end # this is the old version from ACE1.jl; a neat idea. We could return to it. @@ -74,9 +78,10 @@ evaluate_d(t::NormalizedTransform, r::Number) = function inv_transform(t::NormalizedTransform{T}, x::Number) where {T} T1 = promote_type(T, typeof(x)) - if x <= 0 + 𝟙 = one(T1) + if x <= -𝟙 return convert(T1, t.rin) - elseif x >= 1 + elseif x >= 𝟙 return convert(T1, t.rcut) end @@ -102,11 +107,11 @@ read_dict(::Val{:ACEpotentials_NormalizedTransform}, D::Dict) = function test_normalized_transform(t; nx = 1000) fails = 0 - x = range(0.0, 1.0, length = nx) + x = range(-1.0, 1.0, length = nx) r = [inv_transform(t, xi) for xi in x] if !(r[1] ≈ t.rin) fails += 1 - @error("t⁻¹(0) ≈ rin fails") + @error("t⁻¹(-1) ≈ rin fails") end if !(r[end] ≈ t.rcut) fails += 1 @@ -124,7 +129,7 @@ function test_normalized_transform(t; nx = 1000) r = range(0.0, t.rcut, length = nx) x = [t(ri) for ri in r] - if !(x[1] ≈ 0.0) + if !(x[1] ≈ -1.0) fails += 1 @error("t(rin) ≈ yin fails") end @@ -160,37 +165,11 @@ function test_normalized_transform(t; nx = 1000) for r1 in rr x1 += evaluate(t, r1) end - end + end return (fails == 0) end -# test transform from ACE1 to be merged with the above. -# function test_transform(T, rrange, ntests = 100) - -# rmin, rmax = extrema(rrange) -# rr = rmin .+ rand(100) * (rmax-rmin) -# xx = [ transform(T, r) for r in rr ] -# # check syntactic sugar -# xx1 = [ T(r) for r in rr ] -# print_tf(@test xx1 == xx) -# # check inversion -# rr1 = inv_transform.(Ref(T), xx) -# print_tf(@test rr1 ≈ rr) -# # check gradient -# dx = transform_d.(Ref(T), rr) -# adx = ForwardDiff.derivative.(Ref(r -> transform(T, r)), rr) -# print_tf(@test dx ≈ adx) - -# # TODO: check that the transform doesn't allocate -# @allocated begin -# x = 0.0; -# for r in rr -# x += transform(T, r) -# end -# end -# end - @@ -227,3 +206,10 @@ function agnesi_transform(r0, rcut, p, q; GeneralizedAgnesiTransform(p, q, a, rin, r0), rin, rcut ) end + +function agnesi_transform(rin0cut::NamedTuple, p, q) + rin = rin0cut.rin + r0 = rin0cut.r0 + rcut = rin0cut.rcut + return agnesi_transform(r0, rcut, p, q, rin = rin) +end diff --git a/test/Project.toml b/test/Project.toml index 429ffa7c..c13e4e29 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,5 +4,7 @@ ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" JuLIP = "945c410c-986d-556a-acb1-167a618e0462" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl new file mode 100644 index 00000000..844652fc --- /dev/null +++ b/test/models/test_learnable_Rnl.jl @@ -0,0 +1,43 @@ + + + +using Pkg; Pkg.activate("."); +using TestEnv; TestEnv.activate(); + +using ACEpotentials +import Polynomials4ML +P4ML = Polynomials4ML +M = ACEpotentials.Models + +using Random, LuxCore +rng = Random.MersenneTwister(1234) + +## + +# LearnableRnlrzzBasis( +# zlist, polys, transforms, envelopes, rincut, spec::Vector{T_NL_TUPLE}; +# weights=nothing, meta=Dict{String, Any}()) = +# LeanrableRnlBasis(_convert_zlist(zlist), polys, +# _auto_trans(transforms, length(zlist)), +# _auto_envel(envelopes, length(zlist)), +# _auto_rincut(rincut, length(zlist)), +# _auto_weights(weights, length(zlist)), +# meta) + +Dtot = 5 +lmax = 3 +elements = (:Si, :O) +zlist = M._convert_zlist(elements) +rin0cuts = M._default_rin0cuts(elements) +transforms = M.agnesi_transform.(rin0cuts, 2, 2) +polys = P4ML.legendre_basis(Dtot+1) +envelopes = M.PolyEnvelope2sX(-1.0, 1.0, 2, 2) +spec = [ (n = n, l = l) for n = 1:(Dtot+1), l = 0:lmax if (n-1 + l) <= Dtot ] + +basis = M.LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, rin0cuts, spec) +ps, st = LuxCore.setup(rng, basis) + +r = 3.0 +Zi = zlist[1] +Zj = zlist[2] +Rnl, st1 = basis(r, Zi, Zj, ps, st) diff --git a/test/models/test_radial_transforms.jl b/test/models/test_radial_transforms.jl index a5e8a4e5..e61626e7 100644 --- a/test/models/test_radial_transforms.jl +++ b/test/models/test_radial_transforms.jl @@ -29,17 +29,17 @@ dy_2_2 = ACEpotentials.Models.evaluate_d.(Ref(trans_2_2), rr) dy_2_4 = ACEpotentials.Models.evaluate_d.(Ref(trans_2_4), rr) dy_1_3 = ACEpotentials.Models.evaluate_d.(Ref(trans_1_3), rr) -plt1 = plot(rr, y_2_2, label="Agnesi(2,2)", lw=2, legend=:topleft, ylims = (-0.2, 1.2)) +plt1 = plot(rr, y_2_2, label="Agnesi(2,2)", lw=2, legend=:topleft, ylims = (-1.2, 1.2)) plot!(rr, y_2_4, label="Agnesi(2,4)", lw=2) plot!(rr, y_1_3, label="Agnesi(1,3)", lw=2) vline!([0.0, r0, rcut], ls=:dash, lw=2, label="rin, r0, rcut") -plt2 = plot(rr, dy_2_2, label="∇Agnesi(2,2)", lw=2, legend=:topright, ylims = (-0.05, 0.35)) +plt2 = plot(rr, dy_2_2, label="∇Agnesi(2,2)", lw=2, legend=:topright, ylims = (-0.05, 0.8)) plot!(rr, dy_2_4, label="∇Agnesi(2,4)", lw=2) plot!(rr, dy_1_3, label="∇Agnesi(1,3)", lw=2) vline!([0.0, r0, rcut], ls=:dash, lw=2, label="rin, r0, rcut") -plot(plt1, plt2, layout=(1,2), size = (600, 800)) +plot(plt1, plt2, layout=(2,1), size = (600, 800)) =# ## From 1b9fa978a3b0910c13ae557d311eb644d54ba5ec Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 5 May 2024 17:14:41 -0700 Subject: [PATCH 005/112] simple Rnlrzz constructor --- src/models/Rnl_basis.jl | 60 ++++++++++++++----------------- src/models/ace_heuristics.jl | 33 +++++++++++++++++ src/models/elements.jl | 19 ++++++++++ src/models/models.jl | 1 + test/models/test_learnable_Rnl.jl | 29 ++++----------- 5 files changed, 85 insertions(+), 57 deletions(-) create mode 100644 src/models/ace_heuristics.jl diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 882af7fa..66c78450 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -11,16 +11,20 @@ abstract type AbstractRnlzzBasis <: AbstractExplicitLayer end # each smatrix in the types below indexes (i, j) # where i is the center, j is neighbour +const NT_RIN0CUTS{T} = NamedTuple{(:rin, :r0, :rcut), Tuple{T, T, T}} +const NT_NL_SPEC = NamedTuple{(:n, :l), Tuple{Int, Int}} -struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW} <: AbstractRnlzzBasis +struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractRnlzzBasis _i2z::NTuple{NZ, Int} polys::TPOLY transforms::SMatrix{NZ, NZ, TT} envelopes::SMatrix{NZ, NZ, TENV} - # rcut::SMatrix{NZ, NZ, T} # matrix of (rin, rout) - weights::SMatrix{NZ, NZ, TW} # learnable weights, nothing when using Lux - #-------------- - # meta should contain spec, rin0cuts + # -------------- + weights::SMatrix{NZ, NZ, TW} # learnable weights, `nothing` when using Lux + rin0cuts::SMatrix{NZ, NZ, NT_RIN0CUTS{T}} # matrix of (rin, rout, rcut) + spec::Vector{NT_NL_SPEC} + # -------------- + # meta meta::Dict{String, Any} end @@ -36,48 +40,44 @@ end # meta::Dict{String, Any} # end + # a few getter functions for convenient access to those fields of matrices -_rincut_zz(obj, zi, zj) = obj.rincut[_z2i(obj, zi), _z2i(obj, zj)] +_rincut_zz(obj, zi, zj) = obj.rin0cut[_z2i(obj, zi), _z2i(obj, zj)] _envelope_zz(obj, zi, zj) = obj.envelopes[_z2i(obj, zi), _z2i(obj, zj)] _spline_zz(obj, zi, zj) = obj.splines[_z2i(obj, zi), _z2i(obj, zj)] _transform_zz(obj, zi, zj) = obj.transforms[_z2i(obj, zi), _z2i(obj, zj)] -_poly_zz(obj, zi, zj) = obj.poly[_z2i(obj, zi), _z2i(obj, zj)] +# _polys_zz(obj, zi, zj) = obj.polys[_z2i(obj, zi), _z2i(obj, zj)] # ------------------------------------------------------------ # CONSTRUCTORS AND UTILITIES # ------------------------------------------------------------ -# these _auto_trans are very poor and need to take care of a lot more +# these _auto_... are very poor and need to take care of a lot more # cases, e.g. we may want to pass in the objects as a Matrix rather than # SMatrix ... -_auto_trans(t, NZ) = (t isa SMatrix) ? t : SMatrix{NZ, NZ}(fill(t, (NZ, NZ))) - -_auto_envel(env, NZ) = (env isa SMatrix) ? env : SMatrix{NZ, NZ}(fill(env, (NZ, NZ))) - -_auto_rincut(rincut, NZ) = (rincut isa SMatrix) ? rincut : SMatrix{NZ, NZ}(fill(rincut, (NZ, NZ))) -_auto_weights(weights, NZ) = (weights isa SMatrix) ? weights : SMatrix{NZ, NZ}(fill(weights, (NZ, NZ))) function LearnableRnlrzzBasis( - zlist, polys, transforms, envelopes, - rin0cuts, - spec::Vector{<: NamedTuple}; + zlist, polys, transforms, envelopes, rin0cuts, + spec::AbstractVector{NT_NL_SPEC}; weights=nothing, meta=Dict{String, Any}()) - meta["rin0cuts"] = rin0cuts - meta["spec"] = spec - LearnableRnlrzzBasis(_convert_zlist(zlist), polys, - _auto_trans(transforms, length(zlist)), - _auto_envel(envelopes, length(zlist)), - # _auto_rincut(rincut, length(zlist)), - _auto_weights(weights, length(zlist)), - meta) + NZ = length(zlist) + LearnableRnlrzzBasis(_convert_zlist(zlist), + polys, + _make_smatrix(transforms, NZ), + _make_smatrix(envelopes, NZ), + # -------------- + _make_smatrix(weights, NZ), + _make_smatrix(rin0cuts, NZ), + collect(spec), + meta) end -Base.length(basis::LearnableRnlrzzBasis) = length(basis.meta["spec"]) +Base.length(basis::LearnableRnlrzzBasis) = length(basis.spec) function initialparameters(rng::AbstractRNG, basis::LearnableRnlrzzBasis) @@ -99,14 +99,6 @@ function initialstates(rng::AbstractRNG, end -# function learnable_Rnlrzz_basis(zlist; -# polys = :auto, -# transforms = :auto, -# envelopes = :auto, -# rincut = :auto, -# weight = :auto) - -# end function splinify(basis::LearnableRnlrzzBasis) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl new file mode 100644 index 00000000..44a2631f --- /dev/null +++ b/src/models/ace_heuristics.jl @@ -0,0 +1,33 @@ + + +function ace_learnable_Rnlrzz(; + Dtot = nothing, + lmax = nothing, + elements = nothing, + spec = nothing, + rin0cuts = _default_rin0cuts(elements), + transforms = agnesi_transform.(rin0cuts, 2, 2), + polys = Polynomials4ML.legendre_basis(Dtot+1), + envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) + ) + if elements == nothing + error("elements must be specified!") + end + if (spec == nothing) && (Dtot == nothing || lmax == nothing) + error("Must specify either `spec` or `Dtot` and `lmax`") + end + + zlist =_convert_zlist(elements) + + if spec == nothing + spec = [ (n = n, l = l) for n = 1:(Dtot+1), l = 0:lmax + if (n-1 + l) <= Dtot ] + end + + maxn = maximum([ s.n for s in spec ]) + if maxn > length(polys) + error("maxn > length of polynomial basis") + end + + return LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, rin0cuts, spec) +end diff --git a/src/models/elements.jl b/src/models/elements.jl index aec87dca..99b2aa44 100644 --- a/src/models/elements.jl +++ b/src/models/elements.jl @@ -39,6 +39,25 @@ function _default_rin0cuts(zlist; rinfactor = 0.0, rcutfactor = 2.5) return SMatrix{NZ, NZ}([ rin0cut(zi, zj) for zi in zlist, zj in zlist ]) end +""" +Takes an object and converts it to an `SMatrix{NZ, NZ}` via the following rules: +- if `obj` is already an `SMatrix{NZ, NZ}` then it just return `obj` +- if `obj` is an `AbstractMatrix` and `size(obj) == (NZ, NZ)` then it + converts it to an `SMatrix{NZ, NZ}` with the same entries. +- otherwise it generates an `SMatrix{NZ, NZ}` filled with the value `obj`. +""" +function _make_smatrix(obj, NZ) + if obj isa SMatrix{NZ, NZ} + return obj + end + if obj isa AbstractMatrix && size(obj) == (NZ, NZ) + return SMatrix{NZ, NZ}(obj) + end + if obj isa AbstractArray && size(obj) != (NZ, NZ) + error("`_make_smatrix` : if the input `obj` is an `AbstractArray` then it must be of size `(NZ, NZ)`") + end + return SMatrix{NZ, NZ}(fill(obj, (NZ, NZ))) +end # a one-hot embedding for the z variable. # function embed_z(ace, Rs, Zs) diff --git a/src/models/models.jl b/src/models/models.jl index de017067..c92c7e6f 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -9,4 +9,5 @@ include("radial_transforms.jl") include("Rnl_basis.jl") +include("ace_heuristics.jl") end \ No newline at end of file diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 844652fc..8b5fc0c0 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -2,11 +2,9 @@ using Pkg; Pkg.activate("."); -using TestEnv; TestEnv.activate(); +# using TestEnv; TestEnv.activate(); using ACEpotentials -import Polynomials4ML -P4ML = Polynomials4ML M = ACEpotentials.Models using Random, LuxCore @@ -14,30 +12,15 @@ rng = Random.MersenneTwister(1234) ## -# LearnableRnlrzzBasis( -# zlist, polys, transforms, envelopes, rincut, spec::Vector{T_NL_TUPLE}; -# weights=nothing, meta=Dict{String, Any}()) = -# LeanrableRnlBasis(_convert_zlist(zlist), polys, -# _auto_trans(transforms, length(zlist)), -# _auto_envel(envelopes, length(zlist)), -# _auto_rincut(rincut, length(zlist)), -# _auto_weights(weights, length(zlist)), -# meta) - Dtot = 5 lmax = 3 elements = (:Si, :O) -zlist = M._convert_zlist(elements) -rin0cuts = M._default_rin0cuts(elements) -transforms = M.agnesi_transform.(rin0cuts, 2, 2) -polys = P4ML.legendre_basis(Dtot+1) -envelopes = M.PolyEnvelope2sX(-1.0, 1.0, 2, 2) -spec = [ (n = n, l = l) for n = 1:(Dtot+1), l = 0:lmax if (n-1 + l) <= Dtot ] - -basis = M.LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, rin0cuts, spec) +basis = M.ace_learnable_Rnlrzz(Dtot = Dtot, lmax = lmax, elements = elements) + ps, st = LuxCore.setup(rng, basis) r = 3.0 -Zi = zlist[1] -Zj = zlist[2] +Zi = basis._i2z[1] +Zj = basis._i2z[2] Rnl, st1 = basis(r, Zi, Zj, ps, st) + From cc00109ec011a28c42ea27ab0bc7d06683317f03 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 5 May 2024 23:40:57 -0700 Subject: [PATCH 006/112] several steps towards generating an ACE model --- Project.toml | 4 + src/models/ace.jl | 72 ++++++++++++++ test/models/test_ace.jl | 157 ++++++++++++++++++++++++++++++ test/models/test_learnable_Rnl.jl | 2 +- 4 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 src/models/ace.jl create mode 100644 test/models/test_ace.jl diff --git a/Project.toml b/Project.toml index c7b052b5..90fa8afb 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ ACE1 = "e3f9bc04-086e-409a-ba78-e9769fe067bb" ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" +EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc" ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" @@ -20,8 +21,10 @@ Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +RepLieGroups = "f07d36f2-91c4-427a-b67b-965fe5ebe1d2" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" @@ -35,6 +38,7 @@ ExtXYZ = "0.1.14" JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" +SpheriCart = "0.0.3" StaticArrays = "1" UltraFastACE = "0.0.2" YAML = "0.4" diff --git a/src/models/ace.jl b/src/models/ace.jl new file mode 100644 index 00000000..24fbb39d --- /dev/null +++ b/src/models/ace.jl @@ -0,0 +1,72 @@ + +using LuxCore: AbstractExplicitLayer, + AbstractExplicitContainerLayer, + initialparameters, + initialstates + +using SparseArrays: SparseMatrixCSC + +using SpheriCart: SolidHarmonics, SphericalHarmonics + +# ------------------------------------------------------------ +# ACE MODEL SPECIFICATION + + +struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer + _i2z::NTuple{NZ, Int} + rbasis::TRAD + ybasis::TY + abasis::TA + aabasis::TAA + A2Bmap::SparseMatrixCSC{T, Int} + # -------------- + # we can add a FS embedding here + # -------------- + bparams::NTuple{NZ, Vector{T}} + aaparams::NTuple{NZ, Vector{T}} + # -------------- + meta::Dict{String, Any} +end + +# ------------------------------------------------------------ +# CONSTRUCTORS AND UTILITIES + +const NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} + +function _A_from_AA_spec(AA_spec) + A_spec = NT_NLM[] + for bb in AA_spec + append!(A_spec, bb) + end + return unique(sort(A_spec)) +end + +function make_Y_basis(Ytype, lmax) + if Ytype == :solid + return SolidHarmonics(lmax) + elseif Ytype == :spherical + return SphericalHarmonics(lmax) + end + + error("unknown `Ytype` = $Ytype - I don't know how to generate a spherical basis from this.") +end + +function sort_and_filter_AA_spec(AA_spec) + + return unique(sort(AA_spec)) +end + + +function generate_A2B_map(AA_spec) + function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Rot3DCoeffs_long{L,T}}, spec) where {L,T} + +end + +function ace_model(rbasis, Ytype, AA_spec) + + A_spec = _A_from_AA_spec(AA_spec) + lmax = maximum([b.l for b in a_spec]) + ybasis = make_Y_basis(Ytype, lmax) + +end + diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl new file mode 100644 index 00000000..71b3e9ee --- /dev/null +++ b/test/models/test_ace.jl @@ -0,0 +1,157 @@ + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); + +using ACEpotentials +M = ACEpotentials.Models + +using Polynomials4ML +P4ML = Polynomials4ML + +using Random, LuxCore +rng = Random.MersenneTwister(1234) + +function _inv_list(l) + d = Dict() + for (i, x) in enumerate(l) + d[x] = i + end + return d +end + +struct TotalDegree + wL::Float64 +end + +TotalDegree() = TotalDegree(1.5) + +(l::TotalDegree)(b::NamedTuple) = b.n + b.l +(l::TotalDegree)(bb::Vector{<: NamedTuple}) = sum(l(b) for b in bb) + + +function make_AA_spec(; order = nothing, + r_spec = nothing, + max_level = nothing, + level = nothing, ) + # compute the r levels + r_level = [ level(b) for b in r_spec ] + + # generate the A basis spec from the radial basis spec. + NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} + A_spec = NT_NLM[] + A_spec_level = eltype(r_level)[] + for br in r_spec + for m = -br.l:br.l + b = (n = br.n, l = br.l, m = m) + push!(A_spec, b) + push!(A_spec_level, level(b)) + end + end + p = sortperm(A_spec_level) + A_spec = A_spec[p] + A_spec_level = A_spec_level[p] + inv_A_spec = _inv_list(A_spec) + + # generate the AA basis spec from the A basis spec + tup2b = vv -> [ A_spec[v] for v in vv[vv .> 0] ] + admissible = bb -> (length(bb) == 0) || level(bb) <= max_level + filter_ = EquivariantModels.RPE_filter_real(0) + + AA_spec = EquivariantModels.gensparse(; + NU = order, tup2b = tup2b, filter = filter_, + admissible = admissible, + minvv = fill(0, order), + maxvv = fill(length(A_spec), order), + ordered = true) + + AA_spec = [ vv[vv .> 0] for vv in AA_spec if !(isempty(vv[vv .> 0])) ] + + # map back to nlm + AA_spec_nlm = Vector{NT_NLM}[] + if length(AA_spec[1]) == 0 + push!(AA_spec_nlm, NT_NLM[]) + idx0 = 2 + else + idx0 = 1 + end + for vv in AA_spec + push!(AA_spec_nlm, [ A_spec[v] for v in vv ]) + end + + return AA_spec_nlm +end + +function make_A_spec(AA_spec, level) + NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} + A_spec = NT_NLM[] + for bb in AA_spec + append!(A_spec, bb) + end + A_spec_level = [ level(b) for b in A_spec ] + p = sortperm(A_spec_level) + A_spec = A_spec[p] + return A_spec +end + +## + +elements = (:Si, :O) +level = TotalDegree() +max_level = 8 +lmax = 4 + +rbasis = M.ace_learnable_Rnlrzz(Dtot = Dtot, lmax = lmax, elements = elements) +r_spec = rbasis.spec + +AA_spec = make_specs(order = 3, r_spec = r_spec, + level = level, max_level = max_level) + +## + +import RepLieGroups +import EquivariantModels +import SpheriCart + +cgen = EquivariantModels.Rot3DCoeffs_real(0) +AA2BB_map = EquivariantModels._rpi_A2B_matrix(cgen, AA_spec) + +keep_AA_idx = findall(sum(abs, AA2BB_map; dims = 1)[:] .> 0) + +AA_spec = AA_spec[keep_AA_idx] +AA2BB_map = AA2BB_map[:, keep_AA_idx] + +A_spec = make_A_spec(AA_spec, level) + +maxl = maximum([ b.l for b in A_spec ]) + +ybasis = SpheriCart.SolidHarmonics(maxl) + +## +# now we need to take the human-readable specs and convert them into +# the layer-readable specs + +r_spec = rbasis.spec + +# this should go into sphericart or P4ML +NT_LM = NamedTuple{(:l, :m), Tuple{Int, Int}} +y_spec = NT_LM[] +for i = 1:SpheriCart.sizeY(maxl) + l, m = SpheriCart.idx2lm(i) + push!(y_spec, (l = l, m = m)) +end + +# get the idx version of A_spec +inv_r_spec = _inv_list(r_spec) +inv_y_spec = _inv_list(y_spec) +A_spec_idx = [ (inv_r_spec[(n=b.n, l=b.l)], inv_y_spec[(l=b.l, m=b.m)]) + for b in A_spec ] +a_basis = P4ML.PooledSparseProduct(A_spec_idx) +a_basis.meta["A_spec"] = A_spec + +inv_A_spec = _inv_list(A_spec) +AA_spec_idx = [ [ inv_A_spec[b] for b in bb ] for bb in AA_spec ] +sort!.(AA_spec_idx) +aa_basis = P4ML.SparseSymmProdDAG(AA_spec_idx) +aa_basis.meta["AA_spec"] = AA_spec + +length(aa_basis) \ No newline at end of file diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 8b5fc0c0..f18ff363 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -1,7 +1,7 @@ -using Pkg; Pkg.activate("."); +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using ACEpotentials From dec0643f9d4fd4d657d6f1e714e00d93fc5ec55e Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 6 May 2024 11:00:22 -0700 Subject: [PATCH 007/112] cleanup ace model generation --- Project.toml | 1 + src/models/Rnl_basis.jl | 2 +- src/models/ace.jl | 127 +++++++++++++++++++++++++----- src/models/ace_heuristics.jl | 97 ++++++++++++++++++++--- src/models/models.jl | 7 +- src/models/utils.jl | 79 +++++++++++++++++++ test/models/test_ace.jl | 147 ++--------------------------------- test/models/test_models.jl | 0 8 files changed, 289 insertions(+), 171 deletions(-) create mode 100644 src/models/utils.jl create mode 100644 test/models/test_models.jl diff --git a/Project.toml b/Project.toml index 90fa8afb..371a59b1 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RepLieGroups = "f07d36f2-91c4-427a-b67b-965fe5ebe1d2" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 66c78450..61549160 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -90,7 +90,7 @@ function initialparameters(rng::AbstractRNG, W = W ./ sqrt.(sum(W.^2, dims = 2)) end - return (W = [ _W() for i = 1:NZ, j = 1:NZ ], ) + return (Wnlq = [ _W() for i = 1:NZ for j = 1:NZ ], ) end function initialstates(rng::AbstractRNG, diff --git a/src/models/ace.jl b/src/models/ace.jl index 24fbb39d..fc7631fb 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -4,15 +4,20 @@ using LuxCore: AbstractExplicitLayer, initialparameters, initialstates +using Random: AbstractRNG using SparseArrays: SparseMatrixCSC +import SpheriCart using SpheriCart: SolidHarmonics, SphericalHarmonics +import RepLieGroups +import EquivariantModels +import Polynomials4ML # ------------------------------------------------------------ # ACE MODEL SPECIFICATION -struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer +struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer{(:rbasis,)} _i2z::NTuple{NZ, Int} rbasis::TRAD ybasis::TY @@ -20,7 +25,7 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer aabasis::TAA A2Bmap::SparseMatrixCSC{T, Int} # -------------- - # we can add a FS embedding here + # we can add a nonlinear embedding here # -------------- bparams::NTuple{NZ, Vector{T}} aaparams::NTuple{NZ, Vector{T}} @@ -33,15 +38,7 @@ end const NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} -function _A_from_AA_spec(AA_spec) - A_spec = NT_NLM[] - for bb in AA_spec - append!(A_spec, bb) - end - return unique(sort(A_spec)) -end - -function make_Y_basis(Ytype, lmax) +function _make_Y_basis(Ytype, lmax) if Ytype == :solid return SolidHarmonics(lmax) elseif Ytype == :spherical @@ -51,22 +48,112 @@ function make_Y_basis(Ytype, lmax) error("unknown `Ytype` = $Ytype - I don't know how to generate a spherical basis from this.") end -function sort_and_filter_AA_spec(AA_spec) +# can we ignore the level function here? +function _make_A_spec(AA_spec, level) + NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} + A_spec = NT_NLM[] + for bb in AA_spec + append!(A_spec, bb) + end + A_spec_level = [ level(b) for b in A_spec ] + p = sortperm(A_spec_level) + A_spec = A_spec[p] + return A_spec +end + +# this should go into sphericart or P4ML +function _make_Y_spec(maxl::Integer) + NT_LM = NamedTuple{(:l, :m), Tuple{Int, Int}} + y_spec = NT_LM[] + for i = 1:SpheriCart.sizeY(maxl) + l, m = SpheriCart.idx2lm(i) + push!(y_spec, (l = l, m = m)) + end + return y_spec +end + + +function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, + level = TotalDegree()) + # generate the coupling coefficients + cgen = EquivariantModels.Rot3DCoeffs_real(0) + AA2BB_map = EquivariantModels._rpi_A2B_matrix(cgen, AA_spec) + + # find which AA basis functions are actually used and discard the rest + keep_AA_idx = findall(sum(abs, AA2BB_map; dims = 1)[:] .> 0) + AA_spec = AA_spec[keep_AA_idx] + AA2BB_map = AA2BB_map[:, keep_AA_idx] + + # generate the corresponding A basis spec + A_spec = _make_A_spec(AA_spec, level) + + # from the A basis we can generate the Y basis since we now know the + # maximum l value (though we probably already knew that from r_spec) + maxl = maximum([ b.l for b in A_spec ]) + ybasis = _make_Y_basis(Ytype, maxl) + + # now we need to take the human-readable specs and convert them into + # the layer-readable specs + r_spec = rbasis.spec + y_spec = _make_Y_spec(maxl) + + # get the idx version of A_spec + inv_r_spec = _inv_list(r_spec) + inv_y_spec = _inv_list(y_spec) + A_spec_idx = [ (inv_r_spec[(n=b.n, l=b.l)], inv_y_spec[(l=b.l, m=b.m)]) + for b in A_spec ] + # from this we can now generate the A basis layer + a_basis = Polynomials4ML.PooledSparseProduct(A_spec_idx) + a_basis.meta["A_spec"] = A_spec #(also store the human-readable spec) + + # get the idx version of AA_spec + inv_A_spec = _inv_list(A_spec) + AA_spec_idx = [ [ inv_A_spec[b] for b in bb ] for bb in AA_spec ] + sort!.(AA_spec_idx) + # from this we can now generate the AA basis layer + aa_basis = Polynomials4ML.SparseSymmProdDAG(AA_spec_idx) + aa_basis.meta["AA_spec"] = AA_spec # (also store the human-readable spec) - return unique(sort(AA_spec)) + NZ = _get_nz(rbasis) + n_B_params, n_AA_params = size(AA2BB_map) + return ACEModel(rbasis._i2z, rbasis, ybasis, a_basis, aa_basis, AA2BB_map, + ntuple(_ -> zeros(n_B_params), NZ), + ntuple(_ -> zeros(n_AA_params), NZ), + Dict{String, Any}() ) end +# TODO: it is not entirely clear that the `level` is really needed here +# since it is implicitly already encoded in AA_spec +function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level) + return _generate_ace_model(rbasis, Ytype, AA_spec, level) +end + +# NOTE : a nice convenience constructure is also provided in `ace_heuristics.jl` -function generate_A2B_map(AA_spec) - function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Rot3DCoeffs_long{L,T}}, spec) where {L,T} +# ------------------------------------------------------------ +# Lux stuff + +function initialparameters(rng::AbstractRNG, + model::ACEModel) + NZ = _get_nz(model) + n_B_params, n_AA_params = size(model.A2Bmap) + # only the B params are parameters, the AA params are uniquely defined + # via the B params. + return (WB = [ zeros(n_B_params) for _=1:NZ ], + rbasis = initialparameters(rng, model.rbasis), ) +end +function initialstates(rng::AbstractRNG, + model::ACEModel) + return ( rbasis = initialstates(rng, model.rbasis), ) end -function ace_model(rbasis, Ytype, AA_spec) +(l::ACEModel)(args...) = evaluate(l, args...) - A_spec = _A_from_AA_spec(AA_spec) - lmax = maximum([b.l for b in a_spec]) - ybasis = make_Y_basis(Ytype, lmax) -end +# ------------------------------------------------------------ +# Model Evaluation +# this should possibly be moved to a separate file once it +# gets more complicated. + diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 44a2631f..e2eec39d 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -1,33 +1,110 @@ +# -------------------------------------------------- +# different notions of "level" / total degree. +abstract type AbstractLevel end +struct TotalDegree <: AbstractLevel + wn::Float64 + wl::Float64 +end + +TotalDegree() = TotalDegree(1.0, 0.66) + +(l::TotalDegree)(b::NamedTuple) = b.n/l.wn + b.l/l.wl +(l::TotalDegree)(bb::AbstractVector{<: NamedTuple}) = sum(l(b) for b in bb) + + +struct EuclideanDegree <: AbstractLevel + wn::Float64 + wl::Float64 +end + +EuclideanDegree() = EuclideanDegree(1.0, 0.66) + +(l::EuclideanDegree)(b::NamedTuple) = sqrt( (b.n/l.wn)^2 + (b.l/l.wl)^2 ) +(l::EuclideanDegree)(bb::AbstractVector{<: NamedTuple}) = sqrt( sum(l(b)^2 for b in bb) ) + + +# ------------------------------------------------------- +# construction of Rnlrzz bases with lots of defaults +# +# TODO: offer simple option to initialize ACE1-like or trACE-like +# function ace_learnable_Rnlrzz(; - Dtot = nothing, - lmax = nothing, + max_level = nothing, + level = nothing, + maxl = nothing, + maxn = nothing, elements = nothing, spec = nothing, rin0cuts = _default_rin0cuts(elements), transforms = agnesi_transform.(rin0cuts, 2, 2), - polys = Polynomials4ML.legendre_basis(Dtot+1), + polys = :legendre, envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) ) if elements == nothing error("elements must be specified!") end - if (spec == nothing) && (Dtot == nothing || lmax == nothing) - error("Must specify either `spec` or `Dtot` and `lmax`") + if (spec == nothing) && (level == nothing || max_level == nothing) + error("Must specify either `spec` or `level, max_level`!") end zlist =_convert_zlist(elements) if spec == nothing - spec = [ (n = n, l = l) for n = 1:(Dtot+1), l = 0:lmax - if (n-1 + l) <= Dtot ] + spec = [ (n = n, l = l) for n = 1:maxn, l = 0:maxl + if level((n = n, l = l)) <= max_level ] + end + + # now the actual maxn is the maximum n in the spec + actual_maxn = maximum([ s.n for s in spec ]) + + if polys isa Symbol + if polys == :legendre + polys = Polynomials4ML.legendre_basis(actual_maxn) + else + error("unknown polynomial type : $polys") + end end - maxn = maximum([ s.n for s in spec ]) - if maxn > length(polys) - error("maxn > length of polynomial basis") + if actual_maxn > length(polys) + error("actual_maxn > length of polynomial basis") end return LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, rin0cuts, spec) end + + + +function ace_model(; elements = nothing, + order = nothing, + Ytype = :solid, + # radial basis + rbasis = nothing, + rbasis_type = :learnable, + maxl = 30, # maxl, max are fairly high defaults + maxn = 50, # that we will likely never reach + # basis size parameters + level = nothing, + max_level = nothing, + + ) + # construct an rbasis if needed + if isnothing(rbasis) + if rbasis_type == :learnable + rbasis = ace_learnable_Rnlrzz(; max_level = max_level, level = level, + maxl = maxl, maxn = maxn, + elements = elements) + else + error("unknown rbasis_type = $rbasis_type") + end + end + + AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, + level = level, max_level = max_level) + + return ace_model(rbasis, Ytype, AA_spec, level) +end + + +# ------------------------------------------------------- diff --git a/src/models/models.jl b/src/models/models.jl index c92c7e6f..e436851a 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -1,6 +1,8 @@ module Models +include("utils.jl") + include("elements.jl") include("radial_envelopes.jl") @@ -9,5 +11,8 @@ include("radial_transforms.jl") include("Rnl_basis.jl") -include("ace_heuristics.jl") +include("ace_heuristics.jl") + +include("ace.jl") + end \ No newline at end of file diff --git a/src/models/utils.jl b/src/models/utils.jl new file mode 100644 index 00000000..785bf6f7 --- /dev/null +++ b/src/models/utils.jl @@ -0,0 +1,79 @@ +import EquivariantModels + +function _inv_list(l) + d = Dict() + for (i, x) in enumerate(l) + d[x] = i + end + return d +end + + + +# TODO : the `sparse_AA_spec` should be replaced with a `sparse_ace_spec` +# which generates only a (n, l) spec. From that, we can then generate +# the corresponding (n, l, ) AA spec. This would be much more readable. + +""" +This is one of the most important functions to generate an ACE model with +sparse AA basis. It generates the AA basis specification as a list (`Vector`) +of vectors of `@NamedTuple{n::Int, l::Int, m::Int}`. + +### Parameters + +* `order` : maximum correlation order +* `r_spec` : radial basis specification in the format `Vector{@NamedTuple{a::Int64, b::Int64}}` +* `max_level` : maximum level of the basis, either a single scalar, or an iterable (one for each order) +* `level` : a function that computes the level of a basis element; see e.g. `TotalDegree` and `EuclideanDegree` +""" +function sparse_AA_spec(; order = nothing, + r_spec = nothing, + max_level = nothing, + level = nothing, ) + # compute the r levels + r_level = [ level(b) for b in r_spec ] + + # generate the A basis spec from the radial basis spec. + NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} + A_spec = NT_NLM[] + A_spec_level = eltype(r_level)[] + for br in r_spec + for m = -br.l:br.l + b = (n = br.n, l = br.l, m = m) + push!(A_spec, b) + push!(A_spec_level, level(b)) + end + end + p = sortperm(A_spec_level) + A_spec = A_spec[p] + A_spec_level = A_spec_level[p] + inv_A_spec = _inv_list(A_spec) + + # generate the AA basis spec from the A basis spec + tup2b = vv -> [ A_spec[v] for v in vv[vv .> 0] ] + admissible = bb -> (length(bb) == 0) || level(bb) <= max_level + filter_ = EquivariantModels.RPE_filter_real(0) + + AA_spec = EquivariantModels.gensparse(; + NU = order, tup2b = tup2b, filter = filter_, + admissible = admissible, + minvv = fill(0, order), + maxvv = fill(length(A_spec), order), + ordered = true) + + AA_spec = [ vv[vv .> 0] for vv in AA_spec if !(isempty(vv[vv .> 0])) ] + + # map back to nlm + AA_spec_nlm = Vector{NT_NLM}[] + if length(AA_spec[1]) == 0 + push!(AA_spec_nlm, NT_NLM[]) + idx0 = 2 + else + idx0 = 1 + end + for vv in AA_spec + push!(AA_spec_nlm, [ A_spec[v] for v in vv ]) + end + + return AA_spec_nlm +end diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 71b3e9ee..5ab479bf 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -5,153 +5,22 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) using ACEpotentials M = ACEpotentials.Models -using Polynomials4ML -P4ML = Polynomials4ML - using Random, LuxCore rng = Random.MersenneTwister(1234) -function _inv_list(l) - d = Dict() - for (i, x) in enumerate(l) - d[x] = i - end - return d -end - -struct TotalDegree - wL::Float64 -end - -TotalDegree() = TotalDegree(1.5) - -(l::TotalDegree)(b::NamedTuple) = b.n + b.l -(l::TotalDegree)(bb::Vector{<: NamedTuple}) = sum(l(b) for b in bb) - - -function make_AA_spec(; order = nothing, - r_spec = nothing, - max_level = nothing, - level = nothing, ) - # compute the r levels - r_level = [ level(b) for b in r_spec ] - - # generate the A basis spec from the radial basis spec. - NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} - A_spec = NT_NLM[] - A_spec_level = eltype(r_level)[] - for br in r_spec - for m = -br.l:br.l - b = (n = br.n, l = br.l, m = m) - push!(A_spec, b) - push!(A_spec_level, level(b)) - end - end - p = sortperm(A_spec_level) - A_spec = A_spec[p] - A_spec_level = A_spec_level[p] - inv_A_spec = _inv_list(A_spec) - - # generate the AA basis spec from the A basis spec - tup2b = vv -> [ A_spec[v] for v in vv[vv .> 0] ] - admissible = bb -> (length(bb) == 0) || level(bb) <= max_level - filter_ = EquivariantModels.RPE_filter_real(0) - - AA_spec = EquivariantModels.gensparse(; - NU = order, tup2b = tup2b, filter = filter_, - admissible = admissible, - minvv = fill(0, order), - maxvv = fill(length(A_spec), order), - ordered = true) - - AA_spec = [ vv[vv .> 0] for vv in AA_spec if !(isempty(vv[vv .> 0])) ] - - # map back to nlm - AA_spec_nlm = Vector{NT_NLM}[] - if length(AA_spec[1]) == 0 - push!(AA_spec_nlm, NT_NLM[]) - idx0 = 2 - else - idx0 = 1 - end - for vv in AA_spec - push!(AA_spec_nlm, [ A_spec[v] for v in vv ]) - end - - return AA_spec_nlm -end - -function make_A_spec(AA_spec, level) - NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} - A_spec = NT_NLM[] - for bb in AA_spec - append!(A_spec, bb) - end - A_spec_level = [ level(b) for b in A_spec ] - p = sortperm(A_spec_level) - A_spec = A_spec[p] - return A_spec -end - ## elements = (:Si, :O) -level = TotalDegree() -max_level = 8 -lmax = 4 +level = M.TotalDegree() +max_level = 12 +order = 3 -rbasis = M.ace_learnable_Rnlrzz(Dtot = Dtot, lmax = lmax, elements = elements) -r_spec = rbasis.spec +model = M.ace_model(; elements = elements, order = order, Ytype = :solid, + level = level, max_level = max_level, maxl = 4) -AA_spec = make_specs(order = 3, r_spec = r_spec, - level = level, max_level = max_level) +ps, st = LuxCore.setup(rng, model) -## - -import RepLieGroups -import EquivariantModels -import SpheriCart - -cgen = EquivariantModels.Rot3DCoeffs_real(0) -AA2BB_map = EquivariantModels._rpi_A2B_matrix(cgen, AA_spec) - -keep_AA_idx = findall(sum(abs, AA2BB_map; dims = 1)[:] .> 0) - -AA_spec = AA_spec[keep_AA_idx] -AA2BB_map = AA2BB_map[:, keep_AA_idx] - -A_spec = make_A_spec(AA_spec, level) - -maxl = maximum([ b.l for b in A_spec ]) +# TODO: the number of parameters seems off. -ybasis = SpheriCart.SolidHarmonics(maxl) - -## -# now we need to take the human-readable specs and convert them into -# the layer-readable specs - -r_spec = rbasis.spec - -# this should go into sphericart or P4ML -NT_LM = NamedTuple{(:l, :m), Tuple{Int, Int}} -y_spec = NT_LM[] -for i = 1:SpheriCart.sizeY(maxl) - l, m = SpheriCart.idx2lm(i) - push!(y_spec, (l = l, m = m)) -end - -# get the idx version of A_spec -inv_r_spec = _inv_list(r_spec) -inv_y_spec = _inv_list(y_spec) -A_spec_idx = [ (inv_r_spec[(n=b.n, l=b.l)], inv_y_spec[(l=b.l, m=b.m)]) - for b in A_spec ] -a_basis = P4ML.PooledSparseProduct(A_spec_idx) -a_basis.meta["A_spec"] = A_spec - -inv_A_spec = _inv_list(A_spec) -AA_spec_idx = [ [ inv_A_spec[b] for b in bb ] for bb in AA_spec ] -sort!.(AA_spec_idx) -aa_basis = P4ML.SparseSymmProdDAG(AA_spec_idx) -aa_basis.meta["AA_spec"] = AA_spec +## -length(aa_basis) \ No newline at end of file diff --git a/test/models/test_models.jl b/test/models/test_models.jl new file mode 100644 index 00000000..e69de29b From 7f319837f25bc45803aa6c19b4dc0cf2104ff5be Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 6 May 2024 14:39:07 -0700 Subject: [PATCH 008/112] draft model evaluation + symmetry tests --- Project.toml | 2 ++ src/models/Rnl_basis.jl | 37 +++++++++++++++++--- src/models/ace.jl | 65 +++++++++++++++++++++++++++++++++--- src/models/ace_heuristics.jl | 9 +++-- src/models/utils.jl | 36 ++++++++++++++++++++ test/models/test_ace.jl | 27 +++++++++++++-- 6 files changed, 163 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 371a59b1..c2e5a114 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.6.6" [deps] ACE1 = "e3f9bc04-086e-409a-ba78-e9769fe067bb" ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" +ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc" @@ -14,6 +15,7 @@ Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JuLIP = "945c410c-986d-556a-acb1-167a618e0462" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 61549160..68260d33 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -90,7 +90,7 @@ function initialparameters(rng::AbstractRNG, W = W ./ sqrt.(sum(W.^2, dims = 2)) end - return (Wnlq = [ _W() for i = 1:NZ for j = 1:NZ ], ) + return (Wnlq = [ _W() for i = 1:NZ, j = 1:NZ ], ) end function initialstates(rng::AbstractRNG, @@ -112,14 +112,43 @@ import Polynomials4ML (l::LearnableRnlrzzBasis)(args...) = evaluate(l, args...) -function evaluate(basis::LearnableRnlrzzBasis, r, Zi, Zj, ps, st) +function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) iz = _z2i(basis, Zi) jz = _z2i(basis, Zj) - Wij = ps.W[iz, jz] + Wij = ps.Wnlq[iz, jz] + trans_ij = basis.transforms[iz, jz] + x = trans_ij(r) + P = Polynomials4ML.evaluate(basis.polys, x) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + Rnl[:] .= Wij * (P .* e) + return Rnl, st +end + +function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) + iz = _z2i(basis, Zi) + jz = _z2i(basis, Zj) + Wij = ps.Wnlq[iz, jz] trans_ij = basis.transforms[iz, jz] x = trans_ij(r) P = Polynomials4ML.evaluate(basis.polys, x) env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, x) return Wij * (P .* e), st -end \ No newline at end of file +end + + +function evaluate_batched(basis::LearnableRnlrzzBasis, + rs::AbstractVector{<: Real}, zi, zjs, ps, st) + @assert length(rs) == length(zjs) + # evaluate the first one to get the types and size + Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) + # allocate storage + Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + # then evaluate the rest in-place + for j = 1:length(rs) + evaluate!((@view Rnl[j, :]), basis, rs[j], zi, zjs[j], ps, st) + end + return Rnl, st +end + diff --git a/src/models/ace.jl b/src/models/ace.jl index fc7631fb..ddb44ce3 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -4,8 +4,12 @@ using LuxCore: AbstractExplicitLayer, initialparameters, initialstates +using Lux: glorot_normal + using Random: AbstractRNG using SparseArrays: SparseMatrixCSC +using StaticArrays: SVector +using LinearAlgebra: norm, dot import SpheriCart using SpheriCart: SolidHarmonics, SphericalHarmonics @@ -123,12 +127,15 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, end # TODO: it is not entirely clear that the `level` is really needed here -# since it is implicitly already encoded in AA_spec +# since it is implicitly already encoded in AA_spec. We need a +# function `auto_level` that generates level automagically from AA_spec. function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level) return _generate_ace_model(rbasis, Ytype, AA_spec, level) end -# NOTE : a nice convenience constructure is also provided in `ace_heuristics.jl` +# NOTE : a nicer convenience constructor is also provided in `ace_heuristics.jl` +# this is where we should move all defaults, heuristics and other things +# that make life good. # ------------------------------------------------------------ # Lux stuff @@ -137,9 +144,20 @@ function initialparameters(rng::AbstractRNG, model::ACEModel) NZ = _get_nz(model) n_B_params, n_AA_params = size(model.A2Bmap) + # only the B params are parameters, the AA params are uniquely defined # via the B params. - return (WB = [ zeros(n_B_params) for _=1:NZ ], + + # there are different ways to initialize parameters + if model.meta["init_WB"] == "zeros" + winit = zeros + elseif model.meta["init_WB"] == "glorot_normal" + winit = glorot_normal + else + error("unknown `init_WB` = $(model.meta["init_WB"])") + end + + return (WB = [ winit(Float64, n_B_params) for _=1:NZ ], rbasis = initialparameters(rng, model.rbasis), ) end @@ -156,4 +174,43 @@ end # this should possibly be moved to a separate file once it # gets more complicated. - +# these _getlmax and _length should be moved into SpheriCart +_getlmax(ybasis::SolidHarmonics{L}) where {L} = L +_length(ybasis::SolidHarmonics) = SpheriCart.sizeY(_getlmax(ybasis)) + +function evaluate(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + # get the radii + rs = [ norm(r) for r in Rs ] # use Bumper + + # evaluate the radial basis + # use Bumper to pre-allocate + Rnl, _st = evaluate_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) + + # evaluate the Y basis + Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here + SpheriCart.compute!(Ylm, model.ybasis, Rs) + + # evaluate the A basis + TA = promote_type(T, eltype(Rnl)) + A = zeros(T, length(model.abasis)) + Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) + + # evaluate the AA basis + _AA = zeros(T, length(model.aabasis)) # use Bumper here + Polynomials4ML.evaluate!(_AA, model.aabasis, A) + # project to the actual AA basis + proj = model.aabasis.projection + AA = _AA[proj] # use Bumper here, or view; needs experimentation. + + # evaluate the coupling coefficients + B = model.A2Bmap * AA + + # contract with params + i_z0 = _z2i(model.rbasis, Z0) + val = dot(B, ps.WB[i_z0]) + + return val, st +end \ No newline at end of file diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index e2eec39d..d97e4408 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -1,3 +1,5 @@ + + # -------------------------------------------------- # different notions of "level" / total degree. @@ -87,7 +89,7 @@ function ace_model(; elements = nothing, # basis size parameters level = nothing, max_level = nothing, - + init_WB = :zeros, ) # construct an rbasis if needed if isnothing(rbasis) @@ -103,7 +105,10 @@ function ace_model(; elements = nothing, AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, level = level, max_level = max_level) - return ace_model(rbasis, Ytype, AA_spec, level) + model = ace_model(rbasis, Ytype, AA_spec, level) + model.meta["init_WB"] = String(init_WB) + + return model end diff --git a/src/models/utils.jl b/src/models/utils.jl index 785bf6f7..213c20d9 100644 --- a/src/models/utils.jl +++ b/src/models/utils.jl @@ -77,3 +77,39 @@ function sparse_AA_spec(; order = nothing, return AA_spec_nlm end + + + + +import ACE1 + +function rand_atenv(model, Nat) + z0 = rand(model._i2z) + zs = rand(model._i2z, Nat) + + rs = Float64[] + for zj in zs + iz0 = _z2i(model, z0) + izj = _z2i(model, zj) + x = 2 * rand() - 1 + t_ij = model.rbasis.transforms[iz0, izj] + r_ij = inv_transform(t_ij, x) + push!(rs, r_ij) + end + Rs = [ r * ACE1.Random.rand_sphere() for r in rs ] + return Rs, zs, z0 +end + + +using StaticArrays: @SMatrix +using LinearAlgebra: qr + +function rand_rot() + A = @SMatrix randn(3, 3) + Q, _ = qr(A) + return Q +end + +rand_iso() = rand([-1,1]) * rand_rot() + + diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 5ab479bf..3f6dc4a6 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -2,6 +2,9 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); +using Test, ACEbase +using ACEbase.Testing: print_tf + using ACEpotentials M = ACEpotentials.Models @@ -12,15 +15,33 @@ rng = Random.MersenneTwister(1234) elements = (:Si, :O) level = M.TotalDegree() -max_level = 12 +max_level = 15 order = 3 model = M.ace_model(; elements = elements, order = order, Ytype = :solid, - level = level, max_level = max_level, maxl = 4) + level = level, max_level = max_level, maxl = 8, + init_WB = :glorot_normal) ps, st = LuxCore.setup(rng, model) -# TODO: the number of parameters seems off. +# TODO: the number of parameters is completely off, so something is +# likely wrong here. + ## +@info("Test Rotation-Invariance of the Model") + +for ntest = 1:50 + Nat = rand(8:16) + Rs, Zs, Z0 = M.rand_atenv(model, Nat) + val, st = M.evaluate(model, Rs, Zs, Z0, ps, st) + + p = shuffle(1:Nat) + Rs1 = Ref(M.rand_iso()) .* Rs[p] + Zs1 = Zs[p] + val1, st = M.evaluate(model, Rs1, Zs1, Z0, ps, st) + + print_tf(@test abs(val - val1) < 1e-10) +end +println() From b9b67d99c668f2f2d36b51fbb2bed6300018dcf2 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 6 May 2024 15:30:59 -0700 Subject: [PATCH 009/112] r gradient of Rnlzz basis --- src/models/Rnl_basis.jl | 17 +++++++++++++++++ test/models/test_ace.jl | 11 +++++++++++ test/models/test_learnable_Rnl.jl | 15 ++++++++++++--- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 68260d33..59ee58ef 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -152,3 +152,20 @@ function evaluate_batched(basis::LearnableRnlrzzBasis, return Rnl, st end + +# ----- gradients +# because the typical scenario is that we have few r, then moderately +# many q and then many (n, l), this seems to be best done in Forward-mode. + +import ForwardDiff +using ForwardDiff: Dual + +function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} + d_r = Dual{T}(r, one(T)) + d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) + Rnl = ForwardDiff.value.(d_Rnl) + Rnl_d = ForwardDiff.extract_derivative(T, d_Rnl) + return Rnl, Rnl_d, st +end + + diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 3f6dc4a6..b702bb42 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -45,3 +45,14 @@ for ntest = 1:50 print_tf(@test abs(val - val1) < 1e-10) end println() + +## + +# # first test shows the performance is not at all awful even without any +# # optimizations and reductions in memory allocations. +# using BenchmarkTools +# Rs, Zs, z0 = M.rand_atenv(model, 16) +# @btime M.evaluate($model, $Rs, $Zs, $Z0, $ps, $st) + +## + diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index f18ff363..691217df 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -12,10 +12,12 @@ rng = Random.MersenneTwister(1234) ## -Dtot = 5 -lmax = 3 +max_level = 8 +level = M.TotalDegree() +maxl = 3; maxn = max_level; elements = (:Si, :O) -basis = M.ace_learnable_Rnlrzz(Dtot = Dtot, lmax = lmax, elements = elements) +basis = M.ace_learnable_Rnlrzz(; level=level, max_level=max_level, + maxl = maxl, maxn = maxn, elements = elements) ps, st = LuxCore.setup(rng, basis) @@ -23,4 +25,11 @@ r = 3.0 Zi = basis._i2z[1] Zj = basis._i2z[2] Rnl, st1 = basis(r, Zi, Zj, ps, st) +Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) + + +## + +@btime ($basis)(r, Zi, Zj, $ps, $st) +@btime M.evaluate_ed($basis, r, Zi, Zj, $ps, $st) From 7957bb7f6f6e319eb01c8f7e87d0bc435f7b0118 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 6 May 2024 19:57:38 -0700 Subject: [PATCH 010/112] intermediate backup --- src/models/Rnl_basis.jl | 20 ++++++ src/models/ace.jl | 105 ++++++++++++++++++++++++++++++ test/models/test_learnable_Rnl.jl | 3 + 3 files changed, 128 insertions(+) diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 59ee58ef..3fe7267e 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -169,3 +169,23 @@ function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T end +function evaluate_ed_batched(basis::LearnableRnlrzzBasis, + rs::AbstractVector{T}, Zi, Zs, ps, st + ) where {T <: Real} + + @assert length(rs) == length(Zs) + Rnl1, st = evaluate(basis, rs[1], Zi, Zs[1], ps, st) + Rnl = zeros(T, length(rs), length(Rnl1)) + Rnl_d = zeros(T, length(rs), length(Rnl1)) + + for j = 1:length(rs) + d_r = Dual{T}(rs[j], one(T)) + d_Rnl, st = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here + map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) + map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) + end + + return Rnl, Rnl_d, st +end + + diff --git a/src/models/ace.jl b/src/models/ace.jl index ddb44ce3..a7602a71 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -212,5 +212,110 @@ function evaluate(model::ACEModel, i_z0 = _z2i(model.rbasis, Z0) val = dot(B, ps.WB[i_z0]) + return val, st +end + + + +function evaluate(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + # get the radii + rs = [ norm(r) for r in Rs ] # use Bumper + + # evaluate the radial basis + # use Bumper to pre-allocate + Rnl, _st = evaluate_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) + + # evaluate the Y basis + Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here + SpheriCart.compute!(Ylm, model.ybasis, Rs) + + # evaluate the A basis + TA = promote_type(T, eltype(Rnl)) + A = zeros(T, length(model.abasis)) + Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) + + # evaluate the AA basis + _AA = zeros(T, length(model.aabasis)) # use Bumper here + Polynomials4ML.evaluate!(_AA, model.aabasis, A) + # project to the actual AA basis + proj = model.aabasis.projection + AA = _AA[proj] # use Bumper here, or view; needs experimentation. + + # evaluate the coupling coefficients + B = model.A2Bmap * AA + + # contract with params + i_z0 = _z2i(model.rbasis, Z0) + val = dot(B, ps.WB[i_z0]) + + return val, st +end + + +function evaluate_ed(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + + # ---------- EMBEDDINGS ------------ + # (these are done in forward mode, so not part of the fwd, bwd passes) + + # get the radii + rs = [ norm(r) for r in Rs ] # use Bumper + + # evaluate the radial basis + # use Bumper to pre-allocate + Rnl, dRnl, _st = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) + # evaluate the Y basis + Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here + dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) + SpheriCart.compute_with_grad!(Ylm, dYlm, model.ybasis, Rs) + + # ---------- FORWARD PASS ------------ + + # evaluate the A basis + TA = promote_type(T, eltype(Rnl)) + A = zeros(T, length(model.abasis)) + Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) + + # evaluate the AA basis + _AA = zeros(T, length(model.aabasis)) # use Bumper here + Polynomials4ML.evaluate!(_AA, model.aabasis, A) + # project to the actual AA basis + proj = model.aabasis.projection + AA = _AA[proj] # use Bumper here, or view; needs experimentation. + + # evaluate the coupling coefficients + B = model.A2Bmap * AA + + # contract with params + i_z0 = _z2i(model.rbasis, Z0) + Ei = dot(B, ps.WB[i_z0]) + + # ---------- BACKWARD PASS ------------ + + # ∂Ei / ∂B = WB[i_z0] + ∂B = ps.WB[i_z0] + + # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA + # = (WB[i_z0]) * A2Bmap + ∂AA = model.A2Bmap' * ∂B + + # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A + # = pullback(aabasis, ∂AA) + + # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) + + + # ---------- ASSEMBLE DERIVATIVES ------------ + # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm + # as follows: + # ∂Ei / ∂𝐫ⱼ = ∑_nl ∂Ei / ∂Rnl[j] * ∂Rnl[j] / ∂𝐫ⱼ + # + ∑_lm ∂Ei / ∂Ylm[j] * ∂Ylm[j] / ∂𝐫ⱼ + + return val, st end \ No newline at end of file diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 691217df..4d5f16bb 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -33,3 +33,6 @@ Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) @btime ($basis)(r, Zi, Zj, $ps, $st) @btime M.evaluate_ed($basis, r, Zi, Zj, $ps, $st) +rs = [r, r, r] +Zs = [Zj, Zj, Zj] +Rnl, Rnl_d, st1 = M.evaluate_ed_batched(basis, rs, Zi, Zs, ps, st) \ No newline at end of file From cd83fb998d403b7335ec07b592e7182f8f5a3919 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 6 May 2024 21:47:56 -0700 Subject: [PATCH 011/112] site energy gradient done and tested --- src/models/ace.jl | 40 +++++++++++++++++++------------ test/models/test_ace.jl | 27 ++++++++++++++++----- test/models/test_learnable_Rnl.jl | 16 +++++++++++-- 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index a7602a71..6b1fc75d 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -263,16 +263,16 @@ function evaluate_ed(model::ACEModel, # (these are done in forward mode, so not part of the fwd, bwd passes) # get the radii - rs = [ norm(r) for r in Rs ] # use Bumper + rs = [ norm(r) for r in Rs ] # TODO: use Bumper # evaluate the radial basis - # use Bumper to pre-allocate + # TODO: use Bumper to pre-allocate Rnl, dRnl, _st = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here + Ylm = zeros(T, length(Rs), _length(model.ybasis)) # TODO: use Bumper dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) - SpheriCart.compute_with_grad!(Ylm, dYlm, model.ybasis, Rs) + SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) # ---------- FORWARD PASS ------------ @@ -282,16 +282,18 @@ function evaluate_ed(model::ACEModel, Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) # evaluate the AA basis - _AA = zeros(T, length(model.aabasis)) # use Bumper here + _AA = zeros(T, length(model.aabasis)) # TODO: use Bumper here Polynomials4ML.evaluate!(_AA, model.aabasis, A) # project to the actual AA basis proj = model.aabasis.projection - AA = _AA[proj] # use Bumper here, or view; needs experimentation. + AA = _AA[proj] # TODO: use Bumper here, or view; needs experimentation. - # evaluate the coupling coefficients + # evaluate the coupling coefficients + # TODO: use Bumper and do it in-place B = model.A2Bmap * AA # contract with params + # (here we can insert another nonlinearity instead of the simple dot) i_z0 = _z2i(model.rbasis, Z0) Ei = dot(B, ps.WB[i_z0]) @@ -300,22 +302,30 @@ function evaluate_ed(model::ACEModel, # ∂Ei / ∂B = WB[i_z0] ∂B = ps.WB[i_z0] - # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA - # = (WB[i_z0]) * A2Bmap - ∂AA = model.A2Bmap' * ∂B + # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap + ∂AA = model.A2Bmap' * ∂B # TODO: make this in-place + _∂AA = zeros(T, length(_AA)) + _∂AA[proj] = ∂AA - # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A - # = pullback(aabasis, ∂AA) + # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) + ∂A = zeros(T, length(model.abasis)) + Polynomials4ML.pullback_arg!(∂A, _∂AA, model.aabasis, _AA) # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) - + ∂Rnl = zeros(T, size(Rnl)) + ∂Ylm = zeros(T, size(Ylm)) + Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, model.abasis, (Rnl, Ylm)) # ---------- ASSEMBLE DERIVATIVES ------------ # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm # as follows: # ∂Ei / ∂𝐫ⱼ = ∑_nl ∂Ei / ∂Rnl[j] * ∂Rnl[j] / ∂𝐫ⱼ # + ∑_lm ∂Ei / ∂Ylm[j] * ∂Ylm[j] / ∂𝐫ⱼ + ∇Ei = zeros(SVector{3, T}, length(Rs)) + for j = 1:length(Rs) + ∇Ei[j] = dot(∂Rnl[j, :], dRnl[j, :]) * (Rs[j] / rs[j]) + + sum(∂Ylm[j, :] .* dYlm[j, :]) + end - - return val, st + return Ei, ∇Ei, st end \ No newline at end of file diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index b702bb42..4a0cd3e5 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -8,7 +8,7 @@ using ACEbase.Testing: print_tf using ACEpotentials M = ACEpotentials.Models -using Random, LuxCore +using Random, LuxCore, StaticArrays rng = Random.MersenneTwister(1234) ## @@ -48,11 +48,26 @@ println() ## -# # first test shows the performance is not at all awful even without any -# # optimizations and reductions in memory allocations. -# using BenchmarkTools -# Rs, Zs, z0 = M.rand_atenv(model, 16) -# @btime M.evaluate($model, $Rs, $Zs, $Z0, $ps, $st) +Rs, Zs, z0 = M.rand_atenv(model, 16) +Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) + +Ei1, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + +Ei ≈ Ei1 + +Us = randn(SVector{3, Float64}, length(Rs)) +F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] +dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) +ACEbase.Testing.fdtest(F, dF, 0.0) + +## + +# first test shows the performance is not at all awful even without any +# optimizations and reductions in memory allocations. +using BenchmarkTools +Rs, Zs, z0 = M.rand_atenv(model, 16) +@btime M.evaluate($model, $Rs, $Zs, $z0, $ps, $st) +@btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) ## diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 4d5f16bb..4a937858 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -27,11 +27,23 @@ Zj = basis._i2z[2] Rnl, st1 = basis(r, Zi, Zj, ps, st) Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) +@info("Test derivatives of Rnlrzz basis") + +for ntest = 1:20 + r = 2.0 + rand() + Zi = rand(basis._i2z) + Zj = rand(basis._i2z) + U = randn(eltype(Rnl), length(Rnl)) + F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)[1]) + dF(t) = dot(U, M.evaluate_ed(basis, r + t, Zi, Zj, ps, st)[2]) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) +end +println() ## -@btime ($basis)(r, Zi, Zj, $ps, $st) -@btime M.evaluate_ed($basis, r, Zi, Zj, $ps, $st) +# @btime ($basis)(r, Zi, Zj, $ps, $st) +# @btime M.evaluate_ed($basis, r, Zi, Zj, $ps, $st) rs = [r, r, r] Zs = [Zj, Zj, Zj] From 982ead4e6f4e7bb5391e969d410e7de183e042fe Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 8 May 2024 14:47:03 -0700 Subject: [PATCH 012/112] first working version of grad_params --- Project.toml | 4 ++ src/models/Rnl_basis.jl | 88 +++++++++++++++++++++++- src/models/ace.jl | 110 +++++++++++++++++++++--------- test/models/test_ace.jl | 44 +++++++++--- test/models/test_learnable_Rnl.jl | 6 +- 5 files changed, 205 insertions(+), 47 deletions(-) diff --git a/Project.toml b/Project.toml index c2e5a114..1f6bb74e 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc" ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -17,6 +18,7 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" @@ -24,6 +26,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RepLieGroups = "f07d36f2-91c4-427a-b67b-965fe5ebe1d2" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -31,6 +34,7 @@ SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ACE1 = "0.12" diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 3fe7267e..3a864cbf 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -138,21 +138,47 @@ function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) end +# function evaluate_batched(basis::LearnableRnlrzzBasis, +# rs::AbstractVector{<: Real}, zi, zjs, ps, st) +# @assert length(rs) == length(zjs) +# # evaluate the first one to get the types and size +# Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) +# # allocate storage +# Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) +# # then evaluate the rest in-place +# for j = 1:length(rs) +# evaluate!((@view Rnl[j, :]), basis, rs[j], zi, zjs[j], ps, st) +# end +# return Rnl, st +# end + function evaluate_batched(basis::LearnableRnlrzzBasis, - rs::AbstractVector{<: Real}, zi, zjs, ps, st) + rs, zi, zjs, ps, st) + @assert length(rs) == length(zjs) # evaluate the first one to get the types and size Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) - # allocate storage + # ... and then allocate storage Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + # then evaluate the rest in-place for j = 1:length(rs) - evaluate!((@view Rnl[j, :]), basis, rs[j], zi, zjs[j], ps, st) + iz = _z2i(basis, zi) + jz = _z2i(basis, zjs[j]) + trans_ij = basis.transforms[iz, jz] + x = trans_ij(rs[j]) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + P = Polynomials4ML.evaluate(basis.polys, x) .* e + Rnl[j, :] = ps.Wnlq[iz, jz] * P end + return Rnl, st end + + # ----- gradients # because the typical scenario is that we have few r, then moderately # many q and then many (n, l), this seems to be best done in Forward-mode. @@ -189,3 +215,59 @@ function evaluate_ed_batched(basis::LearnableRnlrzzBasis, end + + +# -------- RRULES + +import ChainRulesCore: rrule, NotImplemented, NoTangent + +# NB : iz = īz = _z2i(z0) throughout +# +# Rnl[j, nl] = Wnlq[iz, jz] * Pq * e +# ∂_Wn̄l̄q̄[īz,j̄z] { ∑_jnl Δ[j,nl] * Rnl[j, nl] } +# = ∑_jnl Δ[j,nl] * Pq * e * δ_q̄q * δ_l̄l * δ_n̄n * δ_{īz,iz} * δ_{j̄z,jz} +# = ∑_{jz = j̄z} Δ[j̄z, n̄l̄] * P_q̄ * e +# +function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + @assert length(rs) == length(zjs) + # evaluate the first one to get the types and size + Rnl_1, _ = evaluate(basis, rs[1], zi, zjs[1], ps, st) + # ... and then allocate storage + Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + + # output storage for the gradients + ∂Wnlq = deepcopy(ps.Wnlq) + fill!.(∂Wnlq, 0) + + # then evaluate the rest in-place + for j = 1:length(rs) + iz = _z2i(basis, zi) + jz = _z2i(basis, zjs[j]) + trans_ij = basis.transforms[iz, jz] + x = trans_ij(rs[j]) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + P = Polynomials4ML.evaluate(basis.polys, x) .* e + # TODO: the P shouuld be stored inside a closure in the + # forward pass and then resused. + + # TODO: ... and obviously this part here needs to be moved + # to a SIMD loop. + ∂Wnlq[iz, jz][:, :] .+= Δ[j, :] * P' + end + + return (Wnql = ∂Wnlq,) +end + + +function rrule(::typeof(evaluate_batched), + basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) + + return (Rnl, st), + Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), + pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), + NoTangent()) +end \ No newline at end of file diff --git a/src/models/ace.jl b/src/models/ace.jl index 6b1fc75d..bff4f6e7 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -174,6 +174,8 @@ end # this should possibly be moved to a separate file once it # gets more complicated. +import Zygote + # these _getlmax and _length should be moved into SpheriCart _getlmax(ybasis::SolidHarmonics{L}) where {L} = L _length(ybasis::SolidHarmonics) = SpheriCart.sizeY(_getlmax(ybasis)) @@ -217,20 +219,26 @@ end -function evaluate(model::ACEModel, - Rs::AbstractVector{SVector{3, T}}, Zs, Z0, - ps, st) where {T} +function evaluate_ed(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + + # ---------- EMBEDDINGS ------------ + # (these are done in forward mode, so not part of the fwd, bwd passes) + # get the radii - rs = [ norm(r) for r in Rs ] # use Bumper + rs = [ norm(r) for r in Rs ] # TODO: use Bumper # evaluate the radial basis - # use Bumper to pre-allocate - Rnl, _st = evaluate_batched(model.rbasis, rs, Z0, Zs, - ps.rbasis, st.rbasis) - + # TODO: use Bumper to pre-allocate + Rnl, dRnl, _st = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here - SpheriCart.compute!(Ylm, model.ybasis, Rs) + Ylm = zeros(T, length(Rs), _length(model.ybasis)) # TODO: use Bumper + dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) + SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) + + # ---------- FORWARD PASS ------------ # evaluate the A basis TA = promote_type(T, eltype(Rnl)) @@ -238,24 +246,56 @@ function evaluate(model::ACEModel, Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) # evaluate the AA basis - _AA = zeros(T, length(model.aabasis)) # use Bumper here + _AA = zeros(T, length(model.aabasis)) # TODO: use Bumper here Polynomials4ML.evaluate!(_AA, model.aabasis, A) # project to the actual AA basis proj = model.aabasis.projection - AA = _AA[proj] # use Bumper here, or view; needs experimentation. + AA = _AA[proj] # TODO: use Bumper here, or view; needs experimentation. - # evaluate the coupling coefficients + # evaluate the coupling coefficients + # TODO: use Bumper and do it in-place B = model.A2Bmap * AA # contract with params + # (here we can insert another nonlinearity instead of the simple dot) i_z0 = _z2i(model.rbasis, Z0) - val = dot(B, ps.WB[i_z0]) - - return val, st + Ei = dot(B, ps.WB[i_z0]) + + # ---------- BACKWARD PASS ------------ + + # ∂Ei / ∂B = WB[i_z0] + ∂B = ps.WB[i_z0] + + # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap + ∂AA = model.A2Bmap' * ∂B # TODO: make this in-place + _∂AA = zeros(T, length(_AA)) + _∂AA[proj] = ∂AA + + # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) + ∂A = zeros(T, length(model.abasis)) + Polynomials4ML.pullback_arg!(∂A, _∂AA, model.aabasis, _AA) + + # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) + ∂Rnl = zeros(T, size(Rnl)) + ∂Ylm = zeros(T, size(Ylm)) + Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, model.abasis, (Rnl, Ylm)) + + # ---------- ASSEMBLE DERIVATIVES ------------ + # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm + # as follows: + # ∂Ei / ∂𝐫ⱼ = ∑_nl ∂Ei / ∂Rnl[j] * ∂Rnl[j] / ∂𝐫ⱼ + # + ∑_lm ∂Ei / ∂Ylm[j] * ∂Ylm[j] / ∂𝐫ⱼ + ∇Ei = zeros(SVector{3, T}, length(Rs)) + for j = 1:length(Rs) + ∇Ei[j] = dot(∂Rnl[j, :], dRnl[j, :]) * (Rs[j] / rs[j]) + + sum(∂Ylm[j, :] .* dYlm[j, :]) + end + + return Ei, ∇Ei, st end -function evaluate_ed(model::ACEModel, +function grad_params(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} @@ -267,8 +307,8 @@ function evaluate_ed(model::ACEModel, # evaluate the radial basis # TODO: use Bumper to pre-allocate - Rnl, dRnl, _st = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, - ps.rbasis, st.rbasis) + (Rnl, _st), pb_Rnl = rrule(evaluate_batched, model.rbasis, + rs, Z0, Zs, ps.rbasis, st.rbasis) # evaluate the Y basis Ylm = zeros(T, length(Rs), _length(model.ybasis)) # TODO: use Bumper dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) @@ -299,7 +339,9 @@ function evaluate_ed(model::ACEModel, # ---------- BACKWARD PASS ------------ - # ∂Ei / ∂B = WB[i_z0] + # we need ∂WB = ∂Ei/∂WB -> this goes into the gradient + # but we also need ∂B = ∂Ei / ∂B = WB[i_z0] to backpropagate + ∂WB_i = B ∂B = ps.WB[i_z0] # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap @@ -313,19 +355,21 @@ function evaluate_ed(model::ACEModel, # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) ∂Rnl = zeros(T, size(Rnl)) - ∂Ylm = zeros(T, size(Ylm)) + ∂Ylm = zeros(T, size(Ylm)) # we could make this a black hole since we don't need it. Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, model.abasis, (Rnl, Ylm)) # ---------- ASSEMBLE DERIVATIVES ------------ - # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm - # as follows: - # ∂Ei / ∂𝐫ⱼ = ∑_nl ∂Ei / ∂Rnl[j] * ∂Rnl[j] / ∂𝐫ⱼ - # + ∑_lm ∂Ei / ∂Ylm[j] * ∂Ylm[j] / ∂𝐫ⱼ - ∇Ei = zeros(SVector{3, T}, length(Rs)) - for j = 1:length(Rs) - ∇Ei[j] = dot(∂Rnl[j, :], dRnl[j, :]) * (Rs[j] / rs[j]) + - sum(∂Ylm[j, :] .* dYlm[j, :]) - end - - return Ei, ∇Ei, st -end \ No newline at end of file + # the first grad_param is ∂WB, which we already have but it needs to be + # written into a vector of vectors + ∂WB = [ zeros(eltype(∂WB_i), size(∂WB_i)) for _=1:_get_nz(model) ] + ∂WB[i_z0] = ∂WB_i + + # the second one is the gradient with respect to Rnl params + # + # ∂Ei / ∂Wn̄l̄q̄ + # = ∑_nlj ∂Ei / ∂Rnl[j] * ∂Rnl[j] / ∂Wn̄l̄q̄ + # = pullback(∂Rnl, rbasis, args...) + _, _, _, _, ∂Wqnl, _ = pb_Rnl(∂Rnl) # this should be a named tuple already. + + return Ei, (WB = ∂WB, rbasis = ∂Wqnl), st +end diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 4a0cd3e5..424301d8 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -3,12 +3,14 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using Test, ACEbase -using ACEbase.Testing: print_tf +using ACEbase.Testing: print_tf, println_slim using ACEpotentials M = ACEpotentials.Models -using Random, LuxCore, StaticArrays +using Optimisers + +using Random, LuxCore, StaticArrays, LinearAlgebra rng = Random.MersenneTwister(1234) ## @@ -48,26 +50,50 @@ println() ## +@info("Test derivatives w.r.t. positions") Rs, Zs, z0 = M.rand_atenv(model, 16) Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) - Ei1, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) +println_slim(@test Ei ≈ Ei1) + +for ntest = 1:20 + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Us = randn(SVector{3, Float64}, Nat) + F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] + dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) +end +println() + +## -Ei ≈ Ei1 +@info("Test derivatives w.r.t. parameters") +Nat = 15 +Rs, Zs, z0 = M.rand_atenv(model, Nat) +Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) +Ei1, ∇Ei, st = M.grad_params(model, Rs, Zs, z0, ps, st) +println_slim(@test Ei ≈ Ei1) -Us = randn(SVector{3, Float64}, length(Rs)) -F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] -dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) -ACEbase.Testing.fdtest(F, dF, 0.0) +for ntest = 1:20 + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + pvec, _restruct = destructure(ps) + uvec = randn(length(pvec)) / sqrt(length(pvec)) + F(t) = M.evaluate(model, Rs, Zs, z0, _restruct(pvec + t * uvec), st)[1] + dF0 = dot( destructure( M.grad_params(model, Rs, Zs, z0, ps, st)[2] )[1], uvec ) + print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose = false)) +end ## + # first test shows the performance is not at all awful even without any # optimizations and reductions in memory allocations. using BenchmarkTools Rs, Zs, z0 = M.rand_atenv(model, 16) @btime M.evaluate($model, $Rs, $Zs, $z0, $ps, $st) @btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) +@btime M.grad_params($model, $Rs, $Zs, $z0, $ps, $st) ## - diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 4a937858..974c5f3c 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -7,7 +7,8 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) using ACEpotentials M = ACEpotentials.Models -using Random, LuxCore +using Random, LuxCore, Test, ACEbase, LinearAlgebra +using ACEbase.Testing: print_tf rng = Random.MersenneTwister(1234) ## @@ -47,4 +48,5 @@ println() rs = [r, r, r] Zs = [Zj, Zj, Zj] -Rnl, Rnl_d, st1 = M.evaluate_ed_batched(basis, rs, Zi, Zs, ps, st) \ No newline at end of file +Rnl, Rnl_d, st1 = M.evaluate_ed_batched(basis, rs, Zi, Zs, ps, st) + From bb83dc96f33b6eadb0e5a87e96f987f287aecd82 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 8 May 2024 15:50:28 -0700 Subject: [PATCH 013/112] reverse over reverse done --- src/models/Rnl_basis.jl | 6 ++++-- src/models/ace.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/models/test_ace.jl | 25 +++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 3a864cbf..ed04d971 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -237,8 +237,10 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) # output storage for the gradients - ∂Wnlq = deepcopy(ps.Wnlq) - fill!.(∂Wnlq, 0) + T_∂Wnlq = promote_type(eltype(Δ), eltype(rs)) + NZ = _get_nz(basis) + ∂Wnlq = [ zeros(T_∂Wnlq, size(ps.Wnlq[i,j])) + for i = 1:NZ, j = 1:NZ ] # then evaluate the rest in-place for j = 1:length(rs) diff --git a/src/models/ace.jl b/src/models/ace.jl index bff4f6e7..854eedce 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -373,3 +373,41 @@ function grad_params(model::ACEModel, return Ei, (WB = ∂WB, rbasis = ∂Wqnl), st end + + +using Optimisers: destructure +using ForwardDiff: Dual, value, extract_derivative + +function pullback_2_mixed(Δ, Δd, model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} + # this is implemented as a directional derivative + # following a wonderful discussion on Discourse with + # Steven G. Johnson and Avik Pal + # + # we want the pullback for the pair (Ei, ∇Ei) computed via evaluate_ed + # i.e. for Δ, Δd the output sensitivities we want to compute + # ∂_w { Δ * Ei + Δd * ∇Ei } + # where ∂_w = gradient with respect to parameters. We first rewrite this as + # Δ * ∂_w Ei + d/dt ∂_w Ei( x + t Δd) |_{t = 0} + # Which we can compute from + # ∂_w Ei( x + Dual(0, 1) * Δd ) + # Beautiful. + + Rs_d = Rs + Dual{T}(0, 1) * Δd + Ei_d, ∂Ei_d, st = grad_params(model, Rs_d, Zs, Z0, ps, st) + + # To make our life easy we can hack the named-tuples a bit + # TODO: this can and probably should be made more efficient + ∂Ei_d_vec, _restruct = destructure(∂Ei_d) + # extract the gradient w.r.t. parameters + ∂Ei = value.(∂Ei_d_vec) + # extract the directional derivative w.r.t. positions + ∂∇Ei_Δd = extract_derivative.(T, ∂Ei_d_vec) + + # combine to produce the output + return _restruct(Δ * ∂Ei + ∂∇Ei_Δd) +end + + + + \ No newline at end of file diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 424301d8..88e70797 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -96,4 +96,29 @@ Rs, Zs, z0 = M.rand_atenv(model, 16) @btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) @btime M.grad_params($model, $Rs, $Zs, $z0, $ps, $st) + ## + +@info("Test second mixed derivatives reverse-over-reverse") +for ntest = 1:20 + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Us = randn(SVector{3, Float64}, Nat) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei, ∂Ei, _ = M.grad_params(model, Rs, Zs, z0, ps, st) + + # test partial derivative w.r.t. the Ei component + ∂2_Ei = M.pullback_2_mixed(1.0, 0*Us, model, Rs, Zs, z0, ps, st) + print_tf(@test destructure(∂2_Ei)[1] ≈ destructure(∂Ei)[1]) + + # test partial derivative w.r.t. the ∇Ei component + ∂2_∇Ei = M.pullback_2_mixed(0.0, Us, model, Rs, Zs, z0, ps, st) + ∂2_∇Ei_vec = destructure(∂2_∇Ei)[1] + + ps_vec, _restruct = destructure(ps) + vs_vec = randn(length(ps_vec)) / sqrt(length(ps_vec)) + F(t) = dot(Us, M.evaluate_ed(model, Rs, Zs, z0, _restruct(ps_vec + t * vs_vec), st)[2]) + dF0 = dot(∂2_∇Ei_vec, vs_vec) + print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false)) +end + From 39dc24083b788d10e7b79034464b493faffbe2d2 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 8 May 2024 15:53:21 -0700 Subject: [PATCH 014/112] benchmark cleanup --- test/models/test_ace.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 88e70797..188498fd 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -85,18 +85,6 @@ for ntest = 1:20 print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose = false)) end -## - - -# first test shows the performance is not at all awful even without any -# optimizations and reductions in memory allocations. -using BenchmarkTools -Rs, Zs, z0 = M.rand_atenv(model, 16) -@btime M.evaluate($model, $Rs, $Zs, $z0, $ps, $st) -@btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) -@btime M.grad_params($model, $Rs, $Zs, $z0, $ps, $st) - - ## @info("Test second mixed derivatives reverse-over-reverse") @@ -122,3 +110,17 @@ for ntest = 1:20 print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false)) end + +## + + +# # first test shows the performance is not at all awful even without any +# # optimizations and reductions in memory allocations. +# using BenchmarkTools +# Nat = 15 +# Rs, Zs, z0 = M.rand_atenv(model, Nat) +# Us = randn(SVector{3, Float64}, Nat) +# print(" evaluate : "); @btime M.evaluate($model, $Rs, $Zs, $z0, $ps, $st) +# print("evaluate_ed : "); @btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) +# print("grad_params : "); @btime M.grad_params($model, $Rs, $Zs, $z0, $ps, $st) +# print(" reverse^2 : "); @btime M.pullback_2_mixed(rand(), $Us, $model, $Rs, $Zs, $z0, $ps, $st) From d811cc7ec89f68fba46eab6abd4ad380236c6706 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 9 May 2024 22:09:54 -0700 Subject: [PATCH 015/112] a first bad basis implementation --- src/models/ace.jl | 64 ++++++++++++++++++++++++++++++++++++++++- test/models/test_ace.jl | 50 +++++++++++++++++++++++++------- 2 files changed, 103 insertions(+), 11 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index 854eedce..9823a0cb 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -408,6 +408,68 @@ function pullback_2_mixed(Δ, Δd, model::ACEModel, return _restruct(Δ * ∂Ei + ∂∇Ei_Δd) end +# ------------------------------------------------------------ +# ACE basis evaluation + + +function get_basis_inds(model::ACEModel, Z) + len_Bi = size(model.A2Bmap, 1) + i_z = _z2i(model.rbasis, Z) + return (i_z - 1) * len_Bi .+ (1:len_Bi) +end + +function evaluate_basis(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + # get the radii + rs = [ norm(r) for r in Rs ] # use Bumper + # evaluate the radial basis + # use Bumper to pre-allocate + Rnl, _st = evaluate_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) + + # evaluate the Y basis + Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here + SpheriCart.compute!(Ylm, model.ybasis, Rs) + + # evaluate the A basis + TA = promote_type(T, eltype(Rnl)) + A = zeros(T, length(model.abasis)) + Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) + + # evaluate the AA basis + _AA = zeros(T, length(model.aabasis)) # use Bumper here + Polynomials4ML.evaluate!(_AA, model.aabasis, A) + # project to the actual AA basis + proj = model.aabasis.projection + AA = _AA[proj] # use Bumper here, or view; needs experimentation. + + # evaluate the coupling coefficients + # TODO: use Bumper and do it in-place + Bi = model.A2Bmap * AA + B = zeros(eltype(Bi), length(Bi) * _get_nz(model)) + B[get_basis_inds(model, Z0)] .= Bi - \ No newline at end of file + return B, st +end + + +function evaluate_basis_ed(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + + B, st = evaluate_basis(model, Rs, Zs, Z0, ps, st) + + _vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs) + _svecs(Rsvec::AbstractVector{T}) where {T} = reinterpret(SVector{3, T}, Rsvec) + + dB_vec = ForwardDiff.jacobian( + _Rs -> evaluate_basis(model, _svecs(_Rs), Zs, Z0, ps, st)[1], + _vec(Rs)) + dB1 = reinterpret(SVector{3, T}, collect(dB_vec')[:]) + dB = collect( permutedims( reshape(dB1, length(Rs), length(B)), + (2, 1) ) ) + + return B, dB, st +end \ No newline at end of file diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 188498fd..9e85d521 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -113,14 +113,44 @@ end ## +@info("Test basis implementation") + +for ntest = 1:30 + Nat = 15 + Rs, Zs, z0 = M.rand_atenv(model, Nat) + i_z0 = M._z2i(model, z0) + Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) + B, st = M.evaluate_basis(model, Rs, Zs, z0, ps, st) + θ = vcat(ps.WB...) + print_tf(@test Ei ≈ dot(B, θ)) + + Ei, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + B, ∇B, st = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) + θ = vcat(ps.WB...) + print_tf(@test Ei ≈ dot(B, θ)) + print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) +end + + + +## + +@info("Basic performance benchmarks") +# first test shows the performance is not at all awful even without any +# optimizations and reductions in memory allocations. +using BenchmarkTools +Nat = 15 +Rs, Zs, z0 = M.rand_atenv(model, Nat) +Us = randn(SVector{3, Float64}, Nat) + +@info("Evaluation and adjoints") +print(" evaluate : "); @btime M.evaluate($model, $Rs, $Zs, $z0, $ps, $st) +print("evaluate_ed : "); @btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) +print("grad_params : "); @btime M.grad_params($model, $Rs, $Zs, $z0, $ps, $st) +print(" reverse^2 : "); @btime M.pullback_2_mixed(rand(), $Us, $model, $Rs, $Zs, $z0, $ps, $st) + +@info("Basis evaluation ") +@info(" NB: this is currently implemented using ForwardDiff and likely inefficient") +print("evaluate_basis : "); @btime M.evaluate_basis($model, $Rs, $Zs, $z0, $ps, $st) +print("evaluate_basis_ed : "); @btime M.evaluate_basis_ed($model, $Rs, $Zs, $z0, $ps, $st) -# # first test shows the performance is not at all awful even without any -# # optimizations and reductions in memory allocations. -# using BenchmarkTools -# Nat = 15 -# Rs, Zs, z0 = M.rand_atenv(model, Nat) -# Us = randn(SVector{3, Float64}, Nat) -# print(" evaluate : "); @btime M.evaluate($model, $Rs, $Zs, $z0, $ps, $st) -# print("evaluate_ed : "); @btime M.evaluate_ed($model, $Rs, $Zs, $z0, $ps, $st) -# print("grad_params : "); @btime M.grad_params($model, $Rs, $Zs, $z0, $ps, $st) -# print(" reverse^2 : "); @btime M.pullback_2_mixed(rand(), $Us, $model, $Rs, $Zs, $z0, $ps, $st) From b9d0afde95ef644632208ccf50b8b3b5e4ec6fd6 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 9 May 2024 22:33:46 -0700 Subject: [PATCH 016/112] jacobians of forces --- src/models/ace.jl | 32 +++++++++++++++++++++++++------- test/models/test_ace.jl | 20 +++++++++++++++++--- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index 9823a0cb..1c5ed884 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -454,6 +454,8 @@ function evaluate_basis(model::ACEModel, return B, st end +__vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs) +__svecs(Rsvec::AbstractVector{T}) where {T} = reinterpret(SVector{3, T}, Rsvec) function evaluate_basis_ed(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, @@ -461,15 +463,31 @@ function evaluate_basis_ed(model::ACEModel, B, st = evaluate_basis(model, Rs, Zs, Z0, ps, st) - _vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs) - _svecs(Rsvec::AbstractVector{T}) where {T} = reinterpret(SVector{3, T}, Rsvec) - dB_vec = ForwardDiff.jacobian( - _Rs -> evaluate_basis(model, _svecs(_Rs), Zs, Z0, ps, st)[1], - _vec(Rs)) - dB1 = reinterpret(SVector{3, T}, collect(dB_vec')[:]) + _Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st)[1], + __vec(Rs)) + dB1 = __svecs(collect(dB_vec')[:]) dB = collect( permutedims( reshape(dB1, length(Rs), length(B)), (2, 1) ) ) return B, dB, st -end \ No newline at end of file +end + + + +function jacobian_grad_params(model::ACEModel, + Rs::AbstractVector{SVector{3, T}}, Zs, Z0, + ps, st) where {T} + + Ei, ∂Ei, st = grad_params(model, Rs, Zs, Z0, ps, st) + ∂∂Ei_vec = ForwardDiff.jacobian( _Rs -> ( + destructure( grad_params(model, __svecs(_Rs), Zs, Z0, ps, st)[2] )[1] + ), + __vec(Rs)) + ∂Ei_vec = destructure(∂Ei)[1] + ∂∂Ei = collect( permutedims( + reshape( __svecs((∂∂Ei_vec')[:]), length(Rs), length(∂Ei_vec) ), + (2, 1) ) ) + return Ei, ∂Ei_vec, ∂∂Ei, st +end + diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 9e85d521..6e48564a 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -8,7 +8,7 @@ using ACEbase.Testing: print_tf, println_slim using ACEpotentials M = ACEpotentials.Models -using Optimisers +using Optimisers, ForwardDiff using Random, LuxCore, StaticArrays, LinearAlgebra rng = Random.MersenneTwister(1234) @@ -131,6 +131,19 @@ for ntest = 1:30 print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) end +## + +@info("Test the full mixed jacobian") + +for ntest = 1:30 + Nat = 15 + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat) + F(t) = destructure( M.grad_params(model, Rs + t * Us, Zs, z0, ps, st)[2] )[1] + dF0 = ForwardDiff.derivative(F, 0.0) + ∂∂Ei = M.jacobian_grad_params(model, Rs, Zs, z0, ps, st)[3] + print_tf(@test dF0 ≈ transpose.(∂∂Ei) * Us) +end ## @@ -151,6 +164,7 @@ print(" reverse^2 : "); @btime M.pullback_2_mixed(rand(), $Us, $model, $Rs, $Zs @info("Basis evaluation ") @info(" NB: this is currently implemented using ForwardDiff and likely inefficient") -print("evaluate_basis : "); @btime M.evaluate_basis($model, $Rs, $Zs, $z0, $ps, $st) -print("evaluate_basis_ed : "); @btime M.evaluate_basis_ed($model, $Rs, $Zs, $z0, $ps, $st) +print(" evaluate_basis : "); @btime M.evaluate_basis($model, $Rs, $Zs, $z0, $ps, $st) +print(" evaluate_basis_ed : "); @btime M.evaluate_basis_ed($model, $Rs, $Zs, $z0, $ps, $st) +print("jacobian_grad_params : "); @btime M.jacobian_grad_params($model, $Rs, $Zs, $z0, $ps, $st) From b1fd8240bdae33da19c29bcff26098bb572f44bf Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 10 May 2024 12:19:40 -0700 Subject: [PATCH 017/112] towards splinifying the basis --- Project.toml | 1 + src/models/Rnl_basis.jl | 320 +++++++++--------------------- src/models/Rnl_learnable.jl | 205 +++++++++++++++++++ src/models/Rnl_splines.jl | 193 ++++++++++++++++++ src/models/elements.jl | 1 + src/models/models.jl | 4 +- test/models/test_learnable_Rnl.jl | 8 + 7 files changed, 500 insertions(+), 232 deletions(-) create mode 100644 src/models/Rnl_learnable.jl create mode 100644 src/models/Rnl_splines.jl diff --git a/Project.toml b/Project.toml index 1f6bb74e..ddcc1675 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index ed04d971..2a2ad1b2 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -2,19 +2,32 @@ import LuxCore: AbstractExplicitLayer, initialparameters, initialstates -using StaticArrays: SMatrix +using StaticArrays: SMatrix, SVector using Random: AbstractRNG -abstract type AbstractRnlzzBasis <: AbstractExplicitLayer end +import OffsetArrays, Interpolations +using Interpolations: cubic_spline_interpolation, BSpline, Cubic, Line, OnGrid + # NOTEs: -# each smatrix in the types below indexes (i, j) +# each smatrix in the Rnl types indexes (i, j) # where i is the center, j is neighbour const NT_RIN0CUTS{T} = NamedTuple{(:rin, :r0, :rcut), Tuple{T, T, T}} const NT_NL_SPEC = NamedTuple{(:n, :l), Tuple{Int, Int}} - -struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractRnlzzBasis +const SPL_OF_SVEC{DIM, T} = + Interpolations.Extrapolation{SVector{DIM, T}, 1, + Interpolations.ScaledInterpolation{SVector{DIM, T}, 1, + Interpolations.BSplineInterpolation{SVector{DIM, T}, 1, + OffsetArrays.OffsetVector{SVector{DIM, T}, Vector{SVector{DIM, T}}}, + BSpline{Cubic{Line{OnGrid}}}, Tuple{Base.OneTo{Int}}}, + BSpline{Cubic{Line{OnGrid}}}, + Tuple{StepRangeLen{T, Base.TwicePrecision{T}, Base.TwicePrecision{T}, Int}}}, + BSpline{Cubic{Line{OnGrid}}}, Interpolations.Throw{Nothing} + } + + +struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractExplicitLayer _i2z::NTuple{NZ, Int} polys::TPOLY transforms::SMatrix{NZ, NZ, TT} @@ -28,248 +41,93 @@ struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractRnlzzBasis meta::Dict{String, Any} end +function set_params(basis::LearnableRnlrzzBasis, ps) + return LearnableRnlrzzBasis(basis._i2z, + basis.polys, + basis.transforms, + basis.envelopes, + # --------------- + _make_smatrix(ps.Wnlq, _get_nz(basis)), + basis.rin0cuts, + basis.spec, + # --------------- + basis.meta) +end -# struct SplineRnlrzzBasis{NZ, SPL, ENV} <: AbstractRnlzzBasis -# _i2z::NTuple{NZ, Int} # iz -> z mapping -# splines::SMatrix{NZ, NZ, SPL} # matrix of splined radial bases -# envelopes::SMatrix{NZ, NZ, ENV} # matrix of radial envelopes -# rincut::SMatrix{NZ, NZ, Tuple{T, T}} # matrix of (rin, rout) -# #-------------- -# # meta should contain spec -# meta::Dict{String, Any} -# end +struct SplineRnlrzzBasis{NZ, TT, TENV, LEN, T} <: AbstractExplicitLayer + _i2z::NTuple{NZ, Int} + transforms::SMatrix{NZ, NZ, TT} + envelopes::SMatrix{NZ, NZ, TENV} + splines::SMatrix{NZ, NZ, SPL_OF_SVEC{LEN, T}} + # -------------- + rin0cuts::SMatrix{NZ, NZ, NT_RIN0CUTS{T}} # matrix of (rin, rout, rcut) + spec::Vector{NT_NL_SPEC} + # -------------- + meta::Dict{String, Any} +end + # a few getter functions for convenient access to those fields of matrices -_rincut_zz(obj, zi, zj) = obj.rin0cut[_z2i(obj, zi), _z2i(obj, zj)] +_rincut_zz(obj, zi, zj) = obj.rin0cuts[_z2i(obj, zi), _z2i(obj, zj)] +_rin0cuts_zz(obj, zi, zj) = obj.rin0cuts[_z2i(obj, zi), _z2i(obj, zj)] +_rcut_zz(obj, zi, zj) = obj.rin0cuts[_z2i(obj, zi), _z2i(obj, zj)].rcut +_rin_zz(obj, zi, zj) = obj.rin0cuts[_z2i(obj, zi), _z2i(obj, zj)].rin +_r0_zz(obj, zi, zj) = obj.rin0cuts[_z2i(obj, zi), _z2i(obj, zj)].r0 _envelope_zz(obj, zi, zj) = obj.envelopes[_z2i(obj, zi), _z2i(obj, zj)] _spline_zz(obj, zi, zj) = obj.splines[_z2i(obj, zi), _z2i(obj, zj)] _transform_zz(obj, zi, zj) = obj.transforms[_z2i(obj, zi), _z2i(obj, zj)] -# _polys_zz(obj, zi, zj) = obj.polys[_z2i(obj, zi), _z2i(obj, zj)] - - -# ------------------------------------------------------------ -# CONSTRUCTORS AND UTILITIES -# ------------------------------------------------------------ - -# these _auto_... are very poor and need to take care of a lot more -# cases, e.g. we may want to pass in the objects as a Matrix rather than -# SMatrix ... - - - - -function LearnableRnlrzzBasis( - zlist, polys, transforms, envelopes, rin0cuts, - spec::AbstractVector{NT_NL_SPEC}; - weights=nothing, - meta=Dict{String, Any}()) - NZ = length(zlist) - LearnableRnlrzzBasis(_convert_zlist(zlist), - polys, - _make_smatrix(transforms, NZ), - _make_smatrix(envelopes, NZ), - # -------------- - _make_smatrix(weights, NZ), - _make_smatrix(rin0cuts, NZ), - collect(spec), - meta) -end - -Base.length(basis::LearnableRnlrzzBasis) = length(basis.spec) - -function initialparameters(rng::AbstractRNG, - basis::LearnableRnlrzzBasis) - NZ = _get_nz(basis) - len_nl = length(basis) - len_q = length(basis.polys) - - function _W() - W = randn(rng, len_nl, len_q) - W = W ./ sqrt.(sum(W.^2, dims = 2)) - end - - return (Wnlq = [ _W() for i = 1:NZ, j = 1:NZ ], ) -end - -function initialstates(rng::AbstractRNG, - basis::LearnableRnlrzzBasis) - return NamedTuple() -end - - - -function splinify(basis::LearnableRnlrzzBasis) - -end -# ------------------------------------------------------------ -# EVALUATION INTERFACE -# ------------------------------------------------------------ +_get_T(basis::LearnableRnlrzzBasis) = typeof(basis.rin0cuts[1,1].rin) -import Polynomials4ML +function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) -(l::LearnableRnlrzzBasis)(args...) = evaluate(l, args...) + # transform : r ∈ [rin, rcut] -> x + # and then Rnl = Wnl_q * Pq(x) * env(x) gives the basis. + # The problem with this is that we cannot evaluate the envelope from just + # r coordinates. We therefore keep the transform inside the splinified + # basis and only splinify the last operation, x -> Rnl(x) + # this also has the potential advantage that few spline points are needed, + # and that we get access to the same meta-information about the model building. + # + # in the following we assume all transforms map [rin, rcut] -> [-1, 1] -function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) - iz = _z2i(basis, Zi) - jz = _z2i(basis, Zj) - Wij = ps.Wnlq[iz, jz] - trans_ij = basis.transforms[iz, jz] - x = trans_ij(r) - P = Polynomials4ML.evaluate(basis.polys, x) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - Rnl[:] .= Wij * (P .* e) - return Rnl, st -end - -function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) - iz = _z2i(basis, Zi) - jz = _z2i(basis, Zj) - Wij = ps.Wnlq[iz, jz] - trans_ij = basis.transforms[iz, jz] - x = trans_ij(r) - P = Polynomials4ML.evaluate(basis.polys, x) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - return Wij * (P .* e), st -end - - -# function evaluate_batched(basis::LearnableRnlrzzBasis, -# rs::AbstractVector{<: Real}, zi, zjs, ps, st) -# @assert length(rs) == length(zjs) -# # evaluate the first one to get the types and size -# Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) -# # allocate storage -# Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) -# # then evaluate the rest in-place -# for j = 1:length(rs) -# evaluate!((@view Rnl[j, :]), basis, rs[j], zi, zjs[j], ps, st) -# end -# return Rnl, st -# end - -function evaluate_batched(basis::LearnableRnlrzzBasis, - rs, zi, zjs, ps, st) - - @assert length(rs) == length(zjs) - # evaluate the first one to get the types and size - Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) - # ... and then allocate storage - Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) - - # then evaluate the rest in-place - for j = 1:length(rs) - iz = _z2i(basis, zi) - jz = _z2i(basis, zjs[j]) - trans_ij = basis.transforms[iz, jz] - x = trans_ij(rs[j]) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - P = Polynomials4ML.evaluate(basis.polys, x) .* e - Rnl[j, :] = ps.Wnlq[iz, jz] * P - end - - return Rnl, st -end - - - - -# ----- gradients -# because the typical scenario is that we have few r, then moderately -# many q and then many (n, l), this seems to be best done in Forward-mode. - -import ForwardDiff -using ForwardDiff: Dual - -function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} - d_r = Dual{T}(r, one(T)) - d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) - Rnl = ForwardDiff.value.(d_Rnl) - Rnl_d = ForwardDiff.extract_derivative(T, d_Rnl) - return Rnl, Rnl_d, st -end - - -function evaluate_ed_batched(basis::LearnableRnlrzzBasis, - rs::AbstractVector{T}, Zi, Zs, ps, st - ) where {T <: Real} - - @assert length(rs) == length(Zs) - Rnl1, st = evaluate(basis, rs[1], Zi, Zs[1], ps, st) - Rnl = zeros(T, length(rs), length(Rnl1)) - Rnl_d = zeros(T, length(rs), length(Rnl1)) - - for j = 1:length(rs) - d_r = Dual{T}(rs[j], one(T)) - d_Rnl, st = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here - map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) - map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) - end - - return Rnl, Rnl_d, st -end - - - - -# -------- RRULES - -import ChainRulesCore: rrule, NotImplemented, NoTangent - -# NB : iz = īz = _z2i(z0) throughout -# -# Rnl[j, nl] = Wnlq[iz, jz] * Pq * e -# ∂_Wn̄l̄q̄[īz,j̄z] { ∑_jnl Δ[j,nl] * Rnl[j, nl] } -# = ∑_jnl Δ[j,nl] * Pq * e * δ_q̄q * δ_l̄l * δ_n̄n * δ_{īz,iz} * δ_{j̄z,jz} -# = ∑_{jz = j̄z} Δ[j̄z, n̄l̄] * P_q̄ * e -# -function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, - rs, zi, zjs, ps, st) - @assert length(rs) == length(zjs) - # evaluate the first one to get the types and size - Rnl_1, _ = evaluate(basis, rs[1], zi, zjs[1], ps, st) - # ... and then allocate storage - Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) - - # output storage for the gradients - T_∂Wnlq = promote_type(eltype(Δ), eltype(rs)) NZ = _get_nz(basis) - ∂Wnlq = [ zeros(T_∂Wnlq, size(ps.Wnlq[i,j])) - for i = 1:NZ, j = 1:NZ ] - - # then evaluate the rest in-place - for j = 1:length(rs) - iz = _z2i(basis, zi) - jz = _z2i(basis, zjs[j]) - trans_ij = basis.transforms[iz, jz] - x = trans_ij(rs[j]) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - P = Polynomials4ML.evaluate(basis.polys, x) .* e - # TODO: the P shouuld be stored inside a closure in the - # forward pass and then resused. - - # TODO: ... and obviously this part here needs to be moved - # to a SIMD loop. - ∂Wnlq[iz, jz][:, :] .+= Δ[j, :] * P' + T = _get_T(basis) + LEN = size(basis.weights[1, 1], 1) + _splines = Matrix{SPL_OF_SVEC{LEN, T}}(undef, (NZ, NZ)) + x_nodes = range(-1.0, 1.0, length = nnodes) + polys = basis.polys + + for iz0 = 1:NZ, iz1 = 1:NZ + rin0cut = basis.rin0cuts[iz0, iz1] + rin, rcut = rin0cut.rin, rin0cut.rcut + + Tij = basis.transforms[iz0, iz1] + Wnlq_ij = basis.weights[iz0, iz1] + Rnl = [ SVector{LEN}( Wnlq_ij * Polynomials4ML.evaluate(polys, x) ) + for x in x_nodes ] + + # now we need to spline the Rnl + splines_ij = cubic_spline_interpolation(x_nodes, Rnl) + _splines[iz0, iz1] = splines_ij end - return (Wnql = ∂Wnlq,) -end + splines = SMatrix{NZ, NZ, SPL_OF_SVEC{LEN, T}}(_splines) + + spl_basis = SplineRnlrzzBasis(basis._i2z, + basis.transforms, + basis.envelopes, + splines, + basis.rin0cuts, + basis.spec, + basis.meta) + spl_basis.meta["info"] = "constructed from LearnableRnlrzzBasis via `splinify`" -function rrule(::typeof(evaluate_batched), - basis::LearnableRnlrzzBasis, - rs, zi, zjs, ps, st) - Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) + # we should probably store more meta-data from which the splines can be + # easily reconstructed. - return (Rnl, st), - Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), - pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), - NoTangent()) + return spl_basis end \ No newline at end of file diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl new file mode 100644 index 00000000..8229785f --- /dev/null +++ b/src/models/Rnl_learnable.jl @@ -0,0 +1,205 @@ + + + +# ------------------------------------------------------------ +# CONSTRUCTORS AND UTILITIES +# ------------------------------------------------------------ + +function LearnableRnlrzzBasis( + zlist, polys, transforms, envelopes, rin0cuts, + spec::AbstractVector{NT_NL_SPEC}; + weights=nothing, + meta=Dict{String, Any}()) + NZ = length(zlist) + LearnableRnlrzzBasis(_convert_zlist(zlist), + polys, + _make_smatrix(transforms, NZ), + _make_smatrix(envelopes, NZ), + # -------------- + _make_smatrix(weights, NZ), + _make_smatrix(rin0cuts, NZ), + collect(spec), + meta) +end + +Base.length(basis::LearnableRnlrzzBasis) = length(basis.spec) + +function initialparameters(rng::AbstractRNG, + basis::LearnableRnlrzzBasis) + NZ = _get_nz(basis) + len_nl = length(basis) + len_q = length(basis.polys) + + function _W() + W = randn(rng, len_nl, len_q) + W = W ./ sqrt.(sum(W.^2, dims = 2)) + end + + return (Wnlq = [ _W() for i = 1:NZ, j = 1:NZ ], ) +end + +function initialstates(rng::AbstractRNG, + basis::LearnableRnlrzzBasis) + return NamedTuple() +end + + + + +# ------------------------------------------------------------ +# EVALUATION INTERFACE +# ------------------------------------------------------------ + +import Polynomials4ML + +(l::LearnableRnlrzzBasis)(args...) = evaluate(l, args...) + +function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) + iz = _z2i(basis, Zi) + jz = _z2i(basis, Zj) + Wij = ps.Wnlq[iz, jz] + trans_ij = basis.transforms[iz, jz] + x = trans_ij(r) + P = Polynomials4ML.evaluate(basis.polys, x) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + Rnl[:] .= Wij * (P .* e) + return Rnl, st +end + +function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) + iz = _z2i(basis, Zi) + jz = _z2i(basis, Zj) + Wij = ps.Wnlq[iz, jz] + trans_ij = basis.transforms[iz, jz] + x = trans_ij(r) + P = Polynomials4ML.evaluate(basis.polys, x) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + return Wij * (P .* e), st +end + + +function evaluate_batched(basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + + @assert length(rs) == length(zjs) + # evaluate the first one to get the types and size + Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) + # ... and then allocate storage + Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + + # then evaluate the rest in-place + for j = 1:length(rs) + iz = _z2i(basis, zi) + jz = _z2i(basis, zjs[j]) + trans_ij = basis.transforms[iz, jz] + x = trans_ij(rs[j]) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + P = Polynomials4ML.evaluate(basis.polys, x) .* e + Rnl[j, :] = ps.Wnlq[iz, jz] * P + end + + return Rnl, st +end + + + + +# ----- gradients +# because the typical scenario is that we have few r, then moderately +# many q and then many (n, l), this seems to be best done in Forward-mode. +# in initial tests it seems the performance is very near optimal +# so there is little sense trying to do something manual. + +import ForwardDiff +using ForwardDiff: Dual + +function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} + d_r = Dual{T}(r, one(T)) + d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) + Rnl = ForwardDiff.value.(d_Rnl) + Rnl_d = ForwardDiff.extract_derivative(T, d_Rnl) + return Rnl, Rnl_d, st +end + + +function evaluate_ed_batched(basis::LearnableRnlrzzBasis, + rs::AbstractVector{T}, Zi, Zs, ps, st + ) where {T <: Real} + + @assert length(rs) == length(Zs) + Rnl1, st = evaluate(basis, rs[1], Zi, Zs[1], ps, st) + Rnl = zeros(T, length(rs), length(Rnl1)) + Rnl_d = zeros(T, length(rs), length(Rnl1)) + + for j = 1:length(rs) + d_r = Dual{T}(rs[j], one(T)) + d_Rnl, st = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here + map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) + map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) + end + + return Rnl, Rnl_d, st +end + + + + +# -------- RRULES + +import ChainRulesCore: rrule, NotImplemented, NoTangent + +# NB : iz = īz = _z2i(z0) throughout +# +# Rnl[j, nl] = Wnlq[iz, jz] * Pq * e +# ∂_Wn̄l̄q̄[īz,j̄z] { ∑_jnl Δ[j,nl] * Rnl[j, nl] } +# = ∑_jnl Δ[j,nl] * Pq * e * δ_q̄q * δ_l̄l * δ_n̄n * δ_{īz,iz} * δ_{j̄z,jz} +# = ∑_{jz = j̄z} Δ[j̄z, n̄l̄] * P_q̄ * e +# +function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + @assert length(rs) == length(zjs) + # evaluate the first one to get the types and size + Rnl_1, _ = evaluate(basis, rs[1], zi, zjs[1], ps, st) + # ... and then allocate storage + Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + + # output storage for the gradients + T_∂Wnlq = promote_type(eltype(Δ), eltype(rs)) + NZ = _get_nz(basis) + ∂Wnlq = [ zeros(T_∂Wnlq, size(ps.Wnlq[i,j])) + for i = 1:NZ, j = 1:NZ ] + + # then evaluate the rest in-place + for j = 1:length(rs) + iz = _z2i(basis, zi) + jz = _z2i(basis, zjs[j]) + trans_ij = basis.transforms[iz, jz] + x = trans_ij(rs[j]) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + P = Polynomials4ML.evaluate(basis.polys, x) .* e + # TODO: the P shouuld be stored inside a closure in the + # forward pass and then resused. + + # TODO: ... and obviously this part here needs to be moved + # to a SIMD loop. + ∂Wnlq[iz, jz][:, :] .+= Δ[j, :] * P' + end + + return (Wnql = ∂Wnlq,) +end + + +function rrule(::typeof(evaluate_batched), + basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) + + return (Rnl, st), + Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), + pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), + NoTangent()) +end \ No newline at end of file diff --git a/src/models/Rnl_splines.jl b/src/models/Rnl_splines.jl new file mode 100644 index 00000000..5867921c --- /dev/null +++ b/src/models/Rnl_splines.jl @@ -0,0 +1,193 @@ + + +# ------------------------------------------------------------ +# CONSTRUCTORS AND UTILITIES +# ------------------------------------------------------------ + + +Base.length(basis::SplineRnlrzzBasis) = length(basis.spec) + +function initialparameters(rng::AbstractRNG, + basis::SplineRnlrzzBasis) + return NamedTuple() +end + +function initialstates(rng::AbstractRNG, + basis::SplineRnlrzzBasis) + return NamedTuple() +end + + +# ------------------------------------------------------------ +# EVALUATION INTERFACE +# ------------------------------------------------------------ + +#= +(l::SplineRnlrzzBasis)(args...) = evaluate(l, args...) + +function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) + iz = _z2i(basis, Zi) + jz = _z2i(basis, Zj) + spl_ij = basis.splines[iz, jz] + env_ij = basis.envelopes[iz, jz] + Rnl[:] .= evaluate(spl_ij, r) .* evaluate(env_ij, r) + + trans_ij = basis.transforms[iz, jz] + x = trans_ij(r) + P = Polynomials4ML.evaluate(basis.polys, x) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + Rnl[:] .= Wij * (P .* e) + return Rnl, st +end + +function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) + iz = _z2i(basis, Zi) + jz = _z2i(basis, Zj) + Wij = ps.Wnlq[iz, jz] + trans_ij = basis.transforms[iz, jz] + x = trans_ij(r) + P = Polynomials4ML.evaluate(basis.polys, x) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + return Wij * (P .* e), st +end + + +# function evaluate_batched(basis::LearnableRnlrzzBasis, +# rs::AbstractVector{<: Real}, zi, zjs, ps, st) +# @assert length(rs) == length(zjs) +# # evaluate the first one to get the types and size +# Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) +# # allocate storage +# Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) +# # then evaluate the rest in-place +# for j = 1:length(rs) +# evaluate!((@view Rnl[j, :]), basis, rs[j], zi, zjs[j], ps, st) +# end +# return Rnl, st +# end + +function evaluate_batched(basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + + @assert length(rs) == length(zjs) + # evaluate the first one to get the types and size + Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) + # ... and then allocate storage + Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + + # then evaluate the rest in-place + for j = 1:length(rs) + iz = _z2i(basis, zi) + jz = _z2i(basis, zjs[j]) + trans_ij = basis.transforms[iz, jz] + x = trans_ij(rs[j]) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + P = Polynomials4ML.evaluate(basis.polys, x) .* e + Rnl[j, :] = ps.Wnlq[iz, jz] * P + end + + return Rnl, st +end + + + + +# ----- gradients +# because the typical scenario is that we have few r, then moderately +# many q and then many (n, l), this seems to be best done in Forward-mode. + +import ForwardDiff +using ForwardDiff: Dual + +function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} + d_r = Dual{T}(r, one(T)) + d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) + Rnl = ForwardDiff.value.(d_Rnl) + Rnl_d = ForwardDiff.extract_derivative(T, d_Rnl) + return Rnl, Rnl_d, st +end + + +function evaluate_ed_batched(basis::LearnableRnlrzzBasis, + rs::AbstractVector{T}, Zi, Zs, ps, st + ) where {T <: Real} + + @assert length(rs) == length(Zs) + Rnl1, st = evaluate(basis, rs[1], Zi, Zs[1], ps, st) + Rnl = zeros(T, length(rs), length(Rnl1)) + Rnl_d = zeros(T, length(rs), length(Rnl1)) + + for j = 1:length(rs) + d_r = Dual{T}(rs[j], one(T)) + d_Rnl, st = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here + map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) + map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) + end + + return Rnl, Rnl_d, st +end + + + + +# -------- RRULES + +import ChainRulesCore: rrule, NotImplemented, NoTangent + +# NB : iz = īz = _z2i(z0) throughout +# +# Rnl[j, nl] = Wnlq[iz, jz] * Pq * e +# ∂_Wn̄l̄q̄[īz,j̄z] { ∑_jnl Δ[j,nl] * Rnl[j, nl] } +# = ∑_jnl Δ[j,nl] * Pq * e * δ_q̄q * δ_l̄l * δ_n̄n * δ_{īz,iz} * δ_{j̄z,jz} +# = ∑_{jz = j̄z} Δ[j̄z, n̄l̄] * P_q̄ * e +# +function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + @assert length(rs) == length(zjs) + # evaluate the first one to get the types and size + Rnl_1, _ = evaluate(basis, rs[1], zi, zjs[1], ps, st) + # ... and then allocate storage + Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + + # output storage for the gradients + T_∂Wnlq = promote_type(eltype(Δ), eltype(rs)) + NZ = _get_nz(basis) + ∂Wnlq = [ zeros(T_∂Wnlq, size(ps.Wnlq[i,j])) + for i = 1:NZ, j = 1:NZ ] + + # then evaluate the rest in-place + for j = 1:length(rs) + iz = _z2i(basis, zi) + jz = _z2i(basis, zjs[j]) + trans_ij = basis.transforms[iz, jz] + x = trans_ij(rs[j]) + env_ij = basis.envelopes[iz, jz] + e = evaluate(env_ij, x) + P = Polynomials4ML.evaluate(basis.polys, x) .* e + # TODO: the P shouuld be stored inside a closure in the + # forward pass and then resused. + + # TODO: ... and obviously this part here needs to be moved + # to a SIMD loop. + ∂Wnlq[iz, jz][:, :] .+= Δ[j, :] * P' + end + + return (Wnql = ∂Wnlq,) +end + + +function rrule(::typeof(evaluate_batched), + basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) + + return (Rnl, st), + Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), + pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), + NoTangent()) +end + +=# \ No newline at end of file diff --git a/src/models/elements.jl b/src/models/elements.jl index 99b2aa44..0ec4354b 100644 --- a/src/models/elements.jl +++ b/src/models/elements.jl @@ -59,6 +59,7 @@ function _make_smatrix(obj, NZ) return SMatrix{NZ, NZ}(fill(obj, (NZ, NZ))) end + # a one-hot embedding for the z variable. # function embed_z(ace, Rs, Zs) # TF = eltype(eltype(Rs)) diff --git a/src/models/models.jl b/src/models/models.jl index e436851a..8f0b4ea5 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -10,9 +10,11 @@ include("radial_envelopes.jl") include("radial_transforms.jl") include("Rnl_basis.jl") +include("Rnl_learnable.jl") +include("Rnl_splines.jl") -include("ace_heuristics.jl") +include("ace_heuristics.jl") include("ace.jl") end \ No newline at end of file diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 974c5f3c..34ce3f76 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -50,3 +50,11 @@ rs = [r, r, r] Zs = [Zj, Zj, Zj] Rnl, Rnl_d, st1 = M.evaluate_ed_batched(basis, rs, Zi, Zs, ps, st) +# more tests needed to check the correctness of the batched version +# can we implement some tests that check consistent with ACE1 in special cases? + +## + +basisp = M.set_params(basis, ps) +splb = M.splinify(basisp) + From 8c2918cabe29e2aeeba5b842a3ba620b12ca7163 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 10 May 2024 14:48:02 -0700 Subject: [PATCH 018/112] full implementation and test of splines --- src/models/Rnl_splines.jl | 135 +++++++----------------------- src/models/utils.jl | 14 ++-- test/models/test_Rnl.jl | 132 +++++++++++++++++++++++++++++ test/models/test_learnable_Rnl.jl | 42 +++++++--- 4 files changed, 201 insertions(+), 122 deletions(-) create mode 100644 test/models/test_Rnl.jl diff --git a/src/models/Rnl_splines.jl b/src/models/Rnl_splines.jl index 5867921c..12592897 100644 --- a/src/models/Rnl_splines.jl +++ b/src/models/Rnl_splines.jl @@ -22,53 +22,30 @@ end # EVALUATION INTERFACE # ------------------------------------------------------------ -#= (l::SplineRnlrzzBasis)(args...) = evaluate(l, args...) -function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) - iz = _z2i(basis, Zi) - jz = _z2i(basis, Zj) - spl_ij = basis.splines[iz, jz] - env_ij = basis.envelopes[iz, jz] - Rnl[:] .= evaluate(spl_ij, r) .* evaluate(env_ij, r) - trans_ij = basis.transforms[iz, jz] - x = trans_ij(r) - P = Polynomials4ML.evaluate(basis.polys, x) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - Rnl[:] .= Wij * (P .* e) - return Rnl, st -end +# function evaluate!(Rnl, basis::SplineRnlrzzBasis, r::Real, Zi, Zj, ps, st) +# Rnl[:] .= evaluate(basis, r, Zi, Zj, ps, st) +# return Rnl, st +# end -function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) + +function evaluate(basis::SplineRnlrzzBasis, r::Real, Zi, Zj, ps, st) iz = _z2i(basis, Zi) jz = _z2i(basis, Zj) - Wij = ps.Wnlq[iz, jz] - trans_ij = basis.transforms[iz, jz] - x = trans_ij(r) - P = Polynomials4ML.evaluate(basis.polys, x) + T_ij = basis.transforms[iz, jz] env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - return Wij * (P .* e), st -end + spl_ij = basis.splines[iz, jz] + x_ij = T_ij(r) + e_ij = evaluate(env_ij, x_ij) + + return spl_ij(x_ij) * e_ij, st +end -# function evaluate_batched(basis::LearnableRnlrzzBasis, -# rs::AbstractVector{<: Real}, zi, zjs, ps, st) -# @assert length(rs) == length(zjs) -# # evaluate the first one to get the types and size -# Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) -# # allocate storage -# Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) -# # then evaluate the rest in-place -# for j = 1:length(rs) -# evaluate!((@view Rnl[j, :]), basis, rs[j], zi, zjs[j], ps, st) -# end -# return Rnl, st -# end -function evaluate_batched(basis::LearnableRnlrzzBasis, +function evaluate_batched(basis::SplineRnlrzzBasis, rs, zi, zjs, ps, st) @assert length(rs) == length(zjs) @@ -76,33 +53,26 @@ function evaluate_batched(basis::LearnableRnlrzzBasis, Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) # ... and then allocate storage Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + Rnl[1, :] .= Rnl_1 # then evaluate the rest in-place - for j = 1:length(rs) - iz = _z2i(basis, zi) - jz = _z2i(basis, zjs[j]) - trans_ij = basis.transforms[iz, jz] - x = trans_ij(rs[j]) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - P = Polynomials4ML.evaluate(basis.polys, x) .* e - Rnl[j, :] = ps.Wnlq[iz, jz] * P + for j = 2:length(rs) + Rnl[j, :], st = evaluate(basis, rs[j], zi, zjs[j], ps, st) end return Rnl, st end - - # ----- gradients # because the typical scenario is that we have few r, then moderately # many q and then many (n, l), this seems to be best done in Forward-mode. + import ForwardDiff using ForwardDiff: Dual -function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} +function evaluate_ed(basis::SplineRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} d_r = Dual{T}(r, one(T)) d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) Rnl = ForwardDiff.value.(d_Rnl) @@ -111,20 +81,22 @@ function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T end -function evaluate_ed_batched(basis::LearnableRnlrzzBasis, + +function evaluate_ed_batched(basis::SplineRnlrzzBasis, rs::AbstractVector{T}, Zi, Zs, ps, st ) where {T <: Real} @assert length(rs) == length(Zs) - Rnl1, st = evaluate(basis, rs[1], Zi, Zs[1], ps, st) + Rnl1, ∇Rnl1, st = evaluate_ed(basis, rs[1], Zi, Zs[1], ps, st) Rnl = zeros(T, length(rs), length(Rnl1)) Rnl_d = zeros(T, length(rs), length(Rnl1)) + Rnl[1, :] .= Rnl1 + Rnl_d[1, :] .= ∇Rnl1 for j = 1:length(rs) - d_r = Dual{T}(rs[j], one(T)) - d_Rnl, st = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here - map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) - map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) + Rnl_j, ∇Rnl_j, st = evaluate_ed(basis, rs[j], Zi, Zs[j], ps, st) + Rnl[j, :] = Rnl_j + Rnl_d[j, :] = ∇Rnl_j end return Rnl, Rnl_d, st @@ -133,61 +105,12 @@ end -# -------- RRULES - -import ChainRulesCore: rrule, NotImplemented, NoTangent - -# NB : iz = īz = _z2i(z0) throughout -# -# Rnl[j, nl] = Wnlq[iz, jz] * Pq * e -# ∂_Wn̄l̄q̄[īz,j̄z] { ∑_jnl Δ[j,nl] * Rnl[j, nl] } -# = ∑_jnl Δ[j,nl] * Pq * e * δ_q̄q * δ_l̄l * δ_n̄n * δ_{īz,iz} * δ_{j̄z,jz} -# = ∑_{jz = j̄z} Δ[j̄z, n̄l̄] * P_q̄ * e -# -function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, - rs, zi, zjs, ps, st) - @assert length(rs) == length(zjs) - # evaluate the first one to get the types and size - Rnl_1, _ = evaluate(basis, rs[1], zi, zjs[1], ps, st) - # ... and then allocate storage - Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) - - # output storage for the gradients - T_∂Wnlq = promote_type(eltype(Δ), eltype(rs)) - NZ = _get_nz(basis) - ∂Wnlq = [ zeros(T_∂Wnlq, size(ps.Wnlq[i,j])) - for i = 1:NZ, j = 1:NZ ] - - # then evaluate the rest in-place - for j = 1:length(rs) - iz = _z2i(basis, zi) - jz = _z2i(basis, zjs[j]) - trans_ij = basis.transforms[iz, jz] - x = trans_ij(rs[j]) - env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) - P = Polynomials4ML.evaluate(basis.polys, x) .* e - # TODO: the P shouuld be stored inside a closure in the - # forward pass and then resused. - - # TODO: ... and obviously this part here needs to be moved - # to a SIMD loop. - ∂Wnlq[iz, jz][:, :] .+= Δ[j, :] * P' - end - - return (Wnql = ∂Wnlq,) -end - - function rrule(::typeof(evaluate_batched), - basis::LearnableRnlrzzBasis, + basis::SplineRnlrzzBasis, rs, zi, zjs, ps, st) Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) return (Rnl, st), Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), - pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), - NoTangent()) + NamedTuple(), NoTangent()) end - -=# \ No newline at end of file diff --git a/src/models/utils.jl b/src/models/utils.jl index 213c20d9..de8d12da 100644 --- a/src/models/utils.jl +++ b/src/models/utils.jl @@ -83,16 +83,18 @@ end import ACE1 -function rand_atenv(model, Nat) - z0 = rand(model._i2z) - zs = rand(model._i2z, Nat) +rand_atenv(model::ACEModel, Nat) = rand_atenv(model.rbasis, Nat) + +function rand_atenv(rbasis::Union{LearnableRnlrzzBasis, SplineRnlrzzBasis}, Nat) + z0 = rand(rbasis._i2z) + zs = rand(rbasis._i2z, Nat) rs = Float64[] for zj in zs - iz0 = _z2i(model, z0) - izj = _z2i(model, zj) + iz0 = _z2i(rbasis, z0) + izj = _z2i(rbasis, zj) x = 2 * rand() - 1 - t_ij = model.rbasis.transforms[iz0, izj] + t_ij = rbasis.transforms[iz0, izj] r_ij = inv_transform(t_ij, x) push!(rs, r_ij) end diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl new file mode 100644 index 00000000..08a0af62 --- /dev/null +++ b/test/models/test_Rnl.jl @@ -0,0 +1,132 @@ + + + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); + +using ACEpotentials +M = ACEpotentials.Models + +using Random, LuxCore, Test, ACEbase, LinearAlgebra +using ACEbase.Testing: print_tf +rng = Random.MersenneTwister(1234) + +## + +max_level = 8 +level = M.TotalDegree() +maxl = 3; maxn = max_level; +elements = (:Si, :O) +basis = M.ace_learnable_Rnlrzz(; level=level, max_level=max_level, + maxl = maxl, maxn = maxn, elements = elements) + +ps, st = LuxCore.setup(rng, basis) + +r = 3.0 +Zi = basis._i2z[1] +Zj = basis._i2z[2] +Rnl, st1 = basis(r, Zi, Zj, ps, st) +Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) + +@info("Test derivatives of LearnableRnlrzzBasis") + +for ntest = 1:20 + r = 2.0 + rand() + Zi = rand(basis._i2z) + Zj = rand(basis._i2z) + U = randn(eltype(Rnl), length(Rnl)) + F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)[1]) + dF(t) = dot(U, M.evaluate_ed(basis, r + t, Zi, Zj, ps, st)[2]) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) +end +println() + +## + +@info("LearnableRnlrzz : Consistency of single and batched evaluation") + +for ntest = 1:20 + Nat = rand(8:16) + Rs, Zs, Z0 = M.rand_atenv(basis, Nat) + rs = norm.(Rs) + + Rnl = [ M.evaluate(basis, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] + Rnl_b, st = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) + print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) + + Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) + ∇Rnl = [ M.evaluate_ed(basis, r, Z0, z, ps, st)[2] + for (r, z) in zip(rs, Zs) ] + + print_tf(@test Rnl_b ≈ Rnl_b2) + print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) +end + +## + +basis_p = M.set_params(basis, ps) + + +@info("Testing SplineRnlrzzBasis consistency via splinify") + +for ntest = 1:30 + Nat = 1 + Rs, Zs, Zi = M.rand_atenv(basis, Nat) + r = norm(Rs[1]) + Zj = Zs[1] + + Rnl, _ = basis(r, Zi, Zj, ps, st) + + for (nnodes, tol) in [(30, 1e-3), (100, 1e-5), (1000, 1e-8)] + + basis_spl = M.splinify(basis_p; nnodes = nnodes) + ps_spl, st_spl = LuxCore.setup(rng, basis_spl) + Rnl_spl, _ = basis_spl(r, Zi, Zj, ps_spl, st_spl) + rel_err = (Rnl - Rnl_spl) ./ (1 .+ abs.(Rnl)) + # use 1-norm here to not stress about small outliers + print_tf(@test norm(rel_err, 1)/length(Rnl) < tol) + # @show norm(rel_err, 1) / length(Rnl) + end +end + +## + +@info("Test derivatives of SplineRnlrzzBasis") + +basis_p = M.set_params(basis, ps) +basis_spl = M.splinify(basis_p; nnodes = 100) + +for ntest = 1:20 + Rs, Zs, Zi = M.rand_atenv(basis_spl, 1) + r = norm(Rs[1]); Zj = Zs[1] + Rnl = basis_spl(r, Zi, Zj, ps, st)[1] + U = randn(eltype(Rnl), length(Rnl)) + F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)[1]) + dF(t) = dot(U, M.evaluate_ed(basis, r + t, Zi, Zj, ps, st)[2]) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) +end +println() + +## + +@info("SplineRnlrzz : Consistency of single and batched evaluation") + +basis_p = M.set_params(basis, ps) +basis_spl = M.splinify(basis_p; nnodes = 100) + +for ntest = 1:20 + Nat = rand(8:16) + Rs, Zs, Z0 = M.rand_atenv(basis_spl, Nat) + rs = norm.(Rs) + + Rnl = [ M.evaluate(basis_spl, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] + Rnl_b, st = M.evaluate_batched(basis_spl, rs, Z0, Zs, ps, st) + print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) + + Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis_spl, rs, Z0, Zs, ps, st) + ∇Rnl = [ M.evaluate_ed(basis_spl, r, Z0, z, ps, st)[2] + for (r, z) in zip(rs, Zs) ] + + print_tf(@test Rnl_b ≈ Rnl_b2) + print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) +end diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 34ce3f76..9477a60e 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -43,18 +43,40 @@ println() ## -# @btime ($basis)(r, Zi, Zj, $ps, $st) -# @btime M.evaluate_ed($basis, r, Zi, Zj, $ps, $st) +@info("LearnableRnlrzz : Consistency of single and batched evaluation") -rs = [r, r, r] -Zs = [Zj, Zj, Zj] -Rnl, Rnl_d, st1 = M.evaluate_ed_batched(basis, rs, Zi, Zs, ps, st) - -# more tests needed to check the correctness of the batched version -# can we implement some tests that check consistent with ACE1 in special cases? +for ntest = 1:20 + Nat = rand(8:16) + Rs, Zs, Z0 = M.rand_atenv(basis, Nat) + rs = norm.(Rs) + + Rnl = [ M.evaluate(basis, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] + Rnl_b, st = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) + print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) + + Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) + ∇Rnl = [ M.evaluate_ed(basis, r, Z0, z, ps, st)[2] + for (r, z) in zip(rs, Zs) ] + + print_tf(@test Rnl_b ≈ Rnl_b2) + print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) +end ## -basisp = M.set_params(basis, ps) -splb = M.splinify(basisp) +basis_p = M.set_params(basis, ps) + +basis_spl = M.splinify(basis_p; nnodes = 30) +ps_spl, st_spl = LuxCore.setup(rng, basis_spl) + +Rnl, _ = basis(r, Zi, Zj, ps, st) +Rnl_spl, _ = basis_spl(r, Zi, Zj, ps_spl, st_spl) + +norm(Rnl - Rnl_spl, Inf) + +Rnl, ∇Rnl, _ = M.evaluate_ed(basis, r, Zi, Zj, ps, st) +Rnl_spl, ∇Rnl_spl, _ = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) + +norm(Rnl - Rnl_spl, Inf) +norm(∇Rnl - ∇Rnl_spl, Inf) From eed5ffd3cdde8630a3868b93f65a2fdd3fb199fc Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 11 May 2024 23:40:44 -0700 Subject: [PATCH 019/112] draft calculator --- Project.toml | 9 ++++ src/models/ace.jl | 5 ++ src/models/calculators.jl | 99 ++++++++++++++++++++++++++++++++++ src/models/models.jl | 7 ++- test/models/test_calculator.jl | 35 ++++++++++++ 5 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 src/models/calculators.jl create mode 100644 test/models/test_calculator.jl diff --git a/Project.toml b/Project.toml index ddcc1675..245671fb 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,15 @@ ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" +AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc" ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478" +Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JuLIP = "945c410c-986d-556a-acb1-167a618e0462" @@ -18,6 +24,8 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +NeighbourLists = "2fcf5ba9-9ed4-57cf-b73f-ff513e316b9c" +ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -34,6 +42,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/models/ace.jl b/src/models/ace.jl index 1c5ed884..64d24667 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -40,6 +40,11 @@ end # ------------------------------------------------------------ # CONSTRUCTORS AND UTILITIES +# this is terrible : I'm assuming here that there is a unique +# output type, which is of course not the case. It is needed temporarily +# to make things work with AtomsCalculators and EmpiricalPotentials +fl_type(::ACEModel{NZ, TRAD, TY, TA, TAA, T}) where {NZ, TRAD, TY, TA, TAA, T} = T + const NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} function _make_Y_basis(Ytype, lmax) diff --git a/src/models/calculators.jl b/src/models/calculators.jl new file mode 100644 index 00000000..2d135331 --- /dev/null +++ b/src/models/calculators.jl @@ -0,0 +1,99 @@ + +import EmpiricalPotentials +import EmpiricalPotentials: SitePotential, + cutoff_radius, + eval_site, + eval_grad_site, + site_virial, + PairList, + get_neighbours, + atomic_number + +import AtomsCalculators +import AtomsCalculators: energy_forces_virial + +using Folds, ChunkSplitters, Unitful, NeighbourLists + +using ComponentArrays: ComponentArray + +using ObjectPools: release! + +struct ACEPotential{MOD} <: SitePotential + model::MOD +end + +# TODO: allow user to specify what units the model is working with + +energy_unit(::ACEPotential) = 1.0u"eV" +distance_unit(::ACEPotential) = 1.0u"Å" +force_unit(V) = energy_unit(V) / distance_unit(V) +Base.zero(V::ACEPotential) = zero(energy_unit(V)) + + +# --------------------------------------------------------------- +# EmpiricalPotentials / SitePotential based implementation +# +# this currently doesn't know how to handle ps and st +# it assumes implicitly without checking that the model is +# storing its own parameters. + +cutoff_radius(V::ACEPotential{<: ACEModel}) = + maximum(x.rcut for x in V.model.rbasis.rin0cuts) * distance_unit(V) + +eval_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) = + evaluate(V.model, Rs, Zs, z0) * energy_unit(V) + +eval_grad_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) = + evaluate_ed(V.model, Rs, Zs, z0) * force_unit(V) + + +# --------------------------------------------------------------- +# manual implementation allowing parameters +# but basically copied from the EmpiricalPotentials implementation + +import JuLIP +import AtomsBase + +AtomsBase.atomic_number(at::JuLIP.Atoms, iat::Integer) = at.Z[iat] + +function energy_forces_virial( + at, V::ACEPotential{<: ACEModel}, ps, st; + domain = 1:length(at), + executor = ThreadedEx(), + ntasks = Threads.nthreads(), + nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + kwargs... + ) + + T = fl_type(V.model) # this is ACE specific + init_e() = zero(T) * energy_unit(V) + init_f() = zeros(SVector{3, T}, length(at)) * force_unit(V) + init_v() = zero(SMatrix{3, 3, T}) * energy_unit(V) + + # TODO: each task needs its own state if that is where + # the temporary arrays will be stored? + # but if we use bumper then there is no issue + + + E_F_V = Folds.sum(collect(chunks(domain, ntasks)), + executor; + init = [init_e(), init_f(), init_v()], + ) do (sub_domain, _) + + energy = init_e() + forces = init_f() + virial = init_v() + + for i in sub_domain + Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) + v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, ps, st) + energy += v * energy_unit(V) + forces[Js] -= (dv * force_unit(V)) + forces[i] += sum(dv) * force_unit(V) + virial += JuLIP.Potentials.site_virial(dv, Rs) * energy_unit(V) + release!(Js); release!(Rs); release!(Zs) + end + [energy, forces, virial] + end + return (energy = E_F_V[1], forces = E_F_V[2], virial = E_F_V[3]) +end diff --git a/src/models/models.jl b/src/models/models.jl index 8f0b4ea5..3dc7883b 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -1,8 +1,6 @@ module Models -include("utils.jl") - include("elements.jl") include("radial_envelopes.jl") @@ -17,4 +15,9 @@ include("Rnl_splines.jl") include("ace_heuristics.jl") include("ace.jl") +include("calculators.jl") + +include("utils.jl") + + end \ No newline at end of file diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl new file mode 100644 index 00000000..3ca019dc --- /dev/null +++ b/test/models/test_calculator.jl @@ -0,0 +1,35 @@ + + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); + +using Test, ACEbase +using ACEbase.Testing: print_tf, println_slim + +using ACEpotentials +M = ACEpotentials.Models + +using Optimisers, ForwardDiff + +using Random, LuxCore, StaticArrays, LinearAlgebra +rng = Random.MersenneTwister(1234) + +## + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 15 +order = 3 + +model = M.ace_model(; elements = elements, order = order, Ytype = :solid, + level = level, max_level = max_level, maxl = 8, + init_WB = :glorot_normal) + +ps, st = LuxCore.setup(rng, model) + +calc = M.ACEPotential(model) + +## + +at = bulk(:Si, cubic=true) * 2 +evf = M.energy_forces_virial(at, calc, ps, st) \ No newline at end of file From d0f70d620c4e0f908ec53a37b0ec6cd4436ad1da Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 12 May 2024 17:06:31 -0700 Subject: [PATCH 020/112] pullback through EFV --- src/models/calculators.jl | 123 ++++++++++++++++++++++++++++++++- test/models/test_calculator.jl | 31 ++++++++- 2 files changed, 150 insertions(+), 4 deletions(-) diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 2d135331..21a03ff3 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -12,10 +12,13 @@ import EmpiricalPotentials: SitePotential, import AtomsCalculators import AtomsCalculators: energy_forces_virial -using Folds, ChunkSplitters, Unitful, NeighbourLists +using Folds, ChunkSplitters, Unitful, NeighbourLists, + Optimisers, LuxCore, ChainRulesCore using ComponentArrays: ComponentArray +import ChainRulesCore: rrule, NoTangent, ZeroTangent + using ObjectPools: release! struct ACEPotential{MOD} <: SitePotential @@ -53,6 +56,9 @@ eval_grad_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) = import JuLIP import AtomsBase +using Unitful: ustrip +_ustrip(x) = ustrip(x) +_ustrip(x::ZeroTangent) = x AtomsBase.atomic_number(at::JuLIP.Atoms, iat::Integer) = at.Z[iat] @@ -73,8 +79,7 @@ function energy_forces_virial( # TODO: each task needs its own state if that is where # the temporary arrays will be stored? # but if we use bumper then there is no issue - - + E_F_V = Folds.sum(collect(chunks(domain, ntasks)), executor; init = [init_e(), init_f(), init_v()], @@ -97,3 +102,115 @@ function energy_forces_virial( end return (energy = E_F_V[1], forces = E_F_V[2], virial = E_F_V[3]) end + + +# this implements the pullback of the energy_forces_virial function +# w.r.t. to the parameters only!! +# we should implement similar pullback helpers for forces and remove them +# from the function below to be re-used broadly. + +# function site_virial(dV::AbstractVector{SVector{3, T1}}, +# Rs::AbstractVector{SVector{3, T2}}) where {T1, T2} +# T = promote_type(T1, T2) +# return sum( dVj * rj' for (dVj, rj) in zip(dV, Rs), +# init = zero(SMatrix{3, 3, T}) ) +# end + +# function pullback_sitevirial_dV(Δ, Rs) +# # Δ : virial = ∑_j dVj' * Δ * rj +# # ∂_dVj (Δ : virial) = Δ * rj +# return [ Δ * rj for rj in Rs ] +# end + +function pullback_EFV(Δefv, + at, V::ACEPotential{<: ACEModel}, ps, st; + domain = 1:length(at), + executor = ThreadedEx(), + ntasks = Threads.nthreads(), + nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + kwargs... + ) + + T = fl_type(V.model) + ps_vec, _restruct = destructure(ps) + TP = promote_type(eltype(ps_vec), T) + + # We resolve the pullback through the summation-over-sites manually, e.g., + # E = ∑_i E_i + # ∂ (Δe * E) = ∑_i ∂( Δe * E_i ) + + # TODO : There is a lot of ustrip hacking which implicitly + # assumes that the loss is dimensionless and that the + # gradient w.r.t. parameters therefore must also be dimensionless + + g_vec = Folds.sum(collect(chunks(domain, ntasks)), + executor; + init = zeros(TP, length(ps_vec)), + ) do (sub_domain, _) + + g_loc = zeros(TP, length(ps_vec)) + for i in sub_domain + Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) + + Δei = _ustrip(Δefv.energy) + + # them adjoint for dV needs combination of the virial and forces pullback + Δdi = [ _ustrip.(Δefv.virial * rj) for rj in Rs ] + for α = 1:length(Js) + # F[Js[α]] -= dV[α], F[i] += dV[α] + # ∂_dvj { Δf[Js[α]] * F[Js[α]] } -> + Δdi[α] -= _ustrip.( Δefv.forces[Js[α]] ) + Δdi[α] += _ustrip.( Δefv.forces[i] ) + end + + # now we can apply the pullback through evaluate_ed + # (maybe this needs to be renamed, it sounds a bit cryptic) + if eltype(Δdi) == ZeroTangent + g_nt = grad_params(V.model, Rs, Zs, z0, ps, st)[2] + mult = Δei + else + g_nt = pullback_2_mixed(Δei, Δdi, V.model, Rs, Zs, z0, ps, st) + mult = one(TP) + end + + release!(Js); release!(Rs); release!(Zs) + + # convert it back to a vector so we can accumulate it in the sum. + # this is quite bad - in the call to pullback_2_mixed we just + # converted it from a vector to a named tuple. We need to look into + # using something like ComponentArrays.jl to avoid this. + g_loc += destructure(g_nt)[1] * mult + end + g_loc + end + + return _restruct(g_vec) +end + + +function rrule(::typeof(energy_forces_virial), + at, V::ACEPotential{<: ACEModel}, ps, st; + domain = 1:length(at), + executor = ThreadedEx(), + ntasks = Threads.nthreads(), + nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + kwargs... + ) + + # TODO : analyze this code flow carefully and see if we can + # re-use any of the computations done in the EFV evaluation + EFV = energy_forces_virial(at, V, ps, st; + domain = domain, + executor = executor, + ntasks = ntasks, + nlist = nlist, + kwargs...) + + return EFV, Δefv -> ( NoTangent(), NoTangent(), NoTangent(), + pullback_EFV(Δefv, at, V, ps, st; + domain = domain, + executor = executor, + ntasks = ntasks, + nlist = nlist, + kwargs...), NoTangent() ) +end diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 3ca019dc..15077733 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -32,4 +32,33 @@ calc = M.ACEPotential(model) ## at = bulk(:Si, cubic=true) * 2 -evf = M.energy_forces_virial(at, calc, ps, st) \ No newline at end of file +evf = M.energy_forces_virial(at, calc, ps, st) + + +## +# testing the AD through a loss function + + +at = rattle!(bulk(:Si, cubic=true), 0.1) + +using Unitful +using Unitful: ustrip + +wE = 1.0 / u"eV" +wV = 1.0 / u"eV" +wF = 0.33 / u"eV/Å" + +function loss(at, calc, ps, st) + efv = M.energy_forces_virial(at, calc, ps, st) + _norm_sq(f) = sum(abs2, f) + return ( wE^2 * efv.energy^2 / length(at) + + wV^2 * sum(abs2, efv.virial) / length(at) + + wF^2 * sum(_norm_sq, efv.forces) ) +end + +## + +using Zygote +Zygote.refresh() + +Zygote.gradient(ps -> loss(at, calc, ps, st), ps)[1] From e8f8c675997531a2933f36b220f63bfd0dca0620 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 12 May 2024 23:07:41 -0700 Subject: [PATCH 021/112] fix force assembly bug --- src/models/calculators.jl | 44 ++++++++++++++++++++++--- test/models/test_calculator.jl | 59 +++++++++++++++++++++++++++++++--- 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 21a03ff3..5c0a8b80 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -23,8 +23,12 @@ using ObjectPools: release! struct ACEPotential{MOD} <: SitePotential model::MOD + ps + st end +ACEPotential(model) = ACEPotential(model, nothing, nothing) + # TODO: allow user to specify what units the model is working with energy_unit(::ACEPotential) = 1.0u"eV" @@ -44,10 +48,12 @@ cutoff_radius(V::ACEPotential{<: ACEModel}) = maximum(x.rcut for x in V.model.rbasis.rin0cuts) * distance_unit(V) eval_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) = - evaluate(V.model, Rs, Zs, z0) * energy_unit(V) + evaluate(V.model, Rs, Zs, z0, V.ps, V.st)[1] -eval_grad_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) = - evaluate_ed(V.model, Rs, Zs, z0) * force_unit(V) +function eval_grad_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) + v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, V.ps, V.st) + return v, dv +end # --------------------------------------------------------------- @@ -62,6 +68,32 @@ _ustrip(x::ZeroTangent) = x AtomsBase.atomic_number(at::JuLIP.Atoms, iat::Integer) = at.Z[iat] +function energy_forces_virial_serial( + at, V::ACEPotential{<: ACEModel}, ps, st; + domain = 1:length(at), + nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + ) + + T = fl_type(V.model) # this is ACE specific + energy = zero(T) + forces = zeros(SVector{3, T}, length(at)) + virial = zero(SMatrix{3, 3, T}) + + for i in domain + Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) + v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, ps, st) + energy += v + for α = 1:length(Js) + forces[Js[α]] -= dv[α] + forces[i] += dv[α] + end + virial += JuLIP.Potentials.site_virial(dv, Rs) + release!(Js); release!(Rs); release!(Zs) + end + return (energy = energy, forces = forces, virial = virial) +end + + function energy_forces_virial( at, V::ACEPotential{<: ACEModel}, ps, st; domain = 1:length(at), @@ -93,8 +125,10 @@ function energy_forces_virial( Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, ps, st) energy += v * energy_unit(V) - forces[Js] -= (dv * force_unit(V)) - forces[i] += sum(dv) * force_unit(V) + for α = 1:length(Js) + forces[Js[α]] -= dv[α] * force_unit(V) + forces[i] += dv[α] * force_unit(V) + end virial += JuLIP.Potentials.site_virial(dv, Rs) * energy_unit(V) release!(Js); release!(Rs); release!(Zs) end diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 15077733..97d5ba64 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -9,7 +9,8 @@ using ACEbase.Testing: print_tf, println_slim using ACEpotentials M = ACEpotentials.Models -using Optimisers, ForwardDiff +using Optimisers, ForwardDiff, Unitful +import AtomsCalculators using Random, LuxCore, StaticArrays, LinearAlgebra rng = Random.MersenneTwister(1234) @@ -27,18 +28,66 @@ model = M.ace_model(; elements = elements, order = order, Ytype = :solid, ps, st = LuxCore.setup(rng, model) -calc = M.ACEPotential(model) +calc = M.ACEPotential(model, ps, st) ## -at = bulk(:Si, cubic=true) * 2 -evf = M.energy_forces_virial(at, calc, ps, st) +@info("Testing correctness of potential energy") +for ntest = 1:20 + at = rattle!(bulk(:Si, cubic=true) * 2, 0.1) + at_flex = AtomsBase.FlexibleSystem(at) + nlist = JuLIP.neighbourlist(at, ustrip(M.cutoff_radius(calc))) + E = 0.0 + for i = 1:length(at) + Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) + z0 = at.Z[i] + E += M.evaluate(calc.model, Rs, Zs, z0, ps, st)[1] + end + efv = M.energy_forces_virial(at, calc, ps, st) + E2 = AtomsCalculators.potential_energy(at_flex, calc) + print_tf(@test abs(E - ustrip(efv.energy))/abs(E) < 1e-12) + print_tf(@test abs(E - ustrip(E2)) / abs(E) < 1e-12) +end + +## + +@info("Testing correctness of forces ") +@info(" .... TODO TEST VIRIALS ..... ") + +at = rattle!(bulk(:Si, cubic=true), 0.1) +at_flex = AtomsBase.FlexibleSystem(at) + +@info(" consistency local vs EmpiricalPotentials implementation") +@info("this currently fails due to a bug in EmpiricalPotentials") +# efv1 = M.energy_forces_virial(at, calc, ps, st) +# efv2 = AtomsCalculators.energy_forces_virial(at_flex, calc) +# efv3 = M.energy_forces_virial_serial(at, calc, ps, st) +# print_tf(@test efv1.energy ≈ efv2.energy) +# print_tf(@test all(efv1.forces .≈ efv2.force)) +# print_tf(@test efv1.virial ≈ efv1.virial) +# print_tf(@test efv1.energy ≈ efv3.energy) +# print_tf(@test all(efv1.forces .≈ efv3.forces)) +# print_tf(@test efv1.virial ≈ efv3.virial) + +## + +@info("test consistency of forces with energy") +@info(" TODO: write virial test!") +for ntest = 1:10 + at = rattle!(bulk(:Si, cubic=true), 0.1) + Us = randn(SVector{3, Float64}, length(at)) / length(at) + dF0 = - dot(Us, M.energy_forces_virial_serial(at, calc, ps, st).forces) + X0 = deepcopy(at.X) + F(t) = M.energy_forces_virial_serial(JuLIP.set_positions!(at, X0 + t * Us), + calc, ps, st).energy + print_tf( @test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false ) ) +end +println() ## # testing the AD through a loss function - at = rattle!(bulk(:Si, cubic=true), 0.1) using Unitful From c818ea3629255e8a6c614389a25691bb2e7711f2 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 12 May 2024 23:36:25 -0700 Subject: [PATCH 022/112] simple test to differentiate through loss --- src/models/calculators.jl | 2 +- test/models/test_calculator.jl | 26 +++++++++++++++++++------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 5c0a8b80..33ed4bbe 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -189,7 +189,7 @@ function pullback_EFV(Δefv, Δei = _ustrip(Δefv.energy) # them adjoint for dV needs combination of the virial and forces pullback - Δdi = [ _ustrip.(Δefv.virial * rj) for rj in Rs ] + Δdi = [ - _ustrip.(Δefv.virial * rj) for rj in Rs ] for α = 1:length(Js) # F[Js[α]] -= dV[α], F[i] += dV[α] # ∂_dvj { Δf[Js[α]] * F[Js[α]] } -> diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 97d5ba64..9c348890 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -75,6 +75,7 @@ at_flex = AtomsBase.FlexibleSystem(at) @info(" TODO: write virial test!") for ntest = 1:10 at = rattle!(bulk(:Si, cubic=true), 0.1) + at.Z[[3,6,8]] .= 8 Us = randn(SVector{3, Float64}, length(at)) / length(at) dF0 = - dot(Us, M.energy_forces_virial_serial(at, calc, ps, st).forces) X0 = deepcopy(at.X) @@ -88,11 +89,16 @@ println() ## # testing the AD through a loss function -at = rattle!(bulk(:Si, cubic=true), 0.1) +using Zygote using Unitful using Unitful: ustrip +# random structure +at = rattle!(bulk(:Si, cubic=true), 0.1) +at.Z[[3,6,8]] .= 8 + +# need to make sure that the weights in the loss remove the units! wE = 1.0 / u"eV" wV = 1.0 / u"eV" wF = 0.33 / u"eV/Å" @@ -100,14 +106,20 @@ wF = 0.33 / u"eV/Å" function loss(at, calc, ps, st) efv = M.energy_forces_virial(at, calc, ps, st) _norm_sq(f) = sum(abs2, f) - return ( wE^2 * efv.energy^2 / length(at) - + wV^2 * sum(abs2, efv.virial) / length(at) + return ( wE^2 * efv.energy^2 / length(at) + + wV^2 * sum(abs2, efv.virial) / length(at) + wF^2 * sum(_norm_sq, efv.forces) ) end -## -using Zygote -Zygote.refresh() +g = Zygote.gradient(ps -> loss(at, calc, ps, st), ps)[1] + +p_vec, _restruct = destructure(ps) +g_vec = destructure(g)[1] +u = randn(length(p_vec)) / length(p_vec) +dot(g_vec, u) +_ps(t) = _restruct(p_vec + t * u) +F(t) = loss(at, calc, _ps(t), st) +dF0 = dot(g_vec, u) -Zygote.gradient(ps -> loss(at, calc, ps, st), ps)[1] +ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=true) \ No newline at end of file From de58353ae2135232bdd053413b0917caeedee499 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 13 May 2024 07:43:38 -0700 Subject: [PATCH 023/112] improved loss diff test --- test/models/test_calculator.jl | 55 +++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 9c348890..3e3a65ab 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -88,38 +88,45 @@ println() ## # testing the AD through a loss function - +@info("Testing Zygote-AD through a loss function") using Zygote using Unitful using Unitful: ustrip -# random structure -at = rattle!(bulk(:Si, cubic=true), 0.1) -at.Z[[3,6,8]] .= 8 - # need to make sure that the weights in the loss remove the units! -wE = 1.0 / u"eV" -wV = 1.0 / u"eV" -wF = 0.33 / u"eV/Å" +for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), + (0.0 / u"eV", 1.0 / u"eV", 0.0 / u"eV/Å"), + (0.0 / u"eV", 0.0 / u"eV", 1.0 / u"eV/Å"), + (1.0 / u"eV", 0.1 / u"eV", 0.1 / u"eV/Å") ] + # random structure + at = rattle!(bulk(:Si, cubic=true), 0.1) + at.Z[[3,6,8]] .= 8 -function loss(at, calc, ps, st) - efv = M.energy_forces_virial(at, calc, ps, st) - _norm_sq(f) = sum(abs2, f) - return ( wE^2 * efv.energy^2 / length(at) - + wV^2 * sum(abs2, efv.virial) / length(at) - + wF^2 * sum(_norm_sq, efv.forces) ) -end + # wE = 1.0 / u"eV" + # wV = 1.0 / u"eV" + # wF = 0.33 / u"eV/Å" + + function loss(at, calc, ps, st) + efv = M.energy_forces_virial(at, calc, ps, st) + _norm_sq(f) = sum(abs2, f) + return ( wE^2 * efv.energy^2 / length(at) + + wV^2 * sum(abs2, efv.virial) / length(at) + + wF^2 * sum(_norm_sq, efv.forces) ) + end -g = Zygote.gradient(ps -> loss(at, calc, ps, st), ps)[1] + g = Zygote.gradient(ps -> loss(at, calc, ps, st), ps)[1] -p_vec, _restruct = destructure(ps) -g_vec = destructure(g)[1] -u = randn(length(p_vec)) / length(p_vec) -dot(g_vec, u) -_ps(t) = _restruct(p_vec + t * u) -F(t) = loss(at, calc, _ps(t), st) -dF0 = dot(g_vec, u) + p_vec, _restruct = destructure(ps) + g_vec = destructure(g)[1] + u = randn(length(p_vec)) / length(p_vec) + dot(g_vec, u) + _ps(t) = _restruct(p_vec + t * u) + F(t) = loss(at, calc, _ps(t), st) + dF0 = dot(g_vec, u) -ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=true) \ No newline at end of file + @info("(wE, wV, wF) = ($wE, $wV, $wF)") + FDTEST = ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=true) + println(@test FDTEST) +end \ No newline at end of file From ded8e406a4d8e01743b4fe20c3c45ef5179bb47b Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 14 May 2024 15:27:25 -0700 Subject: [PATCH 024/112] remove JuLIP dependence of new calculators --- Project.toml | 2 +- src/models/calculators.jl | 42 ++++++++++++---------------------- test/models/test_calculator.jl | 35 ++++++++++++++-------------- 3 files changed, 33 insertions(+), 46 deletions(-) diff --git a/Project.toml b/Project.toml index 245671fb..0c4e8c55 100644 --- a/Project.toml +++ b/Project.toml @@ -9,10 +9,10 @@ ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" +AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" EquivariantModels = "73ee3e68-46fd-466f-9c56-451dc0291ebc" ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478" diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 33ed4bbe..b69bcf6d 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -15,8 +15,6 @@ import AtomsCalculators: energy_forces_virial using Folds, ChunkSplitters, Unitful, NeighbourLists, Optimisers, LuxCore, ChainRulesCore -using ComponentArrays: ComponentArray - import ChainRulesCore: rrule, NoTangent, ZeroTangent using ObjectPools: release! @@ -60,18 +58,24 @@ end # manual implementation allowing parameters # but basically copied from the EmpiricalPotentials implementation -import JuLIP import AtomsBase +using EmpiricalPotentials: PairList, get_neighbours, site_virial + using Unitful: ustrip _ustrip(x) = ustrip(x) _ustrip(x::ZeroTangent) = x -AtomsBase.atomic_number(at::JuLIP.Atoms, iat::Integer) = at.Z[iat] +_site_virial(dV::AbstractVector{SVector{3, T1}}, + Rs::AbstractVector{SVector{3, T2}}) where {T1, T2} = + ( + - sum( dVi * Ri' for (dVi, Ri) in zip(dV, Rs); + init = zero(SMatrix{3, 3, promote_type(T1, T2)}) ) + ) function energy_forces_virial_serial( at, V::ACEPotential{<: ACEModel}, ps, st; domain = 1:length(at), - nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + nlist = PairList(at, cutoff_radius(V)), ) T = fl_type(V.model) # this is ACE specific @@ -87,7 +91,7 @@ function energy_forces_virial_serial( forces[Js[α]] -= dv[α] forces[i] += dv[α] end - virial += JuLIP.Potentials.site_virial(dv, Rs) + virial += _site_virial(dv, Rs) release!(Js); release!(Rs); release!(Zs) end return (energy = energy, forces = forces, virial = virial) @@ -99,7 +103,7 @@ function energy_forces_virial( domain = 1:length(at), executor = ThreadedEx(), ntasks = Threads.nthreads(), - nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + nlist = PairList(at, cutoff_radius(V)), kwargs... ) @@ -129,7 +133,7 @@ function energy_forces_virial( forces[Js[α]] -= dv[α] * force_unit(V) forces[i] += dv[α] * force_unit(V) end - virial += JuLIP.Potentials.site_virial(dv, Rs) * energy_unit(V) + virial += _site_virial(dv, Rs) * energy_unit(V) release!(Js); release!(Rs); release!(Zs) end [energy, forces, virial] @@ -138,30 +142,12 @@ function energy_forces_virial( end -# this implements the pullback of the energy_forces_virial function -# w.r.t. to the parameters only!! -# we should implement similar pullback helpers for forces and remove them -# from the function below to be re-used broadly. - -# function site_virial(dV::AbstractVector{SVector{3, T1}}, -# Rs::AbstractVector{SVector{3, T2}}) where {T1, T2} -# T = promote_type(T1, T2) -# return sum( dVj * rj' for (dVj, rj) in zip(dV, Rs), -# init = zero(SMatrix{3, 3, T}) ) -# end - -# function pullback_sitevirial_dV(Δ, Rs) -# # Δ : virial = ∑_j dVj' * Δ * rj -# # ∂_dVj (Δ : virial) = Δ * rj -# return [ Δ * rj for rj in Rs ] -# end - function pullback_EFV(Δefv, at, V::ACEPotential{<: ACEModel}, ps, st; domain = 1:length(at), executor = ThreadedEx(), ntasks = Threads.nthreads(), - nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + nlist = PairList(at, cutoff_radius(V)), kwargs... ) @@ -227,7 +213,7 @@ function rrule(::typeof(energy_forces_virial), domain = 1:length(at), executor = ThreadedEx(), ntasks = Threads.nthreads(), - nlist = JuLIP.neighbourlist(at, cutoff_radius(V)/distance_unit(V)), + nlist = PairList(at, cutoff_radius(V)), kwargs... ) diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 3e3a65ab..ab8424b3 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -12,6 +12,11 @@ M = ACEpotentials.Models using Optimisers, ForwardDiff, Unitful import AtomsCalculators +using AtomsBuilder, EmpiricalPotentials +using AtomsBuilder: bulk, rattle! +using EmpiricalPotentials: get_neighbours + + using Random, LuxCore, StaticArrays, LinearAlgebra rng = Random.MersenneTwister(1234) @@ -32,19 +37,18 @@ calc = M.ACEPotential(model, ps, st) ## + @info("Testing correctness of potential energy") for ntest = 1:20 at = rattle!(bulk(:Si, cubic=true) * 2, 0.1) - at_flex = AtomsBase.FlexibleSystem(at) - nlist = JuLIP.neighbourlist(at, ustrip(M.cutoff_radius(calc))) + nlist = PairList(at, M.cutoff_radius(calc)) E = 0.0 for i = 1:length(at) - Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) - z0 = at.Z[i] + Js, Rs, Zs, z0 = get_neighbours(at, calc, nlist, i) E += M.evaluate(calc.model, Rs, Zs, z0, ps, st)[1] end efv = M.energy_forces_virial(at, calc, ps, st) - E2 = AtomsCalculators.potential_energy(at_flex, calc) + E2 = AtomsCalculators.potential_energy(at, calc) print_tf(@test abs(E - ustrip(efv.energy))/abs(E) < 1e-12) print_tf(@test abs(E - ustrip(E2)) / abs(E) < 1e-12) end @@ -55,7 +59,6 @@ end @info(" .... TODO TEST VIRIALS ..... ") at = rattle!(bulk(:Si, cubic=true), 0.1) -at_flex = AtomsBase.FlexibleSystem(at) @info(" consistency local vs EmpiricalPotentials implementation") @info("this currently fails due to a bug in EmpiricalPotentials") @@ -75,17 +78,19 @@ at_flex = AtomsBase.FlexibleSystem(at) @info(" TODO: write virial test!") for ntest = 1:10 at = rattle!(bulk(:Si, cubic=true), 0.1) - at.Z[[3,6,8]] .= 8 + Z = AtomsBuilder._get_atomic_numbers(at) + Z[[3,6,8]] .= 8 + at = AtomsBuilder._set_atomic_numbers(at, Z) Us = randn(SVector{3, Float64}, length(at)) / length(at) dF0 = - dot(Us, M.energy_forces_virial_serial(at, calc, ps, st).forces) - X0 = deepcopy(at.X) - F(t) = M.energy_forces_virial_serial(JuLIP.set_positions!(at, X0 + t * Us), - calc, ps, st).energy + X0 = AtomsBuilder._get_positions(at) + F(t) = M.energy_forces_virial_serial( + AtomsBuilder._set_positions(at, X0 + (t * u"Å") * Us), + calc, ps, st).energy print_tf( @test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false ) ) end println() - ## # testing the AD through a loss function @info("Testing Zygote-AD through a loss function") @@ -101,11 +106,8 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), (1.0 / u"eV", 0.1 / u"eV", 0.1 / u"eV/Å") ] # random structure at = rattle!(bulk(:Si, cubic=true), 0.1) - at.Z[[3,6,8]] .= 8 - - # wE = 1.0 / u"eV" - # wV = 1.0 / u"eV" - # wF = 0.33 / u"eV/Å" + Z = AtomsBuilder._get_atomic_numbers(at) + Z[[3,6,8]] .= 8 function loss(at, calc, ps, st) efv = M.energy_forces_virial(at, calc, ps, st) @@ -115,7 +117,6 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), + wF^2 * sum(_norm_sq, efv.forces) ) end - g = Zygote.gradient(ps -> loss(at, calc, ps, st), ps)[1] p_vec, _restruct = destructure(ps) From 3ed974acc04372e3fb66e601abc69595f06674ed Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 14 May 2024 16:12:46 -0700 Subject: [PATCH 025/112] switch params to tensors --- src/models/Rnl_basis.jl | 8 ++++---- src/models/Rnl_learnable.jl | 34 +++++++++++++++++++++------------- src/models/ace.jl | 36 ++++++++++++++++++++++++------------ test/models/test_models.jl | 0 4 files changed, 49 insertions(+), 29 deletions(-) delete mode 100644 test/models/test_models.jl diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 2a2ad1b2..3366438a 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -33,7 +33,7 @@ struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractExplicitLayer transforms::SMatrix{NZ, NZ, TT} envelopes::SMatrix{NZ, NZ, TENV} # -------------- - weights::SMatrix{NZ, NZ, TW} # learnable weights, `nothing` when using Lux + weights::Array{TW, 4} # learnable weights, `nothing` when using Lux rin0cuts::SMatrix{NZ, NZ, NT_RIN0CUTS{T}} # matrix of (rin, rout, rcut) spec::Vector{NT_NL_SPEC} # -------------- @@ -47,7 +47,7 @@ function set_params(basis::LearnableRnlrzzBasis, ps) basis.transforms, basis.envelopes, # --------------- - _make_smatrix(ps.Wnlq, _get_nz(basis)), + ps.Wnlq, basis.rin0cuts, basis.spec, # --------------- @@ -95,7 +95,7 @@ function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) NZ = _get_nz(basis) T = _get_T(basis) - LEN = size(basis.weights[1, 1], 1) + LEN = size(basis.weights, 1) _splines = Matrix{SPL_OF_SVEC{LEN, T}}(undef, (NZ, NZ)) x_nodes = range(-1.0, 1.0, length = nnodes) polys = basis.polys @@ -105,7 +105,7 @@ function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) rin, rcut = rin0cut.rin, rin0cut.rcut Tij = basis.transforms[iz0, iz1] - Wnlq_ij = basis.weights[iz0, iz1] + Wnlq_ij = @view basis.weights[:, :, iz0, iz1] Rnl = [ SVector{LEN}( Wnlq_ij * Polynomials4ML.evaluate(polys, x) ) for x in x_nodes ] diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 8229785f..d8702220 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -10,13 +10,16 @@ function LearnableRnlrzzBasis( spec::AbstractVector{NT_NL_SPEC}; weights=nothing, meta=Dict{String, Any}()) - NZ = length(zlist) + NZ = length(zlist) + if isnothing(weights) + weights = fill(nothing, (1,1,1,1)) + end LearnableRnlrzzBasis(_convert_zlist(zlist), polys, _make_smatrix(transforms, NZ), _make_smatrix(envelopes, NZ), # -------------- - _make_smatrix(weights, NZ), + weights, _make_smatrix(rin0cuts, NZ), collect(spec), meta) @@ -30,20 +33,26 @@ function initialparameters(rng::AbstractRNG, len_nl = length(basis) len_q = length(basis.polys) - function _W() - W = randn(rng, len_nl, len_q) - W = W ./ sqrt.(sum(W.^2, dims = 2)) + Wnlq = zeros(len_nl, len_q, NZ, NZ) + for i = 1:NZ, j = 1:NZ + Wnlq[:, :, i, j] .= glorot_normal(rng, Float64, len_nl, len_q) end - return (Wnlq = [ _W() for i = 1:NZ, j = 1:NZ ], ) + return (Wnlq = Wnlq, ) end function initialstates(rng::AbstractRNG, basis::LearnableRnlrzzBasis) return NamedTuple() end - + +function parameterlength(basis::LearnableRnlrzzBasis) + NZ = _get_nz(basis) + len_nl = length(basis) + len_q = length(basis.polys) + return len_nl * len_q * NZ * NZ +end # ------------------------------------------------------------ @@ -57,7 +66,7 @@ import Polynomials4ML function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) iz = _z2i(basis, Zi) jz = _z2i(basis, Zj) - Wij = ps.Wnlq[iz, jz] + Wij = @view ps.Wnlq[:, :, iz, jz] trans_ij = basis.transforms[iz, jz] x = trans_ij(r) P = Polynomials4ML.evaluate(basis.polys, x) @@ -70,7 +79,7 @@ end function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) iz = _z2i(basis, Zi) jz = _z2i(basis, Zj) - Wij = ps.Wnlq[iz, jz] + Wij = @view ps.Wnlq[:, :, iz, jz] trans_ij = basis.transforms[iz, jz] x = trans_ij(r) P = Polynomials4ML.evaluate(basis.polys, x) @@ -98,7 +107,7 @@ function evaluate_batched(basis::LearnableRnlrzzBasis, env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, x) P = Polynomials4ML.evaluate(basis.polys, x) .* e - Rnl[j, :] = ps.Wnlq[iz, jz] * P + Rnl[j, :] = (@view ps.Wnlq[:, :, iz, jz]) * P end return Rnl, st @@ -169,8 +178,7 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, # output storage for the gradients T_∂Wnlq = promote_type(eltype(Δ), eltype(rs)) NZ = _get_nz(basis) - ∂Wnlq = [ zeros(T_∂Wnlq, size(ps.Wnlq[i,j])) - for i = 1:NZ, j = 1:NZ ] + ∂Wnlq = zeros(T_∂Wnlq, size(ps.Wnlq)) # then evaluate the rest in-place for j = 1:length(rs) @@ -186,7 +194,7 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, # TODO: ... and obviously this part here needs to be moved # to a SIMD loop. - ∂Wnlq[iz, jz][:, :] .+= Δ[j, :] * P' + ∂Wnlq[:, :, iz, jz] .+= Δ[j, :] * P' end return (Wnql = ∂Wnlq,) diff --git a/src/models/ace.jl b/src/models/ace.jl index 64d24667..3d449598 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -1,8 +1,10 @@ -using LuxCore: AbstractExplicitLayer, +import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, - initialstates + initialstates, + parameterlength + using Lux: glorot_normal @@ -31,8 +33,8 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer{(:rb # -------------- # we can add a nonlinear embedding here # -------------- - bparams::NTuple{NZ, Vector{T}} - aaparams::NTuple{NZ, Vector{T}} + bparams::Matrix{T} # : x NZ matrix of B parameters + # aaparams::NTuple{NZ, Vector{T}} # -------------- meta::Dict{String, Any} end @@ -126,8 +128,8 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, NZ = _get_nz(rbasis) n_B_params, n_AA_params = size(AA2BB_map) return ACEModel(rbasis._i2z, rbasis, ybasis, a_basis, aa_basis, AA2BB_map, - ntuple(_ -> zeros(n_B_params), NZ), - ntuple(_ -> zeros(n_AA_params), NZ), + zeros(n_B_params, NZ), + # ntuple(_ -> zeros(n_AA_params), NZ), Dict{String, Any}() ) end @@ -162,7 +164,12 @@ function initialparameters(rng::AbstractRNG, error("unknown `init_WB` = $(model.meta["init_WB"])") end - return (WB = [ winit(Float64, n_B_params) for _=1:NZ ], + WB = zeros(n_B_params, NZ) + for iz = 1:NZ + WB[:, iz] .= winit(rng, Float64, n_B_params) + end + + return (WB = WB, rbasis = initialparameters(rng, model.rbasis), ) end @@ -173,6 +180,11 @@ end (l::ACEModel)(args...) = evaluate(l, args...) +function parameterlength(model::ACEModel) + NZ = _get_nz(model) + n_B_params, n_AA_params = size(model.A2Bmap) + return NZ * n_B_params + parameterlength(model.rbasis) +end # ------------------------------------------------------------ # Model Evaluation @@ -217,7 +229,7 @@ function evaluate(model::ACEModel, # contract with params i_z0 = _z2i(model.rbasis, Z0) - val = dot(B, ps.WB[i_z0]) + val = dot(B, (@view ps.WB[:, i_z0])) return val, st end @@ -264,12 +276,12 @@ function evaluate_ed(model::ACEModel, # contract with params # (here we can insert another nonlinearity instead of the simple dot) i_z0 = _z2i(model.rbasis, Z0) - Ei = dot(B, ps.WB[i_z0]) + Ei = dot(B, (@view ps.WB[:, i_z0])) # ---------- BACKWARD PASS ------------ # ∂Ei / ∂B = WB[i_z0] - ∂B = ps.WB[i_z0] + ∂B = @view ps.WB[:, i_z0] # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap ∂AA = model.A2Bmap' * ∂B # TODO: make this in-place @@ -340,14 +352,14 @@ function grad_params(model::ACEModel, # contract with params # (here we can insert another nonlinearity instead of the simple dot) i_z0 = _z2i(model.rbasis, Z0) - Ei = dot(B, ps.WB[i_z0]) + Ei = dot(B, (@view ps.WB[:, i_z0])) # ---------- BACKWARD PASS ------------ # we need ∂WB = ∂Ei/∂WB -> this goes into the gradient # but we also need ∂B = ∂Ei / ∂B = WB[i_z0] to backpropagate ∂WB_i = B - ∂B = ps.WB[i_z0] + ∂B = @view ps.WB[:, i_z0] # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap ∂AA = model.A2Bmap' * ∂B # TODO: make this in-place diff --git a/test/models/test_models.jl b/test/models/test_models.jl deleted file mode 100644 index e69de29b..00000000 From d23e3d21c137e4774c2c0decb04688d3e66beaca Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 14 May 2024 21:46:16 -0700 Subject: [PATCH 026/112] cleanup of new tests --- Project.toml | 1 - src/models/Rnl_basis.jl | 3 ++- src/models/Rnl_learnable.jl | 12 ++++----- src/models/Rnl_splines.jl | 2 +- src/models/ace.jl | 15 ++++++----- src/models/radial_envelopes.jl | 15 ++++++----- test/models/test_Rnl.jl | 19 +++++++++++--- test/models/test_ace.jl | 37 +++++++++++++++++++-------- test/models/test_calculator.jl | 9 ++++++- test/models/test_learnable_Rnl.jl | 4 ++- test/models/test_models.jl | 8 ++++++ test/models/test_radial_transforms.jl | 4 +-- test/runtests.jl | 2 ++ 13 files changed, 92 insertions(+), 39 deletions(-) create mode 100644 test/models/test_models.jl diff --git a/Project.toml b/Project.toml index 0c4e8c55..c6a14d6f 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RepLieGroups = "f07d36f2-91c4-427a-b67b-965fe5ebe1d2" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" -RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 3366438a..c90b4e0a 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -130,4 +130,5 @@ function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) # easily reconstructed. return spl_basis -end \ No newline at end of file +end + diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index d8702220..e1ac97ff 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -1,4 +1,4 @@ - +import LuxCore # ------------------------------------------------------------ @@ -47,7 +47,7 @@ function initialstates(rng::AbstractRNG, end -function parameterlength(basis::LearnableRnlrzzBasis) +function LuxCore.parameterlength(basis::LearnableRnlrzzBasis) NZ = _get_nz(basis) len_nl = length(basis) len_q = length(basis.polys) @@ -71,7 +71,7 @@ function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) x = trans_ij(r) P = Polynomials4ML.evaluate(basis.polys, x) env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) + e = evaluate(env_ij, r, x) Rnl[:] .= Wij * (P .* e) return Rnl, st end @@ -84,7 +84,7 @@ function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) x = trans_ij(r) P = Polynomials4ML.evaluate(basis.polys, x) env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) + e = evaluate(env_ij, r, x) return Wij * (P .* e), st end @@ -105,7 +105,7 @@ function evaluate_batched(basis::LearnableRnlrzzBasis, trans_ij = basis.transforms[iz, jz] x = trans_ij(rs[j]) env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) + e = evaluate(env_ij, rs[j], x) P = Polynomials4ML.evaluate(basis.polys, x) .* e Rnl[j, :] = (@view ps.Wnlq[:, :, iz, jz]) * P end @@ -187,7 +187,7 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, trans_ij = basis.transforms[iz, jz] x = trans_ij(rs[j]) env_ij = basis.envelopes[iz, jz] - e = evaluate(env_ij, x) + e = evaluate(env_ij, rs[j], x) P = Polynomials4ML.evaluate(basis.polys, x) .* e # TODO: the P shouuld be stored inside a closure in the # forward pass and then resused. diff --git a/src/models/Rnl_splines.jl b/src/models/Rnl_splines.jl index 12592897..a9dc9970 100644 --- a/src/models/Rnl_splines.jl +++ b/src/models/Rnl_splines.jl @@ -39,7 +39,7 @@ function evaluate(basis::SplineRnlrzzBasis, r::Real, Zi, Zj, ps, st) spl_ij = basis.splines[iz, jz] x_ij = T_ij(r) - e_ij = evaluate(env_ij, x_ij) + e_ij = evaluate(env_ij, r, x_ij) return spl_ij(x_ij) * e_ij, st end diff --git a/src/models/ace.jl b/src/models/ace.jl index 3d449598..7070a6af 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -2,9 +2,7 @@ import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, - initialstates, - parameterlength - + initialstates using Lux: glorot_normal @@ -25,6 +23,7 @@ import Polynomials4ML struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer{(:rbasis,)} _i2z::NTuple{NZ, Int} + # -------------- rbasis::TRAD ybasis::TY abasis::TA @@ -34,7 +33,11 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer{(:rb # we can add a nonlinear embedding here # -------------- bparams::Matrix{T} # : x NZ matrix of B parameters - # aaparams::NTuple{NZ, Vector{T}} + # aaparams::NTuple{NZ, Vector{T}} (not used right now) + # -------------- + # pair potential + # pairbasis::TPAIR + # pairparams::Matrix{T} # -------------- meta::Dict{String, Any} end @@ -180,10 +183,10 @@ end (l::ACEModel)(args...) = evaluate(l, args...) -function parameterlength(model::ACEModel) +function LuxCore.parameterlength(model::ACEModel) NZ = _get_nz(model) n_B_params, n_AA_params = size(model.A2Bmap) - return NZ * n_B_params + parameterlength(model.rbasis) + return NZ * n_B_params end # ------------------------------------------------------------ diff --git a/src/models/radial_envelopes.jl b/src/models/radial_envelopes.jl index 1a3435a9..5f7fa61a 100644 --- a/src/models/radial_envelopes.jl +++ b/src/models/radial_envelopes.jl @@ -12,7 +12,7 @@ end PolyEnvelope1sR(rcut, p) = PolyEnvelope1sR(rcut, p, Dict{String, Any}()) -function evaluate(env::PolyEnvelope1sR, r::T) where T +function evaluate(env::PolyEnvelope1sR, r::T, x::T) where T if r >= env.rcut return zero(T) end @@ -21,8 +21,10 @@ function evaluate(env::PolyEnvelope1sR, r::T) where T return ( (r/env.rcut)^(-p) - 1.0) * (1 - r / env.rcut) end -evaluate_d(env::PolyEnvelope1sR, r) = - ForwardDiff.derivative(x -> evaluate(env, x), r) +evaluate_d(env::PolyEnvelope1sR, r::T, x::T) where {T} = + (ForwardDiff.derivative(x -> evaluate(env, x), r), + zero(T),) + # ---------------------------- @@ -50,7 +52,7 @@ function PolyEnvelope2sX(x1, x2, p1, p2) end -function evaluate(env::PolyEnvelope2sX, x::T) where T +function evaluate(env::PolyEnvelope2sX, r::T, x::T) where T x1, x2 = env.x1, env.x2 p1, p2 = env.p1, env.p2 s = env.s @@ -63,7 +65,8 @@ function evaluate(env::PolyEnvelope2sX, x::T) where T end -evaluate_d(env::PolyEnvelope2sX, x::T) where T = - ForwardDiff.derivative(x -> evaluate(env, x), x) +evaluate_d(env::PolyEnvelope2sX, r::T, x::T) where T = + (zero(T), ForwardDiff.derivative(x -> evaluate(env, x), x)) + diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl index 08a0af62..d3cef0b4 100644 --- a/test/models/test_Rnl.jl +++ b/test/models/test_Rnl.jl @@ -1,7 +1,7 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using ACEpotentials @@ -31,6 +31,7 @@ Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) @info("Test derivatives of LearnableRnlrzzBasis") for ntest = 1:20 + local r, Zi, Zj, U, F, dF r = 2.0 + rand() Zi = rand(basis._i2z) Zj = rand(basis._i2z) @@ -46,12 +47,14 @@ println() @info("LearnableRnlrzz : Consistency of single and batched evaluation") for ntest = 1:20 + local Rs, Rnl, Zs, Z0, Nat, st1, ∇Rnl + Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(basis, Nat) rs = norm.(Rs) Rnl = [ M.evaluate(basis, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] - Rnl_b, st = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) + Rnl_b, st1 = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) @@ -61,6 +64,7 @@ for ntest = 1:20 print_tf(@test Rnl_b ≈ Rnl_b2) print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) end +println() ## @@ -70,6 +74,8 @@ basis_p = M.set_params(basis, ps) @info("Testing SplineRnlrzzBasis consistency via splinify") for ntest = 1:30 + local Nat, Rs, Zs, Zi, r, Zj, Rnl + Nat = 1 Rs, Zs, Zi = M.rand_atenv(basis, Nat) r = norm(Rs[1]) @@ -78,6 +84,7 @@ for ntest = 1:30 Rnl, _ = basis(r, Zi, Zj, ps, st) for (nnodes, tol) in [(30, 1e-3), (100, 1e-5), (1000, 1e-8)] + local basis_spl, ps_spl, st_spl, Rnl_spl basis_spl = M.splinify(basis_p; nnodes = nnodes) ps_spl, st_spl = LuxCore.setup(rng, basis_spl) @@ -88,6 +95,7 @@ for ntest = 1:30 # @show norm(rel_err, 1) / length(Rnl) end end +println() ## @@ -97,6 +105,8 @@ basis_p = M.set_params(basis, ps) basis_spl = M.splinify(basis_p; nnodes = 100) for ntest = 1:20 + local Rs, Zs, Zi, Zj, r, Rnl, U, F, dF + Rs, Zs, Zi = M.rand_atenv(basis_spl, 1) r = norm(Rs[1]); Zj = Zs[1] Rnl = basis_spl(r, Zi, Zj, ps, st)[1] @@ -115,12 +125,14 @@ basis_p = M.set_params(basis, ps) basis_spl = M.splinify(basis_p; nnodes = 100) for ntest = 1:20 + local Rnl, Rs, Zs, Z0, Nat, st1, ∇Rnl + Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(basis_spl, Nat) rs = norm.(Rs) Rnl = [ M.evaluate(basis_spl, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] - Rnl_b, st = M.evaluate_batched(basis_spl, rs, Z0, Zs, ps, st) + Rnl_b, st1 = M.evaluate_batched(basis_spl, rs, Z0, Zs, ps, st) print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis_spl, rs, Z0, Zs, ps, st) @@ -130,3 +142,4 @@ for ntest = 1:20 print_tf(@test Rnl_b ≈ Rnl_b2) print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) end +println() \ No newline at end of file diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 6e48564a..cf26f3fe 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using Test, ACEbase @@ -26,23 +26,22 @@ model = M.ace_model(; elements = elements, order = order, Ytype = :solid, ps, st = LuxCore.setup(rng, model) -# TODO: the number of parameters is completely off, so something is -# likely wrong here. - ## @info("Test Rotation-Invariance of the Model") for ntest = 1:50 + local st1, Nat, Rs, Zs, Z0 + Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(model, Nat) - val, st = M.evaluate(model, Rs, Zs, Z0, ps, st) + val, st1 = M.evaluate(model, Rs, Zs, Z0, ps, st) p = shuffle(1:Nat) Rs1 = Ref(M.rand_iso()) .* Rs[p] Zs1 = Zs[p] - val1, st = M.evaluate(model, Rs1, Zs1, Z0, ps, st) + val1, st1 = M.evaluate(model, Rs1, Zs1, Z0, ps, st) print_tf(@test abs(val - val1) < 1e-10) end @@ -57,6 +56,8 @@ Ei1, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) println_slim(@test Ei ≈ Ei1) for ntest = 1:20 + local Nat, Rs, Zs, z0, Us, F, dF + Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) Us = randn(SVector{3, Float64}, Nat) @@ -76,6 +77,8 @@ Ei1, ∇Ei, st = M.grad_params(model, Rs, Zs, z0, ps, st) println_slim(@test Ei ≈ Ei1) for ntest = 1:20 + local Nat, Rs, Zs, z0, pvec, uvec, F, dF0 + Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) pvec, _restruct = destructure(ps) @@ -84,11 +87,15 @@ for ntest = 1:20 dF0 = dot( destructure( M.grad_params(model, Rs, Zs, z0, ps, st)[2] )[1], uvec ) print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose = false)) end +println() ## @info("Test second mixed derivatives reverse-over-reverse") for ntest = 1:20 + local Nat, Rs, Zs, Us, Ei, ∂Ei, ∂2_Ei, + ps_vec, vs_vec, F, dF0, z0 + Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) Us = randn(SVector{3, Float64}, Nat) @@ -109,33 +116,38 @@ for ntest = 1:20 dF0 = dot(∂2_∇Ei_vec, vs_vec) print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false)) end - +println() ## @info("Test basis implementation") for ntest = 1:30 + local Nat, Rs, Zs, z0, Ei, B, θ, st1 , ∇Ei + Nat = 15 Rs, Zs, z0 = M.rand_atenv(model, Nat) i_z0 = M._z2i(model, z0) - Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) - B, st = M.evaluate_basis(model, Rs, Zs, z0, ps, st) + Ei, st1 = M.evaluate(model, Rs, Zs, z0, ps, st) + B, st1 = M.evaluate_basis(model, Rs, Zs, z0, ps, st) θ = vcat(ps.WB...) print_tf(@test Ei ≈ dot(B, θ)) - Ei, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) - B, ∇B, st = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) + Ei, ∇Ei, st1 = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + B, ∇B, st1 = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) θ = vcat(ps.WB...) print_tf(@test Ei ≈ dot(B, θ)) print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) end +println() ## @info("Test the full mixed jacobian") for ntest = 1:30 + local Nat, Rs, Zs, z0, Ei, ∇Ei, ∂∂Ei, Us, F, dF0 + Nat = 15 Rs, Zs, z0 = M.rand_atenv(model, Nat) Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat) @@ -144,10 +156,12 @@ for ntest = 1:30 ∂∂Ei = M.jacobian_grad_params(model, Rs, Zs, z0, ps, st)[3] print_tf(@test dF0 ≈ transpose.(∂∂Ei) * Us) end +println() ## +#= @info("Basic performance benchmarks") # first test shows the performance is not at all awful even without any # optimizations and reductions in memory allocations. @@ -168,3 +182,4 @@ print(" evaluate_basis : "); @btime M.evaluate_basis($model, $Rs, $Zs, $z0, print(" evaluate_basis_ed : "); @btime M.evaluate_basis_ed($model, $Rs, $Zs, $z0, $ps, $st) print("jacobian_grad_params : "); @btime M.jacobian_grad_params($model, $Rs, $Zs, $z0, $ps, $st) +=# \ No newline at end of file diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index ab8424b3..8ab48aac 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -1,6 +1,6 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using Test, ACEbase @@ -40,6 +40,8 @@ calc = M.ACEPotential(model, ps, st) @info("Testing correctness of potential energy") for ntest = 1:20 + local Rs, Zs, z0, at + at = rattle!(bulk(:Si, cubic=true) * 2, 0.1) nlist = PairList(at, M.cutoff_radius(calc)) E = 0.0 @@ -52,6 +54,7 @@ for ntest = 1:20 print_tf(@test abs(E - ustrip(efv.energy))/abs(E) < 1e-12) print_tf(@test abs(E - ustrip(E2)) / abs(E) < 1e-12) end +println() ## @@ -77,6 +80,8 @@ at = rattle!(bulk(:Si, cubic=true), 0.1) @info("test consistency of forces with energy") @info(" TODO: write virial test!") for ntest = 1:10 + local at, Us, dF0, X0, F, Z + at = rattle!(bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) Z[[3,6,8]] .= 8 @@ -104,6 +109,8 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), (0.0 / u"eV", 1.0 / u"eV", 0.0 / u"eV/Å"), (0.0 / u"eV", 0.0 / u"eV", 1.0 / u"eV/Å"), (1.0 / u"eV", 0.1 / u"eV", 0.1 / u"eV/Å") ] + local at + # random structure at = rattle!(bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 9477a60e..f9fa57f4 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -1,7 +1,7 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using ACEpotentials @@ -31,6 +31,7 @@ Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) @info("Test derivatives of Rnlrzz basis") for ntest = 1:20 + global ps, st r = 2.0 + rand() Zi = rand(basis._i2z) Zj = rand(basis._i2z) @@ -46,6 +47,7 @@ println() @info("LearnableRnlrzz : Consistency of single and batched evaluation") for ntest = 1:20 + global ps, st Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(basis, Nat) rs = norm.(Rs) diff --git a/test/models/test_models.jl b/test/models/test_models.jl new file mode 100644 index 00000000..69b35002 --- /dev/null +++ b/test/models/test_models.jl @@ -0,0 +1,8 @@ + +@testset "Models" begin + @testset "Radial Envelopes" begin; include("test_radial_envelopes.jl"); end + @testset "Radial Transforms" begin; include("test_radial_transforms.jl"); end + @testset "Rnlrzz Basis" begin; include("test_Rnl.jl"); end + @testset "ACE Model" begin; include("test_ace.jl"); end + @testset "ACE Calculator" begin; include("test_calculator.jl"); end +end diff --git a/test/models/test_radial_transforms.jl b/test/models/test_radial_transforms.jl index e61626e7..0cb3952e 100644 --- a/test/models/test_radial_transforms.jl +++ b/test/models/test_radial_transforms.jl @@ -1,6 +1,6 @@ -using Pkg; Pkg.activate("."); -using TestEnv; TestEnv.activate(); +# using Pkg; Pkg.activate("."); +# using TestEnv; TestEnv.activate(); using ACEpotentials, Test diff --git a/test/runtests.jl b/test/runtests.jl index 2010ec7c..8bb0522a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ using ACEpotentials, Test, LazyArtifacts # experimental @testset "UF_ACE" begin include("test_uface.jl") end + include("models/test_models.jl") + # outdated @testset "Read data" begin include("outdated/test_data.jl") end @testset "Basis" begin include("outdated/test_basis.jl") end From 40fb1887d9232164959bb3a9f4e602904422604f Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 15 May 2024 14:09:47 -0700 Subject: [PATCH 027/112] most of the way to include pair potentials --- src/models/Rnl_basis.jl | 5 +- src/models/ace.jl | 119 +++++++++++++++++++++++++++------ src/models/ace_heuristics.jl | 45 ++++++++++++- test/models/test_ace.jl | 7 +- test/models/test_models.jl | 1 + test/models/test_pair_basis.jl | 80 ++++++++++++++++++++++ 6 files changed, 227 insertions(+), 30 deletions(-) create mode 100644 test/models/test_pair_basis.jl diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index c90b4e0a..2e61bd2d 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -81,6 +81,9 @@ _transform_zz(obj, zi, zj) = obj.transforms[_z2i(obj, zi), _z2i(obj, zj)] _get_T(basis::LearnableRnlrzzBasis) = typeof(basis.rin0cuts[1,1].rin) +splinify(basis::SplineRnlrzzBasis; kwargs...) = basis + + function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) # transform : r ∈ [rin, rcut] -> x @@ -105,7 +108,7 @@ function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) rin, rcut = rin0cut.rin, rin0cut.rcut Tij = basis.transforms[iz0, iz1] - Wnlq_ij = @view basis.weights[:, :, iz0, iz1] + Wnlq_ij = @view basis.weights[:, :, iz0, iz1] Rnl = [ SVector{LEN}( Wnlq_ij * Polynomials4ML.evaluate(polys, x) ) for x in x_nodes ] diff --git a/src/models/ace.jl b/src/models/ace.jl index 7070a6af..d395a550 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -21,11 +21,14 @@ import Polynomials4ML # ACE MODEL SPECIFICATION -struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer{(:rbasis,)} +struct ACEModel{NZ, TRAD, TY, TA, TAA, T, TPAIR} <: AbstractExplicitContainerLayer{(:rbasis,)} _i2z::NTuple{NZ, Int} # -------------- + # embeddings of the particles rbasis::TRAD ybasis::TY + # -------------- + # the tensor format abasis::TA aabasis::TAA A2Bmap::SparseMatrixCSC{T, Int} @@ -36,8 +39,8 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T} <: AbstractExplicitContainerLayer{(:rb # aaparams::NTuple{NZ, Vector{T}} (not used right now) # -------------- # pair potential - # pairbasis::TPAIR - # pairparams::Matrix{T} + pairbasis::TPAIR + pairparams::Matrix{T} # -------------- meta::Dict{String, Any} end @@ -88,7 +91,8 @@ end function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, - level = TotalDegree()) + level = TotalDegree(), + pair_basis = nothing ) # generate the coupling coefficients cgen = EquivariantModels.Rot3DCoeffs_real(0) AA2BB_map = EquivariantModels._rpi_A2B_matrix(cgen, AA_spec) @@ -128,10 +132,14 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, aa_basis = Polynomials4ML.SparseSymmProdDAG(AA_spec_idx) aa_basis.meta["AA_spec"] = AA_spec # (also store the human-readable spec) - NZ = _get_nz(rbasis) - n_B_params, n_AA_params = size(AA2BB_map) - return ACEModel(rbasis._i2z, rbasis, ybasis, a_basis, aa_basis, AA2BB_map, - zeros(n_B_params, NZ), + # NZ = _get_nz(rbasis) + # n_B_params, n_AA_params = size(AA2BB_map) + + + + return ACEModel(rbasis._i2z, rbasis, ybasis, + a_basis, aa_basis, AA2BB_map, zeros(0,0), + pair_basis, zeros(0,0), # ntuple(_ -> zeros(n_AA_params), NZ), Dict{String, Any}() ) end @@ -139,8 +147,8 @@ end # TODO: it is not entirely clear that the `level` is really needed here # since it is implicitly already encoded in AA_spec. We need a # function `auto_level` that generates level automagically from AA_spec. -function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level) - return _generate_ace_model(rbasis, Ytype, AA_spec, level) +function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level, pair_basis) + return _generate_ace_model(rbasis, Ytype, AA_spec, level, pair_basis) end # NOTE : a nicer convenience constructor is also provided in `ace_heuristics.jl` @@ -150,6 +158,16 @@ end # ------------------------------------------------------------ # Lux stuff +function _W_init(str) + if str == "zeros" + return (rng, T, args...) -> zeros(T, args...) + elseif str == "glorot_normal" + return glorot_normal + else + error("unknown `init_WB` = $str") + end +end + function initialparameters(rng::AbstractRNG, model::ACEModel) NZ = _get_nz(model) @@ -159,26 +177,30 @@ function initialparameters(rng::AbstractRNG, # via the B params. # there are different ways to initialize parameters - if model.meta["init_WB"] == "zeros" - winit = zeros - elseif model.meta["init_WB"] == "glorot_normal" - winit = glorot_normal - else - error("unknown `init_WB` = $(model.meta["init_WB"])") - end - + winit = _W_init(model.meta["init_WB"]) WB = zeros(n_B_params, NZ) for iz = 1:NZ WB[:, iz] .= winit(rng, Float64, n_B_params) end - return (WB = WB, - rbasis = initialparameters(rng, model.rbasis), ) + # generate pair basis parameters + n_pair = length(model.pairbasis) + Wpair = zeros(n_pair, NZ) + winit_pair = _W_init(model.meta["init_Wpair"]) + + for iz = 1:NZ + Wpair[:, iz] .= winit_pair(rng, Float64, n_pair) + end + + return (WB = WB, Wpair = Wpair, + rbasis = initialparameters(rng, model.rbasis), + pairbasis = initialparameters(rng, model.pairbasis), ) end function initialstates(rng::AbstractRNG, model::ACEModel) - return ( rbasis = initialstates(rng, model.rbasis), ) + return ( rbasis = initialstates(rng, model.rbasis), + pairbasis = initialstates(rng, model.pairbasis), ) end (l::ACEModel)(args...) = evaluate(l, args...) @@ -233,6 +255,16 @@ function evaluate(model::ACEModel, # contract with params i_z0 = _z2i(model.rbasis, Z0) val = dot(B, (@view ps.WB[:, i_z0])) + + # ------------------- + # pair potential + if model.pairbasis != nothing + Rpair, _ = evaluate_batched(model.pairbasis, rs, Z0, Zs, + ps.pairbasis, st.pairbasis) + Apair = sum(Rpair, dims=1)[:] + val += dot(Apair, (@view ps.Wpair[:, i_z0])) + end + # ------------------- return val, st end @@ -281,6 +313,7 @@ function evaluate_ed(model::ACEModel, i_z0 = _z2i(model.rbasis, Z0) Ei = dot(B, (@view ps.WB[:, i_z0])) + # ---------- BACKWARD PASS ------------ # ∂Ei / ∂B = WB[i_z0] @@ -311,6 +344,27 @@ function evaluate_ed(model::ACEModel, sum(∂Ylm[j, :] .* dYlm[j, :]) end + + # ------------------- + # pair potential + if model.pairbasis != nothing + Rpair, dRpair, _ = evaluate_ed_batched(model.pairbasis, rs, Z0, Zs, + ps.pairbasis, st.pairbasis) + Apair = sum(Rpair, dims=1)[:] + Wp_i = @view ps.Wpair[:, i_z0] + Ei += dot(Apair, Wp_i) + + # pullback --- I'm now assuming that the pair basis is not learnable. + if !( ps.pairbasis == NamedTuple() ) + error("I'm currently assuming the pair basis is not learnable.") + end + + for j = 1:length(Rs) + ∇Ei[j] += dot(Wp_i, (@view dRpair[j, :])) * (Rs[j] / rs[j]) + end + end + # ------------------- + return Ei, ∇Ei, st end @@ -391,7 +445,28 @@ function grad_params(model::ACEModel, # = pullback(∂Rnl, rbasis, args...) _, _, _, _, ∂Wqnl, _ = pb_Rnl(∂Rnl) # this should be a named tuple already. - return Ei, (WB = ∂WB, rbasis = ∂Wqnl), st + + # ------------------- + # pair potential + if model.pairbasis != nothing + Rpair, _ = evaluate_batched(model.pairbasis, rs, Z0, Zs, + ps.pairbasis, st.pairbasis) + Apair = sum(Rpair, dims=1)[:] + Wp_i = @view ps.Wpair[:, i_z0] + Ei += dot(Apair, Wp_i) + + # pullback --- I'm now assuming that the pair basis is not learnable. + if !( ps.pairbasis == NamedTuple() ) + error("I'm currently assuming the pair basis is not learnable.") + end + + ∂Wpair = zeros(eltype(Apair), size(ps.Wpair)) + ∂Wpair[:, i_z0] = Apair + end + # ------------------- + + + return Ei, (WB = ∂WB, Wpair = ∂Wpair, rbasis = ∂Wqnl), st end diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index d97e4408..60a6f0f9 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -1,4 +1,4 @@ - +import Random # -------------------------------------------------- # different notions of "level" / total degree. @@ -42,7 +42,7 @@ function ace_learnable_Rnlrzz(; rin0cuts = _default_rin0cuts(elements), transforms = agnesi_transform.(rin0cuts, 2, 2), polys = :legendre, - envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) + envelopes = :poly2sx ) if elements == nothing error("elements must be specified!") @@ -52,6 +52,7 @@ function ace_learnable_Rnlrzz(; end zlist =_convert_zlist(elements) + NZ = length(zlist) if spec == nothing spec = [ (n = n, l = l) for n = 1:maxn, l = 0:maxl @@ -69,6 +70,19 @@ function ace_learnable_Rnlrzz(; end end + if transforms isa Tuple && transforms[1] == :agnesi + p = transforms[2] + q = transforms[3] + transforms = agnesi_transform.(rin0cuts, p, q) + end + + if envelopes == :poly2sx + envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) + elseif envelopes == :poly1sr + envelopes = [ PolyEnvelope1sR(rin0cuts[iz, jz].rcut, 1) + for iz = 1:NZ, jz = 1:NZ ] + end + if actual_maxn > length(polys) error("actual_maxn > length of polynomial basis") end @@ -90,6 +104,11 @@ function ace_model(; elements = nothing, level = nothing, max_level = nothing, init_WB = :zeros, + # pair basis + pair_maxn = nothing, + pair_basis = :auto, + init_Wpair = :zeros, + rng = Random.default_rng(), ) # construct an rbasis if needed if isnothing(rbasis) @@ -102,11 +121,31 @@ function ace_model(; elements = nothing, end end + # construct a pair basis if needed + if pair_basis == :auto + @assert pair_maxn isa Integer + @show pair_maxn + + pair_basis = ace_learnable_Rnlrzz(; + elements = rbasis._i2z, + level = TotalDegree(), + max_level = pair_maxn, + maxl = 0, + maxn = pair_maxn, + rin0cuts = rbasis.rin0cuts, + transforms = (:agnesi, 1, 4), + envelopes = :poly1sr ) + end + + ps_pair = initialparameters(rng, pair_basis) + pair_basis_spl = splinify(set_params(pair_basis, ps_pair)) + AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, level = level, max_level = max_level) - model = ace_model(rbasis, Ytype, AA_spec, level) + model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis_spl) model.meta["init_WB"] = String(init_WB) + model.meta["init_Wpair"] = String(init_Wpair) return model end diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index cf26f3fe..d4df7c6f 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -1,5 +1,5 @@ -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using Test, ACEbase @@ -22,11 +22,11 @@ order = 3 model = M.ace_model(; elements = elements, order = order, Ytype = :solid, level = level, max_level = max_level, maxl = 8, - init_WB = :glorot_normal) + pair_maxn = 15, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) ps, st = LuxCore.setup(rng, model) - ## @info("Test Rotation-Invariance of the Model") @@ -57,7 +57,6 @@ println_slim(@test Ei ≈ Ei1) for ntest = 1:20 local Nat, Rs, Zs, z0, Us, F, dF - Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) Us = randn(SVector{3, Float64}, Nat) diff --git a/test/models/test_models.jl b/test/models/test_models.jl index 69b35002..4479d46e 100644 --- a/test/models/test_models.jl +++ b/test/models/test_models.jl @@ -3,6 +3,7 @@ @testset "Radial Envelopes" begin; include("test_radial_envelopes.jl"); end @testset "Radial Transforms" begin; include("test_radial_transforms.jl"); end @testset "Rnlrzz Basis" begin; include("test_Rnl.jl"); end + @testset "Pair Basis" begin; include("test_pair_basis.jl"); end @testset "ACE Model" begin; include("test_ace.jl"); end @testset "ACE Calculator" begin; include("test_calculator.jl"); end end diff --git a/test/models/test_pair_basis.jl b/test/models/test_pair_basis.jl new file mode 100644 index 00000000..d3f145a2 --- /dev/null +++ b/test/models/test_pair_basis.jl @@ -0,0 +1,80 @@ + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); + +using ACEpotentials +M = ACEpotentials.Models + +using Random, LuxCore, Test, ACEbase, LinearAlgebra +using ACEbase.Testing: print_tf +rng = Random.MersenneTwister(1234) + +## + +max_level = 16 +level = M.TotalDegree() +maxl = 0; maxn = max_level; +elements = (:Si, :O) +basis = M.ace_learnable_Rnlrzz(; level=level, max_level=max_level, + maxl = maxl, maxn = maxn, + elements = elements, + transforms = (:agnesi, 1, 4), + envelopes = :poly1sr ) + +ps, st = LuxCore.setup(rng, basis) + +r = 3.0 +Zi = basis._i2z[1] +Zj = basis._i2z[2] +Rnl1, st1 = basis(r, Zi, Zj, ps, st) +Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) + +basis_p = M.set_params(basis, ps) +basis_spl = M.splinify(basis_p) +ps_spl, st_spl = LuxCore.setup(rng, basis_spl) + +Rnl2, _ = M.evaluate(basis_spl, r, Zi, Zj, ps_spl, st_spl) +Rnl2, Rnl_d2, _ = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) + +## +# inspect the basis visually + +# using Plots + +# rr = range(0.1, 6.0, length=300) +# Zs = fill(14, length(rr)) +# z0 = 14 +# Rnl, _ = M.evaluate_batched(basis_spl, rr, z0, Zs, ps, st) +# env_rr = M.evaluate.(Ref(basis_spl.envelopes[1,1]), rr, 0.0) + +# plt1 = plot(; ylims = (-2.0, 5.0), ) +# plt2 = plot(; ylims = (-3.0, 3.0), ) +# for n = 1:5 +# plot!(plt1, rr, Rnl[:, n], label = "n=$n") +# plot!(plt2, rr, Rnl[:, n] ./ env_rr, label ="") +# end +# vline!(plt1, [basis_spl.rin0cuts[1,1].r0], label = "r0") +# vline!(plt2, [basis_spl.rin0cuts[1,1].r0], label = "") + +# plot(plt1, plt2, layout = (2,1)) + + +## + + +@info("Test derivatives of Spline Rnl Basis for Pairpot") + +for ntest = 1:20 + local r, Zi, Zj, U, F, dF + Zi = rand(basis_spl._i2z) + Zj = rand(basis_spl._i2z) + r = 2.0 + rand() + U = randn(length(basis_spl)) + F(t) = dot(U, basis_spl(r + t, Zi, Zj, ps, st)[1]) + dF(t) = dot(U, M.evaluate_ed(basis_spl, r + t, Zi, Zj, ps, st)[2]) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) +end +println() + +## + From 6f7f54c92e3c814a806f19f6730cd32d2f91a96a Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 15 May 2024 14:20:39 -0700 Subject: [PATCH 028/112] pair basis done --- src/models/ace.jl | 34 +++++++++++++++++++++++++++++++++- test/models/test_ace.jl | 3 +-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index d395a550..138a1f77 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -513,6 +513,29 @@ function get_basis_inds(model::ACEModel, Z) return (i_z - 1) * len_Bi .+ (1:len_Bi) end +function get_pairbasis_inds(model::ACEModel, Z) + len_Bi = size(model.A2Bmap, 1) + NZ = _get_nz(model) + len_B = NZ * len_Bi + + len_pair = length(model.pairbasis) + i_z = _z2i(model, Z) + return (len_B + (i_z - 1) * len_pair) .+ (1:len_pair) +end + +function len_basis(model::ACEModel) + len_Bi = size(model.A2Bmap, 1) + len_pair = length(model.pairbasis) + NZ = _get_nz(model) + return (len_Bi + len_pair) * NZ +end + + +function get_basis_params(model::ACEModel, ps, ) + # this is magically given by the basis ordering we picked + return vcat(ps.WB[:], ps.Wpair[:]) +end + function evaluate_basis(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} @@ -543,9 +566,18 @@ function evaluate_basis(model::ACEModel, # evaluate the coupling coefficients # TODO: use Bumper and do it in-place Bi = model.A2Bmap * AA - B = zeros(eltype(Bi), length(Bi) * _get_nz(model)) + B = zeros(eltype(Bi), len_basis(model)) B[get_basis_inds(model, Z0)] .= Bi + # ------------------- + # pair potential + if model.pairbasis != nothing + Rpair, _ = evaluate_batched(model.pairbasis, rs, Z0, Zs, + ps.pairbasis, st.pairbasis) + Apair = sum(Rpair, dims=1)[:] + B[get_pairbasis_inds(model, Z0)] .= Apair + end + return B, st end diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index d4df7c6f..9a3b18f9 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -129,12 +129,11 @@ for ntest = 1:30 i_z0 = M._z2i(model, z0) Ei, st1 = M.evaluate(model, Rs, Zs, z0, ps, st) B, st1 = M.evaluate_basis(model, Rs, Zs, z0, ps, st) - θ = vcat(ps.WB...) + θ = M.get_basis_params(model, ps) print_tf(@test Ei ≈ dot(B, θ)) Ei, ∇Ei, st1 = M.evaluate_ed(model, Rs, Zs, z0, ps, st) B, ∇B, st1 = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) - θ = vcat(ps.WB...) print_tf(@test Ei ≈ dot(B, θ)) print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) end From 1fa798c6edbc4776527ccd17835826f2a57449c2 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 15 May 2024 15:10:30 -0700 Subject: [PATCH 029/112] some more test cleanup --- Project.toml | 3 ++- src/models/ace_heuristics.jl | 1 - src/models/calculators.jl | 4 +++- test/Project.toml | 17 +++++++++++++++++ test/models/test_Rnl.jl | 4 ++-- test/models/test_ace.jl | 4 ++-- test/models/test_calculator.jl | 19 ++++++++++--------- test/models/test_pair_basis.jl | 2 +- 8 files changed, 37 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index c6a14d6f..601ba7e5 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ACE1 = "0.12" +ACE1 = "0.12.2" ACE1x = "0.1.8" ACEfit = "0.1.4" ACEmd = "0.1.6" @@ -59,6 +59,7 @@ StaticArrays = "1" UltraFastACE = "0.0.2" YAML = "0.4" julia = "1.9" +Interpolations = "0.14.7" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 60a6f0f9..af7ac7e4 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -124,7 +124,6 @@ function ace_model(; elements = nothing, # construct a pair basis if needed if pair_basis == :auto @assert pair_maxn isa Integer - @show pair_maxn pair_basis = ace_learnable_Rnlrzz(; elements = rbasis._i2z, diff --git a/src/models/calculators.jl b/src/models/calculators.jl index b69bcf6d..d0516e93 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -94,7 +94,9 @@ function energy_forces_virial_serial( virial += _site_virial(dv, Rs) release!(Js); release!(Rs); release!(Zs) end - return (energy = energy, forces = forces, virial = virial) + return (energy = energy * energy_unit(V), + forces = forces * force_unit(V), + virial = virial * energy_unit(V) ) end diff --git a/test/Project.toml b/test/Project.toml index c13e4e29..354eea3e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,27 @@ [deps] ACE1 = "e3f9bc04-086e-409a-ba78-e9769fe067bb" ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" +ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" +AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JuLIP = "945c410c-986d-556a-acb1-167a618e0462" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[compat] +ACE1 = "0.12.2" +ACE1x = "0.1.8" +JuLIP = "0.13.9, 0.14.2" +StaticArrays = "1" +Interpolations = "0.14.7" \ No newline at end of file diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl index d3cef0b4..40635aba 100644 --- a/test/models/test_Rnl.jl +++ b/test/models/test_Rnl.jl @@ -47,7 +47,7 @@ println() @info("LearnableRnlrzz : Consistency of single and batched evaluation") for ntest = 1:20 - local Rs, Rnl, Zs, Z0, Nat, st1, ∇Rnl + local Rs, Rnl, Zs, Z0, Nat, st1, ∇Rnl, rs Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(basis, Nat) @@ -125,7 +125,7 @@ basis_p = M.set_params(basis, ps) basis_spl = M.splinify(basis_p; nnodes = 100) for ntest = 1:20 - local Rnl, Rs, Zs, Z0, Nat, st1, ∇Rnl + local Rnl, Rs, Zs, Z0, Nat, st1, ∇Rnl, rs Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(basis_spl, Nat) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 9a3b18f9..f5e556b5 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using Test, ACEbase @@ -32,7 +32,7 @@ ps, st = LuxCore.setup(rng, model) @info("Test Rotation-Invariance of the Model") for ntest = 1:50 - local st1, Nat, Rs, Zs, Z0 + local st1, Nat, Rs, Zs, Z0, val Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(model, Nat) diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 8ab48aac..a769a938 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -29,7 +29,8 @@ order = 3 model = M.ace_model(; elements = elements, order = order, Ytype = :solid, level = level, max_level = max_level, maxl = 8, - init_WB = :glorot_normal) + pair_maxn = 15, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) ps, st = LuxCore.setup(rng, model) @@ -64,16 +65,16 @@ println() at = rattle!(bulk(:Si, cubic=true), 0.1) @info(" consistency local vs EmpiricalPotentials implementation") -@info("this currently fails due to a bug in EmpiricalPotentials") -# efv1 = M.energy_forces_virial(at, calc, ps, st) -# efv2 = AtomsCalculators.energy_forces_virial(at_flex, calc) -# efv3 = M.energy_forces_virial_serial(at, calc, ps, st) -# print_tf(@test efv1.energy ≈ efv2.energy) +@info("the force test currently fails due to a bug in EmpiricalPotentials") +efv1 = M.energy_forces_virial(at, calc, ps, st) +efv2 = AtomsCalculators.energy_forces_virial(at, calc) +efv3 = M.energy_forces_virial_serial(at, calc, ps, st) +print_tf(@test efv1.energy ≈ efv2.energy) # print_tf(@test all(efv1.forces .≈ efv2.force)) -# print_tf(@test efv1.virial ≈ efv1.virial) -# print_tf(@test efv1.energy ≈ efv3.energy) +print_tf(@test efv1.virial ≈ efv1.virial) +print_tf(@test efv1.energy ≈ efv3.energy) # print_tf(@test all(efv1.forces .≈ efv3.forces)) -# print_tf(@test efv1.virial ≈ efv3.virial) +print_tf(@test efv1.virial ≈ efv3.virial) ## diff --git a/test/models/test_pair_basis.jl b/test/models/test_pair_basis.jl index d3f145a2..d7b44f7a 100644 --- a/test/models/test_pair_basis.jl +++ b/test/models/test_pair_basis.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using ACEpotentials From 929ed11759adfccb30b022f5bb9f2c7f6743f707 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 15 May 2024 15:40:24 -0700 Subject: [PATCH 030/112] towards a tutorial for the new models --- docs/src/newkernels/Project.toml | 8 +++ docs/src/newkernels/newkernels.jl | 89 +++++++++++++++++++++++++++++++ test/models/test_calculator.jl | 5 +- 3 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 docs/src/newkernels/Project.toml create mode 100644 docs/src/newkernels/newkernels.jl diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml new file mode 100644 index 00000000..a852b2a7 --- /dev/null +++ b/docs/src/newkernels/Project.toml @@ -0,0 +1,8 @@ +[deps] +ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" +AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl new file mode 100644 index 00000000..62d4a86e --- /dev/null +++ b/docs/src/newkernels/newkernels.jl @@ -0,0 +1,89 @@ +# This script is to roughly document how to use the new model implementations +# I'll try to explain what can be done and what is missing along the way. +# I am + +using ACEpotentials, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Random + +bulk = AtomsBuilder.bulk +rattle! = AtomsBuilder.rattle! + +# because the new implementation is experimental, it is not exported, +# so I create a little shortcut to have easy access. + +M = ACEpotentials.Models + +# The new implementation tries to follow Lux rules, which likes to be +# disciplined and explicit about random numbers + +rng = Random.MersenneTwister(1234) + + +# I'll create a new model for a simple alloy and then generate a model. +# this generates a trace-like ACE model with a random radial basis. + +elements = (:Al, :Ti) + +model = M.ace_model(; elements = elements, + order = 3, # correlation order + Ytype = :solid, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = 15, # maximum level of the basis functions + pair_maxn = 15, # maximum number of basis functions for the pair potential + init_WB = :glorot_normal, # how to initialize the ACE basis parmeters + init_Wpair = :glorot_normal # how to initialize the pair potential parameters + ) + +# the radial basis specification can be looked at explicitly via + +display(model.rbasis.spec) + +# we can see that it is defined as (n, l) pairs. Each `n` specifies an invariant +# channel coupled to an `l` channel. Each `Rnl` radial basis function is defined +# by `Rnl(r, Zi, Zj) = ∑_q W_nlq(Zi, Zj) * P_q(r)`. + +# some things that are missing: +# - reweighting the basis via a smoothness prior. +# - allow initialization of pair potential basis with one-hot embedding params +# right now the pair potential basis uses trace-like radials +# - convenient ways to inspect the many-body basis specification. + +# Lux wants us to call a setup function to generate the parameters and state +# for the model. + +ps, st = Lux.setup(rng, model) + +# From the model we generate a calculator. This step should probably be integrated. +# into `ace_model`, we can discuss it. + +calc = M.ACEPotential(model, ps, st) + +# We can now treat `calc` as a nonlinear parameterized site potential model. +# - generate a random Al, Ti structure +# - calculate the energy, forces, and virial +# An important point to note is that AtomsBase enforces the use of units. + +function rand_AlTi(nrep, rattle) + # Al : 13; Ti : 22 + at = rattle!(bulk(:Al, cubic=true) * 2, 0.1) + Z = AtomsBuilder._get_atomic_numbers(at) + Z[rand(1:length(at), length(at) ÷ 2)] .= 22 + return AtomsBuilder._set_atomic_numbers(at, Z) +end + + +at = rand_AlTi(2, 0.1) + +efv = M.energy_forces_virial(at, calc, ps, st) + +@info("Energy") +display(efv.energy) + +@info("Virial") +display(efv.virial) + +@info("Forces (on atoms 1..5)") +display(efv.forces[1:5]) + + + diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index a769a938..f68781fc 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -65,15 +65,14 @@ println() at = rattle!(bulk(:Si, cubic=true), 0.1) @info(" consistency local vs EmpiricalPotentials implementation") -@info("the force test currently fails due to a bug in EmpiricalPotentials") efv1 = M.energy_forces_virial(at, calc, ps, st) efv2 = AtomsCalculators.energy_forces_virial(at, calc) efv3 = M.energy_forces_virial_serial(at, calc, ps, st) print_tf(@test efv1.energy ≈ efv2.energy) -# print_tf(@test all(efv1.forces .≈ efv2.force)) +print_tf(@test all(efv1.forces .≈ efv2.forces)) print_tf(@test efv1.virial ≈ efv1.virial) print_tf(@test efv1.energy ≈ efv3.energy) -# print_tf(@test all(efv1.forces .≈ efv3.forces)) +print_tf(@test all(efv1.forces .≈ efv3.forces)) print_tf(@test efv1.virial ≈ efv3.virial) ## From de815de8bb6d8983a8fddc75236d693f7eb09a73 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 15 May 2024 16:25:32 -0700 Subject: [PATCH 031/112] more on the tutorial --- docs/src/newkernels/Project.toml | 4 ++ docs/src/newkernels/newkernels.jl | 85 ++++++++++++++++++++++++++++++- src/models/ace.jl | 3 +- 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index a852b2a7..111198c7 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -1,8 +1,12 @@ [deps] ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index 62d4a86e..76656160 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -3,7 +3,7 @@ # I am using ACEpotentials, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, - Unitful, Random + Unitful, Random, Zygote, Optimisers bulk = AtomsBuilder.bulk rattle! = AtomsBuilder.rattle! @@ -85,5 +85,88 @@ display(efv.virial) @info("Forces (on atoms 1..5)") display(efv.forces[1:5]) +# we can incorporate the parameters and the state into the model struct +# but for now we ignore this possibility and focus on how to train a model. + +# we load our example dataset and convert it to AtomsBase + +data, _, meta = ACEpotentials.example_dataset("TiAl_tutorial") +train_data = FlexibleSystem.(data[1:5:end]) + +# to set up training we specify data keys and training weights. +# to get a unitless loss we need to specify the weights to have inverse +# units to the data. The local loss function is the loss applied to +# a single training structure. This follows the Lux training API +# loss(model, ps, st, data) + +loss = let data_keys = (E_key = :energy, F_key = :force, V_key = :virial), + weights = (wE = 1.0/u"eV", wF = 0.1 / u"eV/Å", wV = 0.1/u"eV") + + function(calc, ps, st, at) + efv = M.energy_forces_virial(at, calc, ps, st) + _norm_sq(f) = sum(abs2, f) + E_dft, F_dft, V_dft = Zygote.ignore() do # Zygote doesn't have an adjoint for creating units :( + ( at.data[data_keys.E_key] * u"eV", + at.data[data_keys.F_key] * u"eV/Å", + at.data[data_keys.V_key] * u"eV" ) + end + return ( weights[:wE]^2 * (efv.energy - E_dft)^2 / length(at) + + weights[:wV]^2 * sum(abs2, efv.virial - V_dft) / length(at) + + weights[:wF]^2 * sum(_norm_sq, efv.forces - F_dft) + ), st, () + end +end + +loss(calc, ps, st, at)[1] + + +# Zygote should now be able to differentiate this loss with respect to parameters +# the gradient is provided in the same format as the parameters, i.e. a NamedTuple. + +at1 = train_data[1] +g = Zygote.gradient(ps -> loss(calc, ps, st, at1)[1], ps)[1] + +@show typeof(g) + +# both parameters and gradients can be serialized into a vector and that +# allows us use of arbitrary optimizers + +ps_vec, _restruct = destructure(ps) +g_vec = destructure(g)[1] + +# Let's now try to optimize the model. Here I'm a bit hazy how to do this +# properly. I'm just modifying a Lux tutorial. There are probably better ways. +# https://github.com/LuxDL/Lux.jl/blob/main/examples/PolynomialFitting/main.jl + +using ADTypes, Printf +vjp_rule = AutoZygote() +opt = Optimisers.Adam() +opt_state = Optimisers.setup(opt, ps) +tstate = Lux.Experimental.TrainState(rng, calc.model, opt) + +function main(tstate, vjp, data, epochs) + for epoch in 1:epochs + grads, loss_val, stats, tstate = Lux.Experimental.compute_gradients( + vjp, loss, data, tstate) + if epoch % 10 == 1 || epoch == epochs + @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss_val + end + tstate = Lux.Experimental.apply_gradients!(tstate, grads) + end + return tstate +end + +main(tstate, vjp_rule, train_data, 100) + + +# the alternative might be to optimize using Optim.jl +using Optim + +adam = Optim.Adam() + +function total_loss(p_vec) + return sum( loss(calc, _restruct(p_vec), st, at)[1] for at in train_data ) +end +result = Optim.optimize(total_loss, ps_vec, adam) diff --git a/src/models/ace.jl b/src/models/ace.jl index 138a1f77..2685d066 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -466,7 +466,8 @@ function grad_params(model::ACEModel, # ------------------- - return Ei, (WB = ∂WB, Wpair = ∂Wpair, rbasis = ∂Wqnl), st + return Ei, (WB = ∂WB, Wpair = ∂Wpair, rbasis = ∂Wqnl, + pairbasis = NamedTuple()), st end From ce3e2921498e48051b37ced815478f20d1444147 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 15 May 2024 21:09:34 -0700 Subject: [PATCH 032/112] nonlinear regression example --- docs/src/newkernels/newkernels.jl | 68 +++++++++++++++++++++---------- src/models/calculators.jl | 2 + 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index 76656160..c9cb4b12 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -138,35 +138,59 @@ g_vec = destructure(g)[1] # properly. I'm just modifying a Lux tutorial. There are probably better ways. # https://github.com/LuxDL/Lux.jl/blob/main/examples/PolynomialFitting/main.jl -using ADTypes, Printf -vjp_rule = AutoZygote() -opt = Optimisers.Adam() -opt_state = Optimisers.setup(opt, ps) -tstate = Lux.Experimental.TrainState(rng, calc.model, opt) - -function main(tstate, vjp, data, epochs) - for epoch in 1:epochs - grads, loss_val, stats, tstate = Lux.Experimental.compute_gradients( - vjp, loss, data, tstate) - if epoch % 10 == 1 || epoch == epochs - @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss_val - end - tstate = Lux.Experimental.apply_gradients!(tstate, grads) - end - return tstate -end +# This is the Lux approach, which I couldn't get to work. + +# using ADTypes, Printf +# vjp_rule = AutoZygote() +# opt = Optimisers.Adam() +# opt_state = Optimisers.setup(opt, ps) +# tstate = Lux.Experimental.TrainState(rng, calc.model, opt) -main(tstate, vjp_rule, train_data, 100) +# function main(tstate, vjp, data, epochs) +# for epoch in 1:epochs +# grads, loss_val, stats, tstate = Lux.Experimental.compute_gradients( +# vjp, loss, data, tstate) +# if epoch % 10 == 1 || epoch == epochs +# @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss_val +# end +# tstate = Lux.Experimental.apply_gradients!(tstate, grads) +# end +# return tstate +# end + +# main(tstate, vjp_rule, train_data, 100) # the alternative might be to optimize using Optim.jl using Optim -adam = Optim.Adam() - function total_loss(p_vec) - return sum( loss(calc, _restruct(p_vec), st, at)[1] for at in train_data ) + return sum( loss(calc, _restruct(p_vec), st, at)[1] + for at in train_data ) end -result = Optim.optimize(total_loss, ps_vec, adam) +function total_loss_grad!(g, p_vec) + g[:] = Zygote.gradient(ps -> total_loss(ps), p_vec)[1] + return g +end + +@time total_loss(ps_vec) +@time total_loss(ps_vec) +@time total_loss_grad!(zeros(length(ps_vec)), ps_vec) +@time total_loss_grad!(zeros(length(ps_vec)), ps_vec) + +result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; + method = Optim.Adam(), + show_trace = true, + iterations = 100) + + +# Now that we've optimized the entire model a little bit +# we can think that the radial basis functions are sufficiently +# optimized. This is of course not true in this case since we didn't +# use enough iterations. But suppose we had converged the nonlinear +# optimization to get a really good radial basis. +# Then, in a second step we can freeze the radial basis and +# optimize the ACE basis coefficients via linear regression. + diff --git a/src/models/calculators.jl b/src/models/calculators.jl index d0516e93..9a35f491 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -34,6 +34,8 @@ distance_unit(::ACEPotential) = 1.0u"Å" force_unit(V) = energy_unit(V) / distance_unit(V) Base.zero(V::ACEPotential) = zero(energy_unit(V)) +initialparameters(rng::AbstractRNG, V::ACEPotential) = initialparameters(rng, V.model) +initialstates(rng::AbstractRNG, V::ACEPotential) = initialstates(rng, V.model) # --------------------------------------------------------------- # EmpiricalPotentials / SitePotential based implementation From e4e4f87fa6ec28eddf72ce7a5987f2d07b7be107 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 16 May 2024 09:56:50 -0700 Subject: [PATCH 033/112] remove params from models to strictly follow Lux; splinify convenience functions --- docs/src/newkernels/newkernels.jl | 25 +++++++++++++++++ src/models/Rnl_basis.jl | 26 +++++------------- src/models/Rnl_learnable.jl | 10 +++---- src/models/ace.jl | 21 ++++++++++++--- src/models/ace_heuristics.jl | 2 +- src/models/calculators.jl | 7 +++++ test/models/test_Rnl.jl | 10 +++---- test/models/test_ace.jl | 26 ++++++++++++++++-- test/models/test_calculator.jl | 39 ++++++++++++++++++++++----- test/models/test_pair_basis.jl | 3 +-- test/models/test_radial_envelopes.jl | 4 +-- test/models/test_radial_transforms.jl | 5 ++-- 12 files changed, 128 insertions(+), 50 deletions(-) diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index c9cb4b12..ce0573c3 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -194,3 +194,28 @@ result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; # Then, in a second step we can freeze the radial basis and # optimize the ACE basis coefficients via linear regression. +# as a first step, we replace the learnable radials with +# splined radials + +ps1_vec = result.minimizer +ps1 = _restruct(ps1_vec) + +rbasis_p = M.set_params(calc.model.rbasis, ps1.rbasis) +rbasis_spl = M.splinify(rbasis_p) + +# next we create a new ACE model with the splined radial basis +# this step should be moved into ACEpotentials.Models and +# automated. + +linmodel = M.ACEModel(calc.model._i2z, + rbasis_spl, + calc.model.ybasis, + calc.model.abasis, + calc.model.aabasis, + calc.model.A2Bmap, + calc.model.bparams, + calc.model.pairbasis, + calc.model.pairparams, + calc.model.meta) +lincalc = M.ACEPotential(linmodel) + diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 2e61bd2d..0303cddb 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -27,13 +27,13 @@ const SPL_OF_SVEC{DIM, T} = } -struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractExplicitLayer +struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, T} <: AbstractExplicitLayer _i2z::NTuple{NZ, Int} polys::TPOLY transforms::SMatrix{NZ, NZ, TT} envelopes::SMatrix{NZ, NZ, TENV} # -------------- - weights::Array{TW, 4} # learnable weights, `nothing` when using Lux + # weights::Array{TW, 4} # learnable weights, `nothing` when using Lux rin0cuts::SMatrix{NZ, NZ, NT_RIN0CUTS{T}} # matrix of (rin, rout, rcut) spec::Vector{NT_NL_SPEC} # -------------- @@ -41,19 +41,6 @@ struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, TW, T} <: AbstractExplicitLayer meta::Dict{String, Any} end -function set_params(basis::LearnableRnlrzzBasis, ps) - return LearnableRnlrzzBasis(basis._i2z, - basis.polys, - basis.transforms, - basis.envelopes, - # --------------- - ps.Wnlq, - basis.rin0cuts, - basis.spec, - # --------------- - basis.meta) -end - struct SplineRnlrzzBasis{NZ, TT, TENV, LEN, T} <: AbstractExplicitLayer _i2z::NTuple{NZ, Int} @@ -81,10 +68,11 @@ _transform_zz(obj, zi, zj) = obj.transforms[_z2i(obj, zi), _z2i(obj, zj)] _get_T(basis::LearnableRnlrzzBasis) = typeof(basis.rin0cuts[1,1].rin) -splinify(basis::SplineRnlrzzBasis; kwargs...) = basis + +splinify(basis::SplineRnlrzzBasis, ps; kwargs...) = basis -function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) +function splinify(basis::LearnableRnlrzzBasis, ps; nnodes = 100) # transform : r ∈ [rin, rcut] -> x # and then Rnl = Wnl_q * Pq(x) * env(x) gives the basis. @@ -98,7 +86,7 @@ function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) NZ = _get_nz(basis) T = _get_T(basis) - LEN = size(basis.weights, 1) + LEN = size(ps.Wnlq, 1) _splines = Matrix{SPL_OF_SVEC{LEN, T}}(undef, (NZ, NZ)) x_nodes = range(-1.0, 1.0, length = nnodes) polys = basis.polys @@ -108,7 +96,7 @@ function splinify(basis::LearnableRnlrzzBasis; nnodes = 100) rin, rcut = rin0cut.rin, rin0cut.rcut Tij = basis.transforms[iz0, iz1] - Wnlq_ij = @view basis.weights[:, :, iz0, iz1] + Wnlq_ij = @view ps.Wnlq[:, :, iz0, iz1] Rnl = [ SVector{LEN}( Wnlq_ij * Polynomials4ML.evaluate(polys, x) ) for x in x_nodes ] diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index e1ac97ff..921cf4ec 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -8,18 +8,18 @@ import LuxCore function LearnableRnlrzzBasis( zlist, polys, transforms, envelopes, rin0cuts, spec::AbstractVector{NT_NL_SPEC}; - weights=nothing, + # weights=nothing, meta=Dict{String, Any}()) NZ = length(zlist) - if isnothing(weights) - weights = fill(nothing, (1,1,1,1)) - end + # if isnothing(weights) + # weights = fill(nothing, (1,1,1,1)) + # end LearnableRnlrzzBasis(_convert_zlist(zlist), polys, _make_smatrix(transforms, NZ), _make_smatrix(envelopes, NZ), # -------------- - weights, + # weights, _make_smatrix(rin0cuts, NZ), collect(spec), meta) diff --git a/src/models/ace.jl b/src/models/ace.jl index 2685d066..f5fbba3f 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -35,12 +35,12 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T, TPAIR} <: AbstractExplicitContainerLay # -------------- # we can add a nonlinear embedding here # -------------- - bparams::Matrix{T} # : x NZ matrix of B parameters + # bparams::Matrix{T} # : x NZ matrix of B parameters # aaparams::NTuple{NZ, Vector{T}} (not used right now) # -------------- # pair potential pairbasis::TPAIR - pairparams::Matrix{T} + # pairparams::Matrix{T} # -------------- meta::Dict{String, Any} end @@ -138,8 +138,8 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, return ACEModel(rbasis._i2z, rbasis, ybasis, - a_basis, aa_basis, AA2BB_map, zeros(0,0), - pair_basis, zeros(0,0), + a_basis, aa_basis, AA2BB_map, # zeros(0,0), + pair_basis, # zeros(0,0), # ntuple(_ -> zeros(n_AA_params), NZ), Dict{String, Any}() ) end @@ -211,6 +211,19 @@ function LuxCore.parameterlength(model::ACEModel) return NZ * n_B_params end +function splinify(model::ACEModel, ps::NamedTuple) + rbasis_spl = splinify(model.rbasis, ps.rbasis) + pairbasis_spl = splinify(model.pairbasis, ps.pairbasis) + return ACEModel(model._i2z, + rbasis_spl, + model.ybasis, + model.abasis, + model.aabasis, + model.A2Bmap, + pairbasis_spl, + model.meta) +end + # ------------------------------------------------------------ # Model Evaluation # this should possibly be moved to a separate file once it diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index af7ac7e4..a52bf808 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -137,7 +137,7 @@ function ace_model(; elements = nothing, end ps_pair = initialparameters(rng, pair_basis) - pair_basis_spl = splinify(set_params(pair_basis, ps_pair)) + pair_basis_spl = splinify(pair_basis, ps_pair) AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, level = level, max_level = max_level) diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 9a35f491..fb1fb8e0 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -37,6 +37,13 @@ Base.zero(V::ACEPotential) = zero(energy_unit(V)) initialparameters(rng::AbstractRNG, V::ACEPotential) = initialparameters(rng, V.model) initialstates(rng::AbstractRNG, V::ACEPotential) = initialstates(rng, V.model) +set_parameters!(V::ACEPotential, ps) = (V.ps = ps; V) +set_states!(V::ACEPotential, st) = (V.st = st; V) +set_psst!(V::ACEPotential, ps, st) = (V.ps = ps; V.st = st; V) + +splinify(V::ACEPotential) = splinify(V, V.ps) +splinify(V::ACEPotential, ps) = ACEPotential(splinify(V.model, ps), nothing, nothing) + # --------------------------------------------------------------- # EmpiricalPotentials / SitePotential based implementation # diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl index 40635aba..4ab211a0 100644 --- a/test/models/test_Rnl.jl +++ b/test/models/test_Rnl.jl @@ -68,8 +68,6 @@ println() ## -basis_p = M.set_params(basis, ps) - @info("Testing SplineRnlrzzBasis consistency via splinify") @@ -86,7 +84,7 @@ for ntest = 1:30 for (nnodes, tol) in [(30, 1e-3), (100, 1e-5), (1000, 1e-8)] local basis_spl, ps_spl, st_spl, Rnl_spl - basis_spl = M.splinify(basis_p; nnodes = nnodes) + basis_spl = M.splinify(basis, ps; nnodes = nnodes) ps_spl, st_spl = LuxCore.setup(rng, basis_spl) Rnl_spl, _ = basis_spl(r, Zi, Zj, ps_spl, st_spl) rel_err = (Rnl - Rnl_spl) ./ (1 .+ abs.(Rnl)) @@ -101,8 +99,7 @@ println() @info("Test derivatives of SplineRnlrzzBasis") -basis_p = M.set_params(basis, ps) -basis_spl = M.splinify(basis_p; nnodes = 100) +basis_spl = M.splinify(basis, ps; nnodes = 100) for ntest = 1:20 local Rs, Zs, Zi, Zj, r, Rnl, U, F, dF @@ -121,8 +118,7 @@ println() @info("SplineRnlrzz : Consistency of single and batched evaluation") -basis_p = M.set_params(basis, ps) -basis_spl = M.splinify(basis_p; nnodes = 100) +basis_spl = M.splinify(basis, ps; nnodes = 100) for ntest = 1:20 local Rnl, Rs, Zs, Z0, Nat, st1, ∇Rnl, rs diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index f5e556b5..e4cfec20 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -157,6 +157,29 @@ end println() +## + +@info("check splinification") +lin_ace = M.splinify(model, ps) +ps_lin, st_lin = LuxCore.setup(rng, lin_ace) +ps_lin.WB[:] .= ps.WB[:] +ps_lin.Wpair[:] .= ps.Wpair[:] + +for ntest = 1:10 + local len, Nat, Rs, Zs, z0, Ei + len = 10 + mae = sum(1:len) do _ + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st)[1] + Ei_lin = M.evaluate(lin_ace, Rs, Zs, z0, ps_lin, st_lin)[1] + abs(Ei - Ei_lin) + end + mae /= len + print_tf(@test mae < 0.01) +end +println() + ## #= @@ -179,5 +202,4 @@ print(" reverse^2 : "); @btime M.pullback_2_mixed(rand(), $Us, $model, $Rs, $Zs print(" evaluate_basis : "); @btime M.evaluate_basis($model, $Rs, $Zs, $z0, $ps, $st) print(" evaluate_basis_ed : "); @btime M.evaluate_basis_ed($model, $Rs, $Zs, $z0, $ps, $st) print("jacobian_grad_params : "); @btime M.jacobian_grad_params($model, $Rs, $Zs, $z0, $ps, $st) - -=# \ No newline at end of file +=# diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index f68781fc..b0394df7 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -86,16 +86,42 @@ for ntest = 1:10 Z = AtomsBuilder._get_atomic_numbers(at) Z[[3,6,8]] .= 8 at = AtomsBuilder._set_atomic_numbers(at, Z) - Us = randn(SVector{3, Float64}, length(at)) / length(at) + Us = randn(SVector{3, Float64}, length(at)) / length(at) * u"Å" dF0 = - dot(Us, M.energy_forces_virial_serial(at, calc, ps, st).forces) X0 = AtomsBuilder._get_positions(at) F(t) = M.energy_forces_virial_serial( - AtomsBuilder._set_positions(at, X0 + (t * u"Å") * Us), - calc, ps, st).energy - print_tf( @test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false ) ) + AtomsBuilder._set_positions(at, X0 + t * Us), + calc, ps, st).energy |> ustrip + print_tf( @test ACEbase.Testing.fdtest(F, t -> ustrip(dF0), 0.0; verbose=false ) ) +end +println() + +## + +@info("check splinification of calculator") + +lin_calc = M.splinify(calc, ps) +ps_lin, st_lin = LuxCore.setup(rng, lin_calc) +ps_lin.WB[:] .= ps.WB[:] +ps_lin.Wpair[:] .= ps.Wpair[:] + +for ntest = 1:10 + len = 10 + mae = sum(1:len) do _ + at = rattle!(bulk(:Si, cubic=true), 0.1) + Z = AtomsBuilder._get_atomic_numbers(at) + Z[[3,6,8]] .= 8 + E = M.energy_forces_virial(at, calc, ps, st).energy + E_lin = M.energy_forces_virial(at, lin_calc, ps_lin, st_lin).energy + abs(E - E_lin) / (abs(E) + abs(E_lin)) + end + mae /= len + print_tf(@test mae < 1e-3) end println() + + ## # testing the AD through a loss function @info("Testing Zygote-AD through a loss function") @@ -109,7 +135,7 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), (0.0 / u"eV", 1.0 / u"eV", 0.0 / u"eV/Å"), (0.0 / u"eV", 0.0 / u"eV", 1.0 / u"eV/Å"), (1.0 / u"eV", 0.1 / u"eV", 0.1 / u"eV/Å") ] - local at + local at, Z, dF0 # random structure at = rattle!(bulk(:Si, cubic=true), 0.1) @@ -137,4 +163,5 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), @info("(wE, wV, wF) = ($wE, $wV, $wF)") FDTEST = ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=true) println(@test FDTEST) -end \ No newline at end of file +end + diff --git a/test/models/test_pair_basis.jl b/test/models/test_pair_basis.jl index d7b44f7a..78c4a95a 100644 --- a/test/models/test_pair_basis.jl +++ b/test/models/test_pair_basis.jl @@ -29,8 +29,7 @@ Zj = basis._i2z[2] Rnl1, st1 = basis(r, Zi, Zj, ps, st) Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) -basis_p = M.set_params(basis, ps) -basis_spl = M.splinify(basis_p) +basis_spl = M.splinify(basis, ps) ps_spl, st_spl = LuxCore.setup(rng, basis_spl) Rnl2, _ = M.evaluate(basis_spl, r, Zi, Zj, ps_spl, st_spl) diff --git a/test/models/test_radial_envelopes.jl b/test/models/test_radial_envelopes.jl index 07705cf3..0d24a5cc 100644 --- a/test/models/test_radial_envelopes.jl +++ b/test/models/test_radial_envelopes.jl @@ -16,10 +16,10 @@ using Plots rcut = 2.0 envpair = ACEpotentials.Models.PolyEnvelope1sR(rcut, 1) rr = range(0.0001, rcut+0.5, length=200) -y2 = ACEpotentials.Models.evaluate.(Ref(envpair), rr) +y2 = ACEpotentials.Models.evaluate.(Ref(envpair), rr, 0*rr) envmb = ACEpotentials.Models.PolyEnvelope2sX(0.0, 1.0, 2, 2) -ymb = ACEpotentials.Models.evaluate.(Ref(envmb), rr) +ymb = ACEpotentials.Models.evaluate.(Ref(envmb), 0*rr, rr) plot(rr, y2, label="pair envelope", lw=2, legend=:topleft, ylims = (-1.0, 3.0)) plot!(rr, ymb, label="mb envelope", lw=2, ) diff --git a/test/models/test_radial_transforms.jl b/test/models/test_radial_transforms.jl index 0cb3952e..cb3aea82 100644 --- a/test/models/test_radial_transforms.jl +++ b/test/models/test_radial_transforms.jl @@ -3,6 +3,7 @@ # using TestEnv; TestEnv.activate(); using ACEpotentials, Test +using ACEbase.Testing: print_tf, println_slim # there are no real tests for envelopes yet. The only thing we have is # a plot of the envelopes to inspect manually. @@ -44,7 +45,7 @@ plot(plt1, plt2, layout=(2,1), size = (600, 800)) ## - +@info("Testing agnesi transforms") rcut = 6.5 r0 = 2.3 @@ -53,6 +54,6 @@ trans_2_4 = ACEpotentials.Models.agnesi_transform(r0, rcut, 2, 4) trans_1_3 = ACEpotentials.Models.agnesi_transform(r0, rcut, 1, 3) for trans in [trans_2_2, trans_2_4, trans_1_3] - @test ACEpotentials.Models.test_normalized_transform(trans_2_2) + println_slim( @test ACEpotentials.Models.test_normalized_transform(trans_2_2) ) end From c8b2d5ec750f806263af2ded786bb635f9dc0ea7 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 16 May 2024 14:56:25 -0700 Subject: [PATCH 034/112] update to newkernels tutorial --- docs/src/newkernels/newkernels.jl | 74 ++++++++++++++++--------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index ce0573c3..5759f6e5 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -5,6 +5,7 @@ using ACEpotentials, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, Unitful, Random, Zygote, Optimisers +# JuLIP (via ACEpotentials) also exports the same functions bulk = AtomsBuilder.bulk rattle! = AtomsBuilder.rattle! @@ -36,7 +37,10 @@ model = M.ace_model(; elements = elements, # the radial basis specification can be looked at explicitly via -display(model.rbasis.spec) +@info("Subset of radial basis specification") +display(model.rbasis.spec[1:10:end]) +@info("Subset of pair basis specification") +display(model.pairbasis.spec[1:2:end]) # we can see that it is defined as (n, l) pairs. Each `n` specifies an invariant # channel coupled to an `l` channel. Each `Rnl` radial basis function is defined @@ -53,10 +57,10 @@ display(model.rbasis.spec) ps, st = Lux.setup(rng, model) -# From the model we generate a calculator. This step should probably be integrated. -# into `ace_model`, we can discuss it. +# From the model we generate a calculator. This step should probably be +# integrated into `ace_model`, we can discuss it. -calc = M.ACEPotential(model, ps, st) +calc = M.ACEPotential(model) # We can now treat `calc` as a nonlinear parameterized site potential model. # - generate a random Al, Ti structure @@ -117,13 +121,13 @@ loss = let data_keys = (E_key = :energy, F_key = :force, V_key = :virial), end end -loss(calc, ps, st, at)[1] +at1 = train_data[1] +loss(calc, ps, st, at1)[1] # Zygote should now be able to differentiate this loss with respect to parameters # the gradient is provided in the same format as the parameters, i.e. a NamedTuple. -at1 = train_data[1] g = Zygote.gradient(ps -> loss(calc, ps, st, at1)[1], ps)[1] @show typeof(g) @@ -132,13 +136,15 @@ g = Zygote.gradient(ps -> loss(calc, ps, st, at1)[1], ps)[1] # allows us use of arbitrary optimizers ps_vec, _restruct = destructure(ps) -g_vec = destructure(g)[1] +g_vec, _ = destructure(g) # the restructure is the same as for the params # Let's now try to optimize the model. Here I'm a bit hazy how to do this -# properly. I'm just modifying a Lux tutorial. There are probably better ways. -# https://github.com/LuxDL/Lux.jl/blob/main/examples/PolynomialFitting/main.jl +# properly. +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # This is the Lux approach, which I couldn't get to work. +# I'm just modifying a Lux tutorial. There are probably better ways. +# https://github.com/LuxDL/Lux.jl/blob/main/examples/PolynomialFitting/main.jl # using ADTypes, Printf # vjp_rule = AutoZygote() @@ -159,27 +165,34 @@ g_vec = destructure(g)[1] # end # main(tstate, vjp_rule, train_data, 100) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # the alternative might be to optimize using Optim.jl +# this would allow us to use a wide range of different optimizers, +# including BFGS and friends. +# the total_loss and total_loss_grad! should be re-implemented properly +# with options to assemble the loss multi-threaded or distributed using Optim -function total_loss(p_vec) +function total_loss(p_vec) return sum( loss(calc, _restruct(p_vec), st, at)[1] - for at in train_data ) + for at in train_data ) end - + function total_loss_grad!(g, p_vec) g[:] = Zygote.gradient(ps -> total_loss(ps), p_vec)[1] return g end +@info("Timing for total loss and loss-grad") @time total_loss(ps_vec) @time total_loss(ps_vec) @time total_loss_grad!(zeros(length(ps_vec)), ps_vec) @time total_loss_grad!(zeros(length(ps_vec)), ps_vec) +@info("Start the optimization") result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; method = Optim.Adam(), show_trace = true, @@ -195,27 +208,18 @@ result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; # optimize the ACE basis coefficients via linear regression. # as a first step, we replace the learnable radials with -# splined radials - -ps1_vec = result.minimizer -ps1 = _restruct(ps1_vec) - -rbasis_p = M.set_params(calc.model.rbasis, ps1.rbasis) -rbasis_spl = M.splinify(rbasis_p) - -# next we create a new ACE model with the splined radial basis -# this step should be moved into ACEpotentials.Models and -# automated. - -linmodel = M.ACEModel(calc.model._i2z, - rbasis_spl, - calc.model.ybasis, - calc.model.abasis, - calc.model.aabasis, - calc.model.A2Bmap, - calc.model.bparams, - calc.model.pairbasis, - calc.model.pairparams, - calc.model.meta) -lincalc = M.ACEPotential(linmodel) +# splined radials. This is not technically needed, but I want to make it +# the default that once we have fixed the radials, we splinify them +# so that we fit to exactly what we export. + +ps1 = _restruct(result.minimizer) +lin_calc = M.splinify(calc, ps1) + +# The next point is that I propose a change to the interface for evaluating +# the basis (as opposed to the model), i.e. replacing +# energy(basis) with energy_basis(model) and similar. +# With this in mind we can now assemble the linear regression problem. + + + From 517ef8d74038b30ee93e05a699621ebc0ffa1814 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 17 May 2024 10:16:48 -0700 Subject: [PATCH 035/112] extend basis eval to calculator --- docs/src/newkernels/newkernels.jl | 35 ++++++++++++++++++++++++-- src/models/calculators.jl | 41 +++++++++++++++++++++++++++++++ test/models/test_calculator.jl | 18 ++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index 5759f6e5..3911656e 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -103,8 +103,10 @@ train_data = FlexibleSystem.(data[1:5:end]) # a single training structure. This follows the Lux training API # loss(model, ps, st, data) -loss = let data_keys = (E_key = :energy, F_key = :force, V_key = :virial), - weights = (wE = 1.0/u"eV", wF = 0.1 / u"eV/Å", wV = 0.1/u"eV") +data_keys = (E_key = :energy, F_key = :force, V_key = :virial) +weights = (wE = 1.0/u"eV", wF = 0.1 / u"eV/Å", wV = 0.1/u"eV") + +loss = let data_keys = data_keys, weights = weights function(calc, ps, st, at) efv = M.energy_forces_virial(at, calc, ps, st) @@ -220,6 +222,35 @@ lin_calc = M.splinify(calc, ps1) # energy(basis) with energy_basis(model) and similar. # With this in mind we can now assemble the linear regression problem. +function local_lsqsys(calc, at, ps, st, weights, keys) + efv = M.energy_forces_virial_basis(at, calc, ps, st) + + # energy + wE = weights[:wE] + E_dft = at.data[data_keys.E_key] * u"eV" + y_E = wE * E_dft + A_E = wE * efv.energy' + + # forces + wF = weights[:wF] + F_dft = at.data[data_keys.F_key] * u"eV/Å" + y_F = wF * reinterpret(eltype(F_dft[1]), F_dft) + A_F = wF * reinterpret(eltype(efv.forces[1]), efv.forces) + + # virial + wV = weights[:wV] + V_dft = at.data[data_keys.V_key] * u"eV" + y_V = wV * V_dft[:] + A_V = wV * reshape(reinterpret(eltype(efv.virial), efv.virial), 9, :) + + return vcat(A_E, A_F, A_V), vcat(y_E, y_F, y_V) +end + +function assemble_lsq(calc, data) + +end + + diff --git a/src/models/calculators.jl b/src/models/calculators.jl index fb1fb8e0..0eddbf76 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -245,3 +245,44 @@ function rrule(::typeof(energy_forces_virial), nlist = nlist, kwargs...), NoTangent() ) end + + +# -------------------------------------------------------- +# Basis evaluation + + +function energy_forces_virial_basis( + at, calc::ACEPotential{<: ACEModel}, ps, st; + domain = 1:length(at), + executor = ThreadedEx(), + ntasks = Threads.nthreads(), + nlist = PairList(at, cutoff_radius(calc)), + kwargs... + ) + + Js, Rs, Zs, z0 = get_neighbours(at, calc, nlist, 1) + E1, _ = evaluate_basis(calc.model, Rs, Zs, z0, ps, st) + N_basis = length(E1) + T = fl_type(calc.model) # this is ACE specific + + E = fill(zero(T) * energy_unit(calc), N_basis) + F = fill(zero(SVector{3, T}) * force_unit(calc), length(at), N_basis) + V = fill(zero(SMatrix{3, 3, T}) * energy_unit(calc), N_basis) + + for i in domain + Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) + v, dv, _ = evaluate_basis_ed(calc.model, Rs, Zs, z0, ps, st) + + for k = 1:N_basis + E[k] += v[k] * energy_unit(calc) + for α = 1:length(Js) + F[Js[α], k] -= dv[k, α] * force_unit(calc) + F[i, k] += dv[k, α] * force_unit(calc) + end + V[k] += _site_virial(dv[k, :], Rs) * energy_unit(calc) + end + release!(Js); release!(Rs); release!(Zs) + end + + return (energy = E, forces = F, virial = V) +end diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index b0394df7..875d45a1 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -120,7 +120,25 @@ for ntest = 1:10 end println() +## + +@info("Test splinified calculator basis usage") + +for ntest = 1:10 + ps_lin, st_lin = LuxCore.setup(rng, lin_calc) + at = rattle!(bulk(:Si, cubic=true), 0.1) + Z = AtomsBuilder._get_atomic_numbers(at) + Z[[3,6,8]] .= 8 + + efv = M.energy_forces_virial(at, lin_calc, ps_lin, st_lin) + efv_b = M.energy_forces_virial_basis(at, lin_calc, ps_lin, st_lin) + ps_vec, _restruct = destructure(ps_lin) + print_tf(@test dot(efv_b.energy, ps_vec) ≈ efv.energy ) + print_tf(@test all(efv_b.forces * ps_vec .≈ efv.forces) ) + print_tf(@test sum(ps_vec .* efv_b.virial) ≈ efv.virial ) +end +println() ## # testing the AD through a loss function From 476e61db7471c798bed489e57548184af420436c Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 17 May 2024 16:50:37 -0700 Subject: [PATCH 036/112] tutorial draft done --- docs/src/newkernels/Project.toml | 6 ++ docs/src/newkernels/newkernels.jl | 124 +++++++++++++++++++++++++++--- 2 files changed, 120 insertions(+), 10 deletions(-) diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index 111198c7..a9be4da4 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -2,10 +2,16 @@ ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" +AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" +EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" +Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" +GeometryOptimization = "673bf261-a53d-43b9-876f-d3c1fc8329c2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Molly = "aa0f7f06-fcc0-5ec4-a7f3-a573f33f9c4c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index 3911656e..ca38c7df 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -22,6 +22,7 @@ rng = Random.MersenneTwister(1234) # I'll create a new model for a simple alloy and then generate a model. # this generates a trace-like ACE model with a random radial basis. +# we probably have to put significant work into initializing better. elements = (:Al, :Ti) @@ -69,7 +70,7 @@ calc = M.ACEPotential(model) function rand_AlTi(nrep, rattle) # Al : 13; Ti : 22 - at = rattle!(bulk(:Al, cubic=true) * 2, 0.1) + at = rattle!(bulk(:Al, cubic=true) * nrep, 0.1) Z = AtomsBuilder._get_atomic_numbers(at) Z[rand(1:length(at), length(at) ÷ 2)] .= 22 return AtomsBuilder._set_atomic_numbers(at, Z) @@ -95,7 +96,9 @@ display(efv.forces[1:5]) # we load our example dataset and convert it to AtomsBase data, _, meta = ACEpotentials.example_dataset("TiAl_tutorial") -train_data = FlexibleSystem.(data[1:5:end]) +data = FlexibleSystem.(data) +train_data = data[1:5:end] +test_data = data[2:5:end] # to set up training we specify data keys and training weights. # to get a unitless loss we need to specify the weights to have inverse @@ -198,14 +201,11 @@ end result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; method = Optim.Adam(), show_trace = true, - iterations = 100) + iterations = 30) # obviously this needs more iterations -# Now that we've optimized the entire model a little bit -# we can think that the radial basis functions are sufficiently -# optimized. This is of course not true in this case since we didn't -# use enough iterations. But suppose we had converged the nonlinear -# optimization to get a really good radial basis. +# We didn't use enough iterations to do anything useful here. But suppose we +# had converged the nonlinear optimization to get a good radial basis. # Then, in a second step we can freeze the radial basis and # optimize the ACE basis coefficients via linear regression. @@ -216,6 +216,7 @@ result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; ps1 = _restruct(result.minimizer) lin_calc = M.splinify(calc, ps1) +lin_ps, lin_st = Lux.setup(rng, lin_calc) # The next point is that I propose a change to the interface for evaluating # the basis (as opposed to the model), i.e. replacing @@ -241,16 +242,119 @@ function local_lsqsys(calc, at, ps, st, weights, keys) wV = weights[:wV] V_dft = at.data[data_keys.V_key] * u"eV" y_V = wV * V_dft[:] - A_V = wV * reshape(reinterpret(eltype(efv.virial), efv.virial), 9, :) + # display( reinterpret(eltype(efv.virial), efv.virial) ) + A_V = wV * reshape(reinterpret(eltype(efv.virial[1]), efv.virial), 9, :) return vcat(A_E, A_F, A_V), vcat(y_E, y_F, y_V) end +# this line just checks that the local assembly makes sense +A1, y1 = local_lsqsys(lin_calc, at1, lin_ps, lin_st, weights, data_keys) -function assemble_lsq(calc, data) +@assert size(A1, 1) == length(y1) == 1 + 3 * length(at1) + 9 +@assert size(A1, 2) == length(destructure(lin_ps)[1]) +# we convert this to a global assembly routine. I thought this version would +# be multi-threaded but something seems to be wrong with it. + +using Folds + +function assemble_lsq(calc, data, weights, data_keys; + rng = Random.GLOBAL_RNG, + executor = Folds.ThreadedEx()) + ps, st = Lux.setup(rng, calc) + blocks = Folds.map(at -> local_lsqsys(lin_calc, at, ps, st, + weights, data_keys), + train_data, executor) + A = reduce(vcat, [b[1] for b in blocks]) + y = reduce(vcat, [b[2] for b in blocks]) + return A, y end +A, y = assemble_lsq(lin_calc, train_data, weights, data_keys) + +# estimate the parameters + +solver = ACEpotentials.ACEfit.BLR() +result = ACEpotentials.ACEfit.solve(solver, A, y) + +# a little hack to turn it into real parameters and a fully parameterized model +# this needs another convenience function provided within ACEpotentials. + +ps, st = Lux.setup(rng, lin_calc) +_, _restruct = destructure(ps) +fit_ps = _restruct(result["C"]) +fit_calc = M.ACEPotential(lin_calc.model, fit_ps, st) + +# can this do anything useful? +# first of all, because we have now specified the parameters, we no longer need +# to drag them around and can use a higher-level interface to evaluate the +# model. For example ... + +using AtomsCalculators +using AtomsCalculators: energy_forces_virial, forces, potential_energy + +at1 = rand(train_data) +efv1 = M.energy_forces_virial(at1, fit_calc, fit_calc.ps, fit_calc.st) +efv2 = energy_forces_virial(at1, fit_calc) +efv1.energy ≈ efv2.energy +all(efv1.forces .≈ efv2.forces) +efv1.virial ≈ efv2.virial +potential_energy(at1, fit_calc) ≈ efv1.energy +AtomsCalculators.virial(at1, fit_calc) ≈ efv1.virial +ef = AtomsCalculators.energy_forces(at1, fit_calc) +ef.energy ≈ efv1.energy +all(ef.forces .≈ efv1.forces) +all(efv1.forces .≈ forces(at1, fit_calc)) + +# Checking accuracy (it is terrible, so lots to fix ...) + +E_err(at) = abs(ustrip(potential_energy(at, fit_calc)) - at.data[data_keys.E_key]) / length(at) +mae_train = sum(E_err, train_data) / length(train_data) +mae_test = sum(E_err, test_data) / length(test_data) + +@info("MAE(train) = $mae_train") +@info("MAE(test) = $mae_test") + + +# Trying some simple geometry optimization +# This seems to run but doesn't update the structure, there is also no +# documentation how to extract information from the result to do so. +# I am probably still doing something wrong here ... + +using GeometryOptimization, OptimizationOptimJL + +@info("Short geometry optimization") +at = rand_AlTi(2, 0.001) +@show potential_energy(at, fit_calc) +solver = OptimizationOptimJL.LBFGS() +optim_options = (f_tol=1e-4, g_tol=1e-4, iterations=30, show_trace=false) +results = minimize_energy!(at, fit_calc; solver, optim_options...) +@show potential_energy(at, fit_calc) +at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, results.u) * u"Å") +@show potential_energy(at, fit_calc) + + +# The last step is to run a simple MD simulation for just a 100 steps. +# this currently doesn't work because Molly doesn't allow arbitrary +# units. (WTF?!) And I can't be bothered to write the wrappers +# needed to convert. + +# import Molly +# at = rand_AlTi(3, 0.01) +# sys_md = Molly.System(at) + +# sys_md = Molly.System(sys_md; +# velocities = random_velocities(sys_md, 298.0u"K"), +# loggers=(temp=TemperatureLogger(100),) +# ) + +# simulator = VelocityVerlet( +# dt=1.0u"fs", +# coupling=AndersenThermostat(temp, 1.0u"ps"), +# ) + +# simulate!(sys, simulator, 100) From 0dcfe3f1011226a5650e98f8b8465cc3b0b4e764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20J=C3=A4rvinen?= Date: Fri, 17 May 2024 19:12:02 -0700 Subject: [PATCH 037/112] fix Molly example --- docs/src/newkernels/Project.toml | 1 + docs/src/newkernels/newkernels.jl | 54 ++++++++++++++++++++----------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index a9be4da4..857f2978 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -1,6 +1,7 @@ [deps] ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index ca38c7df..a80e5e01 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -2,7 +2,7 @@ # I'll try to explain what can be done and what is missing along the way. # I am -using ACEpotentials, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, Unitful, Random, Zygote, Optimisers # JuLIP (via ACEpotentials) also exports the same functions @@ -71,9 +71,12 @@ calc = M.ACEPotential(model) function rand_AlTi(nrep, rattle) # Al : 13; Ti : 22 at = rattle!(bulk(:Al, cubic=true) * nrep, 0.1) - Z = AtomsBuilder._get_atomic_numbers(at) - Z[rand(1:length(at), length(at) ÷ 2)] .= 22 - return AtomsBuilder._set_atomic_numbers(at, Z) + + # swap odd atoms to Ti + particles = map( enumerate(at) ) do (i,atom) + isodd(i) ? AtomsBase.Atom(22, position(atom)) : atom + end + return FlexibleSystem(particles, bounding_box(at), boundary_conditions(at)) end @@ -333,8 +336,18 @@ solver = OptimizationOptimJL.LBFGS() optim_options = (f_tol=1e-4, g_tol=1e-4, iterations=30, show_trace=false) results = minimize_energy!(at, fit_calc; solver, optim_options...) @show potential_energy(at, fit_calc) -at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, results.u) * u"Å") -@show potential_energy(at, fit_calc) +#at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, results.u) * u"Å") +at_new = FastSystem( + bounding_box(at), + boundary_conditions(at), + reinterpret(SVector{3, Float64}, results.u) * u"Å", + atomic_symbol(at), + atomic_number(at), + atomic_mass(at) +) +@show potential_energy(at_new, fit_calc) + + # The last step is to run a simple MD simulation for just a 100 steps. @@ -342,19 +355,24 @@ at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, result # units. (WTF?!) And I can't be bothered to write the wrappers # needed to convert. -# import Molly -# at = rand_AlTi(3, 0.01) -# sys_md = Molly.System(at) +import Molly +at = rand_AlTi(3, 0.01) +# Tell Molly what units are used +sys_md = Molly.System(at; force_units=u"eV/Å", energy_units=u"eV") + +temp = 298.0u"K" -# sys_md = Molly.System(sys_md; -# velocities = random_velocities(sys_md, 298.0u"K"), -# loggers=(temp=TemperatureLogger(100),) -# ) +sys_md = Molly.System( + sys_md; + general_inters = (fit_calc,), + velocities = Molly.random_velocities(sys_md, temp), + loggers=(temp=Molly.TemperatureLogger(100),) +) -# simulator = VelocityVerlet( -# dt=1.0u"fs", -# coupling=AndersenThermostat(temp, 1.0u"ps"), -# ) +simulator = Molly.VelocityVerlet( + dt = 1.0u"fs", + coupling = Molly.AndersenThermostat(temp, 1.0u"ps"), +) -# simulate!(sys, simulator, 100) +Molly.simulate!(sys_md, simulator, 100) From 2dd015fdd30afe039527b5de919c7c50f64f406d Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 17 May 2024 19:35:15 -0700 Subject: [PATCH 038/112] a little bit of cleanup of the newkernels tutorial --- docs/src/newkernels/newkernels.jl | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index a80e5e01..7b8193d3 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -71,10 +71,9 @@ calc = M.ACEPotential(model) function rand_AlTi(nrep, rattle) # Al : 13; Ti : 22 at = rattle!(bulk(:Al, cubic=true) * nrep, 0.1) - - # swap odd atoms to Ti - particles = map( enumerate(at) ) do (i,atom) - isodd(i) ? AtomsBase.Atom(22, position(atom)) : atom + # swap random atoms to Ti + particles = map( enumerate(at) ) do (i, atom) + (rand() < 0.5) ? AtomsBase.Atom(22, position(atom)) : atom end return FlexibleSystem(particles, bounding_box(at), boundary_conditions(at)) end @@ -261,6 +260,8 @@ A1, y1 = local_lsqsys(lin_calc, at1, lin_ps, lin_st, weights, data_keys) # we convert this to a global assembly routine. I thought this version would # be multi-threaded but something seems to be wrong with it. +# this assembly is very slow because the current implementation of the +# basis is very inefficient using Folds @@ -283,8 +284,8 @@ A, y = assemble_lsq(lin_calc, train_data, weights, data_keys) solver = ACEpotentials.ACEfit.BLR() result = ACEpotentials.ACEfit.solve(solver, A, y) -# a little hack to turn it into real parameters and a fully parameterized model -# this needs another convenience function provided within ACEpotentials. +# a little hack to turn it into NamedTuple parameters and a fully parameterized +# model this needs another convenience function provided within ACEpotentials. ps, st = Lux.setup(rng, lin_calc) _, _restruct = destructure(ps) @@ -297,7 +298,7 @@ fit_calc = M.ACEPotential(lin_calc.model, fit_ps, st) # model. For example ... using AtomsCalculators -using AtomsCalculators: energy_forces_virial, forces, potential_energy +using AtomsCalculators: energy_forces_virial, forces, potential_energy, virial at1 = rand(train_data) efv1 = M.energy_forces_virial(at1, fit_calc, fit_calc.ps, fit_calc.st) @@ -306,7 +307,7 @@ efv1.energy ≈ efv2.energy all(efv1.forces .≈ efv2.forces) efv1.virial ≈ efv2.virial potential_energy(at1, fit_calc) ≈ efv1.energy -AtomsCalculators.virial(at1, fit_calc) ≈ efv1.virial +virial(at1, fit_calc) ≈ efv1.virial ef = AtomsCalculators.energy_forces(at1, fit_calc) ef.energy ≈ efv1.energy all(ef.forces .≈ efv1.forces) @@ -325,7 +326,7 @@ mae_test = sum(E_err, test_data) / length(test_data) # Trying some simple geometry optimization # This seems to run but doesn't update the structure, there is also no # documentation how to extract information from the result to do so. -# I am probably still doing something wrong here ... +# The following seems to work but this wants a PR in GeometryOptimization.jl using GeometryOptimization, OptimizationOptimJL @@ -335,8 +336,6 @@ at = rand_AlTi(2, 0.001) solver = OptimizationOptimJL.LBFGS() optim_options = (f_tol=1e-4, g_tol=1e-4, iterations=30, show_trace=false) results = minimize_energy!(at, fit_calc; solver, optim_options...) -@show potential_energy(at, fit_calc) -#at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, results.u) * u"Å") at_new = FastSystem( bounding_box(at), boundary_conditions(at), @@ -348,18 +347,13 @@ at_new = FastSystem( @show potential_energy(at_new, fit_calc) - - # The last step is to run a simple MD simulation for just a 100 steps. -# this currently doesn't work because Molly doesn't allow arbitrary -# units. (WTF?!) And I can't be bothered to write the wrappers -# needed to convert. +# Important: Tell Molly what units are used!! +# This seems to work ok now (Thank you, Teemu). import Molly at = rand_AlTi(3, 0.01) -# Tell Molly what units are used sys_md = Molly.System(at; force_units=u"eV/Å", energy_units=u"eV") - temp = 298.0u"K" sys_md = Molly.System( From 9cb610d643dc3336040168d09d5cf9f4a3c0329a Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 19 May 2024 14:32:34 -0700 Subject: [PATCH 039/112] E0s --- docs/src/newkernels/newkernels.jl | 20 ++++++++-------- src/models/ace.jl | 38 +++++++++++++++++++------------ test/models/test_ace.jl | 4 ++-- test/models/test_calculator.jl | 6 +++-- test/models/test_models.jl | 14 +++++------- test/runtests.jl | 2 +- 6 files changed, 46 insertions(+), 38 deletions(-) diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index 7b8193d3..10cabf13 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -228,10 +228,15 @@ lin_ps, lin_st = Lux.setup(rng, lin_calc) function local_lsqsys(calc, at, ps, st, weights, keys) efv = M.energy_forces_virial_basis(at, calc, ps, st) + # compute the E0s contribution. This needs to be done more + # elegantly and a stacked model would solve this problem. + E0 = sum( calc.model.E0s[M._z2i(calc.model, z)] + for z in AtomsBase.atomic_number(at) ) * u"eV" + # energy wE = weights[:wE] E_dft = at.data[data_keys.E_key] * u"eV" - y_E = wE * E_dft + y_E = wE * (E_dft - E0) A_E = wE * efv.energy' # forces @@ -296,6 +301,7 @@ fit_calc = M.ACEPotential(lin_calc.model, fit_ps, st) # first of all, because we have now specified the parameters, we no longer need # to drag them around and can use a higher-level interface to evaluate the # model. For example ... +# (this should really go into unit tests I think) using AtomsCalculators using AtomsCalculators: energy_forces_virial, forces, potential_energy, virial @@ -336,20 +342,12 @@ at = rand_AlTi(2, 0.001) solver = OptimizationOptimJL.LBFGS() optim_options = (f_tol=1e-4, g_tol=1e-4, iterations=30, show_trace=false) results = minimize_energy!(at, fit_calc; solver, optim_options...) -at_new = FastSystem( - bounding_box(at), - boundary_conditions(at), - reinterpret(SVector{3, Float64}, results.u) * u"Å", - atomic_symbol(at), - atomic_number(at), - atomic_mass(at) -) +at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, results.u) * u"Å") @show potential_energy(at_new, fit_calc) # The last step is to run a simple MD simulation for just a 100 steps. # Important: Tell Molly what units are used!! -# This seems to work ok now (Thank you, Teemu). import Molly at = rand_AlTi(3, 0.01) @@ -370,3 +368,5 @@ simulator = Molly.VelocityVerlet( Molly.simulate!(sys_md, simulator, 100) +@info("This simulation obviously crashed:") +@show sys_md.loggers.temp.history diff --git a/src/models/ace.jl b/src/models/ace.jl index f5fbba3f..fdef1389 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -35,12 +35,9 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T, TPAIR} <: AbstractExplicitContainerLay # -------------- # we can add a nonlinear embedding here # -------------- - # bparams::Matrix{T} # : x NZ matrix of B parameters - # aaparams::NTuple{NZ, Vector{T}} (not used right now) - # -------------- - # pair potential + # pair potential & Vref pairbasis::TPAIR - # pairparams::Matrix{T} + E0s::NTuple{NZ, T} # -------------- meta::Dict{String, Any} end @@ -92,7 +89,8 @@ end function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, level = TotalDegree(), - pair_basis = nothing ) + pair_basis = nothing, + E0s = nothing ) # generate the coupling coefficients cgen = EquivariantModels.Rot3DCoeffs_real(0) AA2BB_map = EquivariantModels._rpi_A2B_matrix(cgen, AA_spec) @@ -131,24 +129,24 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, # from this we can now generate the AA basis layer aa_basis = Polynomials4ML.SparseSymmProdDAG(AA_spec_idx) aa_basis.meta["AA_spec"] = AA_spec # (also store the human-readable spec) - - # NZ = _get_nz(rbasis) - # n_B_params, n_AA_params = size(AA2BB_map) - + if isnothing(E0s) + NZ = _get_nz(rbasis) + E0s = ntuple(i -> 0.0, NZ) + end return ACEModel(rbasis._i2z, rbasis, ybasis, - a_basis, aa_basis, AA2BB_map, # zeros(0,0), - pair_basis, # zeros(0,0), - # ntuple(_ -> zeros(n_AA_params), NZ), + a_basis, aa_basis, AA2BB_map, + pair_basis, E0s, Dict{String, Any}() ) end # TODO: it is not entirely clear that the `level` is really needed here # since it is implicitly already encoded in AA_spec. We need a # function `auto_level` that generates level automagically from AA_spec. -function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level, pair_basis) - return _generate_ace_model(rbasis, Ytype, AA_spec, level, pair_basis) +function ace_model(rbasis, Ytype, AA_spec::AbstractVector, level, + pair_basis, E0s = nothing) + return _generate_ace_model(rbasis, Ytype, AA_spec, level, pair_basis, E0s) end # NOTE : a nicer convenience constructor is also provided in `ace_heuristics.jl` @@ -221,6 +219,7 @@ function splinify(model::ACEModel, ps::NamedTuple) model.aabasis, model.A2Bmap, pairbasis_spl, + model.E0s, model.meta) end @@ -278,6 +277,9 @@ function evaluate(model::ACEModel, val += dot(Apair, (@view ps.Wpair[:, i_z0])) end # ------------------- + # E0s + val += model.E0s[i_z0] + # ------------------- return val, st end @@ -377,6 +379,9 @@ function evaluate_ed(model::ACEModel, end end # ------------------- + # E0s + Ei += model.E0s[i_z0] + # ------------------- return Ei, ∇Ei, st end @@ -477,6 +482,9 @@ function grad_params(model::ACEModel, ∂Wpair[:, i_z0] = Apair end # ------------------- + # E0s + Ei += model.E0s[i_z0] + # ------------------- return Ei, (WB = ∂WB, Wpair = ∂Wpair, rbasis = ∂Wqnl, diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index e4cfec20..948007c7 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -76,7 +76,7 @@ Ei1, ∇Ei, st = M.grad_params(model, Rs, Zs, z0, ps, st) println_slim(@test Ei ≈ Ei1) for ntest = 1:20 - local Nat, Rs, Zs, z0, pvec, uvec, F, dF0 + local Nat, Rs, Zs, z0, pvec, uvec, F, dF0, _restruct Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) @@ -93,7 +93,7 @@ println() @info("Test second mixed derivatives reverse-over-reverse") for ntest = 1:20 local Nat, Rs, Zs, Us, Ei, ∂Ei, ∂2_Ei, - ps_vec, vs_vec, F, dF0, z0 + ps_vec, vs_vec, F, dF0, z0, _restruct Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 875d45a1..d6ee686f 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -41,7 +41,7 @@ calc = M.ACEPotential(model, ps, st) @info("Testing correctness of potential energy") for ntest = 1:20 - local Rs, Zs, z0, at + local Rs, Zs, z0, at, efv at = rattle!(bulk(:Si, cubic=true) * 2, 0.1) nlist = PairList(at, M.cutoff_radius(calc)) @@ -125,6 +125,8 @@ println() @info("Test splinified calculator basis usage") for ntest = 1:10 + local ps_lin, st_lin, at, efv, _restruct + ps_lin, st_lin = LuxCore.setup(rng, lin_calc) at = rattle!(bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) @@ -153,7 +155,7 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), (0.0 / u"eV", 1.0 / u"eV", 0.0 / u"eV/Å"), (0.0 / u"eV", 0.0 / u"eV", 1.0 / u"eV/Å"), (1.0 / u"eV", 0.1 / u"eV", 0.1 / u"eV/Å") ] - local at, Z, dF0 + local at, Z, dF0, g, _restruct, g_vec # random structure at = rattle!(bulk(:Si, cubic=true), 0.1) diff --git a/test/models/test_models.jl b/test/models/test_models.jl index 4479d46e..43ef3948 100644 --- a/test/models/test_models.jl +++ b/test/models/test_models.jl @@ -1,9 +1,7 @@ -@testset "Models" begin - @testset "Radial Envelopes" begin; include("test_radial_envelopes.jl"); end - @testset "Radial Transforms" begin; include("test_radial_transforms.jl"); end - @testset "Rnlrzz Basis" begin; include("test_Rnl.jl"); end - @testset "Pair Basis" begin; include("test_pair_basis.jl"); end - @testset "ACE Model" begin; include("test_ace.jl"); end - @testset "ACE Calculator" begin; include("test_calculator.jl"); end -end +@testset "Radial Envelopes" begin; include("test_radial_envelopes.jl"); end +@testset "Radial Transforms" begin; include("test_radial_transforms.jl"); end +@testset "Rnlrzz Basis" begin; include("test_Rnl.jl"); end +@testset "Pair Basis" begin; include("test_pair_basis.jl"); end +@testset "ACE Model" begin; include("test_ace.jl"); end +@testset "ACE Calculator" begin; include("test_calculator.jl"); end diff --git a/test/runtests.jl b/test/runtests.jl index 8bb0522a..61570aa9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ using ACEpotentials, Test, LazyArtifacts # experimental @testset "UF_ACE" begin include("test_uface.jl") end - include("models/test_models.jl") + @testset "Models" begin include("models/test_models.jl") end # outdated @testset "Read data" begin include("outdated/test_data.jl") end From 8273fea2e0ddb35ba56df079d5462b8ac166c3db Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 19 May 2024 16:12:42 -0700 Subject: [PATCH 040/112] wrapped A, AA, B into a new tensor struct --- src/models/ace.jl | 145 ++++++++++--------------------------------- src/models/models.jl | 1 + src/models/sparse.jl | 59 ++++++++++++++++++ 3 files changed, 94 insertions(+), 111 deletions(-) create mode 100644 src/models/sparse.jl diff --git a/src/models/ace.jl b/src/models/ace.jl index fdef1389..729a8f87 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -7,7 +7,6 @@ import LuxCore: AbstractExplicitLayer, using Lux: glorot_normal using Random: AbstractRNG -using SparseArrays: SparseMatrixCSC using StaticArrays: SVector using LinearAlgebra: norm, dot @@ -20,8 +19,7 @@ import Polynomials4ML # ------------------------------------------------------------ # ACE MODEL SPECIFICATION - -struct ACEModel{NZ, TRAD, TY, TA, TAA, T, TPAIR} <: AbstractExplicitContainerLayer{(:rbasis,)} +struct ACEModel{NZ, TRAD, TY, TTEN, T, TPAIR} <: AbstractExplicitContainerLayer{(:rbasis,)} _i2z::NTuple{NZ, Int} # -------------- # embeddings of the particles @@ -29,9 +27,7 @@ struct ACEModel{NZ, TRAD, TY, TA, TAA, T, TPAIR} <: AbstractExplicitContainerLay ybasis::TY # -------------- # the tensor format - abasis::TA - aabasis::TAA - A2Bmap::SparseMatrixCSC{T, Int} + tensor::TTEN # -------------- # we can add a nonlinear embedding here # -------------- @@ -48,7 +44,7 @@ end # this is terrible : I'm assuming here that there is a unique # output type, which is of course not the case. It is needed temporarily # to make things work with AtomsCalculators and EmpiricalPotentials -fl_type(::ACEModel{NZ, TRAD, TY, TA, TAA, T}) where {NZ, TRAD, TY, TA, TAA, T} = T +fl_type(::ACEModel{NZ, TRAD, TY, TTEN, T, TPAIR}) where {NZ, TRAD, TY, TTEN, T, TPAIR} = T const NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} @@ -135,9 +131,11 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, E0s = ntuple(i -> 0.0, NZ) end + tensor = SparseEquivTensor(a_basis, aa_basis, AA2BB_map, + Dict{String, Any}()) + return ACEModel(rbasis._i2z, rbasis, ybasis, - a_basis, aa_basis, AA2BB_map, - pair_basis, E0s, + tensor, pair_basis, E0s, Dict{String, Any}() ) end @@ -169,7 +167,7 @@ end function initialparameters(rng::AbstractRNG, model::ACEModel) NZ = _get_nz(model) - n_B_params, n_AA_params = size(model.A2Bmap) + n_B_params = length(model.tensor) # only the B params are parameters, the AA params are uniquely defined # via the B params. @@ -215,9 +213,7 @@ function splinify(model::ACEModel, ps::NamedTuple) return ACEModel(model._i2z, rbasis_spl, model.ybasis, - model.abasis, - model.aabasis, - model.A2Bmap, + model.tensor, pairbasis_spl, model.E0s, model.meta) @@ -236,7 +232,10 @@ _length(ybasis::SolidHarmonics) = SpheriCart.sizeY(_getlmax(ybasis)) function evaluate(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, - ps, st) where {T} + ps, st) where + {T} + i_z0 = _z2i(model.rbasis, Z0) + # get the radii rs = [ norm(r) for r in Rs ] # use Bumper @@ -249,23 +248,10 @@ function evaluate(model::ACEModel, Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here SpheriCart.compute!(Ylm, model.ybasis, Rs) - # evaluate the A basis - TA = promote_type(T, eltype(Rnl)) - A = zeros(T, length(model.abasis)) - Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) - - # evaluate the AA basis - _AA = zeros(T, length(model.aabasis)) # use Bumper here - Polynomials4ML.evaluate!(_AA, model.aabasis, A) - # project to the actual AA basis - proj = model.aabasis.projection - AA = _AA[proj] # use Bumper here, or view; needs experimentation. - - # evaluate the coupling coefficients - B = model.A2Bmap * AA + # equivariant tensor product + B, _ = evaluate(model.tensor, Rnl, Ylm) # contract with params - i_z0 = _z2i(model.rbasis, Z0) val = dot(B, (@view ps.WB[:, i_z0])) # ------------------- @@ -290,6 +276,8 @@ function evaluate_ed(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} + i_z0 = _z2i(model.rbasis, Z0) + # ---------- EMBEDDINGS ------------ # (these are done in forward mode, so not part of the fwd, bwd passes) @@ -305,48 +293,20 @@ function evaluate_ed(model::ACEModel, dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) - # ---------- FORWARD PASS ------------ - - # evaluate the A basis - TA = promote_type(T, eltype(Rnl)) - A = zeros(T, length(model.abasis)) - Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) - - # evaluate the AA basis - _AA = zeros(T, length(model.aabasis)) # TODO: use Bumper here - Polynomials4ML.evaluate!(_AA, model.aabasis, A) - # project to the actual AA basis - proj = model.aabasis.projection - AA = _AA[proj] # TODO: use Bumper here, or view; needs experimentation. - - # evaluate the coupling coefficients - # TODO: use Bumper and do it in-place - B = model.A2Bmap * AA + # Forward Pass through the tensor + # keep intermediates to be used in backward pass + B, intermediates = evaluate(model.tensor, Rnl, Ylm) # contract with params # (here we can insert another nonlinearity instead of the simple dot) - i_z0 = _z2i(model.rbasis, Z0) Ei = dot(B, (@view ps.WB[:, i_z0])) - - # ---------- BACKWARD PASS ------------ - + # Start the backward pass # ∂Ei / ∂B = WB[i_z0] ∂B = @view ps.WB[:, i_z0] - - # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap - ∂AA = model.A2Bmap' * ∂B # TODO: make this in-place - _∂AA = zeros(T, length(_AA)) - _∂AA[proj] = ∂AA - - # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) - ∂A = zeros(T, length(model.abasis)) - Polynomials4ML.pullback_arg!(∂A, _∂AA, model.aabasis, _AA) - # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) - ∂Rnl = zeros(T, size(Rnl)) - ∂Ylm = zeros(T, size(Ylm)) - Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, model.abasis, (Rnl, Ylm)) + # backward pass through tensor + ∂Rnl, ∂Ylm = pullback_evaluate(∂B, model.tensor, Rnl, Ylm, intermediates) # ---------- ASSEMBLE DERIVATIVES ------------ # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm @@ -406,23 +366,9 @@ function grad_params(model::ACEModel, dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) - # ---------- FORWARD PASS ------------ - - # evaluate the A basis - TA = promote_type(T, eltype(Rnl)) - A = zeros(T, length(model.abasis)) - Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) - - # evaluate the AA basis - _AA = zeros(T, length(model.aabasis)) # TODO: use Bumper here - Polynomials4ML.evaluate!(_AA, model.aabasis, A) - # project to the actual AA basis - proj = model.aabasis.projection - AA = _AA[proj] # TODO: use Bumper here, or view; needs experimentation. - - # evaluate the coupling coefficients - # TODO: use Bumper and do it in-place - B = model.A2Bmap * AA + # Forward Pass through the tensor + # keep intermediates to be used in backward pass + B, intermediates = evaluate(model.tensor, Rnl, Ylm) # contract with params # (here we can insert another nonlinearity instead of the simple dot) @@ -436,19 +382,8 @@ function grad_params(model::ACEModel, ∂WB_i = B ∂B = @view ps.WB[:, i_z0] - # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap - ∂AA = model.A2Bmap' * ∂B # TODO: make this in-place - _∂AA = zeros(T, length(_AA)) - _∂AA[proj] = ∂AA - - # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) - ∂A = zeros(T, length(model.abasis)) - Polynomials4ML.pullback_arg!(∂A, _∂AA, model.aabasis, _AA) - - # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) - ∂Rnl = zeros(T, size(Rnl)) - ∂Ylm = zeros(T, size(Ylm)) # we could make this a black hole since we don't need it. - Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, model.abasis, (Rnl, Ylm)) + # backward pass through tensor + ∂Rnl, ∂Ylm = pullback_evaluate(∂B, model.tensor, Rnl, Ylm, intermediates) # ---------- ASSEMBLE DERIVATIVES ------------ # the first grad_param is ∂WB, which we already have but it needs to be @@ -530,13 +465,13 @@ end function get_basis_inds(model::ACEModel, Z) - len_Bi = size(model.A2Bmap, 1) + len_Bi = length(model.tensor) i_z = _z2i(model.rbasis, Z) return (i_z - 1) * len_Bi .+ (1:len_Bi) end function get_pairbasis_inds(model::ACEModel, Z) - len_Bi = size(model.A2Bmap, 1) + len_Bi = length(model.tensor) NZ = _get_nz(model) len_B = NZ * len_Bi @@ -546,7 +481,7 @@ function get_pairbasis_inds(model::ACEModel, Z) end function len_basis(model::ACEModel) - len_Bi = size(model.A2Bmap, 1) + len_Bi = length(model.tensor) len_pair = length(model.pairbasis) NZ = _get_nz(model) return (len_Bi + len_pair) * NZ @@ -573,21 +508,9 @@ function evaluate_basis(model::ACEModel, Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here SpheriCart.compute!(Ylm, model.ybasis, Rs) - # evaluate the A basis - TA = promote_type(T, eltype(Rnl)) - A = zeros(T, length(model.abasis)) - Polynomials4ML.evaluate!(A, model.abasis, (Rnl, Ylm)) - - # evaluate the AA basis - _AA = zeros(T, length(model.aabasis)) # use Bumper here - Polynomials4ML.evaluate!(_AA, model.aabasis, A) - # project to the actual AA basis - proj = model.aabasis.projection - AA = _AA[proj] # use Bumper here, or view; needs experimentation. - - # evaluate the coupling coefficients - # TODO: use Bumper and do it in-place - Bi = model.A2Bmap * AA + # equivariant tensor product + Bi, _ = evaluate(model.tensor, Rnl, Ylm) + B = zeros(eltype(Bi), len_basis(model)) B[get_basis_inds(model, Z0)] .= Bi diff --git a/src/models/models.jl b/src/models/models.jl index 3dc7883b..81387bd8 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -11,6 +11,7 @@ include("Rnl_basis.jl") include("Rnl_learnable.jl") include("Rnl_splines.jl") +include("sparse.jl") include("ace_heuristics.jl") include("ace.jl") diff --git a/src/models/sparse.jl b/src/models/sparse.jl new file mode 100644 index 00000000..7882ef36 --- /dev/null +++ b/src/models/sparse.jl @@ -0,0 +1,59 @@ +using SparseArrays: SparseMatrixCSC +import Polynomials4ML + + +struct SparseEquivTensor{T, TA, TAA} + abasis::TA + aabasis::TAA + A2Bmap::SparseMatrixCSC{T, Int} + # ------- + meta::Dict{String, Any} +end + +Base.length(tensor::SparseEquivTensor) = size(tensor.A2Bmap, 1) + + +function evaluate(tensor::SparseEquivTensor{T}, Rnl, Ylm) where {T} + # evaluate the A basis + TA = promote_type(T, eltype(Rnl), eltype(eltype(Ylm))) + A = zeros(TA, length(tensor.abasis)) + Polynomials4ML.evaluate!(A, tensor.abasis, (Rnl, Ylm)) + + # evaluate the AA basis + _AA = zeros(TA, length(tensor.aabasis)) # use Bumper here + Polynomials4ML.evaluate!(_AA, tensor.aabasis, A) + # project to the actual AA basis + proj = tensor.aabasis.projection + AA = _AA[proj] # use Bumper here, or view; needs experimentation. + + # evaluate the coupling coefficients + B = tensor.A2Bmap * AA + + return B, (_AA = _AA, ) +end + + +function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, + intermediates) where {T} + _AA = intermediates._AA + proj = tensor.aabasis.projection + + # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap + ∂AA = tensor.A2Bmap' * ∂B # TODO: make this in-place + _∂AA = zeros(T, length(_AA)) + _∂AA[proj] = ∂AA + + # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) + TA = promote_type(T, eltype(_AA), eltype(∂B), + eltype(Rnl), eltype(eltype(Ylm))) + ∂A = zeros(TA, length(tensor.abasis)) + Polynomials4ML.pullback_arg!(∂A, _∂AA, tensor.aabasis, _AA) + + # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) + ∂Rnl = zeros(TA, size(Rnl)) + ∂Ylm = zeros(TA, size(Ylm)) + Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, tensor.abasis, (Rnl, Ylm)) + + return ∂Rnl, ∂Ylm +end + From 76546d0d35eb91c16649d3cc7359b61863608cec Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 19 May 2024 20:26:52 -0700 Subject: [PATCH 041/112] fixed tests --- test/Project.toml | 4 +++- test/models/test_calculator.jl | 14 +++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 354eea3e..d3225fe8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ACE1 = "e3f9bc04-086e-409a-ba78-e9769fe067bb" ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707" ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" +AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -18,10 +19,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ACE1 = "0.12.2" ACE1x = "0.1.8" +Interpolations = "0.14.7" JuLIP = "0.13.9, 0.14.2" StaticArrays = "1" -Interpolations = "0.14.7" \ No newline at end of file diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index d6ee686f..1982dfd8 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -13,7 +13,7 @@ using Optimisers, ForwardDiff, Unitful import AtomsCalculators using AtomsBuilder, EmpiricalPotentials -using AtomsBuilder: bulk, rattle! +AB = AtomsBuilder using EmpiricalPotentials: get_neighbours @@ -43,7 +43,7 @@ calc = M.ACEPotential(model, ps, st) for ntest = 1:20 local Rs, Zs, z0, at, efv - at = rattle!(bulk(:Si, cubic=true) * 2, 0.1) + at = AB.rattle!(AB.bulk(:Si, cubic=true) * 2, 0.1) nlist = PairList(at, M.cutoff_radius(calc)) E = 0.0 for i = 1:length(at) @@ -62,7 +62,7 @@ println() @info("Testing correctness of forces ") @info(" .... TODO TEST VIRIALS ..... ") -at = rattle!(bulk(:Si, cubic=true), 0.1) +at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) @info(" consistency local vs EmpiricalPotentials implementation") efv1 = M.energy_forces_virial(at, calc, ps, st) @@ -82,7 +82,7 @@ print_tf(@test efv1.virial ≈ efv3.virial) for ntest = 1:10 local at, Us, dF0, X0, F, Z - at = rattle!(bulk(:Si, cubic=true), 0.1) + at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) Z[[3,6,8]] .= 8 at = AtomsBuilder._set_atomic_numbers(at, Z) @@ -108,7 +108,7 @@ ps_lin.Wpair[:] .= ps.Wpair[:] for ntest = 1:10 len = 10 mae = sum(1:len) do _ - at = rattle!(bulk(:Si, cubic=true), 0.1) + at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) Z[[3,6,8]] .= 8 E = M.energy_forces_virial(at, calc, ps, st).energy @@ -128,7 +128,7 @@ for ntest = 1:10 local ps_lin, st_lin, at, efv, _restruct ps_lin, st_lin = LuxCore.setup(rng, lin_calc) - at = rattle!(bulk(:Si, cubic=true), 0.1) + at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) Z[[3,6,8]] .= 8 @@ -158,7 +158,7 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), local at, Z, dF0, g, _restruct, g_vec # random structure - at = rattle!(bulk(:Si, cubic=true), 0.1) + at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) Z = AtomsBuilder._get_atomic_numbers(at) Z[[3,6,8]] .= 8 From d12f5247cea6eb7b8e6bde0127a326ae75d6195f Mon Sep 17 00:00:00 2001 From: James Kermode Date: Mon, 20 May 2024 16:48:20 +0100 Subject: [PATCH 042/112] add missing deps for newkernels example --- docs/src/newkernels/Project.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index 857f2978..59c08d12 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -1,7 +1,8 @@ [deps] +ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" +AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" @@ -15,5 +16,7 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From a59a8a31126d7f7f768dda4daa8ec3886fb37e7c Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 22 May 2024 08:25:28 -0700 Subject: [PATCH 043/112] add CondaPkg to gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 1b4c933f..13b72b45 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ scratch /docs/src/literate_tutorials .vscode -.ipynb_checkpoints \ No newline at end of file +.ipynb_checkpoints +.CondaPkg From 31a273cf546cde3c00e76b6b0f447f17ae6e2a3a Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Fri, 31 May 2024 00:35:46 -0700 Subject: [PATCH 044/112] allow passing E0 --- src/models/ace_heuristics.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index a52bf808..a4a5dfc8 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -95,6 +95,7 @@ end function ace_model(; elements = nothing, order = nothing, Ytype = :solid, + E0s = nothing, # radial basis rbasis = nothing, rbasis_type = :learnable, @@ -142,7 +143,7 @@ function ace_model(; elements = nothing, AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, level = level, max_level = max_level) - model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis_spl) + model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis_spl, E0s) model.meta["init_WB"] = String(init_WB) model.meta["init_Wpair"] = String(init_Wpair) From 4e7266e431b6de77bd6571a08b12a6fc597399f6 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sat, 1 Jun 2024 01:59:27 -0700 Subject: [PATCH 045/112] passes r0incuts correctly --- src/models/ace_heuristics.jl | 12 +++++++++++- test/models/test_calculator.jl | 32 +++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index a4a5dfc8..101bb620 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -96,6 +96,7 @@ function ace_model(; elements = nothing, order = nothing, Ytype = :solid, E0s = nothing, + rin0cuts = nothing, # radial basis rbasis = nothing, rbasis_type = :learnable, @@ -111,12 +112,21 @@ function ace_model(; elements = nothing, init_Wpair = :zeros, rng = Random.default_rng(), ) + + if rin0cuts == nothing + rin0cuts = _default_rin0cuts(elements) + else + NZ = length(elements) + @assert rin0cuts isa SMatrix && size(rin0cuts) == (NZ, NZ) + end + # construct an rbasis if needed if isnothing(rbasis) if rbasis_type == :learnable rbasis = ace_learnable_Rnlrzz(; max_level = max_level, level = level, maxl = maxl, maxn = maxn, - elements = elements) + elements = elements, + rin0cuts = rin0cuts) else error("unknown rbasis_type = $rbasis_type") end diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 1982dfd8..ca492a38 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -26,11 +26,24 @@ elements = (:Si, :O) level = M.TotalDegree() max_level = 15 order = 3 +E0s = (-158.54496821, -2042.0330099956639) +NZ = length(elements) + +function make_rin0cut(zi, zj) + r0 = ACE1x.get_r0(zi, zj) + return (rin = 0.0, r0 = r0, rcut = 6.5) +end + +rin0cuts = SMatrix{NZ, NZ}([make_rin0cut(zi, zj) for zi in elements, zj in elements]) + model = M.ace_model(; elements = elements, order = order, Ytype = :solid, level = level, max_level = max_level, maxl = 8, pair_maxn = 15, - init_WB = :glorot_normal, init_Wpair = :glorot_normal) + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + E0s = E0s, + rin0cuts = rin0cuts + ) ps, st = LuxCore.setup(rng, model) @@ -38,6 +51,21 @@ calc = M.ACEPotential(model, ps, st) ## +@info("Testing correctness of E0s") +ps_vec, _restruct = destructure(ps) +ps_zero = _restruct(zero(ps_vec)) + +for ntest = 1:20 + local Rs, Zs, z0, at, efv + at = AB.rattle!(AB.bulk(:Si, cubic=true) * 2, 0.1) + nlist = PairList(at, M.cutoff_radius(calc)) + efv = M.energy_forces_virial(at, calc, ps_zero, st) + print_tf(@test norm(ustrip(efv.energy) - E0s[1] * length(at)) < 1e-10) +end + +println() + +## @info("Testing correctness of potential energy") for ntest = 1:20 @@ -57,6 +85,7 @@ for ntest = 1:20 end println() + ## @info("Testing correctness of forces ") @@ -185,3 +214,4 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), println(@test FDTEST) end +## \ No newline at end of file From 431f30fccdec128d3e93da17eec9634d37fbb696 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 1 Jun 2024 15:16:59 -0700 Subject: [PATCH 046/112] add randz to E0s test --- test/models/test_calculator.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index ca492a38..1c3740ba 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -57,10 +57,13 @@ ps_zero = _restruct(zero(ps_vec)) for ntest = 1:20 local Rs, Zs, z0, at, efv - at = AB.rattle!(AB.bulk(:Si, cubic=true) * 2, 0.1) + at = AB.randz!( AB.rattle!(AB.bulk(:Si, cubic=true) * 2, 0.1), + (:Si => 0.6, :O => 0.5), ) + n_Si = count(x -> x == 14, AtomsBase.atomic_number(at)) + n_O = count(x -> x == 8, AtomsBase.atomic_number(at)) nlist = PairList(at, M.cutoff_radius(calc)) efv = M.energy_forces_virial(at, calc, ps_zero, st) - print_tf(@test norm(ustrip(efv.energy) - E0s[1] * length(at)) < 1e-10) + print_tf(@test abs(ustrip(efv.energy) - E0s[1] * n_Si - E0s[2] * n_O) < 1e-10) end println() From 769d1bf7163b0d6d387208acc412863c1228549e Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 1 Jun 2024 15:26:01 -0700 Subject: [PATCH 047/112] fixed acemodel display --- src/models/ace.jl | 4 ++-- src/models/sparse.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index 729a8f87..673e7a3e 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -202,9 +202,9 @@ end (l::ACEModel)(args...) = evaluate(l, args...) function LuxCore.parameterlength(model::ACEModel) + # this layer stores the pair basis parameters and the B basis parameters NZ = _get_nz(model) - n_B_params, n_AA_params = size(model.A2Bmap) - return NZ * n_B_params + return NZ^2 * length(model.pairbasis) + NZ * length(model.tensor) end function splinify(model::ACEModel, ps::NamedTuple) diff --git a/src/models/sparse.jl b/src/models/sparse.jl index 7882ef36..d6bddacc 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -2,7 +2,7 @@ using SparseArrays: SparseMatrixCSC import Polynomials4ML -struct SparseEquivTensor{T, TA, TAA} +struct SparseEquivTensor{T, TA, TAA} abasis::TA aabasis::TAA A2Bmap::SparseMatrixCSC{T, Int} From a7a1a5afb8ead3839d58bba81ecc4d5120420a9d Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sat, 1 Jun 2024 22:15:21 -0700 Subject: [PATCH 048/112] restrict E0 to be a dictionary with Symbol and Uniful.Quantity --- src/models/ace.jl | 17 ++++++++++++++++- test/models/test_calculator.jl | 9 +++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index 673e7a3e..3913aa33 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -87,6 +87,10 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, level = TotalDegree(), pair_basis = nothing, E0s = nothing ) + + # storing E0s with unit + model_meta = Dict{String, Any}("E0s" => deepcopy(E0s)) + # generate the coupling coefficients cgen = EquivariantModels.Rot3DCoeffs_real(0) AA2BB_map = EquivariantModels._rpi_A2B_matrix(cgen, AA_spec) @@ -126,9 +130,20 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, aa_basis = Polynomials4ML.SparseSymmProdDAG(AA_spec_idx) aa_basis.meta["AA_spec"] = AA_spec # (also store the human-readable spec) + # process E0s and ustrip any units if isnothing(E0s) NZ = _get_nz(rbasis) E0s = ntuple(i -> 0.0, NZ) + elseif E0s isa Dict{Symbol, <: Quantity} + NZ = _get_nz(rbasis) + _E0s = zeros(NZ) + for sym in keys(E0s) + idx = findfirst(==(AtomicNumber(sym).z), rbasis._i2z) + _E0s[idx] = ustrip(E0s[sym]) + end + E0s = Tuple(_E0s) + else + error("E0s can either be nothing, or in form of a dictionary with keys 'Symbol' and values 'Uniful.Quantity'.") end tensor = SparseEquivTensor(a_basis, aa_basis, AA2BB_map, @@ -136,7 +151,7 @@ function _generate_ace_model(rbasis, Ytype::Symbol, AA_spec::AbstractVector, return ACEModel(rbasis._i2z, rbasis, ybasis, tensor, pair_basis, E0s, - Dict{String, Any}() ) + model_meta ) end # TODO: it is not entirely clear that the `level` is really needed here diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 1c3740ba..9787bc1c 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -25,8 +25,9 @@ rng = Random.MersenneTwister(1234) elements = (:Si, :O) level = M.TotalDegree() max_level = 15 -order = 3 -E0s = (-158.54496821, -2042.0330099956639) +order = 3 +E0s = Dict( :Si => -158.54496821u"eV", + :O => -2042.0330099956639u"eV") NZ = length(elements) function make_rin0cut(zi, zj) @@ -47,7 +48,7 @@ model = M.ace_model(; elements = elements, order = order, Ytype = :solid, ps, st = LuxCore.setup(rng, model) -calc = M.ACEPotential(model, ps, st) +calc = M.ACEPotential(model, ps, st) ## @@ -63,7 +64,7 @@ for ntest = 1:20 n_O = count(x -> x == 8, AtomsBase.atomic_number(at)) nlist = PairList(at, M.cutoff_radius(calc)) efv = M.energy_forces_virial(at, calc, ps_zero, st) - print_tf(@test abs(ustrip(efv.energy) - E0s[1] * n_Si - E0s[2] * n_O) < 1e-10) + print_tf(@test ustrip(abs(efv.energy - E0s[:Si] * n_Si - E0s[:O] * n_O)) < 1e-10) end println() From 1ea12403f0f4278023b4c82a3920b4c87fde1c79 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 2 Jun 2024 00:38:50 -0700 Subject: [PATCH 049/112] fix rSH for ACEModel --- src/models/ace.jl | 4 +- test/models/test_ace.jl | 264 ++++++++++++++++++++-------------------- 2 files changed, 135 insertions(+), 133 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index 3913aa33..76b14716 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -242,8 +242,8 @@ end import Zygote # these _getlmax and _length should be moved into SpheriCart -_getlmax(ybasis::SolidHarmonics{L}) where {L} = L -_length(ybasis::SolidHarmonics) = SpheriCart.sizeY(_getlmax(ybasis)) +_getlmax(ybasis::Union{SolidHarmonics{L}, SphericalHarmonics{L}}) where {L} = L +_length(ybasis::Union{SolidHarmonics, SphericalHarmonics}) = SpheriCart.sizeY(_getlmax(ybasis)) function evaluate(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 948007c7..a9cc9abb 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -20,166 +20,168 @@ level = M.TotalDegree() max_level = 15 order = 3 -model = M.ace_model(; elements = elements, order = order, Ytype = :solid, - level = level, max_level = max_level, maxl = 8, - pair_maxn = 15, - init_WB = :glorot_normal, init_Wpair = :glorot_normal) +for ybasis in [:spherical, :solid] + model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, + level = level, max_level = max_level, maxl = 8, + pair_maxn = 15, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) -ps, st = LuxCore.setup(rng, model) + ps, st = LuxCore.setup(rng, model) -## - -@info("Test Rotation-Invariance of the Model") - -for ntest = 1:50 - local st1, Nat, Rs, Zs, Z0, val + ## - Nat = rand(8:16) - Rs, Zs, Z0 = M.rand_atenv(model, Nat) - val, st1 = M.evaluate(model, Rs, Zs, Z0, ps, st) + @info("Test Rotation-Invariance of the Model") - p = shuffle(1:Nat) - Rs1 = Ref(M.rand_iso()) .* Rs[p] - Zs1 = Zs[p] - val1, st1 = M.evaluate(model, Rs1, Zs1, Z0, ps, st) + for ntest = 1:50 + local st1, Nat, Rs, Zs, Z0, val - print_tf(@test abs(val - val1) < 1e-10) -end -println() + Nat = rand(8:16) + Rs, Zs, Z0 = M.rand_atenv(model, Nat) + val, st1 = M.evaluate(model, Rs, Zs, Z0, ps, st) -## + p = shuffle(1:Nat) + Rs1 = Ref(M.rand_iso()) .* Rs[p] + Zs1 = Zs[p] + val1, st1 = M.evaluate(model, Rs1, Zs1, Z0, ps, st) -@info("Test derivatives w.r.t. positions") -Rs, Zs, z0 = M.rand_atenv(model, 16) -Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) -Ei1, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) -println_slim(@test Ei ≈ Ei1) + print_tf(@test abs(val - val1) < 1e-10) + end + println() -for ntest = 1:20 - local Nat, Rs, Zs, z0, Us, F, dF - Nat = rand(8:16) - Rs, Zs, z0 = M.rand_atenv(model, Nat) - Us = randn(SVector{3, Float64}, Nat) - F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] - dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) - print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) -end -println() + ## -## + @info("Test derivatives w.r.t. positions") + Rs, Zs, z0 = M.rand_atenv(model, 16) + Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei1, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + println_slim(@test Ei ≈ Ei1) -@info("Test derivatives w.r.t. parameters") -Nat = 15 -Rs, Zs, z0 = M.rand_atenv(model, Nat) -Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) -Ei1, ∇Ei, st = M.grad_params(model, Rs, Zs, z0, ps, st) -println_slim(@test Ei ≈ Ei1) + for ntest = 1:20 + local Nat, Rs, Zs, z0, Us, F, dF + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Us = randn(SVector{3, Float64}, Nat) + F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] + dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) + end + println() -for ntest = 1:20 - local Nat, Rs, Zs, z0, pvec, uvec, F, dF0, _restruct + ## - Nat = rand(8:16) + @info("Test derivatives w.r.t. parameters") + Nat = 15 Rs, Zs, z0 = M.rand_atenv(model, Nat) - pvec, _restruct = destructure(ps) - uvec = randn(length(pvec)) / sqrt(length(pvec)) - F(t) = M.evaluate(model, Rs, Zs, z0, _restruct(pvec + t * uvec), st)[1] - dF0 = dot( destructure( M.grad_params(model, Rs, Zs, z0, ps, st)[2] )[1], uvec ) - print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose = false)) -end -println() - -## + Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei1, ∇Ei, st = M.grad_params(model, Rs, Zs, z0, ps, st) + println_slim(@test Ei ≈ Ei1) -@info("Test second mixed derivatives reverse-over-reverse") -for ntest = 1:20 - local Nat, Rs, Zs, Us, Ei, ∂Ei, ∂2_Ei, - ps_vec, vs_vec, F, dF0, z0, _restruct + for ntest = 1:20 + local Nat, Rs, Zs, z0, pvec, uvec, F, dF0, _restruct - Nat = rand(8:16) - Rs, Zs, z0 = M.rand_atenv(model, Nat) - Us = randn(SVector{3, Float64}, Nat) - Ei = M.evaluate(model, Rs, Zs, z0, ps, st) - Ei, ∂Ei, _ = M.grad_params(model, Rs, Zs, z0, ps, st) - - # test partial derivative w.r.t. the Ei component - ∂2_Ei = M.pullback_2_mixed(1.0, 0*Us, model, Rs, Zs, z0, ps, st) - print_tf(@test destructure(∂2_Ei)[1] ≈ destructure(∂Ei)[1]) - - # test partial derivative w.r.t. the ∇Ei component - ∂2_∇Ei = M.pullback_2_mixed(0.0, Us, model, Rs, Zs, z0, ps, st) - ∂2_∇Ei_vec = destructure(∂2_∇Ei)[1] - - ps_vec, _restruct = destructure(ps) - vs_vec = randn(length(ps_vec)) / sqrt(length(ps_vec)) - F(t) = dot(Us, M.evaluate_ed(model, Rs, Zs, z0, _restruct(ps_vec + t * vs_vec), st)[2]) - dF0 = dot(∂2_∇Ei_vec, vs_vec) - print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false)) -end -println() - -## + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + pvec, _restruct = destructure(ps) + uvec = randn(length(pvec)) / sqrt(length(pvec)) + F(t) = M.evaluate(model, Rs, Zs, z0, _restruct(pvec + t * uvec), st)[1] + dF0 = dot( destructure( M.grad_params(model, Rs, Zs, z0, ps, st)[2] )[1], uvec ) + print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose = false)) + end + println() -@info("Test basis implementation") + ## -for ntest = 1:30 - local Nat, Rs, Zs, z0, Ei, B, θ, st1 , ∇Ei + @info("Test second mixed derivatives reverse-over-reverse") + for ntest = 1:20 + local Nat, Rs, Zs, Us, Ei, ∂Ei, ∂2_Ei, + ps_vec, vs_vec, F, dF0, z0, _restruct - Nat = 15 - Rs, Zs, z0 = M.rand_atenv(model, Nat) - i_z0 = M._z2i(model, z0) - Ei, st1 = M.evaluate(model, Rs, Zs, z0, ps, st) - B, st1 = M.evaluate_basis(model, Rs, Zs, z0, ps, st) - θ = M.get_basis_params(model, ps) - print_tf(@test Ei ≈ dot(B, θ)) - - Ei, ∇Ei, st1 = M.evaluate_ed(model, Rs, Zs, z0, ps, st) - B, ∇B, st1 = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) - print_tf(@test Ei ≈ dot(B, θ)) - print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) -end -println() + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Us = randn(SVector{3, Float64}, Nat) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei, ∂Ei, _ = M.grad_params(model, Rs, Zs, z0, ps, st) + + # test partial derivative w.r.t. the Ei component + ∂2_Ei = M.pullback_2_mixed(1.0, 0*Us, model, Rs, Zs, z0, ps, st) + print_tf(@test destructure(∂2_Ei)[1] ≈ destructure(∂Ei)[1]) + + # test partial derivative w.r.t. the ∇Ei component + ∂2_∇Ei = M.pullback_2_mixed(0.0, Us, model, Rs, Zs, z0, ps, st) + ∂2_∇Ei_vec = destructure(∂2_∇Ei)[1] + + ps_vec, _restruct = destructure(ps) + vs_vec = randn(length(ps_vec)) / sqrt(length(ps_vec)) + F(t) = dot(Us, M.evaluate_ed(model, Rs, Zs, z0, _restruct(ps_vec + t * vs_vec), st)[2]) + dF0 = dot(∂2_∇Ei_vec, vs_vec) + print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false)) + end + println() -## + ## -@info("Test the full mixed jacobian") + @info("Test basis implementation") -for ntest = 1:30 - local Nat, Rs, Zs, z0, Ei, ∇Ei, ∂∂Ei, Us, F, dF0 + for ntest = 1:30 + local Nat, Rs, Zs, z0, Ei, B, θ, st1 , ∇Ei - Nat = 15 - Rs, Zs, z0 = M.rand_atenv(model, Nat) - Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat) - F(t) = destructure( M.grad_params(model, Rs + t * Us, Zs, z0, ps, st)[2] )[1] - dF0 = ForwardDiff.derivative(F, 0.0) - ∂∂Ei = M.jacobian_grad_params(model, Rs, Zs, z0, ps, st)[3] - print_tf(@test dF0 ≈ transpose.(∂∂Ei) * Us) -end -println() + Nat = 15 + Rs, Zs, z0 = M.rand_atenv(model, Nat) + i_z0 = M._z2i(model, z0) + Ei, st1 = M.evaluate(model, Rs, Zs, z0, ps, st) + B, st1 = M.evaluate_basis(model, Rs, Zs, z0, ps, st) + θ = M.get_basis_params(model, ps) + print_tf(@test Ei ≈ dot(B, θ)) + + Ei, ∇Ei, st1 = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + B, ∇B, st1 = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) + print_tf(@test Ei ≈ dot(B, θ)) + print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) + end + println() + ## -## + @info("Test the full mixed jacobian") -@info("check splinification") -lin_ace = M.splinify(model, ps) -ps_lin, st_lin = LuxCore.setup(rng, lin_ace) -ps_lin.WB[:] .= ps.WB[:] -ps_lin.Wpair[:] .= ps.Wpair[:] + for ntest = 1:30 + local Nat, Rs, Zs, z0, Ei, ∇Ei, ∂∂Ei, Us, F, dF0 -for ntest = 1:10 - local len, Nat, Rs, Zs, z0, Ei - len = 10 - mae = sum(1:len) do _ - Nat = rand(8:16) + Nat = 15 Rs, Zs, z0 = M.rand_atenv(model, Nat) - Ei = M.evaluate(model, Rs, Zs, z0, ps, st)[1] - Ei_lin = M.evaluate(lin_ace, Rs, Zs, z0, ps_lin, st_lin)[1] - abs(Ei - Ei_lin) + Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat) + F(t) = destructure( M.grad_params(model, Rs + t * Us, Zs, z0, ps, st)[2] )[1] + dF0 = ForwardDiff.derivative(F, 0.0) + ∂∂Ei = M.jacobian_grad_params(model, Rs, Zs, z0, ps, st)[3] + print_tf(@test dF0 ≈ transpose.(∂∂Ei) * Us) + end + println() + + + ## + + @info("check splinification") + lin_ace = M.splinify(model, ps) + ps_lin, st_lin = LuxCore.setup(rng, lin_ace) + ps_lin.WB[:] .= ps.WB[:] + ps_lin.Wpair[:] .= ps.Wpair[:] + + for ntest = 1:10 + local len, Nat, Rs, Zs, z0, Ei + len = 10 + mae = sum(1:len) do _ + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st)[1] + Ei_lin = M.evaluate(lin_ace, Rs, Zs, z0, ps_lin, st_lin)[1] + abs(Ei - Ei_lin) + end + mae /= len + print_tf(@test mae < 0.01) end - mae /= len - print_tf(@test mae < 0.01) -end -println() + println() +end ## #= From 2f165c5fc1427ddeecbd96ac58042a79e93df4c9 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 2 Jun 2024 15:41:26 -0700 Subject: [PATCH 050/112] add ybasis tests --- test/models/test_ace.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index a9cc9abb..7072cc82 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -20,7 +20,39 @@ level = M.TotalDegree() max_level = 15 order = 3 +@info("Test ybasis of the Model is used correctly") +msolid = M.ace_model(; elements = elements, order = order, Ytype = :solid, +level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :glorot_normal) +mspherical = M.ace_model(; elements = elements, order = order, Ytype = :spherical, +level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :zero) +ps, st = LuxCore.setup(rng, msolid) + +# get first l ≠ 0 index +spec = msolid.tensor.aabasis.meta["AA_spec"] +getl(t) = t.l +idxBBl1 = findfirst(t -> any(getl.(t) .> 0), spec) + +# only set firstl1 as 1.0, coming from ∑m +ps.Wpair .= 0.0 +ps.WB .= 0.0 +ps.WB[idxBBl1, :] .= 1.0 + +# check sacling +Nat = 1 +for ntest = 1:20 + local Rs, Zs, Z0, Nal + Rs, Zs, Z0 = M.rand_atenv(msolid, Nat) + valsoild, _ = M.evaluate(msolid, Rs, Zs, Z0, ps, st) + valspherical, _ = M.evaluate(mspherical, Rs, Zs, Z0, ps, st) + print_tf(@test valsoild / valspherical ≈ norm(Rs[1])^(length(spec[idxBBl1]))) +end +println() + +## + for ybasis in [:spherical, :solid] + @info("=== Testing ybasis = $ybasis === ") + local ps, st, Nat model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, level = level, max_level = max_level, maxl = 8, pair_maxn = 15, From 4df33f6a297ccc5ac9a58015d0631ab191d1aac9 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 2 Jun 2024 15:43:43 -0700 Subject: [PATCH 051/112] fix typo --- test/models/test_ace.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 7072cc82..d49919d6 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -40,7 +40,7 @@ ps.WB[idxBBl1, :] .= 1.0 # check sacling Nat = 1 for ntest = 1:20 - local Rs, Zs, Z0, Nal + local Rs, Zs, Z0 Rs, Zs, Z0 = M.rand_atenv(msolid, Nat) valsoild, _ = M.evaluate(msolid, Rs, Zs, Z0, ps, st) valspherical, _ = M.evaluate(mspherical, Rs, Zs, Z0, ps, st) From f14a8c6d3d09246ce23e38df3f19f101fe42bd24 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 2 Jun 2024 17:32:21 -0700 Subject: [PATCH 052/112] change solid vs spher test --- test/models/test_ace.jl | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index d49919d6..a22b33ab 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -20,6 +20,8 @@ level = M.TotalDegree() max_level = 15 order = 3 +## + @info("Test ybasis of the Model is used correctly") msolid = M.ace_model(; elements = elements, order = order, Ytype = :solid, level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :glorot_normal) @@ -27,26 +29,13 @@ mspherical = M.ace_model(; elements = elements, order = order, Ytype = :spherica level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :zero) ps, st = LuxCore.setup(rng, msolid) -# get first l ≠ 0 index -spec = msolid.tensor.aabasis.meta["AA_spec"] -getl(t) = t.l -idxBBl1 = findfirst(t -> any(getl.(t) .> 0), spec) - -# only set firstl1 as 1.0, coming from ∑m -ps.Wpair .= 0.0 -ps.WB .= 0.0 -ps.WB[idxBBl1, :] .= 1.0 - -# check sacling -Nat = 1 -for ntest = 1:20 - local Rs, Zs, Z0 - Rs, Zs, Z0 = M.rand_atenv(msolid, Nat) - valsoild, _ = M.evaluate(msolid, Rs, Zs, Z0, ps, st) - valspherical, _ = M.evaluate(mspherical, Rs, Zs, Z0, ps, st) - print_tf(@test valsoild / valspherical ≈ norm(Rs[1])^(length(spec[idxBBl1]))) +for ntest = 1:30 + 𝐫 = randn(SVector{3, Float64}) + Ysolid = msolid.ybasis(𝐫) + Yspher = mspherical.ybasis(𝐫) + ll = [ M.SpheriCart.idx2lm(i)[1] for i in 1:length(Ysolid) ] + print_tf(@test (Yspher .* (norm(𝐫)).^ll) ≈ Ysolid) end -println() ## From 8557bca67cb127c5719dbc54a43b0621886381e4 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 2 Jun 2024 20:55:52 -0700 Subject: [PATCH 053/112] add get nl --- src/models/sparse.jl | 18 ++++++++++++++++++ src/models/utils.jl | 5 +++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/models/sparse.jl b/src/models/sparse.jl index d6bddacc..b5403d3a 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -57,3 +57,21 @@ function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, return ∂Rnl, ∂Ylm end + +_nl(bb) = [(n = b.n, l = b.l) for b in bb] + +function get_nl(tensor::SparseEquivTensor{T}) where {T} + # assume the new ACE model NEVER has the z channel + spec = tensor.aabasis.meta["AA_spec"] + nBB = size(tensor.A2Bmap, 1) + nnll_list = Vector{NT_NL_SPEC}[] + for i in 1:nBB + AAidx_nnz = tensor.A2Bmap[i, :].nzind + bbs = spec[AAidx_nnz] + @assert all([bb == _nl(bbs[1]) for bb in _nl.(bbs)]) + push!(nnll_list, _nl(bbs[1])) + end + @assert length(nnll_list) == nBB + return nnll_list +end + diff --git a/src/models/utils.jl b/src/models/utils.jl index de8d12da..b27588c5 100644 --- a/src/models/utils.jl +++ b/src/models/utils.jl @@ -78,8 +78,9 @@ function sparse_AA_spec(; order = nothing, return AA_spec_nlm end - - +function get_nl(model::ACEModel) + return get_nl(model.tensor) +end import ACE1 From a21059281d1fcaa0c3ca3d389d50d59d512c15e6 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 2 Jun 2024 21:53:48 -0700 Subject: [PATCH 054/112] renaming and minimal doc --- src/models/sparse.jl | 10 ++++++++-- src/models/utils.jl | 13 +++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/models/sparse.jl b/src/models/sparse.jl index b5403d3a..f4c187bd 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -58,9 +58,15 @@ function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, end -_nl(bb) = [(n = b.n, l = b.l) for b in bb] +""" +Get the specification of the BBbasis as a list (`Vector`) of vectors of `@NamedTuple{n::Int, l::Int}`. -function get_nl(tensor::SparseEquivTensor{T}) where {T} +### Parameters + +* `tensor` : a SparseEquivTensor, possibly from ACEModel +""" +function get_nnll_spec(tensor::SparseEquivTensor{T}) where {T} + _nl(bb) = [(n = b.n, l = b.l) for b in bb] # assume the new ACE model NEVER has the z channel spec = tensor.aabasis.meta["AA_spec"] nBB = size(tensor.A2Bmap, 1) diff --git a/src/models/utils.jl b/src/models/utils.jl index b27588c5..e33cf039 100644 --- a/src/models/utils.jl +++ b/src/models/utils.jl @@ -78,10 +78,19 @@ function sparse_AA_spec(; order = nothing, return AA_spec_nlm end -function get_nl(model::ACEModel) - return get_nl(model.tensor) + +""" +Get the specification of the BBbasis as a list (`Vector`) of vectors of `@NamedTuple{n::Int, l::Int}`. + +### Parameters + +* `model` : an ACEModel +""" +function get_nnll_spec(model::ACEModel) + return get_nnll_spec(model.tensor) end + import ACE1 rand_atenv(model::ACEModel, Nat) = rand_atenv(model.rbasis, Nat) From 04311a30474812ed007a7cdb545ffb9fe7ea2c3d Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 22 Jun 2024 00:52:17 -0700 Subject: [PATCH 055/112] update most tests --- Project.toml | 8 +++----- src/models/calculators.jl | 7 ------- src/models/sparse.jl | 4 ++-- test/models/test_calculator.jl | 20 +++++++++++--------- test/runtests.jl | 13 ++++++------- test/test_io.jl | 34 ++++++++++++++++------------------ 6 files changed, 38 insertions(+), 48 deletions(-) diff --git a/Project.toml b/Project.toml index 601ba7e5..665ec5d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ACEpotentials" uuid = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" -version = "0.6.6" +version = "0.8.0-dev" [deps] ACE1 = "e3f9bc04-086e-409a-ba78-e9769fe067bb" @@ -25,7 +25,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" NeighbourLists = "2fcf5ba9-9ed4-57cf-b73f-ff513e316b9c" -ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -51,15 +50,14 @@ ACE1x = "0.1.8" ACEfit = "0.1.4" ACEmd = "0.1.6" ExtXYZ = "0.1.14" +Interpolations = "0.14.7" JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" -SpheriCart = "0.0.3" +SpheriCart = "0.1.1" StaticArrays = "1" -UltraFastACE = "0.0.2" YAML = "0.4" julia = "1.9" -Interpolations = "0.14.7" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 0eddbf76..ec5649fe 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -17,8 +17,6 @@ using Folds, ChunkSplitters, Unitful, NeighbourLists, import ChainRulesCore: rrule, NoTangent, ZeroTangent -using ObjectPools: release! - struct ACEPotential{MOD} <: SitePotential model::MOD ps @@ -101,7 +99,6 @@ function energy_forces_virial_serial( forces[i] += dv[α] end virial += _site_virial(dv, Rs) - release!(Js); release!(Rs); release!(Zs) end return (energy = energy * energy_unit(V), forces = forces * force_unit(V), @@ -145,7 +142,6 @@ function energy_forces_virial( forces[i] += dv[α] * force_unit(V) end virial += _site_virial(dv, Rs) * energy_unit(V) - release!(Js); release!(Rs); release!(Zs) end [energy, forces, virial] end @@ -204,8 +200,6 @@ function pullback_EFV(Δefv, mult = one(TP) end - release!(Js); release!(Rs); release!(Zs) - # convert it back to a vector so we can accumulate it in the sum. # this is quite bad - in the call to pullback_2_mixed we just # converted it from a vector to a named tuple. We need to look into @@ -281,7 +275,6 @@ function energy_forces_virial_basis( end V[k] += _site_virial(dv[k, :], Rs) * energy_unit(calc) end - release!(Js); release!(Rs); release!(Zs) end return (energy = E, forces = F, virial = V) diff --git a/src/models/sparse.jl b/src/models/sparse.jl index f4c187bd..31369cab 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -47,12 +47,12 @@ function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, TA = promote_type(T, eltype(_AA), eltype(∂B), eltype(Rnl), eltype(eltype(Ylm))) ∂A = zeros(TA, length(tensor.abasis)) - Polynomials4ML.pullback_arg!(∂A, _∂AA, tensor.aabasis, _AA) + Polynomials4ML.unsafe_pullback!(∂A, _∂AA, tensor.aabasis, _AA) # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) ∂Rnl = zeros(TA, size(Rnl)) ∂Ylm = zeros(TA, size(Ylm)) - Polynomials4ML._pullback_evaluate!((∂Rnl, ∂Ylm), ∂A, tensor.abasis, (Rnl, Ylm)) + Polynomials4ML.pullback!((∂Rnl, ∂Ylm), ∂A, tensor.abasis, (Rnl, Ylm)) return ∂Rnl, ∂Ylm end diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 9787bc1c..d3bbc855 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -3,6 +3,8 @@ # using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); +## + using Test, ACEbase using ACEbase.Testing: print_tf, println_slim @@ -116,14 +118,14 @@ for ntest = 1:10 local at, Us, dF0, X0, F, Z at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) - Z = AtomsBuilder._get_atomic_numbers(at) + Z = AtomsBase.atomic_number(at) Z[[3,6,8]] .= 8 - at = AtomsBuilder._set_atomic_numbers(at, Z) + at = AtomsBuilder.set_elements(at, Z) Us = randn(SVector{3, Float64}, length(at)) / length(at) * u"Å" dF0 = - dot(Us, M.energy_forces_virial_serial(at, calc, ps, st).forces) - X0 = AtomsBuilder._get_positions(at) + X0 = AtomsBase.position(at) F(t) = M.energy_forces_virial_serial( - AtomsBuilder._set_positions(at, X0 + t * Us), + AtomsBuilder.set_positions(at, X0 + t * Us), calc, ps, st).energy |> ustrip print_tf( @test ACEbase.Testing.fdtest(F, t -> ustrip(dF0), 0.0; verbose=false ) ) end @@ -142,7 +144,7 @@ for ntest = 1:10 len = 10 mae = sum(1:len) do _ at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) - Z = AtomsBuilder._get_atomic_numbers(at) + Z = AtomsBase.atomic_number(at) Z[[3,6,8]] .= 8 E = M.energy_forces_virial(at, calc, ps, st).energy E_lin = M.energy_forces_virial(at, lin_calc, ps_lin, st_lin).energy @@ -162,7 +164,7 @@ for ntest = 1:10 ps_lin, st_lin = LuxCore.setup(rng, lin_calc) at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) - Z = AtomsBuilder._get_atomic_numbers(at) + Z = AtomsBase.atomic_number(at) Z[[3,6,8]] .= 8 efv = M.energy_forces_virial(at, lin_calc, ps_lin, st_lin) @@ -192,7 +194,7 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), # random structure at = AB.rattle!(AB.bulk(:Si, cubic=true), 0.1) - Z = AtomsBuilder._get_atomic_numbers(at) + Z = AtomsBase.atomic_number(at) Z[[3,6,8]] .= 8 function loss(at, calc, ps, st) @@ -214,8 +216,8 @@ for (wE, wV, wF) in [ (1.0 / u"eV", 0.0 / u"eV", 0.0 / u"eV/Å"), dF0 = dot(g_vec, u) @info("(wE, wV, wF) = ($wE, $wV, $wF)") - FDTEST = ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=true) + FDTEST = ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose=false) println(@test FDTEST) end -## \ No newline at end of file +## diff --git a/test/runtests.jl b/test/runtests.jl index 61570aa9..89c72f93 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,11 +13,10 @@ using ACEpotentials, Test, LazyArtifacts @testset "Models" begin include("models/test_models.jl") end # outdated - @testset "Read data" begin include("outdated/test_data.jl") end - @testset "Basis" begin include("outdated/test_basis.jl") end - @testset "Solver" begin include("outdated/test_solver.jl") end - @testset "Fit ACE" begin include("outdated/test_fit.jl") end - @testset "Read params" begin include("outdated/test_read_params.jl") end - #@testset "Test ace_fit.jl script" begin include("outdated/test_ace_fit.jl") end - + # @testset "Read data" begin include("outdated/test_data.jl") end + # @testset "Basis" begin include("outdated/test_basis.jl") end + # @testset "Solver" begin include("outdated/test_solver.jl") end + # @testset "Fit ACE" begin include("outdated/test_fit.jl") end + # @testset "Read params" begin include("outdated/test_read_params.jl") end + # @testset "Test ace_fit.jl script" begin include("outdated/test_ace_fit.jl") end end diff --git a/test/test_io.jl b/test/test_io.jl index 99ff8602..83746802 100644 --- a/test/test_io.jl +++ b/test/test_io.jl @@ -16,22 +16,20 @@ weights = Dict("default" => Dict("E"=>30.0, "F"=>1.0, "V"=>1.0), "liq" => Dict("E"=>10.0, "F"=>0.66, "V"=>0.25)) -@testset "IO for new save and load" begin - acefit!(model, data; - data_keys..., - weights = weights, - solver = ACEfit.LSQR( - damp = 2e-2, - conlim = 1e12, - atol = 1e-7, - maxiter = 100000, - verbose = false +acefit!(model, data; + data_keys..., + weights = weights, + solver = ACEfit.LSQR( + damp = 2e-2, + conlim = 1e12, + atol = 1e-7, + maxiter = 100000, + verbose = false ) - ) - fname = tempname() * ".json" - pot = ACEpotential(model.potential.components) - @test_throws AssertionError save_potential(fname, model; meta="meta test") - save_potential(fname, model; meta=Dict("test"=>"meta test") ) - npot = load_potential(fname; new_format=true) - @test ace_energy(pot, data[1]) ≈ ace_energy(npot, data[1]) -end \ No newline at end of file + ) +fname = tempname() * ".json" +pot = ACEpotential(model.potential.components) +@test_throws AssertionError save_potential(fname, model; meta="meta test") +save_potential(fname, model; meta=Dict("test"=>"meta test") ) +npot = load_potential(fname; new_format=true) +@test ace_energy(pot, data[1]) ≈ ace_energy(npot, data[1]) From da1db057959a4fa1bc30613192af7ae4d178f3f8 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 22 Jun 2024 10:56:22 -0700 Subject: [PATCH 056/112] fixed tests for new code --- Project.toml | 1 - test/models/test_ace.jl | 7 ++++++- test/models/test_calculator.jl | 18 ++++++++++++++---- test/runtests.jl | 2 +- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 665ec5d5..4a0d7b75 100644 --- a/Project.toml +++ b/Project.toml @@ -54,7 +54,6 @@ Interpolations = "0.14.7" JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" -SpheriCart = "0.1.1" StaticArrays = "1" YAML = "0.4" julia = "1.9" diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index a22b33ab..b525b38d 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -2,6 +2,8 @@ # using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); +## + using Test, ACEbase using ACEbase.Testing: print_tf, println_slim @@ -12,6 +14,7 @@ using Optimisers, ForwardDiff using Random, LuxCore, StaticArrays, LinearAlgebra rng = Random.MersenneTwister(1234) +Random.seed!(11) ## @@ -36,10 +39,12 @@ for ntest = 1:30 ll = [ M.SpheriCart.idx2lm(i)[1] for i in 1:length(Ysolid) ] print_tf(@test (Yspher .* (norm(𝐫)).^ll) ≈ Ysolid) end +println() ## for ybasis in [:spherical, :solid] + @info("=== Testing ybasis = $ybasis === ") local ps, st, Nat model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, @@ -198,7 +203,7 @@ for ybasis in [:spherical, :solid] abs(Ei - Ei_lin) end mae /= len - print_tf(@test mae < 0.01) + print_tf(@test mae < 0.02) end println() diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index d3bbc855..ce054456 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -18,10 +18,19 @@ using AtomsBuilder, EmpiricalPotentials AB = AtomsBuilder using EmpiricalPotentials: get_neighbours - using Random, LuxCore, StaticArrays, LinearAlgebra rng = Random.MersenneTwister(1234) +function _calc_E0s(at, calc) + out = 0.0 + Zs = AtomsBase.atomic_number(at) + for z in Zs + iz = M._z2i(calc.model, z) + out += calc.model.E0s[iz] + end + return out +end + ## elements = (:Si, :O) @@ -66,7 +75,9 @@ for ntest = 1:20 n_O = count(x -> x == 8, AtomsBase.atomic_number(at)) nlist = PairList(at, M.cutoff_radius(calc)) efv = M.energy_forces_virial(at, calc, ps_zero, st) - print_tf(@test ustrip(abs(efv.energy - E0s[:Si] * n_Si - E0s[:O] * n_O)) < 1e-10) + _E0 = E0s[:Si] * n_Si + E0s[:O] * n_O + print_tf(@test ustrip(abs(efv.energy - _E0)) < 1e-10) + print_tf(@test ustrip(_E0) ≈ _calc_E0s(at, calc)) end println() @@ -169,9 +180,8 @@ for ntest = 1:10 efv = M.energy_forces_virial(at, lin_calc, ps_lin, st_lin) efv_b = M.energy_forces_virial_basis(at, lin_calc, ps_lin, st_lin) - ps_vec, _restruct = destructure(ps_lin) - print_tf(@test dot(efv_b.energy, ps_vec) ≈ efv.energy ) + print_tf(@test dot(efv_b.energy, ps_vec) + _calc_E0s(at, lin_calc) * u"eV" ≈ efv.energy ) print_tf(@test all(efv_b.forces * ps_vec .≈ efv.forces) ) print_tf(@test sum(ps_vec .* efv_b.virial) ≈ efv.virial ) end diff --git a/test/runtests.jl b/test/runtests.jl index 89c72f93..f54407d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ using ACEpotentials, Test, LazyArtifacts # experimental @testset "UF_ACE" begin include("test_uface.jl") end - @testset "Models" begin include("models/test_models.jl") end + @testset "New Models" begin include("models/test_models.jl") end # outdated # @testset "Read data" begin include("outdated/test_data.jl") end From 414efe390cdaee6ab92dc8019426f90ec41f22ba Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 2 Jul 2024 14:16:35 +0200 Subject: [PATCH 057/112] many steps towards bumper --- Project.toml | 3 ++ src/models/Rnl_learnable.jl | 38 +++++++++++++-------- src/models/ace.jl | 68 ++++++++++++++++++++----------------- src/models/models.jl | 15 ++++++++ src/models/sparse.jl | 11 ++++-- test/models/test_ace.jl | 7 ++-- 6 files changed, 92 insertions(+), 50 deletions(-) diff --git a/Project.toml b/Project.toml index 4a0d7b75..90dcb80b 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ ACEmd = "69e0c927-b120-467d-b2b3-5b6842148cf4" AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" +Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" @@ -39,8 +40,10 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b" UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +WithAlloc = "fb1aa66a-603c-4c1d-9bc4-66947c7b08dd" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 921cf4ec..f6b66dac 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -89,14 +89,16 @@ function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) end -function evaluate_batched(basis::LearnableRnlrzzBasis, +function evaluate_batched!(Rnl, + basis::LearnableRnlrzzBasis, rs, zi, zjs, ps, st) - @assert length(rs) == length(zjs) + @assert length(rs) == length(zjs) + @assert size(Rnl, 1) >= length(rs) + @assert size(Rnl, 2) >= length(basis) + # evaluate the first one to get the types and size Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) - # ... and then allocate storage - Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) # then evaluate the rest in-place for j = 1:length(rs) @@ -110,7 +112,15 @@ function evaluate_batched(basis::LearnableRnlrzzBasis, Rnl[j, :] = (@view ps.Wnlq[:, :, iz, jz]) * P end - return Rnl, st + return Rnl +end + +function whatalloc(::typeof(evaluate_batched!), + basis::LearnableRnlrzzBasis, + rs::AbstractVector{T}, zi, zjs, ps, st) where {T} + # Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + T1 = promote_type(eltype(ps.Wnlq), T) + return (T1, length(rs), length(basis)) end @@ -201,13 +211,13 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, end -function rrule(::typeof(evaluate_batched), - basis::LearnableRnlrzzBasis, - rs, zi, zjs, ps, st) - Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) +# function rrule(::typeof(evaluate_batched), +# basis::LearnableRnlrzzBasis, +# rs, zi, zjs, ps, st) +# Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) - return (Rnl, st), - Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), - pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), - NoTangent()) -end \ No newline at end of file +# return (Rnl, st), +# Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), +# pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), +# NoTangent()) +# end diff --git a/src/models/ace.jl b/src/models/ace.jl index 76b14716..e92cf18e 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -1,20 +1,14 @@ -import LuxCore: AbstractExplicitLayer, - AbstractExplicitContainerLayer, - initialparameters, - initialstates using Lux: glorot_normal -using Random: AbstractRNG using StaticArrays: SVector using LinearAlgebra: norm, dot +using Polynomials4ML: real_sphericalharmonics, real_solidharmonics -import SpheriCart -using SpheriCart: SolidHarmonics, SphericalHarmonics import RepLieGroups import EquivariantModels -import Polynomials4ML + # ------------------------------------------------------------ # ACE MODEL SPECIFICATION @@ -50,9 +44,9 @@ const NT_NLM = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} function _make_Y_basis(Ytype, lmax) if Ytype == :solid - return SolidHarmonics(lmax) + return real_solidharmonics(lmax) elseif Ytype == :spherical - return SphericalHarmonics(lmax) + return real_sphericalharmonics(lmax) end error("unknown `Ytype` = $Ytype - I don't know how to generate a spherical basis from this.") @@ -71,12 +65,12 @@ function _make_A_spec(AA_spec, level) return A_spec end -# this should go into sphericart or P4ML +# TODO: this should go into sphericart or P4ML function _make_Y_spec(maxl::Integer) NT_LM = NamedTuple{(:l, :m), Tuple{Int, Int}} y_spec = NT_LM[] - for i = 1:SpheriCart.sizeY(maxl) - l, m = SpheriCart.idx2lm(i) + for i = 1:P4ML.SpheriCart.sizeY(maxl) + l, m = P4ML.SpheriCart.idx2lm(i) push!(y_spec, (l = l, m = m)) end return y_spec @@ -234,41 +228,51 @@ function splinify(model::ACEModel, ps::NamedTuple) model.meta) end +# ------------------------------------------------------------ +# utilities + +function radii!(rs, Rs::AbstractVector{SVector{D, T}}) where {D, T <: Real} + @assert length(rs) >= length(Rs) + @inbounds for i = 1:length(Rs) + rs[i] = norm(Rs[i]) + end + return rs +end + +function whatalloc(::typeof(radii!), Rs::AbstractVector{SVector{D, T}}) where {D, T <: Real} + return (T, length(Rs)) +end + # ------------------------------------------------------------ # Model Evaluation # this should possibly be moved to a separate file once it # gets more complicated. -import Zygote - -# these _getlmax and _length should be moved into SpheriCart -_getlmax(ybasis::Union{SolidHarmonics{L}, SphericalHarmonics{L}}) where {L} = L -_length(ybasis::Union{SolidHarmonics, SphericalHarmonics}) = SpheriCart.sizeY(_getlmax(ybasis)) - function evaluate(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} i_z0 = _z2i(model.rbasis, Z0) + @no_escape begin + # get the radii - rs = [ norm(r) for r in Rs ] # use Bumper + rs = @withalloc radii!(Rs) # evaluate the radial basis - # use Bumper to pre-allocate - Rnl, _st = evaluate_batched(model.rbasis, rs, Z0, Zs, - ps.rbasis, st.rbasis) + Rnl = @withalloc evaluate_batched!(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here - SpheriCart.compute!(Ylm, model.ybasis, Rs) + Ylm = @withalloc P4ML.evaluate!(model.ybasis, Rs) # equivariant tensor product - B, _ = evaluate(model.tensor, Rnl, Ylm) + B, _ = @withalloc evaluate!(model.tensor, Rnl, Ylm) # contract with params val = dot(B, (@view ps.WB[:, i_z0])) + # ------------------- # pair potential if model.pairbasis != nothing @@ -281,6 +285,8 @@ function evaluate(model::ACEModel, # E0s val += model.E0s[i_z0] # ------------------- + + end return val, st end @@ -304,8 +310,8 @@ function evaluate_ed(model::ACEModel, Rnl, dRnl, _st = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), _length(model.ybasis)) # TODO: use Bumper - dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) + Ylm = zeros(T, length(Rs), length(model.ybasis)) # TODO: use Bumper + dYlm = zeros(SVector{3, T}, length(Rs), length(model.ybasis)) SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) # Forward Pass through the tensor @@ -377,8 +383,8 @@ function grad_params(model::ACEModel, (Rnl, _st), pb_Rnl = rrule(evaluate_batched, model.rbasis, rs, Z0, Zs, ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), _length(model.ybasis)) # TODO: use Bumper - dYlm = zeros(SVector{3, T}, length(Rs), _length(model.ybasis)) + Ylm = zeros(T, length(Rs), length(model.ybasis)) # TODO: use Bumper + dYlm = zeros(SVector{3, T}, length(Rs), length(model.ybasis)) SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) # Forward Pass through the tensor @@ -520,7 +526,7 @@ function evaluate_basis(model::ACEModel, ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), _length(model.ybasis)) # use Bumper here + Ylm = zeros(T, length(Rs), length(model.ybasis)) # use Bumper here SpheriCart.compute!(Ylm, model.ybasis, Rs) # equivariant tensor product diff --git a/src/models/models.jl b/src/models/models.jl index 81387bd8..8d53dc78 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -1,6 +1,21 @@ module Models +using Random: AbstractRNG + +using StrideArrays, Bumper, WithAlloc +import WithAlloc: whatalloc + +import Zygote + +import Polynomials4ML +const P4ML = Polynomials4ML + +import LuxCore: AbstractExplicitLayer, + AbstractExplicitContainerLayer, + initialparameters, + initialstates + include("elements.jl") include("radial_envelopes.jl") diff --git a/src/models/sparse.jl b/src/models/sparse.jl index 31369cab..19ae7799 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -13,7 +13,7 @@ end Base.length(tensor::SparseEquivTensor) = size(tensor.A2Bmap, 1) -function evaluate(tensor::SparseEquivTensor{T}, Rnl, Ylm) where {T} +function evaluate!(B, _AA, tensor::SparseEquivTensor{T}, Rnl, Ylm) where {T} # evaluate the A basis TA = promote_type(T, eltype(Rnl), eltype(eltype(Ylm))) A = zeros(TA, length(tensor.abasis)) @@ -27,11 +27,18 @@ function evaluate(tensor::SparseEquivTensor{T}, Rnl, Ylm) where {T} AA = _AA[proj] # use Bumper here, or view; needs experimentation. # evaluate the coupling coefficients - B = tensor.A2Bmap * AA + # B = tensor.A2Bmap * AA + mul!(B, tensor.A2Bmap, AA) return B, (_AA = _AA, ) end +function whatalloc(::typeof(evaluate!), tensor::SparseEquivTensor, Rnl, Ylm) + TA = promote_type(eltype(Rnl), eltype(eltype(Ylm))) + TB = promote_type(TA, eltype(tensor.A2Bmap)) + return (TB, size(tensor.A2Bmap, 1),), (TA, length(tensor.abasis),) +end + function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, intermediates) where {T} diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index b525b38d..d597795e 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -36,7 +36,7 @@ for ntest = 1:30 𝐫 = randn(SVector{3, Float64}) Ysolid = msolid.ybasis(𝐫) Yspher = mspherical.ybasis(𝐫) - ll = [ M.SpheriCart.idx2lm(i)[1] for i in 1:length(Ysolid) ] + ll = [ M.P4ML.SpheriCart.idx2lm(i)[1] for i in 1:length(Ysolid) ] print_tf(@test (Yspher .* (norm(𝐫)).^ll) ≈ Ysolid) end println() @@ -44,7 +44,7 @@ println() ## for ybasis in [:spherical, :solid] - + # ybasis = :spherical @info("=== Testing ybasis = $ybasis === ") local ps, st, Nat model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, @@ -73,8 +73,9 @@ for ybasis in [:spherical, :solid] print_tf(@test abs(val - val1) < 1e-10) end println() +end - ## +## @info("Test derivatives w.r.t. positions") Rs, Zs, z0 = M.rand_atenv(model, 16) From c4aed63b96ad8a3ac58139558fab8d82e9e77f3c Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 2 Jul 2024 16:16:38 +0200 Subject: [PATCH 058/112] fixing versions, which fixes lots of Lux warnings --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 90dcb80b..b5a1a031 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ ACE1x = "0.1.8" ACEfit = "0.1.4" ACEmd = "0.1.6" ExtXYZ = "0.1.14" -Interpolations = "0.14.7" +Interpolations = "0.14.7, 0.15" JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" From b0618cb848d39189176b47526d08e34ff7e08650 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 3 Jul 2024 09:34:32 +0200 Subject: [PATCH 059/112] steps towards withalloc in derivatives --- src/models/Rnl_learnable.jl | 44 ++++++++++++++++++++++++------------- src/models/ace.jl | 20 +++++++++-------- src/models/sparse.jl | 10 +++++++++ test/models/test_ace.jl | 5 ++--- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index f6b66dac..ac002f9a 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -73,7 +73,7 @@ function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, r, x) Rnl[:] .= Wij * (P .* e) - return Rnl, st + return Rnl end function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) @@ -85,7 +85,7 @@ function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) P = Polynomials4ML.evaluate(basis.polys, x) env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, r, x) - return Wij * (P .* e), st + return Wij * (P .* e) end @@ -98,7 +98,7 @@ function evaluate_batched!(Rnl, @assert size(Rnl, 2) >= length(basis) # evaluate the first one to get the types and size - Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) + Rnl_1 = evaluate(basis, rs[1], zi, zjs[1], ps, st) # then evaluate the rest in-place for j = 1:length(rs) @@ -118,13 +118,15 @@ end function whatalloc(::typeof(evaluate_batched!), basis::LearnableRnlrzzBasis, rs::AbstractVector{T}, zi, zjs, ps, st) where {T} - # Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) T1 = promote_type(eltype(ps.Wnlq), T) return (T1, length(rs), length(basis)) end - - +function evaluate_batched(basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + Rnl = zeros(whatalloc(evaluate_batched!, basis, rs, zi, zjs, ps, st)...) + return evaluate_batched!(Rnl, basis, rs, zi, zjs, ps, st) +end # ----- gradients # because the typical scenario is that we have few r, then moderately @@ -137,32 +139,44 @@ using ForwardDiff: Dual function evaluate_ed(basis::LearnableRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} d_r = Dual{T}(r, one(T)) - d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) + d_Rnl = evaluate(basis, d_r, Zi, Zj, ps, st) Rnl = ForwardDiff.value.(d_Rnl) Rnl_d = ForwardDiff.extract_derivative(T, d_Rnl) - return Rnl, Rnl_d, st + return Rnl, Rnl_d end -function evaluate_ed_batched(basis::LearnableRnlrzzBasis, +function evaluate_ed_batched!(Rnl, Rnl_d, + basis::LearnableRnlrzzBasis, rs::AbstractVector{T}, Zi, Zs, ps, st ) where {T <: Real} @assert length(rs) == length(Zs) - Rnl1, st = evaluate(basis, rs[1], Zi, Zs[1], ps, st) - Rnl = zeros(T, length(rs), length(Rnl1)) - Rnl_d = zeros(T, length(rs), length(Rnl1)) - for j = 1:length(rs) d_r = Dual{T}(rs[j], one(T)) - d_Rnl, st = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here + d_Rnl = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) end - return Rnl, Rnl_d, st + return Rnl, Rnl_d end +function whatalloc(::typeof(evaluate_ed_batched!), + basis::LearnableRnlrzzBasis, + rs::AbstractVector{T}, Zi, Zs, ps, st) where {T} + T1 = promote_type(eltype(ps.Wnlq), T) + return (T1, length(rs), length(basis)), (T1, length(rs), length(basis)) +end + +function evaluate_ed_batched(basis::LearnableRnlrzzBasis, + rs::AbstractVector{T}, Zi, Zs, ps, st + ) where {T <: Real} + allocinfo = whatalloc(evaluate_ed_batched!, basis, rs, Zi, Zs, ps, st) + Rnl = zeros(allocinfo[1]...) + Rnl_d = zeros(allocinfo[2]...) + return evaluate_ed_batched!(Rnl, Rnl_d, basis, rs, Zi, Zs, ps, st) +end diff --git a/src/models/ace.jl b/src/models/ace.jl index e92cf18e..6cba34ec 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -286,7 +286,7 @@ function evaluate(model::ACEModel, val += model.E0s[i_z0] # ------------------- - end + end # @no_escape return val, st end @@ -298,25 +298,25 @@ function evaluate_ed(model::ACEModel, ps, st) where {T} i_z0 = _z2i(model.rbasis, Z0) + + @no_escape begin # ---------- EMBEDDINGS ------------ # (these are done in forward mode, so not part of the fwd, bwd passes) # get the radii - rs = [ norm(r) for r in Rs ] # TODO: use Bumper + rs = @withalloc radii!(Rs) # evaluate the radial basis - # TODO: use Bumper to pre-allocate - Rnl, dRnl, _st = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, - ps.rbasis, st.rbasis) + # TODO: using @withalloc causes stack overflow + Rnl, dRnl = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), length(model.ybasis)) # TODO: use Bumper - dYlm = zeros(SVector{3, T}, length(Rs), length(model.ybasis)) - SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) + Ylm, dYlm = @withalloc P4ML.evaluate_ed!(model.ybasis, Rs) # Forward Pass through the tensor # keep intermediates to be used in backward pass - B, intermediates = evaluate(model.tensor, Rnl, Ylm) + B, intermediates = @withalloc evaluate!(model.tensor, Rnl, Ylm) # contract with params # (here we can insert another nonlinearity instead of the simple dot) @@ -364,6 +364,8 @@ function evaluate_ed(model::ACEModel, Ei += model.E0s[i_z0] # ------------------- + end # @no_escape + return Ei, ∇Ei, st end diff --git a/src/models/sparse.jl b/src/models/sparse.jl index 19ae7799..bf8d9c2d 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -39,6 +39,16 @@ function whatalloc(::typeof(evaluate!), tensor::SparseEquivTensor, Rnl, Ylm) return (TB, size(tensor.A2Bmap, 1),), (TA, length(tensor.abasis),) end +function evaluate(tensor::SparseEquivTensor, Rnl, Ylm) + allocinfo = whatalloc(evaluate!, tensor, Rnl, Ylm) + B = zeros(allocinfo[1]...) + AA = zeros(allocinfo[2]...) + return evaluate!(B, AA, tensor, Rnl, Ylm) +end + + +# --------- + function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, intermediates) where {T} diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index d597795e..7a8f6f92 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -44,7 +44,7 @@ println() ## for ybasis in [:spherical, :solid] - # ybasis = :spherical + ybasis = :spherical @info("=== Testing ybasis = $ybasis === ") local ps, st, Nat model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, @@ -73,9 +73,8 @@ for ybasis in [:spherical, :solid] print_tf(@test abs(val - val1) < 1e-10) end println() -end -## + ## @info("Test derivatives w.r.t. positions") Rs, Zs, z0 = M.rand_atenv(model, 16) From 8ceecf2ca9a52cb54830e3896106c191ecb3a9b3 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 09:33:56 +0200 Subject: [PATCH 060/112] intermediate --- src/models/ace.jl | 5 ++++- src/models/sparse.jl | 48 ++++++++++++++++++++++++++++++----------- test/models/test_ace.jl | 12 +++++------ 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index 6cba34ec..49c06291 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -309,8 +309,11 @@ function evaluate_ed(model::ACEModel, # evaluate the radial basis # TODO: using @withalloc causes stack overflow + # Rnl, dRnl = @withalloc evaluate_ed_batched!(model.rbasis, rs, Z0, Zs, + # ps.rbasis, st.rbasis) Rnl, dRnl = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, ps.rbasis, st.rbasis) + # evaluate the Y basis Ylm, dYlm = @withalloc P4ML.evaluate_ed!(model.ybasis, Rs) @@ -327,7 +330,7 @@ function evaluate_ed(model::ACEModel, ∂B = @view ps.WB[:, i_z0] # backward pass through tensor - ∂Rnl, ∂Ylm = pullback_evaluate(∂B, model.tensor, Rnl, Ylm, intermediates) + ∂Rnl, ∂Ylm = pullback(∂B, model.tensor, Rnl, Ylm, intermediates) # ---------- ASSEMBLE DERIVATIVES ------------ # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm diff --git a/src/models/sparse.jl b/src/models/sparse.jl index bf8d9c2d..3c141163 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -50,30 +50,54 @@ end # --------- -function pullback_evaluate(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, - intermediates) where {T} +function pullback!(∂Rnl, ∂Ylm, + ∂B, tensor::SparseEquivTensor, Rnl, Ylm, + intermediates) _AA = intermediates._AA proj = tensor.aabasis.projection + T_∂AA = promote_type(eltype(∂B), eltype(tensor.A2Bmap)) + T_∂A = promote_type(T_∂AA, eltype(_AA)) + + @no_escape begin # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap - ∂AA = tensor.A2Bmap' * ∂B # TODO: make this in-place - _∂AA = zeros(T, length(_AA)) - _∂AA[proj] = ∂AA + # ∂AA = tensor.A2Bmap' * ∂B + ∂AA = @alloc(T_∂AA, size(tensor.A2Bmap, 2)) + mul!(∂AA, tensor.A2Bmap', ∂B) + # _∂AA[proj] = ∂AA + _∂AA = @alloc(T_∂AA, length(_AA)) + _∂AA[proj] .= ∂AA # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) - TA = promote_type(T, eltype(_AA), eltype(∂B), - eltype(Rnl), eltype(eltype(Ylm))) - ∂A = zeros(TA, length(tensor.abasis)) - Polynomials4ML.unsafe_pullback!(∂A, _∂AA, tensor.aabasis, _AA) + ∂A = @alloc(T_∂A, length(tensor.abasis)) + P4ML.unsafe_pullback!(∂A, _∂AA, tensor.aabasis, _AA) # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) - ∂Rnl = zeros(TA, size(Rnl)) - ∂Ylm = zeros(TA, size(Ylm)) - Polynomials4ML.pullback!((∂Rnl, ∂Ylm), ∂A, tensor.abasis, (Rnl, Ylm)) + P4ML.pullback!((∂Rnl, ∂Ylm), ∂A, tensor.abasis, (Rnl, Ylm)) + + end # no_escape return ∂Rnl, ∂Ylm end +function whatalloc(::typeof(pullback!), + ∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, + intermediates) where {T} + TA = promote_type(T, eltype(intermediates._AA), eltype(∂B), + eltype(Rnl), eltype(eltype(Ylm))) + return (TA, size(Rnl)...), (TA, size(Ylm)...) +end + +function pullback(∂B, tensor::SparseEquivTensor{T}, Rnl, Ylm, + intermediates) where {T} + alc_∂Rnl, alc_∂Ylm = whatalloc(pullback!, ∂B, tensor, Rnl, Ylm, intermediates) + ∂Rnl = zeros(alc_∂Rnl...) + ∂Ylm = zeros(alc_∂Ylm...) + return pullback!(∂Rnl, ∂Ylm, ∂B, tensor, Rnl, Ylm, intermediates) +end + +# ---------------------------------------- +# utilities """ Get the specification of the BBbasis as a list (`Vector`) of vectors of `@NamedTuple{n::Int, l::Int}`. diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 7a8f6f92..24c8ffcb 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -1,5 +1,5 @@ -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); ## @@ -43,7 +43,7 @@ println() ## -for ybasis in [:spherical, :solid] +# for ybasis in [:spherical, :solid] ybasis = :spherical @info("=== Testing ybasis = $ybasis === ") local ps, st, Nat @@ -54,7 +54,7 @@ for ybasis in [:spherical, :solid] ps, st = LuxCore.setup(rng, model) - ## +## @info("Test Rotation-Invariance of the Model") @@ -74,7 +74,7 @@ for ybasis in [:spherical, :solid] end println() - ## +## @info("Test derivatives w.r.t. positions") Rs, Zs, z0 = M.rand_atenv(model, 16) @@ -89,11 +89,11 @@ for ybasis in [:spherical, :solid] Us = randn(SVector{3, Float64}, Nat) F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) - print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=true)) end println() - ## +## @info("Test derivatives w.r.t. parameters") Nat = 15 From 68d183fa5834ed7f40e98c2dd380644310c9fab6 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 09:43:13 +0200 Subject: [PATCH 061/112] bugfix --- src/models/sparse.jl | 6 ++++-- test/models/test_ace.jl | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/models/sparse.jl b/src/models/sparse.jl index 3c141163..59f080a7 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -59,14 +59,15 @@ function pullback!(∂Rnl, ∂Ylm, T_∂A = promote_type(T_∂AA, eltype(_AA)) @no_escape begin + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ∂Ei / ∂AA = ∂Ei / ∂B * ∂B / ∂AA = (WB[i_z0]) * A2Bmap # ∂AA = tensor.A2Bmap' * ∂B ∂AA = @alloc(T_∂AA, size(tensor.A2Bmap, 2)) mul!(∂AA, tensor.A2Bmap', ∂B) - # _∂AA[proj] = ∂AA _∂AA = @alloc(T_∂AA, length(_AA)) - _∂AA[proj] .= ∂AA + fill!(_∂AA, zero(T_∂AA)) + _∂AA[proj] = ∂AA # ∂Ei / ∂A = ∂Ei / ∂AA * ∂AA / ∂A = pullback(aabasis, ∂AA) ∂A = @alloc(T_∂A, length(tensor.abasis)) @@ -75,6 +76,7 @@ function pullback!(∂Rnl, ∂Ylm, # ∂Ei / ∂Rnl, ∂Ei / ∂Ylm = pullback(abasis, ∂A) P4ML.pullback!((∂Rnl, ∂Ylm), ∂A, tensor.abasis, (Rnl, Ylm)) + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ end # no_escape return ∂Rnl, ∂Ylm diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 24c8ffcb..dab6ff13 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -89,7 +89,7 @@ println() Us = randn(SVector{3, Float64}, Nat) F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) - print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=true)) + print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) end println() From 8d3a03a5b09ed68355e4ef3f99610dc8b4dd8a61 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 11:29:38 +0200 Subject: [PATCH 062/112] grad_params and rev2 --- src/models/Rnl_learnable.jl | 16 +++++++-- src/models/ace.jl | 66 ++++++++++++++++++++++++++++--------- src/models/sparse.jl | 8 ++--- test/models/test_ace.jl | 6 ++-- 4 files changed, 71 insertions(+), 25 deletions(-) diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index ac002f9a..373f41e8 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -155,8 +155,10 @@ function evaluate_ed_batched!(Rnl, Rnl_d, for j = 1:length(rs) d_r = Dual{T}(rs[j], one(T)) d_Rnl = evaluate(basis, d_r, Zi, Zs[j], ps, st) # should reuse memory here - map!(ForwardDiff.value, (@view Rnl[j, :]), d_Rnl) - map!(d -> ForwardDiff.extract_derivative(T, d), (@view Rnl_d[j, :]), d_Rnl) + for t = 1:size(Rnl, 2) + Rnl[j, t] = ForwardDiff.value(d_Rnl[t]) + Rnl_d[j, t] = ForwardDiff.extract_derivative(T, d_Rnl[t]) + end end return Rnl, Rnl_d @@ -235,3 +237,13 @@ end # pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), # NoTangent()) # end + +function rrule(::typeof(evaluate_batched), + basis::LearnableRnlrzzBasis, + rs, zi, zjs, ps, st) + Rnl = evaluate_batched(basis, rs, zi, zjs, ps, st) + + return Rnl, Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), + pullback_evaluate_batched(Δ, basis, rs, zi, zjs, ps, st), + NoTangent()) +end diff --git a/src/models/ace.jl b/src/models/ace.jl index 49c06291..f290bc9b 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -243,6 +243,22 @@ function whatalloc(::typeof(radii!), Rs::AbstractVector{SVector{D, T}}) where {D return (T, length(Rs)) end +function radii_ed!(rs, ∇rs, Rs::AbstractVector{SVector{D, T}}) where {D, T <: Real} + @assert length(rs) >= length(Rs) + @assert length(∇rs) >= length(Rs) + @inbounds for i = 1:length(Rs) + rs[i] = norm(Rs[i]) + ∇rs[i] = Rs[i] / rs[i] + end + return rs, ∇rs +end + +function whatalloc(::typeof(radii_ed!), Rs::AbstractVector{SVector{D, T}}) where {D, T <: Real} + return (T, length(Rs)), (SVector{D, T}, length(Rs)) +end + + + # ------------------------------------------------------------ # Model Evaluation # this should possibly be moved to a separate file once it @@ -300,19 +316,20 @@ function evaluate_ed(model::ACEModel, i_z0 = _z2i(model.rbasis, Z0) @no_escape begin + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ---------- EMBEDDINGS ------------ # (these are done in forward mode, so not part of the fwd, bwd passes) # get the radii - rs = @withalloc radii!(Rs) + rs, ∇rs = @withalloc radii_ed!(Rs) # evaluate the radial basis # TODO: using @withalloc causes stack overflow - # Rnl, dRnl = @withalloc evaluate_ed_batched!(model.rbasis, rs, Z0, Zs, - # ps.rbasis, st.rbasis) - Rnl, dRnl = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, + Rnl, dRnl = @withalloc evaluate_ed_batched!(model.rbasis, rs, Z0, Zs, ps.rbasis, st.rbasis) + # Rnl, dRnl = evaluate_ed_batched(model.rbasis, rs, Z0, Zs, + # ps.rbasis, st.rbasis) # evaluate the Y basis Ylm, dYlm = @withalloc P4ML.evaluate_ed!(model.ybasis, Rs) @@ -330,7 +347,7 @@ function evaluate_ed(model::ACEModel, ∂B = @view ps.WB[:, i_z0] # backward pass through tensor - ∂Rnl, ∂Ylm = pullback(∂B, model.tensor, Rnl, Ylm, intermediates) + ∂Rnl, ∂Ylm = @withalloc pullback!(∂B, model.tensor, Rnl, Ylm, intermediates) # ---------- ASSEMBLE DERIVATIVES ------------ # The ∂Ei / ∂𝐫ⱼ can now be obtained from the ∂Ei / ∂Rnl, ∂Ei / ∂Ylm @@ -338,11 +355,16 @@ function evaluate_ed(model::ACEModel, # ∂Ei / ∂𝐫ⱼ = ∑_nl ∂Ei / ∂Rnl[j] * ∂Rnl[j] / ∂𝐫ⱼ # + ∑_lm ∂Ei / ∂Ylm[j] * ∂Ylm[j] / ∂𝐫ⱼ ∇Ei = zeros(SVector{3, T}, length(Rs)) - for j = 1:length(Rs) - ∇Ei[j] = dot(∂Rnl[j, :], dRnl[j, :]) * (Rs[j] / rs[j]) + - sum(∂Ylm[j, :] .* dYlm[j, :]) + for t = 1:size(∂Rnl, 2) + for j = 1:size(∂Rnl, 1) + ∇Ei[j] += (∂Rnl[j, t] * dRnl[j, t]) * ∇rs[j] + end + end + for t = 1:size(∂Ylm, 2) + for j = 1:size(∂Ylm, 1) + ∇Ei[j] += ∂Ylm[j, t] * dYlm[j, t] + end end - # ------------------- # pair potential @@ -367,6 +389,7 @@ function evaluate_ed(model::ACEModel, Ei += model.E0s[i_z0] # ------------------- + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ end # @no_escape return Ei, ∇Ei, st @@ -377,24 +400,31 @@ function grad_params(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} + @no_escape begin + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # ---------- EMBEDDINGS ------------ # (these are done in forward mode, so not part of the fwd, bwd passes) # get the radii - rs = [ norm(r) for r in Rs ] # TODO: use Bumper + rs = @withalloc radii!(Rs) # evaluate the radial basis # TODO: use Bumper to pre-allocate - (Rnl, _st), pb_Rnl = rrule(evaluate_batched, model.rbasis, + Rnl, pb_Rnl = rrule(evaluate_batched, model.rbasis, rs, Z0, Zs, ps.rbasis, st.rbasis) + # evaluate the Y basis - Ylm = zeros(T, length(Rs), length(model.ybasis)) # TODO: use Bumper - dYlm = zeros(SVector{3, T}, length(Rs), length(model.ybasis)) - SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) + Ylm = @withalloc P4ML.evaluate!(model.ybasis, Rs) + + # Ylm = zeros(T, length(Rs), length(model.ybasis)) # TODO: use Bumper + # dYlm = zeros(SVector{3, T}, length(Rs), length(model.ybasis)) + # SpheriCart.compute_with_gradients!(Ylm, dYlm, model.ybasis, Rs) # Forward Pass through the tensor # keep intermediates to be used in backward pass - B, intermediates = evaluate(model.tensor, Rnl, Ylm) + # B, intermediates = evaluate(model.tensor, Rnl, Ylm) + B, intermediates = @withalloc evaluate!(model.tensor, Rnl, Ylm) # contract with params # (here we can insert another nonlinearity instead of the simple dot) @@ -409,7 +439,8 @@ function grad_params(model::ACEModel, ∂B = @view ps.WB[:, i_z0] # backward pass through tensor - ∂Rnl, ∂Ylm = pullback_evaluate(∂B, model.tensor, Rnl, Ylm, intermediates) + # ∂Rnl, ∂Ylm = pullback_evaluate(∂B, model.tensor, Rnl, Ylm, intermediates) + ∂Rnl, ∂Ylm = @withalloc pullback!(∂B, model.tensor, Rnl, Ylm, intermediates) # ---------- ASSEMBLE DERIVATIVES ------------ # the first grad_param is ∂WB, which we already have but it needs to be @@ -448,6 +479,9 @@ function grad_params(model::ACEModel, # ------------------- + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + end # @no_escape + return Ei, (WB = ∂WB, Wpair = ∂Wpair, rbasis = ∂Wqnl, pairbasis = NamedTuple()), st end diff --git a/src/models/sparse.jl b/src/models/sparse.jl index 59f080a7..e3a3e69c 100644 --- a/src/models/sparse.jl +++ b/src/models/sparse.jl @@ -17,11 +17,11 @@ function evaluate!(B, _AA, tensor::SparseEquivTensor{T}, Rnl, Ylm) where {T} # evaluate the A basis TA = promote_type(T, eltype(Rnl), eltype(eltype(Ylm))) A = zeros(TA, length(tensor.abasis)) - Polynomials4ML.evaluate!(A, tensor.abasis, (Rnl, Ylm)) + P4ML.evaluate!(A, tensor.abasis, (Rnl, Ylm)) # evaluate the AA basis - _AA = zeros(TA, length(tensor.aabasis)) # use Bumper here - Polynomials4ML.evaluate!(_AA, tensor.aabasis, A) + # _AA = zeros(TA, length(tensor.aabasis)) # use Bumper here + P4ML.evaluate!(_AA, tensor.aabasis, A) # project to the actual AA basis proj = tensor.aabasis.projection AA = _AA[proj] # use Bumper here, or view; needs experimentation. @@ -36,7 +36,7 @@ end function whatalloc(::typeof(evaluate!), tensor::SparseEquivTensor, Rnl, Ylm) TA = promote_type(eltype(Rnl), eltype(eltype(Ylm))) TB = promote_type(TA, eltype(tensor.A2Bmap)) - return (TB, size(tensor.A2Bmap, 1),), (TA, length(tensor.abasis),) + return (TB, length(tensor),), (TA, length(tensor.aabasis),) end function evaluate(tensor::SparseEquivTensor, Rnl, Ylm) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index dab6ff13..3fb3c9a0 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -115,7 +115,7 @@ println() end println() - ## +## @info("Test second mixed derivatives reverse-over-reverse") for ntest = 1:20 @@ -144,7 +144,7 @@ println() end println() - ## +## @info("Test basis implementation") @@ -166,7 +166,7 @@ println() end println() - ## +## @info("Test the full mixed jacobian") From 771a718ff7e7245d358de800ea41c09b8a295b13 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 13:13:25 +0200 Subject: [PATCH 063/112] fix basis evaluation --- src/models/ace.jl | 24 +++++++++++------------- test/models/test_ace.jl | 29 ++++++++++++++--------------- 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/models/ace.jl b/src/models/ace.jl index f290bc9b..b366845e 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -304,7 +304,7 @@ function evaluate(model::ACEModel, end # @no_escape - return val, st + return val end @@ -392,7 +392,7 @@ function evaluate_ed(model::ACEModel, # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ end # @no_escape - return Ei, ∇Ei, st + return Ei, ∇Ei end @@ -557,19 +557,17 @@ function evaluate_basis(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} # get the radii - rs = [ norm(r) for r in Rs ] # use Bumper + rs = @withalloc radii!(Rs) # evaluate the radial basis - # use Bumper to pre-allocate - Rnl, _st = evaluate_batched(model.rbasis, rs, Z0, Zs, - ps.rbasis, st.rbasis) + Rnl = evaluate_batched(model.rbasis, rs, Z0, Zs, + ps.rbasis, st.rbasis) # evaluate the Y basis - Ylm = zeros(T, length(Rs), length(model.ybasis)) # use Bumper here - SpheriCart.compute!(Ylm, model.ybasis, Rs) + Ylm = @withalloc P4ML.evaluate!(model.ybasis, Rs) # equivariant tensor product - Bi, _ = evaluate(model.tensor, Rnl, Ylm) + Bi, _ = @withalloc evaluate!(model.tensor, Rnl, Ylm) B = zeros(eltype(Bi), len_basis(model)) B[get_basis_inds(model, Z0)] .= Bi @@ -583,7 +581,7 @@ function evaluate_basis(model::ACEModel, B[get_pairbasis_inds(model, Z0)] .= Apair end - return B, st + return B end __vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs) @@ -593,16 +591,16 @@ function evaluate_basis_ed(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} - B, st = evaluate_basis(model, Rs, Zs, Z0, ps, st) + B = evaluate_basis(model, Rs, Zs, Z0, ps, st) dB_vec = ForwardDiff.jacobian( - _Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st)[1], + _Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st), __vec(Rs)) dB1 = __svecs(collect(dB_vec')[:]) dB = collect( permutedims( reshape(dB1, length(Rs), length(B)), (2, 1) ) ) - return B, dB, st + return B, dB end diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 3fb3c9a0..5be9e472 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -63,12 +63,12 @@ println() Nat = rand(8:16) Rs, Zs, Z0 = M.rand_atenv(model, Nat) - val, st1 = M.evaluate(model, Rs, Zs, Z0, ps, st) + val = M.evaluate(model, Rs, Zs, Z0, ps, st) p = shuffle(1:Nat) Rs1 = Ref(M.rand_iso()) .* Rs[p] Zs1 = Zs[p] - val1, st1 = M.evaluate(model, Rs1, Zs1, Z0, ps, st) + val1 = M.evaluate(model, Rs1, Zs1, Z0, ps, st) print_tf(@test abs(val - val1) < 1e-10) end @@ -78,8 +78,8 @@ println() @info("Test derivatives w.r.t. positions") Rs, Zs, z0 = M.rand_atenv(model, 16) - Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) - Ei1, ∇Ei, st = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei1, ∇Ei = M.evaluate_ed(model, Rs, Zs, z0, ps, st) println_slim(@test Ei ≈ Ei1) for ntest = 1:20 @@ -87,7 +87,7 @@ println() Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) Us = randn(SVector{3, Float64}, Nat) - F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st)[1] + F(t) = M.evaluate(model, Rs + t * Us, Zs, z0, ps, st) dF(t) = dot(M.evaluate_ed(model, Rs + t * Us, Zs, z0, ps, st)[2], Us) print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) end @@ -98,8 +98,8 @@ println() @info("Test derivatives w.r.t. parameters") Nat = 15 Rs, Zs, z0 = M.rand_atenv(model, Nat) - Ei, st = M.evaluate(model, Rs, Zs, z0, ps, st) - Ei1, ∇Ei, st = M.grad_params(model, Rs, Zs, z0, ps, st) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei1, ∇Ei = M.grad_params(model, Rs, Zs, z0, ps, st) println_slim(@test Ei ≈ Ei1) for ntest = 1:20 @@ -109,7 +109,7 @@ println() Rs, Zs, z0 = M.rand_atenv(model, Nat) pvec, _restruct = destructure(ps) uvec = randn(length(pvec)) / sqrt(length(pvec)) - F(t) = M.evaluate(model, Rs, Zs, z0, _restruct(pvec + t * uvec), st)[1] + F(t) = M.evaluate(model, Rs, Zs, z0, _restruct(pvec + t * uvec), st) dF0 = dot( destructure( M.grad_params(model, Rs, Zs, z0, ps, st)[2] )[1], uvec ) print_tf(@test ACEbase.Testing.fdtest(F, t -> dF0, 0.0; verbose = false)) end @@ -126,7 +126,7 @@ println() Rs, Zs, z0 = M.rand_atenv(model, Nat) Us = randn(SVector{3, Float64}, Nat) Ei = M.evaluate(model, Rs, Zs, z0, ps, st) - Ei, ∂Ei, _ = M.grad_params(model, Rs, Zs, z0, ps, st) + Ei, ∂Ei = M.grad_params(model, Rs, Zs, z0, ps, st) # test partial derivative w.r.t. the Ei component ∂2_Ei = M.pullback_2_mixed(1.0, 0*Us, model, Rs, Zs, z0, ps, st) @@ -154,14 +154,13 @@ println() Nat = 15 Rs, Zs, z0 = M.rand_atenv(model, Nat) i_z0 = M._z2i(model, z0) - Ei, st1 = M.evaluate(model, Rs, Zs, z0, ps, st) - B, st1 = M.evaluate_basis(model, Rs, Zs, z0, ps, st) + Ei = M.evaluate(model, Rs, Zs, z0, ps, st) + B = M.evaluate_basis(model, Rs, Zs, z0, ps, st) θ = M.get_basis_params(model, ps) print_tf(@test Ei ≈ dot(B, θ)) - Ei, ∇Ei, st1 = M.evaluate_ed(model, Rs, Zs, z0, ps, st) - B, ∇B, st1 = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) - print_tf(@test Ei ≈ dot(B, θ)) + Ei, ∇Ei = M.evaluate_ed(model, Rs, Zs, z0, ps, st) + B, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) end println() @@ -184,7 +183,7 @@ println() println() - ## +## @info("check splinification") lin_ace = M.splinify(model, ps) From 27aeee779bc375fbc857ec7930d4eac9fcd216fb Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 16:25:42 +0200 Subject: [PATCH 064/112] update splined radials --- src/models/Rnl_splines.jl | 50 ++++++++++++++++++++++++------- src/models/ace.jl | 8 ++--- test/models/test_Rnl.jl | 33 ++++++++++---------- test/models/test_ace.jl | 11 +++---- test/models/test_learnable_Rnl.jl | 31 +++++++++---------- test/models/test_pair_basis.jl | 10 +++---- 6 files changed, 87 insertions(+), 56 deletions(-) diff --git a/src/models/Rnl_splines.jl b/src/models/Rnl_splines.jl index a9dc9970..822e8562 100644 --- a/src/models/Rnl_splines.jl +++ b/src/models/Rnl_splines.jl @@ -41,28 +41,41 @@ function evaluate(basis::SplineRnlrzzBasis, r::Real, Zi, Zj, ps, st) x_ij = T_ij(r) e_ij = evaluate(env_ij, r, x_ij) - return spl_ij(x_ij) * e_ij, st + return spl_ij(x_ij) * e_ij end -function evaluate_batched(basis::SplineRnlrzzBasis, +function evaluate_batched!(Rnl, basis::SplineRnlrzzBasis, rs, zi, zjs, ps, st) @assert length(rs) == length(zjs) # evaluate the first one to get the types and size - Rnl_1, st = evaluate(basis, rs[1], zi, zjs[1], ps, st) + Rnl_1 = evaluate(basis, rs[1], zi, zjs[1], ps, st) # ... and then allocate storage - Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) + # Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) Rnl[1, :] .= Rnl_1 # then evaluate the rest in-place for j = 2:length(rs) - Rnl[j, :], st = evaluate(basis, rs[j], zi, zjs[j], ps, st) + Rnl[j, :] = evaluate(basis, rs[j], zi, zjs[j], ps, st) end - return Rnl, st + return Rnl end +function whatalloc(::typeof(evaluate_batched!), + basis::SplineRnlrzzBasis, + rs, zi, zjs, ps, st) + T = eltype(rs) + return (T, length(rs), length(basis)) +end + + +function evaluate_batched(basis::SplineRnlrzzBasis, + rs, zi, zjs, ps, st) + Rnl = zeros(whatalloc(evaluate_batched!, basis, rs, zi, zjs, ps, st)...) + return evaluate_batched!(Rnl, basis, rs, zi, zjs, ps, st) +end # ----- gradients # because the typical scenario is that we have few r, then moderately @@ -74,10 +87,10 @@ using ForwardDiff: Dual function evaluate_ed(basis::SplineRnlrzzBasis, r::T, Zi, Zj, ps, st) where {T <: Real} d_r = Dual{T}(r, one(T)) - d_Rnl, st = evaluate(basis, d_r, Zi, Zj, ps, st) + d_Rnl = evaluate(basis, d_r, Zi, Zj, ps, st) Rnl = ForwardDiff.value.(d_Rnl) Rnl_d = ForwardDiff.extract_derivative(T, d_Rnl) - return Rnl, Rnl_d, st + return Rnl, Rnl_d end @@ -87,22 +100,37 @@ function evaluate_ed_batched(basis::SplineRnlrzzBasis, ) where {T <: Real} @assert length(rs) == length(Zs) - Rnl1, ∇Rnl1, st = evaluate_ed(basis, rs[1], Zi, Zs[1], ps, st) + Rnl1, ∇Rnl1 = evaluate_ed(basis, rs[1], Zi, Zs[1], ps, st) Rnl = zeros(T, length(rs), length(Rnl1)) Rnl_d = zeros(T, length(rs), length(Rnl1)) Rnl[1, :] .= Rnl1 Rnl_d[1, :] .= ∇Rnl1 for j = 1:length(rs) - Rnl_j, ∇Rnl_j, st = evaluate_ed(basis, rs[j], Zi, Zs[j], ps, st) + Rnl_j, ∇Rnl_j = evaluate_ed(basis, rs[j], Zi, Zs[j], ps, st) Rnl[j, :] = Rnl_j Rnl_d[j, :] = ∇Rnl_j end - return Rnl, Rnl_d, st + return Rnl, Rnl_d end +function whatalloc(::typeof(evaluate_ed_batched!), + basis::SplineRnlrzzBasis, + rs::AbstractVector, Zi, Zs, ps, st) + T = eltype(rs) + return (T, length(rs), length(basis)), (T, length(rs), length(basis)) +end + + +function evaluate_ed_batched(basis::SplineRnlrzzBasis, + rs::AbstractVector, Zi, Zs, ps, st) + alc_Rnl, alc_Rnl_d = whatalloc(evaluate_ed_batched!, basis, rs, Zi, Zs, ps, st) + Rnl = zeros(alc_Rnl...) + Rnl_d = zeros(alc_Rnl_d...) + return evaluate_ed_batched!(Rnl, Rnl_d, basis, rs, Zi, Zs, ps, st) +end function rrule(::typeof(evaluate_batched), diff --git a/src/models/ace.jl b/src/models/ace.jl index b366845e..2fe52a08 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -292,7 +292,7 @@ function evaluate(model::ACEModel, # ------------------- # pair potential if model.pairbasis != nothing - Rpair, _ = evaluate_batched(model.pairbasis, rs, Z0, Zs, + Rpair = evaluate_batched(model.pairbasis, rs, Z0, Zs, ps.pairbasis, st.pairbasis) Apair = sum(Rpair, dims=1)[:] val += dot(Apair, (@view ps.Wpair[:, i_z0])) @@ -369,7 +369,7 @@ function evaluate_ed(model::ACEModel, # ------------------- # pair potential if model.pairbasis != nothing - Rpair, dRpair, _ = evaluate_ed_batched(model.pairbasis, rs, Z0, Zs, + Rpair, dRpair = evaluate_ed_batched(model.pairbasis, rs, Z0, Zs, ps.pairbasis, st.pairbasis) Apair = sum(Rpair, dims=1)[:] Wp_i = @view ps.Wpair[:, i_z0] @@ -459,7 +459,7 @@ function grad_params(model::ACEModel, # ------------------- # pair potential if model.pairbasis != nothing - Rpair, _ = evaluate_batched(model.pairbasis, rs, Z0, Zs, + Rpair = evaluate_batched(model.pairbasis, rs, Z0, Zs, ps.pairbasis, st.pairbasis) Apair = sum(Rpair, dims=1)[:] Wp_i = @view ps.Wpair[:, i_z0] @@ -575,7 +575,7 @@ function evaluate_basis(model::ACEModel, # ------------------- # pair potential if model.pairbasis != nothing - Rpair, _ = evaluate_batched(model.pairbasis, rs, Z0, Zs, + Rpair = evaluate_batched(model.pairbasis, rs, Z0, Zs, ps.pairbasis, st.pairbasis) Apair = sum(Rpair, dims=1)[:] B[get_pairbasis_inds(model, Z0)] .= Apair diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl index 4ab211a0..9cd6a333 100644 --- a/test/models/test_Rnl.jl +++ b/test/models/test_Rnl.jl @@ -1,7 +1,7 @@ -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using ACEpotentials @@ -10,6 +10,7 @@ M = ACEpotentials.Models using Random, LuxCore, Test, ACEbase, LinearAlgebra using ACEbase.Testing: print_tf rng = Random.MersenneTwister(1234) +Random.seed!(1234) ## @@ -25,8 +26,8 @@ ps, st = LuxCore.setup(rng, basis) r = 3.0 Zi = basis._i2z[1] Zj = basis._i2z[2] -Rnl, st1 = basis(r, Zi, Zj, ps, st) -Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) +Rnl = basis(r, Zi, Zj, ps, st) +Rnl, Rnl_d = M.evaluate_ed(basis, r, Zi, Zj, ps, st) @info("Test derivatives of LearnableRnlrzzBasis") @@ -36,7 +37,7 @@ for ntest = 1:20 Zi = rand(basis._i2z) Zj = rand(basis._i2z) U = randn(eltype(Rnl), length(Rnl)) - F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)[1]) + F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)) dF(t) = dot(U, M.evaluate_ed(basis, r + t, Zi, Zj, ps, st)[2]) print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) end @@ -53,13 +54,13 @@ for ntest = 1:20 Rs, Zs, Z0 = M.rand_atenv(basis, Nat) rs = norm.(Rs) - Rnl = [ M.evaluate(basis, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] - Rnl_b, st1 = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) + Rnl = [ M.evaluate(basis, r, Z0, z, ps, st) for (r, z) in zip(rs, Zs) ] + Rnl_b = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) - Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) + Rnl_b2, ∇Rnl_b = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) ∇Rnl = [ M.evaluate_ed(basis, r, Z0, z, ps, st)[2] - for (r, z) in zip(rs, Zs) ] + for (r, z) in zip(rs, Zs) ] print_tf(@test Rnl_b ≈ Rnl_b2) print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) @@ -79,14 +80,14 @@ for ntest = 1:30 r = norm(Rs[1]) Zj = Zs[1] - Rnl, _ = basis(r, Zi, Zj, ps, st) + Rnl = basis(r, Zi, Zj, ps, st) for (nnodes, tol) in [(30, 1e-3), (100, 1e-5), (1000, 1e-8)] local basis_spl, ps_spl, st_spl, Rnl_spl basis_spl = M.splinify(basis, ps; nnodes = nnodes) ps_spl, st_spl = LuxCore.setup(rng, basis_spl) - Rnl_spl, _ = basis_spl(r, Zi, Zj, ps_spl, st_spl) + Rnl_spl = basis_spl(r, Zi, Zj, ps_spl, st_spl) rel_err = (Rnl - Rnl_spl) ./ (1 .+ abs.(Rnl)) # use 1-norm here to not stress about small outliers print_tf(@test norm(rel_err, 1)/length(Rnl) < tol) @@ -106,9 +107,9 @@ for ntest = 1:20 Rs, Zs, Zi = M.rand_atenv(basis_spl, 1) r = norm(Rs[1]); Zj = Zs[1] - Rnl = basis_spl(r, Zi, Zj, ps, st)[1] + Rnl = basis_spl(r, Zi, Zj, ps, st) U = randn(eltype(Rnl), length(Rnl)) - F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)[1]) + F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)) dF(t) = dot(U, M.evaluate_ed(basis, r + t, Zi, Zj, ps, st)[2]) print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) end @@ -127,15 +128,15 @@ for ntest = 1:20 Rs, Zs, Z0 = M.rand_atenv(basis_spl, Nat) rs = norm.(Rs) - Rnl = [ M.evaluate(basis_spl, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] - Rnl_b, st1 = M.evaluate_batched(basis_spl, rs, Z0, Zs, ps, st) + Rnl = [ M.evaluate(basis_spl, r, Z0, z, ps, st) for (r, z) in zip(rs, Zs) ] + Rnl_b = M.evaluate_batched(basis_spl, rs, Z0, Zs, ps, st) print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) - Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis_spl, rs, Z0, Zs, ps, st) + Rnl_b2, ∇Rnl_b = M.evaluate_ed_batched(basis_spl, rs, Z0, Zs, ps, st) ∇Rnl = [ M.evaluate_ed(basis_spl, r, Z0, z, ps, st)[2] for (r, z) in zip(rs, Zs) ] print_tf(@test Rnl_b ≈ Rnl_b2) print_tf(@test all(∇Rnl .≈ [∇Rnl_b[j, :] for j = 1:Nat ])) end -println() \ No newline at end of file +println() diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 5be9e472..d8be7642 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -43,8 +43,8 @@ println() ## -# for ybasis in [:spherical, :solid] - ybasis = :spherical +for ybasis in [:spherical, :solid] + # ybasis = :solid @info("=== Testing ybasis = $ybasis === ") local ps, st, Nat model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, @@ -193,12 +193,12 @@ println() for ntest = 1:10 local len, Nat, Rs, Zs, z0, Ei - len = 10 + len = 100 mae = sum(1:len) do _ Nat = rand(8:16) Rs, Zs, z0 = M.rand_atenv(model, Nat) - Ei = M.evaluate(model, Rs, Zs, z0, ps, st)[1] - Ei_lin = M.evaluate(lin_ace, Rs, Zs, z0, ps_lin, st_lin)[1] + Ei = M.evaluate(model, Rs, Zs, z0, ps, st) + Ei_lin = M.evaluate(lin_ace, Rs, Zs, z0, ps_lin, st_lin) abs(Ei - Ei_lin) end mae /= len @@ -207,6 +207,7 @@ println() println() end + ## #= diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index f9fa57f4..cdb5b0f2 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -11,6 +11,8 @@ using Random, LuxCore, Test, ACEbase, LinearAlgebra using ACEbase.Testing: print_tf rng = Random.MersenneTwister(1234) +Random.seed!(1234) + ## max_level = 8 @@ -26,7 +28,7 @@ r = 3.0 Zi = basis._i2z[1] Zj = basis._i2z[2] Rnl, st1 = basis(r, Zi, Zj, ps, st) -Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) +Rnl, Rnl_d = M.evaluate_ed(basis, r, Zi, Zj, ps, st) @info("Test derivatives of Rnlrzz basis") @@ -36,7 +38,7 @@ for ntest = 1:20 Zi = rand(basis._i2z) Zj = rand(basis._i2z) U = randn(eltype(Rnl), length(Rnl)) - F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)[1]) + F(t) = dot(U, basis(r + t, Zi, Zj, ps, st)) dF(t) = dot(U, M.evaluate_ed(basis, r + t, Zi, Zj, ps, st)[2]) print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) end @@ -52,11 +54,11 @@ for ntest = 1:20 Rs, Zs, Z0 = M.rand_atenv(basis, Nat) rs = norm.(Rs) - Rnl = [ M.evaluate(basis, r, Z0, z, ps, st)[1] for (r, z) in zip(rs, Zs) ] - Rnl_b, st = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) + Rnl = [ M.evaluate(basis, r, Z0, z, ps, st) for (r, z) in zip(rs, Zs) ] + Rnl_b = M.evaluate_batched(basis, rs, Z0, Zs, ps, st) print_tf(@test all([Rnl_b[j, :] for j = 1:Nat] .≈ Rnl)) - Rnl_b2, ∇Rnl_b, _ = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) + Rnl_b2, ∇Rnl_b = M.evaluate_ed_batched(basis, rs, Z0, Zs, ps, st) ∇Rnl = [ M.evaluate_ed(basis, r, Z0, z, ps, st)[2] for (r, z) in zip(rs, Zs) ] @@ -66,19 +68,18 @@ end ## -basis_p = M.set_params(basis, ps) - -basis_spl = M.splinify(basis_p; nnodes = 30) +@info("quick splinification check") +basis_spl = M.splinify(basis, ps; nnodes = 100) ps_spl, st_spl = LuxCore.setup(rng, basis_spl) -Rnl, _ = basis(r, Zi, Zj, ps, st) -Rnl_spl, _ = basis_spl(r, Zi, Zj, ps_spl, st_spl) +Rnl = basis(r, Zi, Zj, ps, st) +Rnl_spl = basis_spl(r, Zi, Zj, ps_spl, st_spl) -norm(Rnl - Rnl_spl, Inf) +println_slim(@test norm(Rnl - Rnl_spl, Inf) < 1e-4) -Rnl, ∇Rnl, _ = M.evaluate_ed(basis, r, Zi, Zj, ps, st) -Rnl_spl, ∇Rnl_spl, _ = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) +Rnl, ∇Rnl = M.evaluate_ed(basis, r, Zi, Zj, ps, st) +Rnl_spl, ∇Rnl_spl = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) -norm(Rnl - Rnl_spl, Inf) -norm(∇Rnl - ∇Rnl_spl, Inf) +println_slim(@test norm(Rnl - Rnl_spl, Inf) < 1e-4 ) +println_slim(@test norm(∇Rnl - ∇Rnl_spl, Inf) < 1e-2 ) diff --git a/test/models/test_pair_basis.jl b/test/models/test_pair_basis.jl index 78c4a95a..8c9a7f0d 100644 --- a/test/models/test_pair_basis.jl +++ b/test/models/test_pair_basis.jl @@ -26,14 +26,14 @@ ps, st = LuxCore.setup(rng, basis) r = 3.0 Zi = basis._i2z[1] Zj = basis._i2z[2] -Rnl1, st1 = basis(r, Zi, Zj, ps, st) -Rnl, Rnl_d, st1 = M.evaluate_ed(basis, r, Zi, Zj, ps, st) +Rnl1 = basis(r, Zi, Zj, ps, st) +Rnl, Rnl_d = M.evaluate_ed(basis, r, Zi, Zj, ps, st) basis_spl = M.splinify(basis, ps) ps_spl, st_spl = LuxCore.setup(rng, basis_spl) -Rnl2, _ = M.evaluate(basis_spl, r, Zi, Zj, ps_spl, st_spl) -Rnl2, Rnl_d2, _ = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) +Rnl2 = M.evaluate(basis_spl, r, Zi, Zj, ps_spl, st_spl) +Rnl2, Rnl_d2 = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) ## # inspect the basis visually @@ -69,7 +69,7 @@ for ntest = 1:20 Zj = rand(basis_spl._i2z) r = 2.0 + rand() U = randn(length(basis_spl)) - F(t) = dot(U, basis_spl(r + t, Zi, Zj, ps, st)[1]) + F(t) = dot(U, basis_spl(r + t, Zi, Zj, ps, st)) dF(t) = dot(U, M.evaluate_ed(basis_spl, r + t, Zi, Zj, ps, st)[2]) print_tf(@test ACEbase.Testing.fdtest(F, dF, 0.0; verbose=false)) end From b05e35530a19b3134c6c2671953847cd3200c632 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 16:33:19 +0200 Subject: [PATCH 065/112] fix calculators --- src/models/Rnl_splines.jl | 7 ++----- src/models/calculators.jl | 12 ++++++------ test/models/test_ace.jl | 2 +- test/models/test_calculator.jl | 4 ++-- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/models/Rnl_splines.jl b/src/models/Rnl_splines.jl index 822e8562..4e9fb17e 100644 --- a/src/models/Rnl_splines.jl +++ b/src/models/Rnl_splines.jl @@ -95,16 +95,13 @@ end -function evaluate_ed_batched(basis::SplineRnlrzzBasis, +function evaluate_ed_batched!(Rnl, Rnl_d, + basis::SplineRnlrzzBasis, rs::AbstractVector{T}, Zi, Zs, ps, st ) where {T <: Real} @assert length(rs) == length(Zs) Rnl1, ∇Rnl1 = evaluate_ed(basis, rs[1], Zi, Zs[1], ps, st) - Rnl = zeros(T, length(rs), length(Rnl1)) - Rnl_d = zeros(T, length(rs), length(Rnl1)) - Rnl[1, :] .= Rnl1 - Rnl_d[1, :] .= ∇Rnl1 for j = 1:length(rs) Rnl_j, ∇Rnl_j = evaluate_ed(basis, rs[j], Zi, Zs[j], ps, st) diff --git a/src/models/calculators.jl b/src/models/calculators.jl index ec5649fe..9d10fc0c 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -53,10 +53,10 @@ cutoff_radius(V::ACEPotential{<: ACEModel}) = maximum(x.rcut for x in V.model.rbasis.rin0cuts) * distance_unit(V) eval_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) = - evaluate(V.model, Rs, Zs, z0, V.ps, V.st)[1] + evaluate(V.model, Rs, Zs, z0, V.ps, V.st) function eval_grad_site(V::ACEPotential{<: ACEModel}, Rs, Zs, z0) - v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, V.ps, V.st) + v, dv = evaluate_ed(V.model, Rs, Zs, z0, V.ps, V.st) return v, dv end @@ -92,7 +92,7 @@ function energy_forces_virial_serial( for i in domain Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) - v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, ps, st) + v, dv = evaluate_ed(V.model, Rs, Zs, z0, ps, st) energy += v for α = 1:length(Js) forces[Js[α]] -= dv[α] @@ -135,7 +135,7 @@ function energy_forces_virial( for i in sub_domain Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) - v, dv, st = evaluate_ed(V.model, Rs, Zs, z0, ps, st) + v, dv = evaluate_ed(V.model, Rs, Zs, z0, ps, st) energy += v * energy_unit(V) for α = 1:length(Js) forces[Js[α]] -= dv[α] * force_unit(V) @@ -255,7 +255,7 @@ function energy_forces_virial_basis( ) Js, Rs, Zs, z0 = get_neighbours(at, calc, nlist, 1) - E1, _ = evaluate_basis(calc.model, Rs, Zs, z0, ps, st) + E1 = evaluate_basis(calc.model, Rs, Zs, z0, ps, st) N_basis = length(E1) T = fl_type(calc.model) # this is ACE specific @@ -265,7 +265,7 @@ function energy_forces_virial_basis( for i in domain Js, Rs, Zs, z0 = get_neighbours(at, V, nlist, i) - v, dv, _ = evaluate_basis_ed(calc.model, Rs, Zs, z0, ps, st) + v, dv = evaluate_basis_ed(calc.model, Rs, Zs, z0, ps, st) for k = 1:N_basis E[k] += v[k] * energy_unit(calc) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index d8be7642..e99e5b5c 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); ## diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index ce054456..1c191a8e 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -1,6 +1,6 @@ -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); ## @@ -93,7 +93,7 @@ for ntest = 1:20 E = 0.0 for i = 1:length(at) Js, Rs, Zs, z0 = get_neighbours(at, calc, nlist, i) - E += M.evaluate(calc.model, Rs, Zs, z0, ps, st)[1] + E += M.evaluate(calc.model, Rs, Zs, z0, ps, st) end efv = M.energy_forces_virial(at, calc, ps, st) E2 = AtomsCalculators.potential_energy(at, calc) From 4d9b2011474deb56c91f2c990d0ec4d42c02f18b Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 17:23:56 +0200 Subject: [PATCH 066/112] all non-sklearn tests pass --- test/models/test_Rnl.jl | 2 +- test/models/test_calculator.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl index 9cd6a333..0fb3d2ea 100644 --- a/test/models/test_Rnl.jl +++ b/test/models/test_Rnl.jl @@ -1,7 +1,7 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); using ACEpotentials diff --git a/test/models/test_calculator.jl b/test/models/test_calculator.jl index 1c191a8e..26530fbc 100644 --- a/test/models/test_calculator.jl +++ b/test/models/test_calculator.jl @@ -1,6 +1,6 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); ## From 6843bb539359af7d43f16359dc930f4630cb26a5 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 4 Jul 2024 22:20:02 +0200 Subject: [PATCH 067/112] replace GeometryOptimisation in example script --- docs/src/newkernels/Project.toml | 3 +-- docs/src/newkernels/newkernels.jl | 13 +++---------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index 59c08d12..9c81af63 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -7,13 +7,12 @@ AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" -GeometryOptimization = "673bf261-a53d-43b9-876f-d3c1fc8329c2" +GeomOpt = "ca147568-c688-4a55-a13d-dbd284330f4b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Molly = "aa0f7f06-fcc0-5ec4-a7f3-a573f33f9c4c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/docs/src/newkernels/newkernels.jl b/docs/src/newkernels/newkernels.jl index 10cabf13..1beab92d 100644 --- a/docs/src/newkernels/newkernels.jl +++ b/docs/src/newkernels/newkernels.jl @@ -330,20 +330,13 @@ mae_test = sum(E_err, test_data) / length(test_data) # Trying some simple geometry optimization -# This seems to run but doesn't update the structure, there is also no -# documentation how to extract information from the result to do so. -# The following seems to work but this wants a PR in GeometryOptimization.jl -using GeometryOptimization, OptimizationOptimJL +using GeomOpt -@info("Short geometry optimization") at = rand_AlTi(2, 0.001) @show potential_energy(at, fit_calc) -solver = OptimizationOptimJL.LBFGS() -optim_options = (f_tol=1e-4, g_tol=1e-4, iterations=30, show_trace=false) -results = minimize_energy!(at, fit_calc; solver, optim_options...) -at_new = AtomsBuilder._set_positions(at, reinterpret(SVector{3, Float64}, results.u) * u"Å") -@show potential_energy(at_new, fit_calc) +_at_opt, info = GeomOpt.minimise(at, fit_calc; g_tol = 1e-4, g_calls_limit = 30 ) +@show potential_energy(_at_opt, fit_calc) # The last step is to run a simple MD simulation for just a 100 steps. From cfd6a042873b21b5ffedbbf6c955e399784be6ea Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 20 Jul 2024 10:44:15 -0700 Subject: [PATCH 068/112] fix version bounds --- Project.toml | 3 ++- test/Project.toml | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index b5a1a031..d38b838b 100644 --- a/Project.toml +++ b/Project.toml @@ -52,8 +52,9 @@ ACE1 = "0.12.2" ACE1x = "0.1.8" ACEfit = "0.1.4" ACEmd = "0.1.6" +AtomsCalculators = "0.1" ExtXYZ = "0.1.14" -Interpolations = "0.14.7, 0.15" +Interpolations = "0.15" JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" diff --git a/test/Project.toml b/test/Project.toml index d3225fe8..facb7542 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -24,6 +24,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ACE1 = "0.12.2" ACE1x = "0.1.8" -Interpolations = "0.14.7" -JuLIP = "0.13.9, 0.14.2" StaticArrays = "1" From 6708d6ab490457fde66c60c2df2be7e76ef54dfc Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 20 Jul 2024 12:18:15 -0700 Subject: [PATCH 069/112] towards linear ACE2 --- src/models/Rnl_learnable.jl | 47 +++++++++++++++++++++++++------ src/models/ace_heuristics.jl | 11 ++++++-- test/models/test_radialweights.jl | 30 ++++++++++++++++++++ 3 files changed, 77 insertions(+), 11 deletions(-) create mode 100644 test/models/test_radialweights.jl diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 373f41e8..333703cd 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -8,18 +8,15 @@ import LuxCore function LearnableRnlrzzBasis( zlist, polys, transforms, envelopes, rin0cuts, spec::AbstractVector{NT_NL_SPEC}; - # weights=nothing, + Winit = :glorot_normal, meta=Dict{String, Any}()) NZ = length(zlist) - # if isnothing(weights) - # weights = fill(nothing, (1,1,1,1)) - # end + meta["Winit"] = string(Winit) LearnableRnlrzzBasis(_convert_zlist(zlist), polys, _make_smatrix(transforms, NZ), _make_smatrix(envelopes, NZ), # -------------- - # weights, _make_smatrix(rin0cuts, NZ), collect(spec), meta) @@ -34,11 +31,26 @@ function initialparameters(rng::AbstractRNG, len_q = length(basis.polys) Wnlq = zeros(len_nl, len_q, NZ, NZ) - for i = 1:NZ, j = 1:NZ - Wnlq[:, :, i, j] .= glorot_normal(rng, Float64, len_nl, len_q) + ps = (Wnlq = Wnlq, ) + + if !haskey(basis.meta, "Winit") + @warn("No key Winit found for radial basis, use glorot_normal to initialize.") + basis.meta["Winit"] = "glorot_normal" end + + if basis.meta["Winit"] == "glorot_normal" + for i = 1:NZ, j = 1:NZ + Wnlq[:, :, i, j] .= glorot_normal(rng, Float64, len_nl, len_q) + end + + elseif basis.meta["Winit"] == "linear" + set_I_weights!(basis, ps) + + else + error("Unknown key Winit = $(basis.meta["Winit"]) to initialize radial basis weights.") + end - return (Wnlq = Wnlq, ) + return ps end function initialstates(rng::AbstractRNG, @@ -54,6 +66,25 @@ function LuxCore.parameterlength(basis::LearnableRnlrzzBasis) return len_nl * len_q * NZ * NZ end +""" +Set the radial weights as they would be in a linear ACE model. +""" +function set_I_weights!(rbasis::LearnableRnlrzzBasis, ps) + NZ = _get_nz(rbasis) + if NZ != 1 + error("set_I_weights! is currently only implemented for NZ = 1") + end + ps.Wnlq[:] .= 0 + for i = 1:NZ, j = 1:NZ + for (i_nl, nl) in enumerate(rbasis.spec) + if nl.n <= size(ps.Wnlq, 2) + ps.Wnlq[i_nl, nl.n, i, j] = 1 + end + end + end + return ps +end + # ------------------------------------------------------------ # EVALUATION INTERFACE diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 101bb620..daa502f1 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -42,7 +42,8 @@ function ace_learnable_Rnlrzz(; rin0cuts = _default_rin0cuts(elements), transforms = agnesi_transform.(rin0cuts, 2, 2), polys = :legendre, - envelopes = :poly2sx + envelopes = :poly2sx, + Winit = :glorot_normal, ) if elements == nothing error("elements must be specified!") @@ -87,7 +88,9 @@ function ace_learnable_Rnlrzz(; error("actual_maxn > length of polynomial basis") end - return LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, rin0cuts, spec) + return LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, + rin0cuts, spec; + Winit = Winit) end @@ -102,6 +105,7 @@ function ace_model(; elements = nothing, rbasis_type = :learnable, maxl = 30, # maxl, max are fairly high defaults maxn = 50, # that we will likely never reach + init_Wradial = :glorot_normal, # basis size parameters level = nothing, max_level = nothing, @@ -126,7 +130,8 @@ function ace_model(; elements = nothing, rbasis = ace_learnable_Rnlrzz(; max_level = max_level, level = level, maxl = maxl, maxn = maxn, elements = elements, - rin0cuts = rin0cuts) + rin0cuts = rin0cuts, + Winit = init_Wradial) else error("unknown rbasis_type = $rbasis_type") end diff --git a/test/models/test_radialweights.jl b/test/models/test_radialweights.jl new file mode 100644 index 00000000..cbb0291f --- /dev/null +++ b/test/models/test_radialweights.jl @@ -0,0 +1,30 @@ +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) + +## + +using ACEpotentials +M = ACEpotentials.Models + +using Random, LuxCore, Test, ACEbase, LinearAlgebra +using ACEbase.Testing: print_tf +rng = Random.MersenneTwister(1234) + +## + +max_level = 8 +level = M.TotalDegree() +maxl = 3; maxn = max_level; +elements = (:Si, ) +order = 3 + +model = M.ace_model(; elements = elements, order = order, Ytype = :solid, + level = level, max_level = max_level, maxl = 8, pair_maxn = 15, + init_WB = :zeros, + init_Wpair = :zeros, + init_Wradial = :linear) + +ps, st = LuxCore.setup(rng, model) + + +## + From 24d6f01d4916fd1d9a2e9233b254852b33e5ea04 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 21 Jul 2024 03:24:16 -0700 Subject: [PATCH 070/112] linear ACE weights for multi-species --- src/models/Rnl_learnable.jl | 5 ++--- test/models/test_radialweights.jl | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 333703cd..1980b66c 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -70,10 +70,9 @@ end Set the radial weights as they would be in a linear ACE model. """ function set_I_weights!(rbasis::LearnableRnlrzzBasis, ps) + # Rnl(r, Z1, Z2) = ∑_q W[(nl), q, Z1, Z2] * P_q(r) + # For linear models this becomes Rnl(r, Z1, Z2) = Pn(r) NZ = _get_nz(rbasis) - if NZ != 1 - error("set_I_weights! is currently only implemented for NZ = 1") - end ps.Wnlq[:] .= 0 for i = 1:NZ, j = 1:NZ for (i_nl, nl) in enumerate(rbasis.spec) diff --git a/test/models/test_radialweights.jl b/test/models/test_radialweights.jl index cbb0291f..76758029 100644 --- a/test/models/test_radialweights.jl +++ b/test/models/test_radialweights.jl @@ -28,3 +28,20 @@ ps, st = LuxCore.setup(rng, model) ## +max_level = 8 +level = M.TotalDegree() +maxl = 3; maxn = max_level + 4; +elements = (:Si, :O) +order = 3 + +model = M.ace_model(; elements = elements, order = order, Ytype = :solid, + level = level, max_level = max_level, maxl = 8, maxn = maxn, + pair_maxn = 15, + init_WB = :zeros, + init_Wpair = :zeros, + init_Wradial = :linear) + +ps, st = LuxCore.setup(rng, model) + +display(ps.rbasis.Wnlq[:,:,1,1]) +size(ps.rbasis.Wnlq) \ No newline at end of file From 2f4c84b6db49de553e662dc6cdc2b2e9b70c9c03 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 21 Jul 2024 03:34:18 -0700 Subject: [PATCH 071/112] allow maxq > maxn --- src/models/ace_heuristics.jl | 19 ++++++++++++++++--- test/models/test_ace.jl | 4 ++-- test/models/test_radialweights.jl | 5 ++++- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index daa502f1..c42e99db 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -37,6 +37,8 @@ function ace_learnable_Rnlrzz(; level = nothing, maxl = nothing, maxn = nothing, + maxq_fact = 1.5, + maxq = :auto, elements = nothing, spec = nothing, rin0cuts = _default_rin0cuts(elements), @@ -62,10 +64,18 @@ function ace_learnable_Rnlrzz(; # now the actual maxn is the maximum n in the spec actual_maxn = maximum([ s.n for s in spec ]) + + if maxq == :auto + maxq = ceil(Int, actual_maxn * maxq_fact) + end + + if maxq < actual_maxn + @warn("maxq < actual_maxn; this results in linear dependence") + end if polys isa Symbol if polys == :legendre - polys = Polynomials4ML.legendre_basis(actual_maxn) + polys = Polynomials4ML.legendre_basis(maxq) else error("unknown polynomial type : $polys") end @@ -99,12 +109,14 @@ function ace_model(; elements = nothing, order = nothing, Ytype = :solid, E0s = nothing, - rin0cuts = nothing, + rin0cuts = :auto, # radial basis rbasis = nothing, rbasis_type = :learnable, maxl = 30, # maxl, max are fairly high defaults maxn = 50, # that we will likely never reach + maxq_fact = 1.5, + maxq = :auto, init_Wradial = :glorot_normal, # basis size parameters level = nothing, @@ -117,7 +129,7 @@ function ace_model(; elements = nothing, rng = Random.default_rng(), ) - if rin0cuts == nothing + if rin0cuts == :auto rin0cuts = _default_rin0cuts(elements) else NZ = length(elements) @@ -129,6 +141,7 @@ function ace_model(; elements = nothing, if rbasis_type == :learnable rbasis = ace_learnable_Rnlrzz(; max_level = max_level, level = level, maxl = maxl, maxn = maxn, + maxq_fact = maxq_fact, maxq = maxq, elements = elements, rin0cuts = rin0cuts, Winit = init_Wradial) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index e99e5b5c..fbccb7a9 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -27,9 +27,9 @@ order = 3 @info("Test ybasis of the Model is used correctly") msolid = M.ace_model(; elements = elements, order = order, Ytype = :solid, -level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :glorot_normal) + level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :glorot_normal) mspherical = M.ace_model(; elements = elements, order = order, Ytype = :spherical, -level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :zero) + level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :zero) ps, st = LuxCore.setup(rng, msolid) for ntest = 1:30 diff --git a/test/models/test_radialweights.jl b/test/models/test_radialweights.jl index 76758029..7cb8bd0b 100644 --- a/test/models/test_radialweights.jl +++ b/test/models/test_radialweights.jl @@ -13,18 +13,21 @@ rng = Random.MersenneTwister(1234) max_level = 8 level = M.TotalDegree() -maxl = 3; maxn = max_level; +maxl = 3; maxn = max_level; maxq_fact = 2; elements = (:Si, ) order = 3 model = M.ace_model(; elements = elements, order = order, Ytype = :solid, level = level, max_level = max_level, maxl = 8, pair_maxn = 15, + maxq_fact = maxq_fact, init_WB = :zeros, init_Wpair = :zeros, init_Wradial = :linear) ps, st = LuxCore.setup(rng, model) +@show size(ps.rbasis.Wnlq) +display(ps.rbasis.Wnlq[:, :, 1, 1]) ## From 5d8de53f8f81b28c4d475410653fcd65825b1dc1 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 24 Jul 2024 08:23:26 -0700 Subject: [PATCH 072/112] backup --- Project.toml | 1 + docs/src/newkernels/Project.toml | 3 -- docs/src/newkernels/linear.jl | 59 ++++++++++++++++++++++++++++++++ test/test_silicon.jl | 2 +- 4 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 docs/src/newkernels/linear.jl diff --git a/Project.toml b/Project.toml index d38b838b..83b1c210 100644 --- a/Project.toml +++ b/Project.toml @@ -61,6 +61,7 @@ Reexport = "1" StaticArrays = "1" YAML = "0.4" julia = "1.9" +EquivariantModels = "0.0.4" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index 9c81af63..c6a407ae 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -5,12 +5,9 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" -EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" GeomOpt = "ca147568-c688-4a55-a13d-dbd284330f4b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -Molly = "aa0f7f06-fcc0-5ec4-a7f3-a573f33f9c4c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl new file mode 100644 index 00000000..f3b692b9 --- /dev/null +++ b/docs/src/newkernels/linear.jl @@ -0,0 +1,59 @@ +# This script is to roughly document how to use the new model implementations +# I'll try to explain what can be done and what is missing along the way. +# I am + +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Zygote, Optimisers +import Random: seed!, MersenneTwister + +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +Z0 = :Si +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") + +# because the new implementation is experimental, it is not exported, +# so I create a little shortcut to have easy access. + +M = ACEpotentials.Models + +# The new implementation tries to follow Lux rules, which likes to be +# disciplined and explicit about random numbers + +rng = MersenneTwister(1234) + +# First we create an ACE1 style potential with some standard parameters + +elements = (Z0,) +order = 3 +totaldegree = 6 +rcut = 5.5 + +@profview begin + model1 = acemodel(elements = elements, + order = order, + totaldegree = totaldegree, + rcut = rcut, ) +end + +# now we create an ACE2 style model that should behave similarly + +model2 = M.ace_model(; elements = elements, + order = order, # correlation order + Ytype = :solid, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = :zeros, # how to initialize the pair potential parameters + init_Wradial = :linear + ) + +# Example dataset + +data, _, meta = ACEpotentials.example_dataset("TiAl_tutorial") +data = FlexibleSystem.(data) +train_data = data[1:5:end] +test_data = data[2:5:end] +data_keys = (E_key = :energy, F_key = :force, V_key = :virial) +weights = (wE = 1.0/u"eV", wF = 0.1 / u"eV/Å", wV = 0.1/u"eV") diff --git a/test/test_silicon.jl b/test/test_silicon.jl index 371ede77..6d2351e4 100644 --- a/test/test_silicon.jl +++ b/test/test_silicon.jl @@ -4,7 +4,7 @@ using LazyArtifacts using PythonCall using Test -### ----- setup ----- +## ----- setup ----- @warn "test_silicon not fully converted yet." model = acemodel(elements = [:Si], From f86835743aa23e4e273965eee91f4e0f408dd958 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 24 Jul 2024 21:59:16 -0700 Subject: [PATCH 073/112] example script for linear fit --- docs/src/newkernels/linear.jl | 179 ++++++++++++++++++++++++++++------ 1 file changed, 149 insertions(+), 30 deletions(-) diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index f3b692b9..8064a7b1 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -2,58 +2,177 @@ # I'll try to explain what can be done and what is missing along the way. # I am +using Random using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, - Unitful, Zygote, Optimisers -import Random: seed!, MersenneTwister + Unitful, Zygote, Optimisers, Folds # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset Z0 = :Si train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:3:end] # because the new implementation is experimental, it is not exported, # so I create a little shortcut to have easy access. M = ACEpotentials.Models -# The new implementation tries to follow Lux rules, which likes to be -# disciplined and explicit about random numbers - -rng = MersenneTwister(1234) - # First we create an ACE1 style potential with some standard parameters -elements = (Z0,) +elements = [Z0,] order = 3 -totaldegree = 6 +totaldegree = 10 rcut = 5.5 -@profview begin - model1 = acemodel(elements = elements, +model1 = acemodel(elements = elements, order = order, totaldegree = totaldegree, rcut = rcut, ) -end # now we create an ACE2 style model that should behave similarly +# this essentially reproduces the rcut = 5.5, we may want a nicer way to +# achieve this. + +rin0cuts = M._default_rin0cuts(elements; rcutfactor = 2.3) + model2 = M.ace_model(; elements = elements, - order = order, # correlation order - Ytype = :solid, # solid vs spherical harmonics - level = M.TotalDegree(), # how to calculate the weights to give to a basis function - max_level = totaldegree, # maximum level of the basis functions - pair_maxn = totaldegree, # maximum number of basis functions for the pair potential - init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = :zeros, # how to initialize the pair potential parameters - init_Wradial = :linear - ) - -# Example dataset - -data, _, meta = ACEpotentials.example_dataset("TiAl_tutorial") -data = FlexibleSystem.(data) -train_data = data[1:5:end] -test_data = data[2:5:end] -data_keys = (E_key = :energy, F_key = :force, V_key = :virial) -weights = (wE = 1.0/u"eV", wF = 0.1 / u"eV/Å", wV = 0.1/u"eV") + order = order, # correlation order + Ytype = :solid, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = :zeros, # how to initialize the pair potential parameters + init_Wradial = :linear, + rin0cuts = rin0cuts, + ) + +# wrap the model into a calculator, which turns it into a potential... + +calc_model2 = M.ACEPotential(model2) + +# Fit the ACE1 model + +# set weights for energy, forces virials +weights = Dict("default" => Dict("E" => 30.0, "F" => 1.0 , "V" => 1.0 ),); +# specify a solver +solver=ACEfit.TruncatedSVD(; rtol = 1e-8) + +acefit!(model1, train; solver=solver) + + + +# Fit the ACE2 model - this still needs a bit of hacking to convert everything +# to the new framework. +# - convert the data to AtomsBase +# - use a different interface to specify data weights and keys +# (this needs to be brough in line with the ACEpotentials framework) +# - rewrite the assembly for the LSQ system from scratch (but this is easy) + +train2 = FlexibleSystem.(train) +test2 = FlexibleSystem.(test) +data_keys = (E_key = :energy, F_key = :force, ) +weights = (wE = 30.0/u"eV", wF = 1.0 / u"eV/Å", ) + +function local_lsqsys(calc, at, ps, st, weights, keys) + efv = M.energy_forces_virial_basis(at, calc, ps, st) + + # There are no E0s in this dataset! + # # compute the E0s contribution. This needs to be done more + # # elegantly and a stacked model would solve this problem. + # E0 = sum( calc.model.E0s[M._z2i(calc.model, z)] + # for z in AtomsBase.atomic_number(at) ) * u"eV" + + # energy + wE = weights[:wE] + E_dft = at.data[data_keys.E_key] * u"eV" + y_E = wE * E_dft # (E_dft - E0) + A_E = wE * efv.energy' + + # forces + wF = weights[:wF] + F_dft = at.data[data_keys.F_key] * u"eV/Å" + y_F = wF * reinterpret(eltype(F_dft[1]), F_dft) + A_F = wF * reinterpret(eltype(efv.forces[1]), efv.forces) + + # # virial + # wV = weights[:wV] + # V_dft = at.data[data_keys.V_key] * u"eV" + # y_V = wV * V_dft[:] + # # display( reinterpret(eltype(efv.virial), efv.virial) ) + # A_V = wV * reshape(reinterpret(eltype(efv.virial[1]), efv.virial), 9, :) + + return vcat(A_E, A_F), vcat(y_E, y_F) +end + + +function assemble_lsq(calc, data, weights, data_keys; + rng = Random.MersenneTwister(1234), + executor = Folds.ThreadedEx()) + ps, st = Lux.setup(rng, calc) + blocks = Folds.map(at -> local_lsqsys(calc, at, ps, st, + weights, data_keys), + data, executor) + A = reduce(vcat, [b[1] for b in blocks]) + y = reduce(vcat, [b[2] for b in blocks]) + return A, y +end + + +A, y = assemble_lsq(calc_model2, train2[1:10], weights, data_keys) + +θ = ACEfit.trunc_svd(svd(A), y, 1e-8) +ps, st = Lux.setup(rng, calc_model2) + +# the next step is a hack. This should be automatable, probably using Lux.freeze. +# But I couldn't quite figure out how to use that. +# Here I'm manually constructing a parameters NamedTuple with rbasis removed. +# then I'm using the destructure / restructure method from Optimizers to +# convert θ into a namedtuple. + +ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = ps.pairbasis, rbasis = NamedTuple()) +_, restruct = destructure(ps_lin) +ps_lin_fit = restruct(θ) +ps_fit = deepcopy(ps) +ps_fit.WB[:] = ps_lin_fit.WB[:] +ps_fit.Wpair[:] = ps_lin_fit.Wpair[:] +calc_model2_fit = M.ACEPotential(model2, ps_fit, st) + + +# Now we can compare errors? +# to make sure we are comparing exactly the same thing, we implement this +# from scratch here ... + +function EF_err(sys::JuLIP.Atoms, calc) + E = JuLIP.energy(calc, sys) + F = JuLIP.forces(calc, sys) + E_ref = JuLIP.get_data(sys, "energy") + F_ref = JuLIP.get_data(sys, "force") + return abs(E - E_ref) / length(sys), norm.(F - F_ref) +end + +function EF_err(sys::AtomsBase.AbstractSystem, calc) + efv = M.energy_forces_virial(sys, calc_model2_fit) + F_ustrip = [ ustrip.(f) for f in efv.forces ] + E_ref = sys.data[:energy] + F_ref = sys.data[:force] + return abs(ustrip(efv.energy) - E_ref) / length(sys), norm.(F_ustrip - F_ref) +end + +function rmse(test, calc) + E_errs = Float64[] + F_errs = Float64[] + for sys in test + E_err, F_err = EF_err(sys, calc) + push!(E_errs, E_err) + append!(F_errs, F_err) + end + return norm(E_errs) / sqrt(length(E_errs)), + norm(F_errs) / sqrt(length(F_errs)) +end + + +E_rmse_1, F_rmse_1 = rmse(test, model1.potential) +E_rmse_2, F_rmse_2 = rmse(test2, calc_model2_fit) From c53d0a2bdcd4009b55c57a6a1a1aeed58dd2b5dc Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 25 Jul 2024 13:26:23 -0700 Subject: [PATCH 074/112] add rbasis_analysis script --- docs/src/newkernels/linear.jl | 9 +- docs/src/newkernels/rbasis_analysis.jl | 111 +++++++++++++++++++++++++ src/models/ace_heuristics.jl | 4 + 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 docs/src/newkernels/rbasis_analysis.jl diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index 8064a7b1..4edb990b 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -4,7 +4,7 @@ using Random using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, - Unitful, Zygote, Optimisers, Folds + Unitful, Zygote, Optimisers, Folds, Printf # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset @@ -28,6 +28,7 @@ rcut = 5.5 model1 = acemodel(elements = elements, order = order, totaldegree = totaldegree, + pure = false, pure2b = false, rcut = rcut, ) # now we create an ACE2 style model that should behave similarly @@ -176,3 +177,9 @@ end E_rmse_1, F_rmse_1 = rmse(test, model1.potential) E_rmse_2, F_rmse_2 = rmse(test2, calc_model2_fit) + + +@printf("Model | E | F \n") +@printf(" ACE1 | %.2e | %.2e \n", E_rmse_1, F_rmse_1) +@printf(" ACE2 | %.2e | %.2e \n", E_rmse_2, F_rmse_2) + diff --git a/docs/src/newkernels/rbasis_analysis.jl b/docs/src/newkernels/rbasis_analysis.jl new file mode 100644 index 00000000..8252b328 --- /dev/null +++ b/docs/src/newkernels/rbasis_analysis.jl @@ -0,0 +1,111 @@ +# This script is to explore the differences between the ACE1 models and the new +# models. This is to help bring the two to feature parity so that ACE1 +# can be retired. + +using Random +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Zygote, Optimisers, Folds, Plots + +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +Z0 = :Si +z1 = AtomicNumber(Z0) +z2 = Int(z1) + +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:3:end] + +# because the new implementation is experimental, it is not exported, +# so I create a little shortcut to have easy access. + +M = ACEpotentials.Models + +# First we create an ACE1 style potential with some standard parameters + +elements = [Z0,] +order = 3 +totaldegree = 10 +rcut = 5.5 + +model1 = acemodel(elements = elements, + order = order, + totaldegree = totaldegree, + pure = false, + pure2b = false, + rcut = rcut, ) + +# now we create an ACE2 style model that should behave similarly + +# this essentially reproduces the rcut = 5.5, we may want a nicer way to +# achieve this. + +rin0cuts = M._default_rin0cuts(elements; rcutfactor = 2.3) + + +model2 = M.ace_model(; elements = elements, + order = order, # correlation order + Ytype = :solid, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = :zeros, # how to initialize the pair potential parameters + init_Wradial = :linear, + rin0cuts = rin0cuts, + ) + +ps, st = Lux.setup(rng, model2) +ps_r = ps.rbasis +st_r = st.rbasis + +# wrap the model into a calculator, which turns it into a potential... + +calc_model2 = M.ACEPotential(model2) + +# extrac the radial basis +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = model2.rbasis + + +## + +rr = range(0.001, rcut + 0.5, length=200) +R1 = reduce(hcat, [ JuLIP.evaluate(rbasis1, r, z1, z1) for r in rr ]) +R2 = reduce(hcat, [ rbasis2(r, z2, z2, ps_r, st_r)[1:10] for r in rr]) + +# normalize +for n = 1:10 + R1[n, :] = R1[n, :] / maximum(abs, R1[n, :]) + R2[n, :] = R2[n, :] / maximum(abs, R2[n, :]) +end + +plt = plot() +for n = 1:4 + plot!(plt, rr, R1[n, :], c = n, label="R1_$n") + plot!(plt, rr, R2[n, :], c = n, ls = :dash, label="R2_$n") +end +plt + + +## + +nmax = 8 + +pairb1 = model1.basis.BB[1].J[1] +P1 = reduce(hcat, [ JuLIP.evaluate(pairb1, r, z1, z1)[1:nmax] for r in rr ]) + +pairb2 = model2.pairbasis +P2 = reduce(hcat, [ pairb2(r, z2, z2, NamedTuple(), NamedTuple())[1:nmax] for r in rr ]) + +# truncate +P1 = min.(max.(P1, -100), 100) +P2 = min.(max.(P2, -100), 100) + +plt = plot(; ylims = (-1.0, 3.0)) +for n = 1:4 + plot!(plt, rr, P1[n, :], c = n, label="P1_$n") + plot!(plt, rr, P2[n, :], c = n, ls = :dash, label="P2_$n") +end +plt + diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index c42e99db..56900274 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -125,6 +125,7 @@ function ace_model(; elements = nothing, # pair basis pair_maxn = nothing, pair_basis = :auto, + pair_learnable = false, init_Wpair = :zeros, rng = Random.default_rng(), ) @@ -165,6 +166,9 @@ function ace_model(; elements = nothing, envelopes = :poly1sr ) end + if !pair_learnable + pair_basis.meta["Winit"] = "linear" + end ps_pair = initialparameters(rng, pair_basis) pair_basis_spl = splinify(pair_basis, ps_pair) From 62422334d9f07bb6266f03e77c42e8d67eb028cc Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 25 Jul 2024 13:35:46 -0700 Subject: [PATCH 075/112] fix lux bug in linear.jl --- docs/src/newkernels/linear.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index 4edb990b..98601787 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -5,11 +5,12 @@ using Random using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, Unitful, Zygote, Optimisers, Folds, Printf +rng = Random.GLOBAL_RNG # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset -Z0 = :Si +Z0 = :Cu train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") train = train[1:3:end] From 00ff95bf5cc8f884e53e4166e6431ad0901e0390 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 28 Jul 2024 11:51:26 -0700 Subject: [PATCH 076/112] perfect math of pair and radial --- docs/src/newkernels/linear.jl | 33 ++++++++-- docs/src/newkernels/rbasis_analysis.jl | 84 ++++++++++++++++++++++---- src/models/Rnl_basis.jl | 4 +- src/models/Rnl_learnable.jl | 4 ++ src/models/ace.jl | 10 +-- src/models/ace_heuristics.jl | 18 +++--- 6 files changed, 120 insertions(+), 33 deletions(-) diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index 98601787..8b5e32c1 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -10,7 +10,7 @@ rng = Random.GLOBAL_RNG # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset -Z0 = :Cu +Z0 = :Si train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") train = train[1:3:end] @@ -28,8 +28,12 @@ rcut = 5.5 model1 = acemodel(elements = elements, order = order, + pin = 2, pcut = 2, + transform = (:agnesi, 2, 2), totaldegree = totaldegree, - pure = false, pure2b = false, + pure = false, + pure2b = false, + pair_envelope = (:r, 1), rcut = rcut, ) # now we create an ACE2 style model that should behave similarly @@ -37,7 +41,8 @@ model1 = acemodel(elements = elements, # this essentially reproduces the rcut = 5.5, we may want a nicer way to # achieve this. -rin0cuts = M._default_rin0cuts(elements; rcutfactor = 2.3) +rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) +rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) model2 = M.ace_model(; elements = elements, order = order, # correlation order @@ -46,11 +51,27 @@ model2 = M.ace_model(; elements = elements, max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = :zeros, # how to initialize the pair potential parameters + init_Wpair = "linear", # how to initialize the pair potential parameters init_Wradial = :linear, + pair_transform = (:agnesi, 1, 3), + pair_learnable = true, rin0cuts = rin0cuts, ) +ps, st = Lux.setup(rng, model2) +ps_r = ps.rbasis +st_r = st.rbasis + +# extract the radial basis +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = model2.rbasis + +k = length(rbasis1.J.A) +rbasis1.J.A[:] .= rbasis2.polys.A[1:k] +rbasis1.J.B[:] .= rbasis2.polys.B[1:k] +rbasis1.J.C[:] .= rbasis2.polys.C[1:k] + + # wrap the model into a calculator, which turns it into a potential... calc_model2 = M.ACEPotential(model2) @@ -134,8 +155,8 @@ ps, st = Lux.setup(rng, calc_model2) # then I'm using the destructure / restructure method from Optimizers to # convert θ into a namedtuple. -ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = ps.pairbasis, rbasis = NamedTuple()) -_, restruct = destructure(ps_lin) +ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = NamedTuple(), rbasis = NamedTuple()) +_θ, restruct = destructure(ps_lin) ps_lin_fit = restruct(θ) ps_fit = deepcopy(ps) ps_fit.WB[:] = ps_lin_fit.WB[:] diff --git a/docs/src/newkernels/rbasis_analysis.jl b/docs/src/newkernels/rbasis_analysis.jl index 8252b328..3a1dc505 100644 --- a/docs/src/newkernels/rbasis_analysis.jl +++ b/docs/src/newkernels/rbasis_analysis.jl @@ -5,6 +5,7 @@ using Random using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, Unitful, Zygote, Optimisers, Folds, Plots +rng = Random.GLOBAL_RNG # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset @@ -30,9 +31,11 @@ rcut = 5.5 model1 = acemodel(elements = elements, order = order, + transform = (:agnesi, 2, 2), totaldegree = totaldegree, pure = false, pure2b = false, + pair_envelope = (:r, 1), rcut = rcut, ) # now we create an ACE2 style model that should behave similarly @@ -40,8 +43,8 @@ model1 = acemodel(elements = elements, # this essentially reproduces the rcut = 5.5, we may want a nicer way to # achieve this. -rin0cuts = M._default_rin0cuts(elements; rcutfactor = 2.3) - +rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) +rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) model2 = M.ace_model(; elements = elements, order = order, # correlation order @@ -50,8 +53,10 @@ model2 = M.ace_model(; elements = elements, max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = :zeros, # how to initialize the pair potential parameters + init_Wpair = "linear", # how to initialize the pair potential parameters init_Wradial = :linear, + pair_transform = (:agnesi, 1, 3), + pair_learnable = true, rin0cuts = rin0cuts, ) @@ -59,21 +64,56 @@ ps, st = Lux.setup(rng, model2) ps_r = ps.rbasis st_r = st.rbasis +# extract the radial basis +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = model2.rbasis +k = length(rbasis1_.J.A) + +# transform old coefficients to new coefficients to make them match + +rbasis1.J.A[:] .= rbasis2.polys.A[1:k] +rbasis1.J.B[:] .= rbasis2.polys.B[1:k] +rbasis1.J.C[:] .= rbasis2.polys.C[1:k] +rbasis1.J.A[2] /= rbasis1.J.A[1] +rbasis1.J.B[2] /= rbasis1.J.A[1] + # wrap the model into a calculator, which turns it into a potential... calc_model2 = M.ACEPotential(model2) -# extrac the radial basis -rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -rbasis2 = model2.rbasis +## + +# sample points +rr = range(0.001, rcut, length=200) + +@info("check the transforms are identical") +t1 = ACE1.Transforms.transform.(Ref(rbasis1.trans), rr, z1, z1) +t2 = rbasis2.transforms[1].(rr) +@show t1 ≈ t2 + +## + +@info("Check the raw polynomials") +xx = range(-1, 1, length=200) +J1 = reduce(hcat, ACE1.evaluate.(Ref(_J), xx))' +J2 = rbasis2.polys(xx) +J1 ≈ J2[:, 1:size(J1, 2)] + +plt = plot() +for n = 1:5 + plot!(J1[:, n], c = n, label = "J1,$n") + plot!(J2[:, n], c = n, ls = :dash, label = "J2,$n") +end +plt ## -rr = range(0.001, rcut + 0.5, length=200) R1 = reduce(hcat, [ JuLIP.evaluate(rbasis1, r, z1, z1) for r in rr ]) R2 = reduce(hcat, [ rbasis2(r, z2, z2, ps_r, st_r)[1:10] for r in rr]) +# R1 = R1_ * Diagonal(R2[1,:]) + # normalize for n = 1:10 R1[n, :] = R1[n, :] / maximum(abs, R1[n, :]) @@ -90,13 +130,13 @@ plt ## -nmax = 8 - pairb1 = model1.basis.BB[1].J[1] -P1 = reduce(hcat, [ JuLIP.evaluate(pairb1, r, z1, z1)[1:nmax] for r in rr ]) - pairb2 = model2.pairbasis -P2 = reduce(hcat, [ pairb2(r, z2, z2, NamedTuple(), NamedTuple())[1:nmax] for r in rr ]) +ps_pair = ps.pairbasis +st_pair = st.pairbasis + +P1 = reduce(hcat, [ JuLIP.evaluate(pairb1, r, z1, z1)[1:nmax] for r in rr ]) +P2 = reduce(hcat, [ pairb2(r, z2, z2, ps_pair, st_pair)[1:nmax] for r in rr ]) # truncate P1 = min.(max.(P1, -100), 100) @@ -109,3 +149,23 @@ for n = 1:4 end plt + +## +@info(" Confirm that the pair bases span the same space ") +# they do to within almost machine precision +# ( in fact the transormation matrix C says they differ only +# up to a sign and scalar factor. ) + +rrr = range(1.0, 5.0, length=100) +P1 = reduce(hcat, [ JuLIP.evaluate(pairb1, r, z1, z1)[1:nmax] for r in rrr ]) +P2 = reduce(hcat, [ pairb2(r, z2, z2, ps_pair, st_pair)[1:nmax] for r in rrr ]) +C = P2' \ P1' +@show norm(P1' - P2' * C) + +@info("Confirm the radial bases span the same space") + +R1 = reduce(hcat, [ JuLIP.evaluate(rbasis1, r, z1, z1)[1:nmax] for r in rrr ]) +R2 = reduce(hcat, [ rbasis2(r, z2, z2, ps_r, st_r)[1:nmax] for r in rrr]) +C = R2' \ R1' +@show norm(R1' - R2' * C) + diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 0303cddb..12f1c0c9 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -27,7 +27,7 @@ const SPL_OF_SVEC{DIM, T} = } -struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, T} <: AbstractExplicitLayer +mutable struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, T} <: AbstractExplicitLayer _i2z::NTuple{NZ, Int} polys::TPOLY transforms::SMatrix{NZ, NZ, TT} @@ -42,7 +42,7 @@ struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, T} <: AbstractExplicitLayer end -struct SplineRnlrzzBasis{NZ, TT, TENV, LEN, T} <: AbstractExplicitLayer +mutable struct SplineRnlrzzBasis{NZ, TT, TENV, LEN, T} <: AbstractExplicitLayer _i2z::NTuple{NZ, Int} transforms::SMatrix{NZ, NZ, TT} envelopes::SMatrix{NZ, NZ, TENV} diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 1980b66c..622c0404 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -46,6 +46,10 @@ function initialparameters(rng::AbstractRNG, elseif basis.meta["Winit"] == "linear" set_I_weights!(basis, ps) + elseif basis.meta["Winit"] == "zeros" + @warn("Setting inner basis weights to zero.") + Wnlq[:] .= 0 + else error("Unknown key Winit = $(basis.meta["Winit"]) to initialize radial basis weights.") end diff --git a/src/models/ace.jl b/src/models/ace.jl index 2fe52a08..b8f1e52a 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -13,7 +13,7 @@ import EquivariantModels # ------------------------------------------------------------ # ACE MODEL SPECIFICATION -struct ACEModel{NZ, TRAD, TY, TTEN, T, TPAIR} <: AbstractExplicitContainerLayer{(:rbasis,)} +mutable struct ACEModel{NZ, TRAD, TY, TTEN, T, TPAIR} <: AbstractExplicitContainerLayer{(:rbasis,)} _i2z::NTuple{NZ, Int} # -------------- # embeddings of the particles @@ -191,7 +191,7 @@ function initialparameters(rng::AbstractRNG, # generate pair basis parameters n_pair = length(model.pairbasis) Wpair = zeros(n_pair, NZ) - winit_pair = _W_init(model.meta["init_Wpair"]) + winit_pair = _W_init(model.meta["init_WB"]) for iz = 1:NZ Wpair[:, iz] .= winit_pair(rng, Float64, n_pair) @@ -376,9 +376,9 @@ function evaluate_ed(model::ACEModel, Ei += dot(Apair, Wp_i) # pullback --- I'm now assuming that the pair basis is not learnable. - if !( ps.pairbasis == NamedTuple() ) - error("I'm currently assuming the pair basis is not learnable.") - end + # if !( ps.pairbasis == NamedTuple() ) + # error("I'm currently assuming the pair basis is not learnable.") + # end for j = 1:length(Rs) ∇Ei[j] += dot(Wp_i, (@view dRpair[j, :])) * (Rs[j] / rs[j]) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 56900274..6876ce60 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -126,7 +126,8 @@ function ace_model(; elements = nothing, pair_maxn = nothing, pair_basis = :auto, pair_learnable = false, - init_Wpair = :zeros, + pair_transform = (:agnesi, 1, 4), + init_Wpair = "linear", rng = Random.default_rng(), ) @@ -162,20 +163,21 @@ function ace_model(; elements = nothing, maxl = 0, maxn = pair_maxn, rin0cuts = rbasis.rin0cuts, - transforms = (:agnesi, 1, 4), + transforms = pair_transform, envelopes = :poly1sr ) end - if !pair_learnable - pair_basis.meta["Winit"] = "linear" - end - ps_pair = initialparameters(rng, pair_basis) - pair_basis_spl = splinify(pair_basis, ps_pair) + pair_basis.meta["Winit"] = init_Wpair + + if !pair_learnable + ps_pair = initialparameters(rng, pair_basis) + pair_basis = splinify(pair_basis, ps_pair) + end AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, level = level, max_level = max_level) - model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis_spl, E0s) + model = ace_model(rbasis, Ytype, AA_spec, level, pair_basis, E0s) model.meta["init_WB"] = String(init_WB) model.meta["init_Wpair"] = String(init_Wpair) From faaae9b31a4836aae1d8f0012299c58f4da293ea Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 28 Jul 2024 16:04:48 -0700 Subject: [PATCH 077/112] confirm Ylms span the same space --- docs/src/newkernels/ylm_analysis.jl | 104 ++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 docs/src/newkernels/ylm_analysis.jl diff --git a/docs/src/newkernels/ylm_analysis.jl b/docs/src/newkernels/ylm_analysis.jl new file mode 100644 index 00000000..fde0e675 --- /dev/null +++ b/docs/src/newkernels/ylm_analysis.jl @@ -0,0 +1,104 @@ +# This script is to explore the differences between the ACE1 models and the new +# models. This is to help bring the two to feature parity so that ACE1 +# can be retired. + +using Random +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Zygote, Optimisers, Folds, Plots +rng = Random.GLOBAL_RNG + +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +Z0 = :Si +z1 = AtomicNumber(Z0) +z2 = Int(z1) + +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:3:end] + +# because the new implementation is experimental, it is not exported, +# so I create a little shortcut to have easy access. + +M = ACEpotentials.Models + +# First we create an ACE1 style potential with some standard parameters + +elements = [Z0,] +order = 3 +totaldegree = 10 +rcut = 5.5 + +model1 = acemodel(elements = elements, + order = order, + transform = (:agnesi, 2, 2), + totaldegree = totaldegree, + pure = false, + pure2b = false, + pair_envelope = (:r, 1), + rcut = rcut, ) + +# now we create an ACE2 style model that should behave similarly + +# this essentially reproduces the rcut = 5.5, we may want a nicer way to +# achieve this. + +rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) +rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) + +model2 = M.ace_model(; elements = elements, + order = order, # correlation order + Ytype = :solid, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = "linear", # how to initialize the pair potential parameters + init_Wradial = :linear, + pair_transform = (:agnesi, 1, 3), + pair_learnable = true, + rin0cuts = rin0cuts, + ) + +ps, st = Lux.setup(rng, model2) +ps_r = ps.rbasis +st_r = st.rbasis + +# extract the radial basis +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = model2.rbasis +k = length(rbasis1.J.A) + +# transform old coefficients to new coefficients to make them match + +rbasis1.J.A[:] .= rbasis2.polys.A[1:k] +rbasis1.J.B[:] .= rbasis2.polys.B[1:k] +rbasis1.J.C[:] .= rbasis2.polys.C[1:k] +rbasis1.J.A[2] /= rbasis1.J.A[1] +rbasis1.J.B[2] /= rbasis1.J.A[1] + +# wrap the model into a calculator, which turns it into a potential... + +calc_model2 = M.ACEPotential(model2) + + +## + +ybasis1 = model1.basis.BB[2].pibasis.basis1p.SH +ybasis2 = model2.ybasis +maxk = length(ybasis2) + +X = [ (u = @SVector rand(3); u/norm(u)) for _ = 1:100 ] +Y1 = reduce(hcat, [ ACE1.evaluate(ybasis1, u)[1:maxk] for u in X ]) +Y1r = real.(Y1) +Y1i = imag.(Y1) +Y2 = reduce(hcat, [ ybasis2(u)[1:maxk] for u in X ]) + +@info("check span real/imag(Y1) = span Y2") +Cr = Y2' \ Y1r' +@show norm(Y1r' - Y2' * Cr) + +Ci = Y2' \ Y1i' +@show norm(Y1i' - Y2' * Ci) + + From 8d71f0d6646b6a963f952df5e7d23b9d54bc6deb Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 28 Jul 2024 16:14:34 -0700 Subject: [PATCH 078/112] angular coupled basis fcns are inconsistent --- docs/src/newkernels/acebasis_analysis.jl | 160 +++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 docs/src/newkernels/acebasis_analysis.jl diff --git a/docs/src/newkernels/acebasis_analysis.jl b/docs/src/newkernels/acebasis_analysis.jl new file mode 100644 index 00000000..7bf64143 --- /dev/null +++ b/docs/src/newkernels/acebasis_analysis.jl @@ -0,0 +1,160 @@ +# This script is to explore the differences between the ACE1 models and the new +# models. This is to help bring the two to feature parity so that ACE1 +# can be retired. + +using Random +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Zygote, Optimisers, Folds, Plots +rng = Random.GLOBAL_RNG + +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +Z0 = :Si +z1 = AtomicNumber(Z0) +z2 = Int(z1) + +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:3:end] + +# because the new implementation is experimental, it is not exported, +# so I create a little shortcut to have easy access. + +M = ACEpotentials.Models + +# First we create an ACE1 style potential with some standard parameters + +elements = [Z0,] +order = 3 +totaldegree = 10 +rcut = 5.5 + +model1 = acemodel(elements = elements, + order = order, + transform = (:agnesi, 2, 2), + totaldegree = totaldegree, + pure = false, + pure2b = false, + pair_envelope = (:r, 1), + rcut = rcut, ) + +# now we create an ACE2 style model that should behave similarly + +# this essentially reproduces the rcut = 5.5, we may want a nicer way to +# achieve this. + +rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) +rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) + +model2 = M.ace_model(; elements = elements, + order = order, # correlation order + Ytype = :solid, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = "linear", # how to initialize the pair potential parameters + init_Wradial = :linear, + pair_transform = (:agnesi, 1, 3), + pair_learnable = true, + rin0cuts = rin0cuts, + ) + +ps, st = Lux.setup(rng, model2) +ps_r = ps.rbasis +st_r = st.rbasis + +# extract the radial basis +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = model2.rbasis +k = length(rbasis1.J.A) + +# transform old coefficients to new coefficients to make them match + +rbasis1.J.A[:] .= rbasis2.polys.A[1:k] +rbasis1.J.B[:] .= rbasis2.polys.B[1:k] +rbasis1.J.C[:] .= rbasis2.polys.C[1:k] +rbasis1.J.A[2] /= rbasis1.J.A[1] +rbasis1.J.B[2] /= rbasis1.J.A[1] + +# wrap the model into a calculator, which turns it into a potential... + +calc_model2 = M.ACEPotential(model2) + + +## + +# look at the specifications +_spec1 = ACE1.get_nl(model1.basis.BB[2]) +spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] +spec2 = M.get_nnll_spec(model2.tensor) +spec1 = sort.(spec1) +spec2 = sort.(spec2) + +Nb = length(spec2) + +## + +@info("Set differences of spec suggest the bases are consistent") +spec_1diff2 = setdiff(spec1, spec2) +spec_2diff1 = setdiff(spec2, spec1) +@show length(spec_2diff1) +@show length(spec_1diff2) +@show length(spec1) - length(spec2) + +## + +idx2in1 = [ findfirst( Ref(bb) .== spec1 ) for bb in spec2 ] +@show length(idx2in1) == Nb + +# now we can check the span + +Nenv = 1000 +XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] +XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] + +B1 = [ ACE1.evaluate(model1.basis.BB[2], x...)[idx2in1] for x in XX1 ] + +I2mb = M.get_basis_inds(model2, z2) +B2 = [ M.evaluate_basis(model2, x..., ps, st)[I2mb] for x in XX2 ] + +A1 = reduce(hcat, B1) +A2 = reduce(hcat, B2) + +# see whether they span the same space +# for the full basis this is not even close to true. ... +C = A1' \ A2' +norm(A2' - A1' * C) + +# we can make a list of all basis functions that fail ... +@info("make a list of failed basis functions") +err = sum(abs, A2' - A1' * C, dims = (1,))[:] +idx_fail = findall(err .> 1e-8) + +spec_fail = spec2[I2mb[idx_fail]] +@info("List of failed basis functions: ") +display(spec_fail) + +@info("Compare with list of basis functions that have l > 0") +maxll = [ maximum(b.l for b in bb) for bb in spec2[I2mb] ] +idx_hasl = findall(maxll .> 0) +@show sort(idx_fail) == sort(idx_hasl) + +## ------------------------------------------------------------------ +## try for the 1-correlations: + +@info("Checking consistencyy of 1-correlations") +idx1 = 1:10 +C1 = A2[idx1, :] / A1[idx1,:] +@show norm(C1 * A2[idx1,:] - A1[idx1,:]) + +## try for 1- and 2-correlations + +idx21 = findall(length.(spec1[idx2in1]) .<= 2) +idx22 = findall(length.(spec2) .<= 2) +C2 = A2[idx22, :]' \ A1[idx21, :]' +@show norm(A1[idx21, :]' - A2[idx22,:]' * C2) +err = sum(abs, A1[idx21, :]' - A2[idx22,:]' * C2, dims = (1,)) +idx_err = findall(err[:] .> 1e-10) + + From f1ce95373441e98d478dcdf76da02e1e89a33df6 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 28 Jul 2024 16:15:16 -0700 Subject: [PATCH 079/112] small cleanup --- docs/src/newkernels/acebasis_analysis.jl | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/docs/src/newkernels/acebasis_analysis.jl b/docs/src/newkernels/acebasis_analysis.jl index 7bf64143..eb48e253 100644 --- a/docs/src/newkernels/acebasis_analysis.jl +++ b/docs/src/newkernels/acebasis_analysis.jl @@ -140,21 +140,3 @@ maxll = [ maximum(b.l for b in bb) for bb in spec2[I2mb] ] idx_hasl = findall(maxll .> 0) @show sort(idx_fail) == sort(idx_hasl) -## ------------------------------------------------------------------ -## try for the 1-correlations: - -@info("Checking consistencyy of 1-correlations") -idx1 = 1:10 -C1 = A2[idx1, :] / A1[idx1,:] -@show norm(C1 * A2[idx1,:] - A1[idx1,:]) - -## try for 1- and 2-correlations - -idx21 = findall(length.(spec1[idx2in1]) .<= 2) -idx22 = findall(length.(spec2) .<= 2) -C2 = A2[idx22, :]' \ A1[idx21, :]' -@show norm(A1[idx21, :]' - A2[idx22,:]' * C2) -err = sum(abs, A1[idx21, :]' - A2[idx22,:]' * C2, dims = (1,)) -idx_err = findall(err[:] .> 1e-10) - - From 8241d1a9a037f42755c36ea579bb1353ad1b2688 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 11:30:31 -0700 Subject: [PATCH 080/112] perfect basis match --- docs/src/newkernels/Project.toml | 1 + docs/src/newkernels/acebasis_analysis.jl | 17 +++++++---------- docs/src/newkernels/ylm_analysis.jl | 5 +++++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index c6a407ae..b4c8cc92 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -10,6 +10,7 @@ GeomOpt = "ca147568-c688-4a55-a13d-dbd284330f4b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/docs/src/newkernels/acebasis_analysis.jl b/docs/src/newkernels/acebasis_analysis.jl index eb48e253..f3ecb745 100644 --- a/docs/src/newkernels/acebasis_analysis.jl +++ b/docs/src/newkernels/acebasis_analysis.jl @@ -48,7 +48,7 @@ rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) model2 = M.ace_model(; elements = elements, order = order, # correlation order - Ytype = :solid, # solid vs spherical harmonics + Ytype = :spherical, # solid vs spherical harmonics level = M.TotalDegree(), # how to calculate the weights to give to a basis function max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential @@ -114,11 +114,14 @@ XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] B1 = [ ACE1.evaluate(model1.basis.BB[2], x...)[idx2in1] for x in XX1 ] +B1_all = [ ACE1.evaluate(model1.basis.BB[2], x...) for x in XX1 ] + I2mb = M.get_basis_inds(model2, z2) B2 = [ M.evaluate_basis(model2, x..., ps, st)[I2mb] for x in XX2 ] A1 = reduce(hcat, B1) +A1_all = reduce(hcat, B1_all) A2 = reduce(hcat, B2) # see whether they span the same space @@ -130,13 +133,7 @@ norm(A2' - A1' * C) @info("make a list of failed basis functions") err = sum(abs, A2' - A1' * C, dims = (1,))[:] idx_fail = findall(err .> 1e-8) +@show idx_fail +@show norm( abs.(C) - I ) -spec_fail = spec2[I2mb[idx_fail]] -@info("List of failed basis functions: ") -display(spec_fail) - -@info("Compare with list of basis functions that have l > 0") -maxll = [ maximum(b.l for b in bb) for bb in spec2[I2mb] ] -idx_hasl = findall(maxll .> 0) -@show sort(idx_fail) == sort(idx_hasl) - +# ALL ARE OK! WE ARE GOOD diff --git a/docs/src/newkernels/ylm_analysis.jl b/docs/src/newkernels/ylm_analysis.jl index fde0e675..807533ba 100644 --- a/docs/src/newkernels/ylm_analysis.jl +++ b/docs/src/newkernels/ylm_analysis.jl @@ -102,3 +102,8 @@ Ci = Y2' \ Y1i' @show norm(Y1i' - Y2' * Ci) +## + +cyp4ml = complex_sphericalharmonics(2) +Yp4 = reduce(hcat, [cyp4ml(u) for u in X] ) +Yp4 ≈ Y1 From a6d3b1b6a52bb0a939abcd5ce3a0679a0f9d4e79 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 11:45:51 -0700 Subject: [PATCH 081/112] regression is closer (no match yet) --- docs/src/newkernels/linear.jl | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index 8b5e32c1..f11ad26a 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -23,12 +23,11 @@ M = ACEpotentials.Models elements = [Z0,] order = 3 -totaldegree = 10 +totaldegree = 12 rcut = 5.5 model1 = acemodel(elements = elements, order = order, - pin = 2, pcut = 2, transform = (:agnesi, 2, 2), totaldegree = totaldegree, pure = false, @@ -46,9 +45,9 @@ rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) model2 = M.ace_model(; elements = elements, order = order, # correlation order - Ytype = :solid, # solid vs spherical harmonics + Ytype = :spherical, # solid vs spherical harmonics level = M.TotalDegree(), # how to calculate the weights to give to a basis function - max_level = totaldegree, # maximum level of the basis functions + max_level = totaldegree+1, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters init_Wpair = "linear", # how to initialize the pair potential parameters @@ -58,25 +57,28 @@ model2 = M.ace_model(; elements = elements, rin0cuts = rin0cuts, ) -ps, st = Lux.setup(rng, model2) +ps, st = Lux.setup(rng, model2) ps_r = ps.rbasis st_r = st.rbasis # extract the radial basis rbasis1 = model1.basis.BB[2].pibasis.basis1p.J rbasis2 = model2.rbasis - k = length(rbasis1.J.A) + +# transform old coefficients to new coefficients to make them match rbasis1.J.A[:] .= rbasis2.polys.A[1:k] rbasis1.J.B[:] .= rbasis2.polys.B[1:k] rbasis1.J.C[:] .= rbasis2.polys.C[1:k] - +rbasis1.J.A[2] /= rbasis1.J.A[1] +rbasis1.J.B[2] /= rbasis1.J.A[1] # wrap the model into a calculator, which turns it into a potential... calc_model2 = M.ACEPotential(model2) -# Fit the ACE1 model +## +#Fit the ACE1 model # set weights for energy, forces virials weights = Dict("default" => Dict("E" => 30.0, "F" => 1.0 , "V" => 1.0 ),); @@ -86,7 +88,7 @@ solver=ACEfit.TruncatedSVD(; rtol = 1e-8) acefit!(model1, train; solver=solver) - +## # Fit the ACE2 model - this still needs a bit of hacking to convert everything # to the new framework. # - convert the data to AtomsBase @@ -138,13 +140,14 @@ function assemble_lsq(calc, data, weights, data_keys; blocks = Folds.map(at -> local_lsqsys(calc, at, ps, st, weights, data_keys), data, executor) + A = reduce(vcat, [b[1] for b in blocks]) y = reduce(vcat, [b[2] for b in blocks]) return A, y end - -A, y = assemble_lsq(calc_model2, train2[1:10], weights, data_keys) +A, y = assemble_lsq(calc_model2, train2, weights, data_keys) +@show size(A) θ = ACEfit.trunc_svd(svd(A), y, 1e-8) ps, st = Lux.setup(rng, calc_model2) @@ -164,6 +167,7 @@ ps_fit.Wpair[:] = ps_lin_fit.Wpair[:] calc_model2_fit = M.ACEPotential(model2, ps_fit, st) +## # Now we can compare errors? # to make sure we are comparing exactly the same thing, we implement this # from scratch here ... From 3f8f13700281c006eb263a4a37491458bcad91a1 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 12:59:05 -0700 Subject: [PATCH 082/112] match the full pair+mb basis --- docs/src/newkernels/acebasis_analysis.jl | 31 +++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/docs/src/newkernels/acebasis_analysis.jl b/docs/src/newkernels/acebasis_analysis.jl index f3ecb745..8a3e4f90 100644 --- a/docs/src/newkernels/acebasis_analysis.jl +++ b/docs/src/newkernels/acebasis_analysis.jl @@ -136,4 +136,33 @@ idx_fail = findall(err .> 1e-8) @show idx_fail @show norm( abs.(C) - I ) -# ALL ARE OK! WE ARE GOOD +@info("Perfect match - we should be good!") + +## +# now we try to reduce the ACE1 basis to be identical to the +# ACE2 basis + +@info("Reduce the ACE1 basis to be identical to the ACE2 basis") +@show size(model1.basis.BB[2].A2Bmaps[1]) + +A1_all[idx2in1, :] == A1 +idx_del = setdiff((1:size(model1.basis.BB[2].A2Bmaps[1], 1)), idx2in1) +model1.basis.BB[2].A2Bmaps[1][idx_del, :] .= 0 +BB2 = ACE1.RPI.remove_zeros(ACE1._cleanup(model1.basis.BB[2])) +@show size(BB2.A2Bmaps[1]) + +## +basis1_red = deepcopy(model1.basis) +basis1_red.BB[2] = BB2 + +function _evaluate(basis::JuLIP.MLIPs.IPSuperBasis, + Rs, Zs, z0) + reduce(vcat, [ACE1.evaluate(B, Rs, Zs, z0) for B in basis.BB]) +end + +A1_ = reduce(hcat, [ _evaluate(basis1_red, x...) for x in XX1]) +A2_ = reduce(hcat, [M.evaluate_basis(model2, x..., ps, st) for x in XX2]) +A2_p = [A2_[end-9:end,:]; A2_[1:end-10,:]] + +C = A1_' \ A2_p' +@show norm(A2_p - C' * A1_) From e5bdd73122f25955e46cef1acf722acfa236bae3 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 13:23:31 -0700 Subject: [PATCH 083/112] simplified basis matching test --- docs/src/newkernels/acebasis_analysis.jl | 2 +- docs/src/newkernels/linear.jl | 7 +- docs/src/newkernels/match_bases.jl | 77 ++++++++++++++++++++++ docs/src/newkernels/test_matching_bases.jl | 49 ++++++++++++++ 4 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 docs/src/newkernels/match_bases.jl create mode 100644 docs/src/newkernels/test_matching_bases.jl diff --git a/docs/src/newkernels/acebasis_analysis.jl b/docs/src/newkernels/acebasis_analysis.jl index 8a3e4f90..d4856169 100644 --- a/docs/src/newkernels/acebasis_analysis.jl +++ b/docs/src/newkernels/acebasis_analysis.jl @@ -60,7 +60,7 @@ model2 = M.ace_model(; elements = elements, rin0cuts = rin0cuts, ) -ps, st = Lux.setup(rng, model2) +ps, st = Lux.setup(rng, model2) ps_r = ps.rbasis st_r = st.rbasis diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index f11ad26a..6ec49fb8 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -47,7 +47,7 @@ model2 = M.ace_model(; elements = elements, order = order, # correlation order Ytype = :spherical, # solid vs spherical harmonics level = M.TotalDegree(), # how to calculate the weights to give to a basis function - max_level = totaldegree+1, # maximum level of the basis functions + max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters init_Wpair = "linear", # how to initialize the pair potential parameters @@ -73,6 +73,11 @@ rbasis1.J.C[:] .= rbasis2.polys.C[1:k] rbasis1.J.A[2] /= rbasis1.J.A[1] rbasis1.J.B[2] /= rbasis1.J.A[1] +## +basis1_red = deepcopy(model1.basis) +basis1_red.BB[2] = BB2 + + # wrap the model into a calculator, which turns it into a potential... calc_model2 = M.ACEPotential(model2) diff --git a/docs/src/newkernels/match_bases.jl b/docs/src/newkernels/match_bases.jl new file mode 100644 index 00000000..828117a7 --- /dev/null +++ b/docs/src/newkernels/match_bases.jl @@ -0,0 +1,77 @@ +# This script is to explore the differences between the ACE1 models and the new +# models. This is to help bring the two to feature parity so that ACE1 +# can be retired. + +using Random +using ACEpotentials, Lux +M = ACEpotentials.Models + +function matching_bases(; Z = :Si, order = 3, totaldegree = 10, + rcut = 5.5) + + model1 = acemodel(elements = elements, + order = order, + transform = (:agnesi, 2, 2), + totaldegree = totaldegree, + pure = false, + pure2b = false, + pair_envelope = (:r, 1), + rcut = rcut, ) + + rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) + rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) + + model2 = M.ace_model(; elements = elements, + order = order, # correlation order + Ytype = :spherical, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = "linear", # how to initialize the pair potential parameters + init_Wradial = :linear, + pair_transform = (:agnesi, 1, 3), + pair_learnable = true, + rin0cuts = rin0cuts, + ) + + ps, st = Lux.setup(rng, model2) + ps_r = ps.rbasis + st_r = st.rbasis + + # extract the radial basis + rbasis1 = model1.basis.BB[2].pibasis.basis1p.J + rbasis2 = model2.rbasis + k = length(rbasis1.J.A) + + # transform old coefficients to new coefficients to make them match + rbasis1.J.A[:] .= rbasis2.polys.A[1:k] + rbasis1.J.B[:] .= rbasis2.polys.B[1:k] + rbasis1.J.C[:] .= rbasis2.polys.C[1:k] + rbasis1.J.A[2] /= rbasis1.J.A[1] + rbasis1.J.B[2] /= rbasis1.J.A[1] + + # fix the basis1 spec + _spec1 = ACE1.get_nl(model1.basis.BB[2]) + spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] + spec2 = M.get_nnll_spec(model2.tensor) + spec1 = sort.(spec1) + spec2 = sort.(spec2) + Nb = length(spec2) + idx2in1 = [ findfirst( Ref(bb) .== spec1 ) for bb in spec2 ] + @show length(idx2in1) == Nb + + idx_del = setdiff((1:size(model1.basis.BB[2].A2Bmaps[1], 1)), idx2in1) + model1.basis.BB[2].A2Bmaps[1][idx_del, :] .= 0 + BB2 = ACE1.RPI.remove_zeros(ACE1._cleanup(model1.basis.BB[2])) + model1.basis.BB[2] = BB2 + + # wrap the model into a calculator, which turns it into a potential... + calc_model2 = M.ACEPotential(model2) + + return model1, model2, calc_model2 +end + + + + diff --git a/docs/src/newkernels/test_matching_bases.jl b/docs/src/newkernels/test_matching_bases.jl new file mode 100644 index 00000000..eb3687dd --- /dev/null +++ b/docs/src/newkernels/test_matching_bases.jl @@ -0,0 +1,49 @@ + +include("match_bases.jl") +using LinearAlgebra, StaticArrays, Lux + +_evaluate(basis::JuLIP.MLIPs.IPSuperBasis, Rs, Zs, z0) = + reduce(vcat, [ACE1.evaluate(B, Rs, Zs, z0) for B in basis.BB]) + +## + +Z0 = :Si +z1 = AtomicNumber(Z0) +z2 = Int(z1) + +model1, model2, calc_model2 = matching_bases(; Z = Z0) +ps, st = Lux.setup(Random.GLOBAL_RNG, model2) + +## +# confirm match on atomic environments + +Nenv = 1000 +XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] +XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] + + +B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) +B2 = reduce(hcat, [M.evaluate_basis(model2, x..., ps, st) for x in XX2]) + +@info("Compute linear transform between bases to show match") +C = B1' \ B2' +@show norm(B2 - C' * B1) + +@info("Transform should be a permuted diagonal, but this is not quite true...") +@show size(C) +@show count(abs.(C) .> 1e-10) + +## + +@info("Check match on a dataset (Zuo)") +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:10] + +EB1 = reduce(hcat, [ ACE1.energy(model1.basis, sys) for sys in train ]) +EB2 = reduce(hcat, [ M.energy_forces_virial_basis(FlexibleSystem(sys), calc_model2, ps, st).energy for sys in train ]) + +@show norm(ustrip.(EB2) - C' * EB1) + From 4d1a0118676c01889cb568872478b4eecdd6284f Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 13:30:40 -0700 Subject: [PATCH 084/112] closely matching regression results --- docs/src/newkernels/linear2.jl | 156 +++++++++++++++++++++ docs/src/newkernels/test_matching_bases.jl | 2 +- 2 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 docs/src/newkernels/linear2.jl diff --git a/docs/src/newkernels/linear2.jl b/docs/src/newkernels/linear2.jl new file mode 100644 index 00000000..d6eec60a --- /dev/null +++ b/docs/src/newkernels/linear2.jl @@ -0,0 +1,156 @@ +# This script is to roughly document how to use the new model implementations +# I'll try to explain what can be done and what is missing along the way. +# I am + +using Random +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Zygote, Optimisers, Folds, Printf +rng = Random.GLOBAL_RNG +M = ACEpotentials.Models + +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +Z0 = :Si +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:3:end] +weights = Dict("default" => Dict("E" => 30.0, "F" => 1.0 , "V" => 1.0 ),); + +# Creating matching ACE1 and ACE2 models + +totaldegree = 10 +order = 3 +rcut = 5.5 + +model1, model2, calc_model2 = matching_bases(; + Z = Z0, rcut = rcut, order = order, totaldegree=totaldegree) + + +## +#Fit the ACE1 model + +solver=ACEfit.TruncatedSVD(; rtol = 1e-8) +acefit!(model1, train; solver=solver, weights=weights) + + +## +# Fit the ACE2 model - this still needs a bit of hacking to convert everything +# to the new framework. +# - convert the data to AtomsBase +# - use a different interface to specify data weights and keys +# (this needs to be brough in line with the ACEpotentials framework) +# - rewrite the assembly for the LSQ system from scratch (but this is easy) + +train2 = FlexibleSystem.(train) +test2 = FlexibleSystem.(test) +data_keys = (E_key = :energy, F_key = :force, ) +weights = (wE = 30.0/u"eV", wF = 1.0 / u"eV/Å", ) + +function local_lsqsys(calc, at, ps, st, weights, keys) + efv = M.energy_forces_virial_basis(at, calc, ps, st) + + # There are no E0s in this dataset! + # # compute the E0s contribution. This needs to be done more + # # elegantly and a stacked model would solve this problem. + # E0 = sum( calc.model.E0s[M._z2i(calc.model, z)] + # for z in AtomsBase.atomic_number(at) ) * u"eV" + + # energy + wE = weights[:wE] + E_dft = at.data[data_keys.E_key] * u"eV" + y_E = wE * E_dft / sqrt(length(at)) # (E_dft - E0) + A_E = wE * efv.energy' / sqrt(length(at)) + + # forces + wF = weights[:wF] + F_dft = at.data[data_keys.F_key] * u"eV/Å" + y_F = wF * reinterpret(eltype(F_dft[1]), F_dft) + A_F = wF * reinterpret(eltype(efv.forces[1]), efv.forces) + + # # virial + # wV = weights[:wV] + # V_dft = at.data[data_keys.V_key] * u"eV" + # y_V = wV * V_dft[:] + # # display( reinterpret(eltype(efv.virial), efv.virial) ) + # A_V = wV * reshape(reinterpret(eltype(efv.virial[1]), efv.virial), 9, :) + + return vcat(A_E, A_F), vcat(y_E, y_F) +end + + +function assemble_lsq(calc, data, weights, data_keys; + rng = Random.MersenneTwister(1234), + executor = Folds.ThreadedEx()) + ps, st = Lux.setup(rng, calc) + blocks = Folds.map(at -> local_lsqsys(calc, at, ps, st, + weights, data_keys), + data, executor) + + A = reduce(vcat, [b[1] for b in blocks]) + y = reduce(vcat, [b[2] for b in blocks]) + return A, y +end + +A, y = assemble_lsq(calc_model2, train2, weights, data_keys) +@show size(A) + +θ = ACEfit.trunc_svd(svd(A), y, 1e-8) +ps, st = Lux.setup(rng, calc_model2) + +# the next step is a hack. This should be automatable, probably using Lux.freeze. +# But I couldn't quite figure out how to use that. +# Here I'm manually constructing a parameters NamedTuple with rbasis removed. +# then I'm using the destructure / restructure method from Optimizers to +# convert θ into a namedtuple. + +ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = NamedTuple(), rbasis = NamedTuple()) +_θ, restruct = destructure(ps_lin) +ps_lin_fit = restruct(θ) +ps_fit = deepcopy(ps) +ps_fit.WB[:] = ps_lin_fit.WB[:] +ps_fit.Wpair[:] = ps_lin_fit.Wpair[:] +calc_model2_fit = M.ACEPotential(model2, ps_fit, st) + + +## +# Now we can compare errors? +# to make sure we are comparing exactly the same thing, we implement this +# from scratch here ... + +function EF_err(sys::JuLIP.Atoms, calc) + E = JuLIP.energy(calc, sys) + F = JuLIP.forces(calc, sys) + E_ref = JuLIP.get_data(sys, "energy") + F_ref = JuLIP.get_data(sys, "force") + return abs(E - E_ref) / length(sys), norm.(F - F_ref) +end + +function EF_err(sys::AtomsBase.AbstractSystem, calc) + efv = M.energy_forces_virial(sys, calc_model2_fit) + F_ustrip = [ ustrip.(f) for f in efv.forces ] + E_ref = sys.data[:energy] + F_ref = sys.data[:force] + return abs(ustrip(efv.energy) - E_ref) / length(sys), norm.(F_ustrip - F_ref) +end + +function rmse(test, calc) + E_errs = Float64[] + F_errs = Float64[] + for sys in test + E_err, F_err = EF_err(sys, calc) + push!(E_errs, E_err) + append!(F_errs, F_err) + end + return norm(E_errs) / sqrt(length(E_errs)), + norm(F_errs) / sqrt(length(F_errs)) +end + + +E_rmse_1, F_rmse_1 = rmse(test, model1.potential) +E_rmse_2, F_rmse_2 = rmse(test2, calc_model2_fit) + + +@printf("Model | E | F \n") +@printf(" ACE1 | %.2e | %.2e \n", E_rmse_1, F_rmse_1) +@printf(" ACE2 | %.2e | %.2e \n", E_rmse_2, F_rmse_2) + diff --git a/docs/src/newkernels/test_matching_bases.jl b/docs/src/newkernels/test_matching_bases.jl index eb3687dd..a8143d88 100644 --- a/docs/src/newkernels/test_matching_bases.jl +++ b/docs/src/newkernels/test_matching_bases.jl @@ -11,8 +11,8 @@ Z0 = :Si z1 = AtomicNumber(Z0) z2 = Int(z1) -model1, model2, calc_model2 = matching_bases(; Z = Z0) ps, st = Lux.setup(Random.GLOBAL_RNG, model2) +model1, model2, calc_model2 = matching_bases(; Z = Z0) ## # confirm match on atomic environments From 4aabf612871e8e9ee88dbc579ec484a16c0c472a Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 14:55:14 -0700 Subject: [PATCH 085/112] cleanup basic linear script --- docs/src/newkernels/linear.jl | 48 ++------ .../{linear2.jl => linear_match.jl} | 0 docs/src/newkernels/ylm_analysis.jl | 109 ------------------ 3 files changed, 11 insertions(+), 146 deletions(-) rename docs/src/newkernels/{linear2.jl => linear_match.jl} (100%) delete mode 100644 docs/src/newkernels/ylm_analysis.jl diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index 6ec49fb8..ca5cd629 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -6,19 +6,19 @@ using Random using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, Unitful, Zygote, Optimisers, Folds, Printf rng = Random.GLOBAL_RNG +M = ACEpotentials.Models +## # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset Z0 = :Si train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") train = train[1:3:end] +wE = 30.0; wF = 1.0; wV = 1.0 +weights = Dict("default" => Dict("E" => wE, "F" => wF , "V" => wV)) -# because the new implementation is experimental, it is not exported, -# so I create a little shortcut to have easy access. - -M = ACEpotentials.Models - +## # First we create an ACE1 style potential with some standard parameters elements = [Z0,] @@ -28,7 +28,6 @@ rcut = 5.5 model1 = acemodel(elements = elements, order = order, - transform = (:agnesi, 2, 2), totaldegree = totaldegree, pure = false, pure2b = false, @@ -37,17 +36,14 @@ model1 = acemodel(elements = elements, # now we create an ACE2 style model that should behave similarly -# this essentially reproduces the rcut = 5.5, we may want a nicer way to -# achieve this. - -rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) +rin0cuts = M._default_rin0cuts(elements) rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) model2 = M.ace_model(; elements = elements, order = order, # correlation order Ytype = :spherical, # solid vs spherical harmonics level = M.TotalDegree(), # how to calculate the weights to give to a basis function - max_level = totaldegree, # maximum level of the basis functions + max_level = totaldegree+1, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters init_Wpair = "linear", # how to initialize the pair potential parameters @@ -58,38 +54,16 @@ model2 = M.ace_model(; elements = elements, ) ps, st = Lux.setup(rng, model2) -ps_r = ps.rbasis -st_r = st.rbasis - -# extract the radial basis -rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -rbasis2 = model2.rbasis -k = length(rbasis1.J.A) - -# transform old coefficients to new coefficients to make them match -rbasis1.J.A[:] .= rbasis2.polys.A[1:k] -rbasis1.J.B[:] .= rbasis2.polys.B[1:k] -rbasis1.J.C[:] .= rbasis2.polys.C[1:k] -rbasis1.J.A[2] /= rbasis1.J.A[1] -rbasis1.J.B[2] /= rbasis1.J.A[1] - -## -basis1_red = deepcopy(model1.basis) -basis1_red.BB[2] = BB2 - # wrap the model into a calculator, which turns it into a potential... calc_model2 = M.ACEPotential(model2) ## -#Fit the ACE1 model +# Fit the ACE1 model -# set weights for energy, forces virials -weights = Dict("default" => Dict("E" => 30.0, "F" => 1.0 , "V" => 1.0 ),); # specify a solver solver=ACEfit.TruncatedSVD(; rtol = 1e-8) - acefit!(model1, train; solver=solver) @@ -104,7 +78,7 @@ acefit!(model1, train; solver=solver) train2 = FlexibleSystem.(train) test2 = FlexibleSystem.(test) data_keys = (E_key = :energy, F_key = :force, ) -weights = (wE = 30.0/u"eV", wF = 1.0 / u"eV/Å", ) +weights = (wE = wE/u"eV", wF = wF / u"eV/Å", ) function local_lsqsys(calc, at, ps, st, weights, keys) efv = M.energy_forces_virial_basis(at, calc, ps, st) @@ -118,8 +92,8 @@ function local_lsqsys(calc, at, ps, st, weights, keys) # energy wE = weights[:wE] E_dft = at.data[data_keys.E_key] * u"eV" - y_E = wE * E_dft # (E_dft - E0) - A_E = wE * efv.energy' + y_E = wE * E_dft / sqrt(length(at)) # (E_dft - E0) + A_E = wE * efv.energy' / sqrt(length(at)) # forces wF = weights[:wF] diff --git a/docs/src/newkernels/linear2.jl b/docs/src/newkernels/linear_match.jl similarity index 100% rename from docs/src/newkernels/linear2.jl rename to docs/src/newkernels/linear_match.jl diff --git a/docs/src/newkernels/ylm_analysis.jl b/docs/src/newkernels/ylm_analysis.jl deleted file mode 100644 index 807533ba..00000000 --- a/docs/src/newkernels/ylm_analysis.jl +++ /dev/null @@ -1,109 +0,0 @@ -# This script is to explore the differences between the ACE1 models and the new -# models. This is to help bring the two to feature parity so that ACE1 -# can be retired. - -using Random -using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, - Unitful, Zygote, Optimisers, Folds, Plots -rng = Random.GLOBAL_RNG - -# we will try this for a simple dataset, Zuo et al -# replace element with any of those available in that dataset - -Z0 = :Si -z1 = AtomicNumber(Z0) -z2 = Int(z1) - -train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") -train = train[1:3:end] - -# because the new implementation is experimental, it is not exported, -# so I create a little shortcut to have easy access. - -M = ACEpotentials.Models - -# First we create an ACE1 style potential with some standard parameters - -elements = [Z0,] -order = 3 -totaldegree = 10 -rcut = 5.5 - -model1 = acemodel(elements = elements, - order = order, - transform = (:agnesi, 2, 2), - totaldegree = totaldegree, - pure = false, - pure2b = false, - pair_envelope = (:r, 1), - rcut = rcut, ) - -# now we create an ACE2 style model that should behave similarly - -# this essentially reproduces the rcut = 5.5, we may want a nicer way to -# achieve this. - -rin0cuts = M._default_rin0cuts(elements) #; rcutfactor = 2.29167) -rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) - -model2 = M.ace_model(; elements = elements, - order = order, # correlation order - Ytype = :solid, # solid vs spherical harmonics - level = M.TotalDegree(), # how to calculate the weights to give to a basis function - max_level = totaldegree, # maximum level of the basis functions - pair_maxn = totaldegree, # maximum number of basis functions for the pair potential - init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = "linear", # how to initialize the pair potential parameters - init_Wradial = :linear, - pair_transform = (:agnesi, 1, 3), - pair_learnable = true, - rin0cuts = rin0cuts, - ) - -ps, st = Lux.setup(rng, model2) -ps_r = ps.rbasis -st_r = st.rbasis - -# extract the radial basis -rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -rbasis2 = model2.rbasis -k = length(rbasis1.J.A) - -# transform old coefficients to new coefficients to make them match - -rbasis1.J.A[:] .= rbasis2.polys.A[1:k] -rbasis1.J.B[:] .= rbasis2.polys.B[1:k] -rbasis1.J.C[:] .= rbasis2.polys.C[1:k] -rbasis1.J.A[2] /= rbasis1.J.A[1] -rbasis1.J.B[2] /= rbasis1.J.A[1] - -# wrap the model into a calculator, which turns it into a potential... - -calc_model2 = M.ACEPotential(model2) - - -## - -ybasis1 = model1.basis.BB[2].pibasis.basis1p.SH -ybasis2 = model2.ybasis -maxk = length(ybasis2) - -X = [ (u = @SVector rand(3); u/norm(u)) for _ = 1:100 ] -Y1 = reduce(hcat, [ ACE1.evaluate(ybasis1, u)[1:maxk] for u in X ]) -Y1r = real.(Y1) -Y1i = imag.(Y1) -Y2 = reduce(hcat, [ ybasis2(u)[1:maxk] for u in X ]) - -@info("check span real/imag(Y1) = span Y2") -Cr = Y2' \ Y1r' -@show norm(Y1r' - Y2' * Cr) - -Ci = Y2' \ Y1i' -@show norm(Y1i' - Y2' * Ci) - - -## - -cyp4ml = complex_sphericalharmonics(2) -Yp4 = reduce(hcat, [cyp4ml(u) for u in X] ) -Yp4 ≈ Y1 From cd54c517ad1ff80b64b46b4750f6cca5b48c83c1 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 15:31:57 -0700 Subject: [PATCH 086/112] streamlining and cleanup --- docs/src/newkernels/linear.jl | 137 ++------------------- docs/src/newkernels/linear_match.jl | 113 ++--------------- docs/src/newkernels/llsq.jl | 98 +++++++++++++++ docs/src/newkernels/match_bases.jl | 3 +- docs/src/newkernels/test_matching_bases.jl | 2 +- 5 files changed, 125 insertions(+), 228 deletions(-) create mode 100644 docs/src/newkernels/llsq.jl diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index ca5cd629..a51b292b 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -8,6 +8,8 @@ using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, rng = Random.GLOBAL_RNG M = ACEpotentials.Models +include(@__DIR__() * "/LLSQ.jl") + ## # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset @@ -16,7 +18,6 @@ Z0 = :Si train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") train = train[1:3:end] wE = 30.0; wF = 1.0; wV = 1.0 -weights = Dict("default" => Dict("E" => wE, "F" => wF , "V" => wV)) ## # First we create an ACE1 style potential with some standard parameters @@ -24,20 +25,6 @@ weights = Dict("default" => Dict("E" => wE, "F" => wF , "V" => wV)) elements = [Z0,] order = 3 totaldegree = 12 -rcut = 5.5 - -model1 = acemodel(elements = elements, - order = order, - totaldegree = totaldegree, - pure = false, - pure2b = false, - pair_envelope = (:r, 1), - rcut = rcut, ) - -# now we create an ACE2 style model that should behave similarly - -rin0cuts = M._default_rin0cuts(elements) -rin0cuts = SMatrix{1,1}((;rin0cuts[1]..., :rcut => 5.5)) model2 = M.ace_model(; elements = elements, order = order, # correlation order @@ -49,22 +36,14 @@ model2 = M.ace_model(; elements = elements, init_Wpair = "linear", # how to initialize the pair potential parameters init_Wradial = :linear, pair_transform = (:agnesi, 1, 3), - pair_learnable = true, - rin0cuts = rin0cuts, + pair_learnable = false, ) -ps, st = Lux.setup(rng, model2) # wrap the model into a calculator, which turns it into a potential... -calc_model2 = M.ACEPotential(model2) - -## -# Fit the ACE1 model - -# specify a solver -solver=ACEfit.TruncatedSVD(; rtol = 1e-8) -acefit!(model1, train; solver=solver) +ps, st = Lux.setup(rng, model2) +calc_model2 = M.ACEPotential(model2, ps, st) ## @@ -80,111 +59,17 @@ test2 = FlexibleSystem.(test) data_keys = (E_key = :energy, F_key = :force, ) weights = (wE = wE/u"eV", wF = wF / u"eV/Å", ) -function local_lsqsys(calc, at, ps, st, weights, keys) - efv = M.energy_forces_virial_basis(at, calc, ps, st) - - # There are no E0s in this dataset! - # # compute the E0s contribution. This needs to be done more - # # elegantly and a stacked model would solve this problem. - # E0 = sum( calc.model.E0s[M._z2i(calc.model, z)] - # for z in AtomsBase.atomic_number(at) ) * u"eV" - - # energy - wE = weights[:wE] - E_dft = at.data[data_keys.E_key] * u"eV" - y_E = wE * E_dft / sqrt(length(at)) # (E_dft - E0) - A_E = wE * efv.energy' / sqrt(length(at)) - - # forces - wF = weights[:wF] - F_dft = at.data[data_keys.F_key] * u"eV/Å" - y_F = wF * reinterpret(eltype(F_dft[1]), F_dft) - A_F = wF * reinterpret(eltype(efv.forces[1]), efv.forces) - - # # virial - # wV = weights[:wV] - # V_dft = at.data[data_keys.V_key] * u"eV" - # y_V = wV * V_dft[:] - # # display( reinterpret(eltype(efv.virial), efv.virial) ) - # A_V = wV * reshape(reinterpret(eltype(efv.virial[1]), efv.virial), 9, :) - - return vcat(A_E, A_F), vcat(y_E, y_F) -end - - -function assemble_lsq(calc, data, weights, data_keys; - rng = Random.MersenneTwister(1234), - executor = Folds.ThreadedEx()) - ps, st = Lux.setup(rng, calc) - blocks = Folds.map(at -> local_lsqsys(calc, at, ps, st, - weights, data_keys), - data, executor) - - A = reduce(vcat, [b[1] for b in blocks]) - y = reduce(vcat, [b[2] for b in blocks]) - return A, y -end - -A, y = assemble_lsq(calc_model2, train2, weights, data_keys) +A, y = LLSQ.assemble_lsq(calc_model2, train2, weights, data_keys) @show size(A) θ = ACEfit.trunc_svd(svd(A), y, 1e-8) -ps, st = Lux.setup(rng, calc_model2) - -# the next step is a hack. This should be automatable, probably using Lux.freeze. -# But I couldn't quite figure out how to use that. -# Here I'm manually constructing a parameters NamedTuple with rbasis removed. -# then I'm using the destructure / restructure method from Optimizers to -# convert θ into a namedtuple. +calc_model2_fit = LLSQ.set_linear_params(calc_model2, θ) -ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = NamedTuple(), rbasis = NamedTuple()) -_θ, restruct = destructure(ps_lin) -ps_lin_fit = restruct(θ) -ps_fit = deepcopy(ps) -ps_fit.WB[:] = ps_lin_fit.WB[:] -ps_fit.Wpair[:] = ps_lin_fit.Wpair[:] -calc_model2_fit = M.ACEPotential(model2, ps_fit, st) +## +# Look at errors +E_rmse_2, F_rmse_2 = LLSQ.rmse(test2, calc_model2_fit) -## -# Now we can compare errors? -# to make sure we are comparing exactly the same thing, we implement this -# from scratch here ... - -function EF_err(sys::JuLIP.Atoms, calc) - E = JuLIP.energy(calc, sys) - F = JuLIP.forces(calc, sys) - E_ref = JuLIP.get_data(sys, "energy") - F_ref = JuLIP.get_data(sys, "force") - return abs(E - E_ref) / length(sys), norm.(F - F_ref) -end - -function EF_err(sys::AtomsBase.AbstractSystem, calc) - efv = M.energy_forces_virial(sys, calc_model2_fit) - F_ustrip = [ ustrip.(f) for f in efv.forces ] - E_ref = sys.data[:energy] - F_ref = sys.data[:force] - return abs(ustrip(efv.energy) - E_ref) / length(sys), norm.(F_ustrip - F_ref) -end - -function rmse(test, calc) - E_errs = Float64[] - F_errs = Float64[] - for sys in test - E_err, F_err = EF_err(sys, calc) - push!(E_errs, E_err) - append!(F_errs, F_err) - end - return norm(E_errs) / sqrt(length(E_errs)), - norm(F_errs) / sqrt(length(F_errs)) -end - - -E_rmse_1, F_rmse_1 = rmse(test, model1.potential) -E_rmse_2, F_rmse_2 = rmse(test2, calc_model2_fit) - - -@printf("Model | E | F \n") -@printf(" ACE1 | %.2e | %.2e \n", E_rmse_1, F_rmse_1) +@printf("Model | E | F \n") @printf(" ACE2 | %.2e | %.2e \n", E_rmse_2, F_rmse_2) diff --git a/docs/src/newkernels/linear_match.jl b/docs/src/newkernels/linear_match.jl index d6eec60a..64c7d0eb 100644 --- a/docs/src/newkernels/linear_match.jl +++ b/docs/src/newkernels/linear_match.jl @@ -5,8 +5,11 @@ using Random using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, Unitful, Zygote, Optimisers, Folds, Printf -rng = Random.GLOBAL_RNG -M = ACEpotentials.Models +# rng = Random.GLOBAL_RNG +# M = ACEpotentials.Models + +include("match_bases.jl") +include("LLSQ.jl") # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset @@ -35,7 +38,8 @@ acefit!(model1, train; solver=solver, weights=weights) ## # Fit the ACE2 model - this still needs a bit of hacking to convert everything -# to the new framework. +# to the new framework. To be moved into ACEpotentials and hide from +# the user ... # - convert the data to AtomsBase # - use a different interface to specify data weights and keys # (this needs to be brough in line with the ACEpotentials framework) @@ -46,108 +50,17 @@ test2 = FlexibleSystem.(test) data_keys = (E_key = :energy, F_key = :force, ) weights = (wE = 30.0/u"eV", wF = 1.0 / u"eV/Å", ) -function local_lsqsys(calc, at, ps, st, weights, keys) - efv = M.energy_forces_virial_basis(at, calc, ps, st) - - # There are no E0s in this dataset! - # # compute the E0s contribution. This needs to be done more - # # elegantly and a stacked model would solve this problem. - # E0 = sum( calc.model.E0s[M._z2i(calc.model, z)] - # for z in AtomsBase.atomic_number(at) ) * u"eV" - - # energy - wE = weights[:wE] - E_dft = at.data[data_keys.E_key] * u"eV" - y_E = wE * E_dft / sqrt(length(at)) # (E_dft - E0) - A_E = wE * efv.energy' / sqrt(length(at)) - - # forces - wF = weights[:wF] - F_dft = at.data[data_keys.F_key] * u"eV/Å" - y_F = wF * reinterpret(eltype(F_dft[1]), F_dft) - A_F = wF * reinterpret(eltype(efv.forces[1]), efv.forces) - - # # virial - # wV = weights[:wV] - # V_dft = at.data[data_keys.V_key] * u"eV" - # y_V = wV * V_dft[:] - # # display( reinterpret(eltype(efv.virial), efv.virial) ) - # A_V = wV * reshape(reinterpret(eltype(efv.virial[1]), efv.virial), 9, :) - - return vcat(A_E, A_F), vcat(y_E, y_F) -end - - -function assemble_lsq(calc, data, weights, data_keys; - rng = Random.MersenneTwister(1234), - executor = Folds.ThreadedEx()) - ps, st = Lux.setup(rng, calc) - blocks = Folds.map(at -> local_lsqsys(calc, at, ps, st, - weights, data_keys), - data, executor) - - A = reduce(vcat, [b[1] for b in blocks]) - y = reduce(vcat, [b[2] for b in blocks]) - return A, y -end - -A, y = assemble_lsq(calc_model2, train2, weights, data_keys) +A, y = LLSQ.assemble_lsq(calc_model2, train2, weights, data_keys) @show size(A) θ = ACEfit.trunc_svd(svd(A), y, 1e-8) -ps, st = Lux.setup(rng, calc_model2) - -# the next step is a hack. This should be automatable, probably using Lux.freeze. -# But I couldn't quite figure out how to use that. -# Here I'm manually constructing a parameters NamedTuple with rbasis removed. -# then I'm using the destructure / restructure method from Optimizers to -# convert θ into a namedtuple. - -ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = NamedTuple(), rbasis = NamedTuple()) -_θ, restruct = destructure(ps_lin) -ps_lin_fit = restruct(θ) -ps_fit = deepcopy(ps) -ps_fit.WB[:] = ps_lin_fit.WB[:] -ps_fit.Wpair[:] = ps_lin_fit.Wpair[:] -calc_model2_fit = M.ACEPotential(model2, ps_fit, st) - +calc_model2_fit = LLSQ.set_linear_params(calc_model2, θ) ## -# Now we can compare errors? -# to make sure we are comparing exactly the same thing, we implement this -# from scratch here ... - -function EF_err(sys::JuLIP.Atoms, calc) - E = JuLIP.energy(calc, sys) - F = JuLIP.forces(calc, sys) - E_ref = JuLIP.get_data(sys, "energy") - F_ref = JuLIP.get_data(sys, "force") - return abs(E - E_ref) / length(sys), norm.(F - F_ref) -end - -function EF_err(sys::AtomsBase.AbstractSystem, calc) - efv = M.energy_forces_virial(sys, calc_model2_fit) - F_ustrip = [ ustrip.(f) for f in efv.forces ] - E_ref = sys.data[:energy] - F_ref = sys.data[:force] - return abs(ustrip(efv.energy) - E_ref) / length(sys), norm.(F_ustrip - F_ref) -end - -function rmse(test, calc) - E_errs = Float64[] - F_errs = Float64[] - for sys in test - E_err, F_err = EF_err(sys, calc) - push!(E_errs, E_err) - append!(F_errs, F_err) - end - return norm(E_errs) / sqrt(length(E_errs)), - norm(F_errs) / sqrt(length(F_errs)) -end - - -E_rmse_1, F_rmse_1 = rmse(test, model1.potential) -E_rmse_2, F_rmse_2 = rmse(test2, calc_model2_fit) +# compute the errors + +E_rmse_1, F_rmse_1 = LLSQ.rmse(test, model1.potential) +E_rmse_2, F_rmse_2 = LLSQ.rmse(test2, calc_model2_fit) @printf("Model | E | F \n") diff --git a/docs/src/newkernels/llsq.jl b/docs/src/newkernels/llsq.jl new file mode 100644 index 00000000..6fa1fe0f --- /dev/null +++ b/docs/src/newkernels/llsq.jl @@ -0,0 +1,98 @@ + +module LLSQ + +using Random, LinearAlgebra, Folds, Lux, Optimisers +using ACEpotentials +M = ACEpotentials.Models +using Random: MersenneTwister + + +function local_lsqsys(calc, at, ps, st, weights, keys) + efv = M.energy_forces_virial_basis(at, calc, ps, st) + + # There are no E0s in this dataset! + # # compute the E0s contribution. This needs to be done more + # # elegantly and a stacked model would solve this problem. + # E0 = sum( calc.model.E0s[M._z2i(calc.model, z)] + # for z in AtomsBase.atomic_number(at) ) * u"eV" + + # energy + wE = weights[:wE] + E_dft = at.data[keys.E_key] * u"eV" + y_E = wE * E_dft / sqrt(length(at)) # (E_dft - E0) + A_E = wE * efv.energy' / sqrt(length(at)) + + # forces + wF = weights[:wF] + F_dft = at.data[keys.F_key] * u"eV/Å" + y_F = wF * reinterpret(eltype(F_dft[1]), F_dft) + A_F = wF * reinterpret(eltype(efv.forces[1]), efv.forces) + + # # virial + # wV = weights[:wV] + # V_dft = at.data[keys.V_key] * u"eV" + # y_V = wV * V_dft[:] + # # display( reinterpret(eltype(efv.virial), efv.virial) ) + # A_V = wV * reshape(reinterpret(eltype(efv.virial[1]), efv.virial), 9, :) + + return vcat(A_E, A_F), vcat(y_E, y_F) +end + + +function assemble_lsq(calc, data, weights, data_keys; + rng = MersenneTwister(1234), + executor = Folds.ThreadedEx()) + ps, st = Lux.setup(rng, calc) + blocks = Folds.map(at -> local_lsqsys(calc, at, ps, st, + weights, data_keys), + data, executor) + + A = reduce(vcat, [b[1] for b in blocks]) + y = reduce(vcat, [b[2] for b in blocks]) + return A, y +end + +function EF_err(sys::JuLIP.Atoms, calc) + E = JuLIP.energy(calc, sys) + F = JuLIP.forces(calc, sys) + E_ref = JuLIP.get_data(sys, "energy") + F_ref = JuLIP.get_data(sys, "force") + return abs(E - E_ref) / length(sys), norm.(F - F_ref) +end + +function EF_err(sys::AtomsBase.AbstractSystem, calc) + efv = M.energy_forces_virial(sys, calc) + F_ustrip = [ ustrip.(f) for f in efv.forces ] + E_ref = sys.data[:energy] + F_ref = sys.data[:force] + return abs(ustrip(efv.energy) - E_ref) / length(sys), norm.(F_ustrip - F_ref) +end + +function rmse(test, calc) + E_errs = Float64[] + F_errs = Float64[] + for sys in test + E_err, F_err = EF_err(sys, calc) + push!(E_errs, E_err) + append!(F_errs, F_err) + end + return norm(E_errs) / sqrt(length(E_errs)), + norm(F_errs) / sqrt(length(F_errs)) +end + + +function set_linear_params(calc, θ) + # TODO: replace the first line with extracting the parameters + # from the calculator!! + ps, st = Lux.setup(MersenneTwister(1234), calc.model) + ps_lin = (WB = ps.WB, Wpair = ps.Wpair, pairbasis = NamedTuple(), rbasis = NamedTuple()) + _θ, restruct = destructure(ps_lin) + ps_lin_fit = restruct(θ) + ps_fit = deepcopy(ps) + ps_fit.WB[:] = ps_lin_fit.WB[:] + ps_fit.Wpair[:] = ps_lin_fit.Wpair[:] + calc_model_fit = M.ACEPotential(calc.model, ps_fit, st) + return calc_model_fit +end + +end \ No newline at end of file diff --git a/docs/src/newkernels/match_bases.jl b/docs/src/newkernels/match_bases.jl index 828117a7..6c7d68b5 100644 --- a/docs/src/newkernels/match_bases.jl +++ b/docs/src/newkernels/match_bases.jl @@ -9,6 +9,7 @@ M = ACEpotentials.Models function matching_bases(; Z = :Si, order = 3, totaldegree = 10, rcut = 5.5) + elements = [Z, ] model1 = acemodel(elements = elements, order = order, transform = (:agnesi, 2, 2), @@ -35,7 +36,7 @@ function matching_bases(; Z = :Si, order = 3, totaldegree = 10, rin0cuts = rin0cuts, ) - ps, st = Lux.setup(rng, model2) + ps, st = Lux.setup(Random.GLOBAL_RNG, model2) ps_r = ps.rbasis st_r = st.rbasis diff --git a/docs/src/newkernels/test_matching_bases.jl b/docs/src/newkernels/test_matching_bases.jl index a8143d88..eb3687dd 100644 --- a/docs/src/newkernels/test_matching_bases.jl +++ b/docs/src/newkernels/test_matching_bases.jl @@ -11,8 +11,8 @@ Z0 = :Si z1 = AtomicNumber(Z0) z2 = Int(z1) -ps, st = Lux.setup(Random.GLOBAL_RNG, model2) model1, model2, calc_model2 = matching_bases(; Z = Z0) +ps, st = Lux.setup(Random.GLOBAL_RNG, model2) ## # confirm match on atomic environments From 51bf4d985f2670cd2ceb6a6d56bbdfa45acabcaf Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 29 Jul 2024 16:44:52 -0700 Subject: [PATCH 087/112] linear-nonlinear optim example --- docs/src/newkernels/Project.toml | 1 + docs/src/newkernels/linear_nonlinear.jl | 169 ++++++++++++++++++++++++ src/models/calculators.jl | 14 +- 3 files changed, 178 insertions(+), 6 deletions(-) create mode 100644 docs/src/newkernels/linear_nonlinear.jl diff --git a/docs/src/newkernels/Project.toml b/docs/src/newkernels/Project.toml index b4c8cc92..7f83ce95 100644 --- a/docs/src/newkernels/Project.toml +++ b/docs/src/newkernels/Project.toml @@ -7,6 +7,7 @@ AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" GeomOpt = "ca147568-c688-4a55-a13d-dbd284330f4b" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/docs/src/newkernels/linear_nonlinear.jl b/docs/src/newkernels/linear_nonlinear.jl new file mode 100644 index 00000000..845ca32c --- /dev/null +++ b/docs/src/newkernels/linear_nonlinear.jl @@ -0,0 +1,169 @@ +# This script is to roughly document how to use the new model implementations +# I'll try to explain what can be done and what is missing along the way. +# I am + +using Random +using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, + Unitful, Zygote, Optimisers, Folds, Printf, Optim, LineSearches +rng = Random.GLOBAL_RNG +M = ACEpotentials.Models + +include(@__DIR__() * "/LLSQ.jl") + +## +# we will try this for a simple dataset, Zuo et al +# replace element with any of those available in that dataset + +Z0 = :Si +train, test, _ = ACEpotentials.example_dataset("Zuo20_$Z0") +train = train[1:3:end] +wE = 30.0; wF = 1.0; wV = 1.0 + +## +# First we create an ACE1 style potential with some standard parameters + +elements = [Z0,] +order = 3 +totaldegree = 12 + +model2 = M.ace_model(; elements = elements, + order = order, # correlation order + Ytype = :spherical, # solid vs spherical harmonics + level = M.TotalDegree(), # how to calculate the weights to give to a basis function + max_level = totaldegree+1, # maximum level of the basis functions + pair_maxn = totaldegree, # maximum number of basis functions for the pair potential + init_WB = :zeros, # how to initialize the ACE basis parmeters + init_Wpair = "linear", # how to initialize the pair potential parameters + init_Wradial = :linear, + pair_transform = (:agnesi, 1, 3), + pair_learnable = false, + ) + + +# wrap the model into a calculator, which turns it into a potential... + +ps, st = Lux.setup(rng, model2) +calc_model2 = M.ACEPotential(model2, ps, st) + + +## +# Fit the ACE2 model - this still needs a bit of hacking to convert everything +# to the new framework. +# - convert the data to AtomsBase +# - use a different interface to specify data weights and keys +# (this needs to be brough in line with the ACEpotentials framework) +# - rewrite the assembly for the LSQ system from scratch (but this is easy) + +train2 = FlexibleSystem.(train) +test2 = FlexibleSystem.(test) +data_keys = (E_key = :energy, F_key = :force, ) +weights = (wE = wE/u"eV", wF = wF / u"eV/Å", ) + +A, y = LLSQ.assemble_lsq(calc_model2, train2, weights, data_keys) +@show size(A) + +θ = ACEfit.trunc_svd(svd(A), y, 1e-8) +calc_model2_fit = LLSQ.set_linear_params(calc_model2, θ) + +## +# Look at errors + +E_train, F_train = LLSQ.rmse(train2, calc_model2_fit) +E_test, F_test = LLSQ.rmse(test2, calc_model2_fit) + +@printf(" | E | F \n") +@printf(" train | %.2e | %.2e \n", E_train, F_train) +@printf(" test | %.2e | %.2e \n", E_test, F_test) + + +## +# Now we can do some nonlinear iterations on the model + +# First we need to define a loss (dropping virials here...) + +loss = let data_keys = data_keys, weights = weights + + function(calc, ps, st, at) + efv = M.energy_forces_virial(at, calc, ps, st) + _norm_sq(f) = sum(abs2, f) + E_dft, F_dft = Zygote.ignore() do + ( at.data[data_keys.E_key] * u"eV", + at.data[data_keys.F_key] * u"eV/Å" ) + end + return ( weights[:wE]^2 * (efv.energy - E_dft)^2 / length(at) + + weights[:wF]^2 * sum(_norm_sq, efv.forces - F_dft) + ), st, () + end +end + + +at1 = train2[1] +calc = deepcopy(calc_model2_fit) +loss(calc, calc.ps, calc.st, at1) + +g = Zygote.gradient(ps -> loss(calc, ps, st, at1)[1], ps)[1] + + +ps_vec, _restruct = destructure(calc.ps) + +function total_loss(p_vec) + return sum( loss(calc, _restruct(p_vec), st, at)[1] + for at in train2 ) +end + +function total_loss_grad!(g, p_vec) + g[:] = Zygote.gradient(ps -> total_loss(ps), p_vec)[1] + return g +end + +total_loss_grad(p_vec) = total_loss_grad!(zeros(length(ps_vec)), ps_vec) + +# these are reasonably fast - check with these: +# @time total_loss(ps_vec) +# @time total_loss(ps_vec) +# @time total_loss_grad!(zeros(length(ps_vec)), ps_vec) +# @time total_loss_grad!(zeros(length(ps_vec)), ps_vec) + +g = total_loss_grad!(zeros(length(ps_vec)), ps_vec) + +@info("Start the optimization") +method = GradientDescent(; alphaguess = InitialHagerZhang(α0=1.0), + linesearch = LineSearches.BackTracking(; order=2),) + +result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; + method = method, + show_trace = true, + iterations = 30) # obviously this needs more iterations + +# this actually terminates unsuccessfull, without +# progress. not sure why it says it was successful ... + +## +# we can try something purely manual ... +ps_vec1 = deepcopy(ps_vec) + +## + +# the following look shows that it seems very very hard to +# reduce the loss further. Unclear what is going on here... + +α = 1e-8 +for n = 1:10 + l = total_loss(ps_vec1) + g = total_loss_grad(ps_vec1) + @printf(" %.2e | %.2e \n", l, norm(g, Inf)) + ps_vec1 -= α * g +end + +## + +# trying Adam now? This is a bit randomized and it shows +# that as soon as we perturb a bit, we get a much higher +# loss and gradient. It suggests that the LLSQ picks out +# an extremely unstable minimum. + +method = Optim.Adam() +result = Optim.optimize(total_loss, total_loss_grad!, ps_vec; + method = method, + show_trace = true, + iterations = 30) # obviously this needs more iterations diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 9d10fc0c..14a7ad62 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -181,13 +181,15 @@ function pullback_EFV(Δefv, Δei = _ustrip(Δefv.energy) - # them adjoint for dV needs combination of the virial and forces pullback + # the adjoint for dV needs combination of the virial and forces pullback Δdi = [ - _ustrip.(Δefv.virial * rj) for rj in Rs ] - for α = 1:length(Js) - # F[Js[α]] -= dV[α], F[i] += dV[α] - # ∂_dvj { Δf[Js[α]] * F[Js[α]] } -> - Δdi[α] -= _ustrip.( Δefv.forces[Js[α]] ) - Δdi[α] += _ustrip.( Δefv.forces[i] ) + if eltype(Δdi) != ZeroTangent + for α = 1:length(Js) + # F[Js[α]] -= dV[α], F[i] += dV[α] + # ∂_dvj { Δf[Js[α]] * F[Js[α]] } -> + Δdi[α] -= _ustrip.( Δefv.forces[Js[α]] ) + Δdi[α] += _ustrip.( Δefv.forces[i] ) + end end # now we can apply the pullback through evaluate_ed From 62960d9df0b0b1ee30acb873a33adf8c6864d2d8 Mon Sep 17 00:00:00 2001 From: CheukHinHoJerry Date: Mon, 29 Jul 2024 21:07:53 -0700 Subject: [PATCH 088/112] minor typo fix --- docs/src/newkernels/linear_match.jl | 2 +- docs/src/newkernels/linear_nonlinear.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/newkernels/linear_match.jl b/docs/src/newkernels/linear_match.jl index 64c7d0eb..0804737c 100644 --- a/docs/src/newkernels/linear_match.jl +++ b/docs/src/newkernels/linear_match.jl @@ -9,7 +9,7 @@ using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, # M = ACEpotentials.Models include("match_bases.jl") -include("LLSQ.jl") +include("llsq.jl") # we will try this for a simple dataset, Zuo et al # replace element with any of those available in that dataset diff --git a/docs/src/newkernels/linear_nonlinear.jl b/docs/src/newkernels/linear_nonlinear.jl index 845ca32c..aacf5680 100644 --- a/docs/src/newkernels/linear_nonlinear.jl +++ b/docs/src/newkernels/linear_nonlinear.jl @@ -8,7 +8,7 @@ using ACEpotentials, AtomsBase, AtomsBuilder, Lux, StaticArrays, LinearAlgebra, rng = Random.GLOBAL_RNG M = ACEpotentials.Models -include(@__DIR__() * "/LLSQ.jl") +include(@__DIR__() * "/llsq.jl") ## # we will try this for a simple dataset, Zuo et al From 6570ee222bdee0df8e649147ced366d9c1f3caf7 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Jul 2024 14:48:16 -0700 Subject: [PATCH 089/112] test for connor bug --- test/test_bugs.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 test/test_bugs.jl diff --git a/test/test_bugs.jl b/test/test_bugs.jl new file mode 100644 index 00000000..d95af5d3 --- /dev/null +++ b/test/test_bugs.jl @@ -0,0 +1,32 @@ + +using ACEpotentials, Test +using Random: seed! + +@info(" ============== Testing for ACEpotentials 208 ================") +@info(" On Julia 1.9 some energy computations were inconsistent. ") + +model = acemodel(elements = [:Ti, ], + order = 3, + totaldegree = 10, + rcut = 6.0, + Eref = [:Ti => -1586.0195, ]) + +# generate random parameters +seed!(1234) +params = randn(length(model.basis)) +# params = params ./ (1:length(params)).^2 # (this is not needed) +ACEpotentials._set_params!(model, params) + +function energy_per_at(pot, i) + at = bulk(:Ti) * i + return JuLIP.energy(pot, at) / length(at) +end + +E_per_at = [ energy_per_at(model.potential, i) for i = 1:10 ] + +maxdiff = maximum(abs(E_per_at[i] - E_per_at[j]) for i = 1:10, j = 1:10 ) +@show maxdiff + +@test maxdiff < 1e-12 + +@info(" ============================================================") From a0a97317f613b6c31e187d5f833a148b69783795 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Jul 2024 16:09:47 -0700 Subject: [PATCH 090/112] update test and julia bound --- .github/workflows/CI.yml | 3 +-- Project.toml | 2 +- test/runtests.jl | 3 +++ test/test_bugs.jl | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c8b936ab..057fbad1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,8 +19,7 @@ jobs: fail-fast: false matrix: julia-version: - - '1.9' - - '1' + - '1.10' - 'nightly' python-version: - '3.8' diff --git a/Project.toml b/Project.toml index d38b838b..657a0b66 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ PrettyTables = "1.3, 2.0" Reexport = "1" StaticArrays = "1" YAML = "0.4" -julia = "1.9" +julia = "~1.10" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index f54407d5..dae9c9fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,9 @@ using ACEpotentials, Test, LazyArtifacts @testset "New Models" begin include("models/test_models.jl") end + # weird stuff + @testset "Weird bugs" begin include("test_bugs.jl") end + # outdated # @testset "Read data" begin include("outdated/test_data.jl") end # @testset "Basis" begin include("outdated/test_basis.jl") end diff --git a/test/test_bugs.jl b/test/test_bugs.jl index d95af5d3..d350d7de 100644 --- a/test/test_bugs.jl +++ b/test/test_bugs.jl @@ -27,6 +27,6 @@ E_per_at = [ energy_per_at(model.potential, i) for i = 1:10 ] maxdiff = maximum(abs(E_per_at[i] - E_per_at[j]) for i = 1:10, j = 1:10 ) @show maxdiff -@test maxdiff < 1e-12 +@test maxdiff < 1e-9 @info(" ============================================================") From 0e5e2ad6a5d791ab29d21bd760d2d5a8b806dcdb Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Jul 2024 21:39:11 -0700 Subject: [PATCH 091/112] stop reexporting ACE1 --- src/ACEpotentials.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ACEpotentials.jl b/src/ACEpotentials.jl index 991328e9..1efec4cc 100644 --- a/src/ACEpotentials.jl +++ b/src/ACEpotentials.jl @@ -2,7 +2,8 @@ module ACEpotentials using Reexport @reexport using JuLIP -@reexport using ACE1 +using ACE1 +export ACE1 @reexport using ACE1x @reexport using ACEfit @reexport using ACEmd From ad48bd7a41c2eaca35f36d4f7796e43e5fb2f50b Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 30 Jul 2024 21:53:52 -0700 Subject: [PATCH 092/112] fix non-exports --- src/ACEpotentials.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ACEpotentials.jl b/src/ACEpotentials.jl index 1efec4cc..e1872483 100644 --- a/src/ACEpotentials.jl +++ b/src/ACEpotentials.jl @@ -2,9 +2,15 @@ module ACEpotentials using Reexport @reexport using JuLIP + using ACE1 export ACE1 -@reexport using ACE1x + +using ACE1x +export ACE1x +import ACE1x: ace_basis, smoothness_prior, ace_defaults, acemodel +export ace_basis, smoothness_prior, ace_defaults, acemodel + @reexport using ACEfit @reexport using ACEmd From b5952e670ae7f95aed79107104284fd84d429dd7 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 1 Aug 2024 08:00:16 -0700 Subject: [PATCH 093/112] start on ace1 compat --- src/models/ace1_compat.jl | 338 ++++++++++++++++++++++++++++++++ src/models/defaults.jl | 28 +++ src/models/smoothness_priors.jl | 37 ++++ test/test_ace1_compat.jl | 14 ++ 4 files changed, 417 insertions(+) create mode 100644 src/models/ace1_compat.jl create mode 100644 src/models/defaults.jl create mode 100644 src/models/smoothness_priors.jl create mode 100644 test/test_ace1_compat.jl diff --git a/src/models/ace1_compat.jl b/src/models/ace1_compat.jl new file mode 100644 index 00000000..eebb7e9f --- /dev/null +++ b/src/models/ace1_compat.jl @@ -0,0 +1,338 @@ +# This module is a translation of most ACE1x options to the new ACE +# kernels. It is used to provide compatibility with the old ACE1 / ACE1x. + +module ACE1compat + +ace1_defaults() = deepcopy(_kw_defaults) + +const _kw_defaults = Dict(:elements => nothing, + :order => nothing, + :totaldegree => nothing, + :wL => 1.5, + # + :rin => 0.0, + :r0 => :bondlen, + :rcut => (:bondlen, 2.5), + :transform => (:agnesi, 2, 4), + :envelope => (:x, 2, 2), + :rbasis => :legendre, + # + :pure2b => false, # TODO: ALLOW TRUE!!! + :delete2b => false, # here too + :pure => false, + # + :pair_rin => :rin, + :pair_rcut => :rcut, + :pair_degree => :totaldegree, + :pair_transform => (:agnesi, 1, 3), + :pair_basis => :legendre, + :pair_envelope => (:r, 2), + # + :Eref => missing, + #temporary variable to specify whether using the variable cutoffs or not + :variable_cutoffs => false, + ) + +const _kw_aliases = Dict( :N => :order, + :species => :elements, + :trans => :transform, + ) + + +function _clean_args(kwargs) + dargs = Dict{Symbol, Any}() + for key in keys(kwargs) + if haskey(_kw_aliases, key) + dargs[_kw_aliases[key]] = kwargs[key] + else + dargs[key] = kwargs[key] + end + end + for key in keys(_kw_defaults) + if !haskey(dargs, key) + dargs[key] = _kw_defaults[key] + end + end + + if dargs[:pair_rcut] == :rcut + dargs[:pair_rcut] = dargs[:rcut] + end + + return namedtuple(dargs) +end + +function _get_order(kwargs) + if haskey(kwargs, :order) + return kwargs[:order] + elseif haskey(kwargs, :bodyorder) + return kwargs[:bodyorder] - 1 + end + error("Cannot determine correlation order or body order of ACE basis from the arguments provided.") +end + +function _get_degrees(kwargs) + + if haskey(kwargs, :totaldegree) + deg = kwargs[:totaldegree] + cor_order = _get_order(kwargs) + + if deg isa Number + maxlevels = [deg for i in 1:cor_order] + elseif deg isa AbstractVector{<: Number} + @assert length(deg) == cor_order + maxlevels = deg + else + error("Cannot determine total degree of ACE basis from the arguments provided.") + end + + wL = kwargs[:wL] + basis_selector = BasisSelector(cor_order, maxlevels, + TotalDegree(1.0, wL)) + maxn = maximum(maxlevels) + + # return basis_selector, maxdeg, maxn + return basis_selector + end + + error("Cannot determine total degree of ACE basis from the arguments provided.") +end + +function _get_r0(kwargs, z1, z2) + if kwargs[:r0] == :bondlen + return DefaultHypers.bond_len(z1, z2) + elseif kwargs[:r0] isa Number + return kwargs[:r0] + elseif kwargs[:r0] isa Dict + return kwargs[:r0][(z1, z2)] + end + error("Unable to determine r0($z1, $z2) from the arguments provided.") +end + +function _get_elements(kwargs) + return [ kwargs[:elements]... ] +end + +function _get_all_r0(kwargs) + elements = _get_elements(kwargs) + r0 = Dict( [ (s1, s2) => _get_r0(kwargs, s1, s2) + for s1 in elements, s2 in elements]... ) +end + +function _get_rcut(kwargs, s1, s2; _rcut = kwargs[:rcut]) + if _rcut isa Tuple + if _rcut[1] == :bondlen # rcut = (:bondlen, rcut_factor) + return _rcut[2] * get_r0(s1, s2) + end + elseif _rcut isa Number # rcut = explicit value + return _rcut + elseif _rcut isa Dict # explicit values for each pair + return _rcut[(s1, s2)] + end + error("Unable to determine rcut($s1, $s2) from the arguments provided.") +end + +function _get_all_rcut(kwargs; _rcut = kwargs[:rcut]) + if _rcut isa Number + return _rcut + end + elements = _get_elements(kwargs) + rcut = Dict( [ (s1, s2) => _get_rcut(kwargs, s1, s2; _rcut = _rcut) + for s1 in elements, s2 in elements]... ) + return rcut +end + + +function _rin0cuts_rcut(zlist, cutoffs::Dict) + function rin0cut(zi, zj) + r0 = DefaultHypers.bond_len(zi, zj) + return (rin = 0.0, r0 = r0, rcut = cutoffs[zi, zj]) + end + NZ = length(zlist) + return SMatrix{NZ, NZ}([ rin0cut(zi, zj) for zi in zlist, zj in zlist ]) +end + + + +function _transform(kwargs; transform = kwargs[:transform]) + elements = _get_elements(kwargs) + + if transform isa Tuple + if transform[1] == :agnesi + if length(transform) != 3 + error("The ACE1 compatibility only supports (:agnesi, p, q) type transforms.") + end + + p = transform[2] + q = transform[3] + r0 = _get_all_r0(kwargs) + rcut = _get_all_rcut(kwargs) + if rcut isa Number || ! kwargs[:variable_cutoffs] + cutoffs = nothing + else + cutoffs = Dict([ (s1, s2) => (0.0, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...) + end + # rcut = maximum(values(rcut)) # multitransform wants a single cutoff. + + # construct the rin0cut structures + rin0cuts = _rin0cuts_rcut(elements, cutoffs) + transforms = agnesi_transform.(rin0cuts, p, q) + return transforms + + # transforms = Dict([ (s1, s2) => agnesi_transform(r0[(s1, s2)], p, q) + # for s1 in elements, s2 in elements]... ) + # trans_ace = multitransform(transforms; rin = 0.0, rcut = rcut, cutoffs=cutoffs) + # return trans_ace + end + end + + error("Unable to determine transform from the arguments provided.") +end + +#= + +function _radial_basis(kwargs) + rbasis = kwargs[:rbasis] + + if rbasis isa ACE1.ScalarBasis + return rbasis + + elseif rbasis == :legendre + Deg, maxdeg, maxn = _get_degrees(kwargs) + cor_order = _get_order(kwargs) + envelope = kwargs[:envelope] + if envelope isa Tuple && envelope[1] == :x + pin = envelope[2] + pcut = envelope[3] + if (kwargs[:pure2b] || kwargs[:pure]) + maxn += (pin + pcut) * (cor_order-1) + end + else + error("Cannot construct the radial basis automatically without knowing the envelope.") + end + + trans_ace = _transform(kwargs) + + Rn_basis = transformed_jacobi(maxn, trans_ace; pcut = pcut, pin = pin) + # println("pcut is", pcut, "pin is", pin, "trans_ace is", trans_ace) + # println(kwargs) + #Rn_basis = transformed_jacobi(maxn, trans_ace, kwargs[:rcut], kwargs[:rin];) + return Rn_basis + end + + error("Unable to determine the radial basis from the arguments provided.") +end + + + + +function _pair_basis(kwargs) + rbasis = kwargs[:pair_basis] + elements = _get_elements(kwargs) + #elements has to be sorted becuase PolyPairBasis (see end of function) assumes sorted. + if kwargs[:variable_cutoffs] + elements = [chemical_symbol(z) for z in JuLIP.Potentials.ZList(elements, static=true).list] + end + + if rbasis isa ACE1.ScalarBasis + return rbasis + + elseif rbasis == :legendre + if kwargs[:pair_degree] == :totaldegree + Deg, maxdeg, maxn = _get_degrees(kwargs) + elseif kwargs[:pair_degree] isa Integer + maxn = kwargs[:pair_degree] + else + error("Cannot determine `maxn` for pair basis from information provided.") + end + + allrcut = _get_all_rcut(kwargs; _rcut = kwargs[:pair_rcut]) + if allrcut isa Number + allrcut = Dict([(s1, s2) => allrcut for s1 in elements, s2 in elements]...) + end + + trans_pair = _transform(kwargs, transform = kwargs[:pair_transform]) + _s2i(s) = z2i(trans_pair.zlist, AtomicNumber(s)) + alltrans = Dict([(s1, s2) => trans_pair.transforms[_s2i(s1), _s2i(s2)].t + for s1 in elements, s2 in elements]...) + + allr0 = _get_all_r0(kwargs) + + function _r_basis(s1, s2, penv) + _env = ACE1.PolyEnvelope(penv, allr0[(s1, s2)], allrcut[(s1, s2)] ) + return transformed_jacobi_env(maxn, alltrans[(s1, s2)], _env, allrcut[(s1, s2)]) + end + + _x_basis(s1, s2, pin, pcut) = transformed_jacobi(maxn, alltrans[(s1, s2)], allrcut[(s1, s2)]; + pcut = pcut, pin = pin) + + envelope = kwargs[:pair_envelope] + if envelope isa Tuple + if envelope[1] == :x + pin = envelope[2] + pcut = envelope[3] + rbases = [ _x_basis(s1, s2, pin, pcut) for s1 in elements, s2 in elements ] + elseif envelope[1] == :r + penv = envelope[2] + rbases = [ _r_basis(s1, s2, penv) for s1 in elements, s2 in elements ] + end + end + end + + return PolyPairBasis(rbases, elements) +end + + + +function mb_ace_basis(kwargs) + elements = _get_elements(kwargs) + cor_order = _get_order(kwargs) + Deg, maxdeg, maxn = _get_degrees(kwargs) + rbasis = _radial_basis(kwargs) + pure2b = kwargs[:pure2b] + + if pure2b && kwargs[:pure] + # error("Cannot use both `pure2b` and `pure` options.") + @info("Option `pure = true` overrides `pure2b=true`") + pure2b = false + end + + if pure2b + rpibasis = Pure2b.pure2b_basis(species = AtomicNumber.(elements), + Rn=rbasis, + D=Deg, + maxdeg=maxdeg, + order=cor_order, + delete2b = kwargs[:delete2b]) + elseif kwargs[:pure] + dirtybasis = ACE1.ace_basis(species = AtomicNumber.(elements), + rbasis=rbasis, + D=Deg, + maxdeg=maxdeg, + N = cor_order, ) + _rem = kwargs[:delete2b] ? 1 : 0 + # remove all zero-basis functions that we might have accidentally created so that we purify less extra basis + dirtybasis = ACE1.RPI.remove_zeros(dirtybasis) + # and finally cleanup the rest of the basis + dirtybasis = ACE1._cleanup(dirtybasis) + # finally purify + rpibasis = ACE1x.Purify.pureRPIBasis(dirtybasis; remove = _rem) + else + rpibasis = ACE1.ace_basis(species = AtomicNumber.(elements), + rbasis=rbasis, + D=Deg, + maxdeg=maxdeg, + N = cor_order, ) + end + + return rpibasis +end + +function ace_basis(; kwargs...) + kwargs = _clean_args(kwargs) + rpiB = mb_ace_basis(kwargs) + pairB = _pair_basis(kwargs) + return JuLIP.MLIPs.IPSuperBasis([pairB, rpiB]); +end +=# + +end \ No newline at end of file diff --git a/src/models/defaults.jl b/src/models/defaults.jl new file mode 100644 index 00000000..6f932abe --- /dev/null +++ b/src/models/defaults.jl @@ -0,0 +1,28 @@ +module DefaultHypers + +import YAML + +# -------------- Bond-length heuristics + +_lengthscales_path = joinpath(@__DIR__, "..", "..", "data", + "length_scales_VASP_auto_length_scales.yaml") +_lengthscales = YAML.load_file(_lengthscales_path) + +bond_len(s::Symbol) = bond_len(AtomicNumber(s)) +bond_len(z::AtomicNumber) = bond_len(convert(Int, z)) + +function bond_len(z::Integer) + if haskey(_lengthscales, z) + return _lengthscales[z]["bond_len"][1] + elseif rnn(AtomicNumber(z)) > 0 + return rnn(AtomicNumber(z)) + end + error("No typical bond length for atomic number $z is known. Please specify manually.") +end + +bond_len(z1, z2) = (bond_len(z1) + bond_len(z2)) / 2 + + +# -------------- + +end \ No newline at end of file diff --git a/src/models/smoothness_priors.jl b/src/models/smoothness_priors.jl new file mode 100644 index 00000000..661935fc --- /dev/null +++ b/src/models/smoothness_priors.jl @@ -0,0 +1,37 @@ + + +# -------------------------------------------------- +# different notions of "level" / total degree. +# selecting the basis in this way is assumed smoothness of the target +# and is closely related to the choice of smoothness prior. + +abstract type AbstractLevel end +struct TotalDegree <: AbstractLevel + wn::Float64 + wl::Float64 +end + +TotalDegree() = TotalDegree(1.0, 2/3) + +(l::TotalDegree)(b::NamedTuple) = b.n/l.wn + b.l/l.wl +(l::TotalDegree)(bb::AbstractVector{<: NamedTuple}) = sum(l(b) for b in bb) + + +struct EuclideanDegree <: AbstractLevel + wn::Float64 + wl::Float64 +end + +EuclideanDegree() = EuclideanDegree(1.0, 2/3) + +(l::EuclideanDegree)(b::NamedTuple) = sqrt( (b.n/l.wn)^2 + (b.l/l.wl)^2 ) +(l::EuclideanDegree)(bb::AbstractVector{<: NamedTuple}) = sqrt( sum(l(b)^2 for b in bb) ) + + +struct BasisSelector + order::Int + maxlevels::AbstractVector{<: Number} + level +end + +# -------------------------------------------------- diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl new file mode 100644 index 00000000..e36b6ac9 --- /dev/null +++ b/test/test_ace1_compat.jl @@ -0,0 +1,14 @@ + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); + +## + +using Random, Test, ACEbase +using ACEbase.Testing: print_tf, println_slim + +using ACEpotentials +M = ACEpotentials.Models + +## + From c3588f94ef450bfd8c23bf25089f7c7a57d25bc2 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 1 Aug 2024 15:24:28 -0700 Subject: [PATCH 094/112] ace1-compat - match radials --- Project.toml | 3 +- data/README | 1 + ...length_scales_VASP_auto_length_scales.yaml | 304 ++++++++++++++++++ src/ACEpotentials.jl | 4 + src/{models => }/ace1_compat.jl | 118 +++++-- src/{models => }/defaults.jl | 3 +- src/models/Rnl_learnable.jl | 2 +- src/models/ace_heuristics.jl | 44 ++- src/models/elements.jl | 21 -- src/models/models.jl | 2 + src/models/smoothness_priors.jl | 16 +- test/test_ace1_compat.jl | 72 ++++- 12 files changed, 501 insertions(+), 89 deletions(-) create mode 100644 data/README create mode 100644 data/length_scales_VASP_auto_length_scales.yaml rename src/{models => }/ace1_compat.jl (75%) rename src/{models => }/defaults.jl (88%) diff --git a/Project.toml b/Project.toml index 691ebd2e..703c004d 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" NeighbourLists = "2fcf5ba9-9ed4-57cf-b73f-ff513e316b9c" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -53,6 +54,7 @@ ACE1x = "0.1.8" ACEfit = "0.1.4" ACEmd = "0.1.6" AtomsCalculators = "0.1" +EquivariantModels = "0.0.4" ExtXYZ = "0.1.14" Interpolations = "0.15" JuLIP = "0.13.9, 0.14.2" @@ -60,7 +62,6 @@ PrettyTables = "1.3, 2.0" Reexport = "1" StaticArrays = "1" YAML = "0.4" -EquivariantModels = "0.0.4" julia = "~1.10.0" [extras] diff --git a/data/README b/data/README new file mode 100644 index 00000000..4566ff11 --- /dev/null +++ b/data/README @@ -0,0 +1 @@ +The file length_scales_VASP_auto_length_scales.yaml is taken from https://github.com/libAtoms/universalSOAP/blob/main/calculations/length_scales_VASP_auto_length_scales.yaml \ No newline at end of file diff --git a/data/length_scales_VASP_auto_length_scales.yaml b/data/length_scales_VASP_auto_length_scales.yaml new file mode 100644 index 00000000..dcb31be2 --- /dev/null +++ b/data/length_scales_VASP_auto_length_scales.yaml @@ -0,0 +1,304 @@ +{1: {"bond_len": [1.2, "NB VASP auto_length_scale"], + "min_bond_len": [0.75, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [3.4, "NB VASP auto_length_scale"]}, + 11: {"bond_len": [3.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [35, "NB VASP auto_length_scale"]}, + 12: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [23, "NB VASP auto_length_scale"]}, + 13: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [17, "NB VASP auto_length_scale"]}, + 14: {"bond_len": [2.4, "NB VASP auto_length_scale"], + "min_bond_len": [2.3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [20, "NB VASP auto_length_scale"]}, + 15: {"bond_len": [2.5, "NB VASP auto_length_scale"], + "min_bond_len": [1.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [15, "NB VASP auto_length_scale"]}, + 16: {"bond_len": [2.3, "NB VASP auto_length_scale"], + "min_bond_len": [1.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [17, "NB VASP auto_length_scale"]}, + 17: {"bond_len": [2.6, "NB VASP auto_length_scale"], + "min_bond_len": [2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [36, "NB VASP auto_length_scale"]}, + 19: {"bond_len": [4.6, "NB VASP auto_length_scale"], + "min_bond_len": [3.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [71, "NB VASP auto_length_scale"]}, + 20: {"bond_len": [3.8, "NB VASP auto_length_scale"], + "min_bond_len": [3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [41, "NB VASP auto_length_scale"]}, + 21: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [22, "NB VASP auto_length_scale"]}, + 22: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.4, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [17, "NB VASP auto_length_scale"]}, + 23: {"bond_len": [2.6, "NB VASP auto_length_scale"], + "min_bond_len": [1.8, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [13, "NB VASP auto_length_scale"]}, + 24: {"bond_len": [2.5, "NB VASP auto_length_scale"], + "min_bond_len": [1.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [12, "NB VASP auto_length_scale"]}, + 26: {"bond_len": [2.5, "NB VASP auto_length_scale"], + "min_bond_len": [1.8, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [11, "NB VASP auto_length_scale"]}, + 27: {"bond_len": [2.5, "NB VASP auto_length_scale"], + "min_bond_len": [2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [11, "NB VASP auto_length_scale"]}, + 28: {"bond_len": [2.5, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [11, "NB VASP auto_length_scale"]}, + 29: {"bond_len": [2.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [12, "NB VASP auto_length_scale"]}, + 3: {"bond_len": [3, "NB VASP auto_length_scale"], + "min_bond_len": [2.4, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [20, "NB VASP auto_length_scale"]}, + 30: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [15, "NB VASP auto_length_scale"]}, + 31: {"bond_len": [3, "NB VASP auto_length_scale"], + "min_bond_len": [2.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [19, "NB VASP auto_length_scale"]}, + 32: {"bond_len": [2.5, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [24, "NB VASP auto_length_scale"]}, + 33: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [20, "NB VASP auto_length_scale"]}, + 34: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2.2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [22, "NB VASP auto_length_scale"]}, + 35: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2.3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [29, "NB VASP auto_length_scale"]}, + 37: {"bond_len": [4.7, "NB VASP auto_length_scale"], + "min_bond_len": [3.7, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [81, "NB VASP auto_length_scale"]}, + 38: {"bond_len": [4.2, "NB VASP auto_length_scale"], + "min_bond_len": [3.8, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [54, "NB VASP auto_length_scale"]}, + 39: {"bond_len": [3.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [33, "NB VASP auto_length_scale"]}, + 4: {"bond_len": [2.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [7.9, "NB VASP auto_length_scale"]}, + 40: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [23, "NB VASP auto_length_scale"]}, + 41: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [18, "NB VASP auto_length_scale"]}, + 42: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [16, "NB VASP auto_length_scale"]}, + 43: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [14, "NB VASP auto_length_scale"]}, + 44: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [14, "NB VASP auto_length_scale"]}, + 45: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [14, "NB VASP auto_length_scale"]}, + 46: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2.4, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [15, "NB VASP auto_length_scale"]}, + 47: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [18, "NB VASP auto_length_scale"]}, + 48: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [23, "NB VASP auto_length_scale"]}, + 49: {"bond_len": [3.4, "NB VASP auto_length_scale"], + "min_bond_len": [2.8, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [27, "NB VASP auto_length_scale"]}, + 5: {"bond_len": [1.8, "NB VASP auto_length_scale"], + "min_bond_len": [1.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [8.3, "NB VASP auto_length_scale"]}, + 50: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [36, "NB VASP auto_length_scale"]}, + 51: {"bond_len": [3.1, "NB VASP auto_length_scale"], + "min_bond_len": [2.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [29, "NB VASP auto_length_scale"]}, + 52: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [32, "NB VASP auto_length_scale"]}, + 53: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.7, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [39, "NB VASP auto_length_scale"]}, + 56: {"bond_len": [4.3, "NB VASP auto_length_scale"], + "min_bond_len": [3.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [59, "NB VASP auto_length_scale"]}, + 57: {"bond_len": [3.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [36, "NB VASP auto_length_scale"]}, + 58: {"bond_len": [3.4, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [27, "NB VASP auto_length_scale"]}, + 59: {"bond_len": [3.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.8, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [36, "NB VASP auto_length_scale"]}, + 6: {"bond_len": [1.4, "NB VASP auto_length_scale"], + "min_bond_len": [1.3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [5.7, "NB VASP auto_length_scale"]}, + 60: {"bond_len": [3.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [35, "NB VASP auto_length_scale"]}, + 61: {"bond_len": [3.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.7, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [34, "NB VASP auto_length_scale"]}, + 62: {"bond_len": [3.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.7, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [33, "NB VASP auto_length_scale"]}, + 64: {"bond_len": [3.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [33, "NB VASP auto_length_scale"]}, + 65: {"bond_len": [3.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [32, "NB VASP auto_length_scale"]}, + 66: {"bond_len": [3.5, "NB VASP auto_length_scale"], + "min_bond_len": [3.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [31, "NB VASP auto_length_scale"]}, + 69: {"bond_len": [3.5, "NB VASP auto_length_scale"], + "min_bond_len": [3.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [30, "NB VASP auto_length_scale"]}, + 7: {"bond_len": [1.6, "NB VASP auto_length_scale"], + "min_bond_len": [1.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [9.2, "NB VASP auto_length_scale"]}, + 70: {"bond_len": [3.8, "NB VASP auto_length_scale"], + "min_bond_len": [3.3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [39, "NB VASP auto_length_scale"]}, + 72: {"bond_len": [3.2, "NB VASP auto_length_scale"], + "min_bond_len": [2.6, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [22, "NB VASP auto_length_scale"]}, + 73: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [18, "NB VASP auto_length_scale"]}, + 74: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [16, "NB VASP auto_length_scale"]}, + 75: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [15, "NB VASP auto_length_scale"]}, + 76: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.1, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [14, "NB VASP auto_length_scale"]}, + 77: {"bond_len": [2.7, "NB VASP auto_length_scale"], + "min_bond_len": [2.2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [15, "NB VASP auto_length_scale"]}, + 78: {"bond_len": [2.8, "NB VASP auto_length_scale"], + "min_bond_len": [2.3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [16, "NB VASP auto_length_scale"]}, + 79: {"bond_len": [2.9, "NB VASP auto_length_scale"], + "min_bond_len": [2.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [18, "NB VASP auto_length_scale"]}, + 8: {"bond_len": [1.7, "NB VASP auto_length_scale"], + "min_bond_len": [1.2, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [11, "NB VASP auto_length_scale"]}, + 80: {"bond_len": [3.3, "NB VASP auto_length_scale"], + "min_bond_len": [3, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [28, "NB VASP auto_length_scale"]}, + 81: {"bond_len": [3.4, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [31, "NB VASP auto_length_scale"]}, + 82: {"bond_len": [3.5, "NB VASP auto_length_scale"], + "min_bond_len": [2.9, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [31, "NB VASP auto_length_scale"]}, + 83: {"bond_len": [3.3, "NB VASP auto_length_scale"], + "min_bond_len": [2.7, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [34, "NB VASP auto_length_scale"]}, + 84: {"bond_len": [3.3, "NB VASP auto_length_scale"], + "min_bond_len": [2.8, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [36, "NB VASP auto_length_scale"]}, + 89: {"bond_len": [4, "NB VASP auto_length_scale"], + "min_bond_len": [3.5, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [46, "NB VASP auto_length_scale"]}, + 9: {"bond_len": [2.1, "NB VASP auto_length_scale"], + "min_bond_len": [1.4, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [15, "NB VASP auto_length_scale"]}, + 90: {"bond_len": [3.6, "NB VASP auto_length_scale"], + "min_bond_len": [2.7, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [32, "NB VASP auto_length_scale"]}, + 91: {"bond_len": [3.3, "NB VASP auto_length_scale"], + "min_bond_len": [2.4, "NB VASP auto_length_scale"], + "other links": {}, + "vol_per_atom": [25, "NB VASP auto_length_scale"]}} diff --git a/src/ACEpotentials.jl b/src/ACEpotentials.jl index e1872483..0d802d06 100644 --- a/src/ACEpotentials.jl +++ b/src/ACEpotentials.jl @@ -14,6 +14,8 @@ export ace_basis, smoothness_prior, ace_defaults, acemodel @reexport using ACEfit @reexport using ACEmd +include("defaults.jl") + include("atoms_data.jl") include("model.jl") include("export.jl") @@ -28,6 +30,8 @@ include("analysis/dataset_analysis.jl") include("experimental.jl") include("models/models.jl") +include("ace1_compat.jl") + include("outdated/fit.jl") include("outdated/data.jl") include("outdated/basis.jl") diff --git a/src/models/ace1_compat.jl b/src/ace1_compat.jl similarity index 75% rename from src/models/ace1_compat.jl rename to src/ace1_compat.jl index eebb7e9f..5bfbe7b3 100644 --- a/src/models/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -3,6 +3,13 @@ module ACE1compat +using NamedTupleTools, StaticArrays +import ACEpotentials: DefaultHypers, Models + +using ACEpotentials.Models: agnesi_transform, + SplineRnlrzzBasis, + ace_learnable_Rnlrzz + ace1_defaults() = deepcopy(_kw_defaults) const _kw_defaults = Dict(:elements => nothing, @@ -29,8 +36,6 @@ const _kw_defaults = Dict(:elements => nothing, :pair_envelope => (:r, 2), # :Eref => missing, - #temporary variable to specify whether using the variable cutoffs or not - :variable_cutoffs => false, ) const _kw_aliases = Dict( :N => :order, @@ -58,6 +63,14 @@ function _clean_args(kwargs) dargs[:pair_rcut] = dargs[:rcut] end + if haskey(dargs, :variable_cutoffs) + @warn("variable_cutoffs argument is ignored") + end + + if kwargs[:pure2b] || kwargs[:pure] + error("ACE1compat current does not support `pure2b` or `pure` options.") + end + return namedtuple(dargs) end @@ -145,13 +158,29 @@ end function _rin0cuts_rcut(zlist, cutoffs::Dict) function rin0cut(zi, zj) r0 = DefaultHypers.bond_len(zi, zj) - return (rin = 0.0, r0 = r0, rcut = cutoffs[zi, zj]) + rin, rcut = cutoffs[zi, zj] + return (rin = rin, r0 = r0, rcut = rcut) end NZ = length(zlist) return SMatrix{NZ, NZ}([ rin0cut(zi, zj) for zi in zlist, zj in zlist ]) end +function _ace1_rin0cuts(kwargs) + elements = _get_elements(kwargs) + r0 = _get_all_r0(kwargs) + rcut = _get_all_rcut(kwargs) + if rcut isa Number + cutoffs = Dict([ (s1, s2) => (0.0, rcut) for s1 in elements, s2 in elements]...) + else + cutoffs = Dict([ (s1, s2) => (0.0, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...) + end + # rcut = maximum(values(rcut)) # multitransform wants a single cutoff. + + # construct the rin0cut structures + rin0cuts = _rin0cuts_rcut(elements, cutoffs) +end + function _transform(kwargs; transform = kwargs[:transform]) elements = _get_elements(kwargs) @@ -161,20 +190,9 @@ function _transform(kwargs; transform = kwargs[:transform]) if length(transform) != 3 error("The ACE1 compatibility only supports (:agnesi, p, q) type transforms.") end - p = transform[2] q = transform[3] - r0 = _get_all_r0(kwargs) - rcut = _get_all_rcut(kwargs) - if rcut isa Number || ! kwargs[:variable_cutoffs] - cutoffs = nothing - else - cutoffs = Dict([ (s1, s2) => (0.0, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...) - end - # rcut = maximum(values(rcut)) # multitransform wants a single cutoff. - - # construct the rin0cut structures - rin0cuts = _rin0cuts_rcut(elements, cutoffs) + rin0cuts = _ace1_rin0cuts(kwargs) transforms = agnesi_transform.(rin0cuts, p, q) return transforms @@ -188,41 +206,73 @@ function _transform(kwargs; transform = kwargs[:transform]) error("Unable to determine transform from the arguments provided.") end -#= + +function _get_Rnl_spec(kwargs) + maxdeg = maximum(kwargs[:totaldegree]) + wL = kwargs[:wL] + lvl = Models.TotalDegree(1.0, 1/wL) + return Models.oneparticle_spec(lvl, maxdeg) +end + function _radial_basis(kwargs) rbasis = kwargs[:rbasis] + elements = _get_elements(kwargs) - if rbasis isa ACE1.ScalarBasis + if rbasis isa SplineRnlrzzBasis return rbasis elseif rbasis == :legendre - Deg, maxdeg, maxn = _get_degrees(kwargs) - cor_order = _get_order(kwargs) - envelope = kwargs[:envelope] - if envelope isa Tuple && envelope[1] == :x - pin = envelope[2] - pcut = envelope[3] - if (kwargs[:pure2b] || kwargs[:pure]) - maxn += (pin + pcut) * (cor_order-1) - end - else - error("Cannot construct the radial basis automatically without knowing the envelope.") - end trans_ace = _transform(kwargs) + rin0cuts = _ace1_rin0cuts(kwargs) + Rnl_spec = _get_Rnl_spec(kwargs) - Rn_basis = transformed_jacobi(maxn, trans_ace; pcut = pcut, pin = pin) - # println("pcut is", pcut, "pin is", pin, "trans_ace is", trans_ace) - # println(kwargs) - #Rn_basis = transformed_jacobi(maxn, trans_ace, kwargs[:rcut], kwargs[:rin];) - return Rn_basis + envelope = kwargs[:envelope] + # this is the default envelope + # envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) + # just check that it hasn't been changed + if envelope != (:x, 2, 2) + error("The ACE1 compatibility only supports (:x, 2, 2) type envelopes for the radial basis") + end + + # finally we need to specify a polynomial basis. ACE1 incorporates + # the envelope into the orthogonality. This corresponds to + # ∫ Pq(x) Pq(x) env(x)^2 dx = δ_{pq} + # which results in a Jacobi basis + pin = envelope[2] + pcut = envelope[3] + polys = (:jacobi, Float64(2*pin), Float64(2*pcut)) + + + # This is to be revisited if we re-introduce pure2b + # if envelope isa Tuple && envelope[1] == :x + # if (kwargs[:pure2b] || kwargs[:pure]) + # maxn += (pin + pcut) * (cor_order-1) + # end + # else + # error("Cannot construct the radial basis automatically without knowing the envelope.") + # end + # Rn_basis = transformed_jacobi(maxn, trans_ace; pcut = pcut, pin = pin) + + Rn_basis = ace_learnable_Rnlrzz(; spec = Rnl_spec, + maxq = maximum(b.n for b in Rnl_spec), + elements = elements, + rin0cuts = rin0cuts, + transforms = trans_ace, + polys = polys, + Winit = "linear") + + ps_Rn = Models.initialparameters(nothing, Rn_basis) + Rn_spl = Models.splinify(Rn_basis, ps_Rn) + return Rn_spl end error("Unable to determine the radial basis from the arguments provided.") end +#= function _pair_basis(kwargs) diff --git a/src/models/defaults.jl b/src/defaults.jl similarity index 88% rename from src/models/defaults.jl rename to src/defaults.jl index 6f932abe..c53cc98a 100644 --- a/src/models/defaults.jl +++ b/src/defaults.jl @@ -1,10 +1,11 @@ module DefaultHypers import YAML +using JuLIP: AtomicNumber, rnn # -------------- Bond-length heuristics -_lengthscales_path = joinpath(@__DIR__, "..", "..", "data", +_lengthscales_path = joinpath(@__DIR__, "..", "data", "length_scales_VASP_auto_length_scales.yaml") _lengthscales = YAML.load_file(_lengthscales_path) diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 622c0404..396583ce 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -24,7 +24,7 @@ end Base.length(basis::LearnableRnlrzzBasis) = length(basis.spec) -function initialparameters(rng::AbstractRNG, +function initialparameters(rng::Union{AbstractRNG, Nothing}, basis::LearnableRnlrzzBasis) NZ = _get_nz(basis) len_nl = length(basis) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 6876ce60..666a80ac 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -1,30 +1,5 @@ import Random -# -------------------------------------------------- -# different notions of "level" / total degree. - -abstract type AbstractLevel end -struct TotalDegree <: AbstractLevel - wn::Float64 - wl::Float64 -end - -TotalDegree() = TotalDegree(1.0, 0.66) - -(l::TotalDegree)(b::NamedTuple) = b.n/l.wn + b.l/l.wl -(l::TotalDegree)(bb::AbstractVector{<: NamedTuple}) = sum(l(b) for b in bb) - - -struct EuclideanDegree <: AbstractLevel - wn::Float64 - wl::Float64 -end - -EuclideanDegree() = EuclideanDegree(1.0, 0.66) - -(l::EuclideanDegree)(b::NamedTuple) = sqrt( (b.n/l.wn)^2 + (b.l/l.wl)^2 ) -(l::EuclideanDegree)(bb::AbstractVector{<: NamedTuple}) = sqrt( sum(l(b)^2 for b in bb) ) - # ------------------------------------------------------- # construction of Rnlrzz bases with lots of defaults @@ -79,6 +54,14 @@ function ace_learnable_Rnlrzz(; else error("unknown polynomial type : $polys") end + elseif polys isa Tuple + if polys[1] == :jacobi + α = polys[2] + β = polys[3] + polys = Polynomials4ML.jacobi_basis(maxq, α, β) + else + error("unknown polynomial type : $polys") + end end if transforms isa Tuple && transforms[1] == :agnesi @@ -186,3 +169,14 @@ end # ------------------------------------------------------- + + +function _default_rin0cuts(zlist; rinfactor = 0.0, rcutfactor = 2.5) + function rin0cut(zi, zj) + r0 = ACE1x.get_r0(zi, zj) + return (rin = r0 * rinfactor, r0 = r0, rcut = r0 * rcutfactor) + end + NZ = length(zlist) + return SMatrix{NZ, NZ}([ rin0cut(zi, zj) for zi in zlist, zj in zlist ]) +end + diff --git a/src/models/elements.jl b/src/models/elements.jl index 0ec4354b..3e4fe018 100644 --- a/src/models/elements.jl +++ b/src/models/elements.jl @@ -30,14 +30,6 @@ function _convert_zlist(zlist) return ntuple(i -> convert(Int, zlist[i]), length(zlist)) end -function _default_rin0cuts(zlist; rinfactor = 0.0, rcutfactor = 2.5) - function rin0cut(zi, zj) - r0 = ACE1x.get_r0(zi, zj) - return (rin = r0 * rinfactor, r0 = r0, rcut = r0 * rcutfactor) - end - NZ = length(zlist) - return SMatrix{NZ, NZ}([ rin0cut(zi, zj) for zi in zlist, zj in zlist ]) -end """ Takes an object and converts it to an `SMatrix{NZ, NZ}` via the following rules: @@ -59,16 +51,3 @@ function _make_smatrix(obj, NZ) return SMatrix{NZ, NZ}(fill(obj, (NZ, NZ))) end - -# a one-hot embedding for the z variable. -# function embed_z(ace, Rs, Zs) -# TF = eltype(eltype(Rs)) -# Ez = acquire!(ace.pool, :Ez, (length(Zs), length(ace.rbasis)), TF) -# fill!(Ez, 0) -# for (j, z) in enumerate(Zs) -# iz = _z2i(ace.rbasis, z) -# Ez[j, iz] = 1 -# end -# return Ez -# end - diff --git a/src/models/models.jl b/src/models/models.jl index 8d53dc78..fa534c92 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -18,6 +18,8 @@ import LuxCore: AbstractExplicitLayer, include("elements.jl") +include("smoothness_priors.jl") + include("radial_envelopes.jl") include("radial_transforms.jl") diff --git a/src/models/smoothness_priors.jl b/src/models/smoothness_priors.jl index 661935fc..e873c8c2 100644 --- a/src/models/smoothness_priors.jl +++ b/src/models/smoothness_priors.jl @@ -28,10 +28,18 @@ EuclideanDegree() = EuclideanDegree(1.0, 2/3) (l::EuclideanDegree)(bb::AbstractVector{<: NamedTuple}) = sqrt( sum(l(b)^2 for b in bb) ) -struct BasisSelector - order::Int - maxlevels::AbstractVector{<: Number} - level +# struct SparseBasisSelector +# order::Int +# maxlevels::AbstractVector{<: Number} +# level +# end + +function oneparticle_spec(level::Union{TotalDegree, EuclideanDegree}, maxlevel) + maxn1 = ceil(Int, maxlevel * level.wn) + maxl1 = ceil(Int, maxlevel * level.wl) + spec = [ (n = n, l = l) for n = 1:maxn1, l = 0:maxl1 + if level((n = n, l = l)) <= maxlevel ] + return sort(spec; by = x -> (x.l, x.n)) end # -------------------------------------------------- diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl index e36b6ac9..22fb9299 100644 --- a/test/test_ace1_compat.jl +++ b/test/test_ace1_compat.jl @@ -1,14 +1,82 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) # using TestEnv; TestEnv.activate(); ## -using Random, Test, ACEbase +using Plots +using Random, Test, ACEbase, LinearAlgebra using ACEbase.Testing: print_tf, println_slim using ACEpotentials M = ACEpotentials.Models +ACE1compat = ACEpotentials.ACE1compat ## + + +params = ( elements = [:Si,], + order = 3, + transform = (:agnesi, 2, 2), + totaldegree = 8, + pure = false, + pure2b = false, + pair_envelope = (:r, 1), + rcut = 5.5, + ) + + +model1 = acemodel(; params...) + + +## + +@info("check the transform construction") + +params_clean = ACE1compat._clean_args(params) +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = ACE1compat._radial_basis(params_clean) + +trans1 = rbasis1.trans.transforms[1] +trans2 = rbasis2.transforms[1] + +rr = params.rcut * rand(200) +t1 = ACE1.Transforms.transform.(Ref(trans1), rr) +t2 = trans2.(rr) +err_t1_t2 = maximum(abs.(t1 .- t2)) +println_slim(@test err_t1_t2 < 1e-12) + +# the envelope - check that the "choices" are the same + +println_slim(@test rbasis1.envelope isa ACE1.OrthPolys.OneEnvelope) +println_slim(@test rbasis1.J.pl == rbasis1.J.pr == 2 ) +println_slim(@test rbasis2.envelopes[1].p1 == rbasis2.envelopes[1].p2 == 2) + +## + +@info("check full radial basis construction") +@info(" This error can be a bit larger since the jacobi basis used in ACE1 is constructed from a discrete measure") +@info("The first test checks Rn vs Rn0") +z1 = AtomicNumber(:Si) +z2 = Int(z1) +rp = range(0.0, params.rcut, length=200) +R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z1, z1) for r in rp ]) +R2 = reduce(hcat, [ rbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rp]) +maxn = size(R1, 1) +scal = [ maximum(R1[n,:]) / maximum(R2[n,:]) for n = 1:maxn ] +err = norm(R1 - Diagonal(scal) * R2[1:maxn, :], Inf) +@show err +println_slim(@test err < 0.001) + +@info("The remaining checks are for Rn0 = Rnl") +for i_nl = 1:size(R2, 1) + n = rbasis2.spec[i_nl].n + print_tf(@test R2[i_nl, :] ≈ R2[n, :]) +end +println() + +## + + + From 3ee6e7ec40ff8364580a1162a501f1257974bd95 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 2 Aug 2024 07:55:32 -0700 Subject: [PATCH 095/112] ace1-compat matching pair basis --- src/ace1_compat.jl | 92 +++++++++++++++++------------------- src/models/ace_heuristics.jl | 8 ++++ test/test_ace1_compat.jl | 16 +++++++ 3 files changed, 67 insertions(+), 49 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 5bfbe7b3..985e8c1e 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -166,10 +166,9 @@ function _rin0cuts_rcut(zlist, cutoffs::Dict) end -function _ace1_rin0cuts(kwargs) +function _ace1_rin0cuts(kwargs; rcutkey = :rcut) elements = _get_elements(kwargs) - r0 = _get_all_r0(kwargs) - rcut = _get_all_rcut(kwargs) + rcut = _get_all_rcut(kwargs; _rcut = kwargs[rcutkey]) if rcut isa Number cutoffs = Dict([ (s1, s2) => (0.0, rcut) for s1 in elements, s2 in elements]...) else @@ -182,7 +181,8 @@ function _ace1_rin0cuts(kwargs) end -function _transform(kwargs; transform = kwargs[:transform]) +function _transform(kwargs; transform = kwargs[:transform], + rcutkey = :rcut) elements = _get_elements(kwargs) if transform isa Tuple @@ -192,7 +192,7 @@ function _transform(kwargs; transform = kwargs[:transform]) end p = transform[2] q = transform[3] - rin0cuts = _ace1_rin0cuts(kwargs) + rin0cuts = _ace1_rin0cuts(kwargs; rcutkey = rcutkey) transforms = agnesi_transform.(rin0cuts, p, q) return transforms @@ -207,8 +207,8 @@ function _transform(kwargs; transform = kwargs[:transform]) end -function _get_Rnl_spec(kwargs) - maxdeg = maximum(kwargs[:totaldegree]) +function _get_Rnl_spec(kwargs, + maxdeg = maximum(kwargs[:totaldegree]) ) wL = kwargs[:wL] lvl = Models.TotalDegree(1.0, 1/wL) return Models.oneparticle_spec(lvl, maxdeg) @@ -272,66 +272,60 @@ function _radial_basis(kwargs) end -#= - - function _pair_basis(kwargs) rbasis = kwargs[:pair_basis] elements = _get_elements(kwargs) - #elements has to be sorted becuase PolyPairBasis (see end of function) assumes sorted. - if kwargs[:variable_cutoffs] - elements = [chemical_symbol(z) for z in JuLIP.Potentials.ZList(elements, static=true).list] - end + rin0cuts = _ace1_rin0cuts(kwargs; rcutkey = :pair_rcut) - if rbasis isa ACE1.ScalarBasis - return rbasis + if rbasis == :legendre - elseif rbasis == :legendre + # SPECIFICATION if kwargs[:pair_degree] == :totaldegree - Deg, maxdeg, maxn = _get_degrees(kwargs) + maxn = maximum(kwargs[:totaldegree]) elseif kwargs[:pair_degree] isa Integer maxn = kwargs[:pair_degree] else error("Cannot determine `maxn` for pair basis from information provided.") end + pair_spec = [ (n = n, l = 0) for n in 1:maxn ] - allrcut = _get_all_rcut(kwargs; _rcut = kwargs[:pair_rcut]) - if allrcut isa Number - allrcut = Dict([(s1, s2) => allrcut for s1 in elements, s2 in elements]...) - end - - trans_pair = _transform(kwargs, transform = kwargs[:pair_transform]) - _s2i(s) = z2i(trans_pair.zlist, AtomicNumber(s)) - alltrans = Dict([(s1, s2) => trans_pair.transforms[_s2i(s1), _s2i(s2)].t - for s1 in elements, s2 in elements]...) - - allr0 = _get_all_r0(kwargs) - - function _r_basis(s1, s2, penv) - _env = ACE1.PolyEnvelope(penv, allr0[(s1, s2)], allrcut[(s1, s2)] ) - return transformed_jacobi_env(maxn, alltrans[(s1, s2)], _env, allrcut[(s1, s2)]) - end - - _x_basis(s1, s2, pin, pcut) = transformed_jacobi(maxn, alltrans[(s1, s2)], allrcut[(s1, s2)]; - pcut = pcut, pin = pin) + # TRANSFORM + trans_pair = _transform(kwargs, transform = kwargs[:pair_transform], + rcutkey = :pair_rcut) + # ENVELOPE + # here we use the same convention, so this is fine envelope = kwargs[:pair_envelope] - if envelope isa Tuple - if envelope[1] == :x - pin = envelope[2] - pcut = envelope[3] - rbases = [ _x_basis(s1, s2, pin, pcut) for s1 in elements, s2 in elements ] - elseif envelope[1] == :r - penv = envelope[2] - rbases = [ _r_basis(s1, s2, penv) for s1 in elements, s2 in elements ] - end - end + + # ------ Here it is getting weird? + # _s2i(s) = z2i(trans_pair.zlist, AtomicNumber(s)) + # alltrans = Dict([(s1, s2) => trans_pair.transforms[_s2i(s1), _s2i(s2)].t + # for s1 in elements, s2 in elements]...) + # allr0 = _get_all_r0(kwargs) + # function _r_basis(s1, s2, penv) + # _env = ACE1.PolyEnvelope(penv, allr0[(s1, s2)], allrcut[(s1, s2)] ) + # return transformed_jacobi_env(maxn, alltrans[(s1, s2)], _env, allrcut[(s1, s2)]) + # end + # _x_basis(s1, s2, pin, pcut) = transformed_jacobi(maxn, alltrans[(s1, s2)], allrcut[(s1, s2)]; + # pcut = pcut, pin = pin) + pair_basis = ace_learnable_Rnlrzz(; spec = pair_spec, + maxq = maxn, + elements = elements, + rin0cuts = rin0cuts, + transforms = trans_pair, + envelopes = envelope, + polys = :legendre, + Winit = "linear" ) + ps_pair = Models.initialparameters(nothing, pair_basis) + pair_spl = Models.splinify(pair_basis, ps_pair) + return pair_spl end - - return PolyPairBasis(rbases, elements) + + error("Cannot determine the pair basis from the arguments provided.") end +#= function mb_ace_basis(kwargs) elements = _get_elements(kwargs) diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 666a80ac..355bf8d7 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -75,6 +75,14 @@ function ace_learnable_Rnlrzz(; elseif envelopes == :poly1sr envelopes = [ PolyEnvelope1sR(rin0cuts[iz, jz].rcut, 1) for iz = 1:NZ, jz = 1:NZ ] + elseif envelopes isa Tuple && envelopes[1] == :x + @assert length(envelopes) == 2 + envelopes = PolyEnvelope2sX(-1.0, 1.0, envelopes[2], envelopes[3]) + elseif envelopes isa Tuple && envelopes[1] == :r + envelopes = [ PolyEnvelope1sR(rin0cuts[iz, jz].rcut, envelopes[2]) + for iz = 1:NZ, jz = 1:NZ ] + else + error("cannot read envelope : $envelopes") end if actual_maxn > length(polys) diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl index 22fb9299..20e36c47 100644 --- a/test/test_ace1_compat.jl +++ b/test/test_ace1_compat.jl @@ -78,5 +78,21 @@ println() ## +@info("Check the pair basis construction") +pairbasis1 = model1.basis.BB[1] +pairbasis2 = ACE1compat._pair_basis(params_clean) + +rr = range(0.001, params.rcut, length=200) +P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z1, z1) for r in rr ]) +P2 = reduce(hcat, [ pairbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rr]) +println_slim(@test size(P1) == size(P2)) + +nmax = size(P1, 1) +scal = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] +P2 = Diagonal(scal) * P2 + +err = norm( (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1), Inf) +@show err +println_slim(@test err < 0.01) From 58aa8c0679c492b9378ce19933812dd17aba4b42 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 2 Aug 2024 09:10:26 -0700 Subject: [PATCH 096/112] ace1compat - mb matching --- src/ace1_compat.jl | 118 ++++++++++++++++++----------------- src/models/ace_heuristics.jl | 11 ++-- test/test_ace1_compat.jl | 55 +++++++++++++--- 3 files changed, 113 insertions(+), 71 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 985e8c1e..42e6a10f 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -99,12 +99,8 @@ function _get_degrees(kwargs) end wL = kwargs[:wL] - basis_selector = BasisSelector(cor_order, maxlevels, - TotalDegree(1.0, wL)) - maxn = maximum(maxlevels) - # return basis_selector, maxdeg, maxn - return basis_selector + return Models.TotalDegree(1.0, 1/wL), maxlevels end error("Cannot determine total degree of ACE basis from the arguments provided.") @@ -297,17 +293,6 @@ function _pair_basis(kwargs) # here we use the same convention, so this is fine envelope = kwargs[:pair_envelope] - # ------ Here it is getting weird? - # _s2i(s) = z2i(trans_pair.zlist, AtomicNumber(s)) - # alltrans = Dict([(s1, s2) => trans_pair.transforms[_s2i(s1), _s2i(s2)].t - # for s1 in elements, s2 in elements]...) - # allr0 = _get_all_r0(kwargs) - # function _r_basis(s1, s2, penv) - # _env = ACE1.PolyEnvelope(penv, allr0[(s1, s2)], allrcut[(s1, s2)] ) - # return transformed_jacobi_env(maxn, alltrans[(s1, s2)], _env, allrcut[(s1, s2)]) - # end - # _x_basis(s1, s2, pin, pcut) = transformed_jacobi(maxn, alltrans[(s1, s2)], allrcut[(s1, s2)]; - # pcut = pcut, pin = pin) pair_basis = ace_learnable_Rnlrzz(; spec = pair_spec, maxq = maxn, elements = elements, @@ -325,58 +310,75 @@ function _pair_basis(kwargs) end -#= +function ace1_model(; kwargs...) + + kwargs = _clean_args(kwargs) -function mb_ace_basis(kwargs) elements = _get_elements(kwargs) cor_order = _get_order(kwargs) - Deg, maxdeg, maxn = _get_degrees(kwargs) rbasis = _radial_basis(kwargs) - pure2b = kwargs[:pure2b] + pairbasis = _pair_basis(kwargs) + lvl, maxlvl = _get_degrees(kwargs) + if all(maxlvl .== maximum(maxlvl)) + maxlvl = maximum(maxlvl) + else + error("ACE1-compat only supports a single-number totaldegree at the moment.") + end - if pure2b && kwargs[:pure] + # if pure2b && kwargs[:pure] # error("Cannot use both `pure2b` and `pure` options.") - @info("Option `pure = true` overrides `pure2b=true`") - pure2b = false + # @info("Option `pure = true` overrides `pure2b=true`") + # pure2b = false + # end + + if kwargs[:pure2b] || kwargs[:pure] + error("ACE1compat does not yet support the `pure2b` or `pure` options.") end - if pure2b - rpibasis = Pure2b.pure2b_basis(species = AtomicNumber.(elements), - Rn=rbasis, - D=Deg, - maxdeg=maxdeg, - order=cor_order, - delete2b = kwargs[:delete2b]) - elseif kwargs[:pure] - dirtybasis = ACE1.ace_basis(species = AtomicNumber.(elements), - rbasis=rbasis, - D=Deg, - maxdeg=maxdeg, - N = cor_order, ) - _rem = kwargs[:delete2b] ? 1 : 0 - # remove all zero-basis functions that we might have accidentally created so that we purify less extra basis - dirtybasis = ACE1.RPI.remove_zeros(dirtybasis) - # and finally cleanup the rest of the basis - dirtybasis = ACE1._cleanup(dirtybasis) - # finally purify - rpibasis = ACE1x.Purify.pureRPIBasis(dirtybasis; remove = _rem) - else - rpibasis = ACE1.ace_basis(species = AtomicNumber.(elements), - rbasis=rbasis, - D=Deg, - maxdeg=maxdeg, - N = cor_order, ) + + # if pure2b + # rpibasis = Pure2b.pure2b_basis(species = AtomicNumber.(elements), + # Rn=rbasis, + # D=Deg, + # maxdeg=maxdeg, + # order=cor_order, + # delete2b = kwargs[:delete2b]) + # elseif kwargs[:pure] + # dirtybasis = ACE1.ace_basis(species = AtomicNumber.(elements), + # rbasis=rbasis, + # D=Deg, + # maxdeg=maxdeg, + # N = cor_order, ) + # _rem = kwargs[:delete2b] ? 1 : 0 + # # remove all zero-basis functions that we might have accidentally created so that we purify less extra basis + # dirtybasis = ACE1.RPI.remove_zeros(dirtybasis) + # # and finally cleanup the rest of the basis + # dirtybasis = ACE1._cleanup(dirtybasis) + # # finally purify + # rpibasis = ACE1x.Purify.pureRPIBasis(dirtybasis; remove = _rem) + # else + + if ismissing(kwargs[:Eref]) + E0s = nothing + else + E0s = kwargs[:Eref] end - return rpibasis + model = Models.ace_model(; elements=elements, + order = cor_order, + Ytype = :spherical, + E0s = E0s, + rbasis = rbasis, + pair_basis = pairbasis, + rin0cuts = rbasis.rin0cuts, + level = lvl, + max_level = maxlvl, + init_WB = :zeros,) + + return model end -function ace_basis(; kwargs...) - kwargs = _clean_args(kwargs) - rpiB = mb_ace_basis(kwargs) - pairB = _pair_basis(kwargs) - return JuLIP.MLIPs.IPSuperBasis([pairB, rpiB]); -end -=# -end \ No newline at end of file +end + + diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 355bf8d7..5359e568 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -156,15 +156,16 @@ function ace_model(; elements = nothing, rin0cuts = rbasis.rin0cuts, transforms = pair_transform, envelopes = :poly1sr ) - end - pair_basis.meta["Winit"] = init_Wpair + pair_basis.meta["Winit"] = init_Wpair - if !pair_learnable - ps_pair = initialparameters(rng, pair_basis) - pair_basis = splinify(pair_basis, ps_pair) + if !pair_learnable + ps_pair = initialparameters(rng, pair_basis) + pair_basis = splinify(pair_basis, ps_pair) + end end + AA_spec = sparse_AA_spec(; order = order, r_spec = rbasis.spec, level = level, max_level = max_level) diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl index 20e36c47..bd4c1366 100644 --- a/test/test_ace1_compat.jl +++ b/test/test_ace1_compat.jl @@ -5,17 +5,16 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) ## using Plots -using Random, Test, ACEbase, LinearAlgebra +using Random, Test, ACEbase, LinearAlgebra, Lux using ACEbase.Testing: print_tf, println_slim using ACEpotentials M = ACEpotentials.Models ACE1compat = ACEpotentials.ACE1compat +rng = Random.MersenneTwister(1234) ## - - params = ( elements = [:Si,], order = 3, transform = (:agnesi, 2, 2), @@ -28,15 +27,20 @@ params = ( elements = [:Si,], model1 = acemodel(; params...) +model2 = ACE1compat.ace1_model(; params...) +ps, st = Lux.setup(rng, model2) ## @info("check the transform construction") -params_clean = ACE1compat._clean_args(params) +# params_clean = ACE1compat._clean_args(params) +# rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +# rbasis2 = ACE1compat._radial_basis(params_clean) + rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -rbasis2 = ACE1compat._radial_basis(params_clean) +rbasis2 = model2.rbasis trans1 = rbasis1.trans.transforms[1] trans2 = rbasis2.transforms[1] @@ -80,7 +84,8 @@ println() @info("Check the pair basis construction") pairbasis1 = model1.basis.BB[1] -pairbasis2 = ACE1compat._pair_basis(params_clean) +# pairbasis2 = ACE1compat._pair_basis(params_clean) +pairbasis2 = model2.pairbasis rr = range(0.001, params.rcut, length=200) P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z1, z1) for r in rr ]) @@ -88,11 +93,45 @@ P2 = reduce(hcat, [ pairbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in r println_slim(@test size(P1) == size(P2)) nmax = size(P1, 1) -scal = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] -P2 = Diagonal(scal) * P2 +scal_pair = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] +P2 = Diagonal(scal_pair) * P2 +scal_err = -(extrema(abs.(scal_pair))...) +@show scal_err +println_slim(@test scal_err < 0.01) err = norm( (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1), Inf) @show err println_slim(@test err < 0.01) +## + +@info("Check the bases span the same space") +@info(" check spec matches") +_spec1 = ACE1.get_nl(model1.basis.BB[2]) +spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] +spec2 = M.get_nnll_spec(model2.tensor) +println_slim(@test sort(sort.(spec1)) == sort(sort.(spec2))) + +Nenv = 300 +XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] +XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] + +_evaluate(basis::JuLIP.MLIPs.IPSuperBasis, Rs, Zs, z0) = + reduce(vcat, [ACE1.evaluate(B, Rs, Zs, z0) for B in basis.BB]) + +B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) +B2 = reduce(hcat, [M.evaluate_basis(model2, x..., ps, st) for x in XX2]) + +Nmb = length(spec1) +B1_mb = B1[end-Nmb+1:end, :] +B2_mb = B2[1:Nmb, :] + +# Compute linear transform between bases to show match +C_mb = B1_mb' \ B2_mb' +@show norm(B2_mb - C_mb' * B1_mb, Inf) +B1_pair = B1[1:end-Nmb, :] +B2_pair = B2[Nmb+1:end, :] +C_pair = B1_pair' \ B2_pair' +@show norm(B2_pair - C_pair' * B1_pair, Inf) +norm(B2_pair - Diagonal(scal_pair) \ B1_pair) \ No newline at end of file From 95cfc2463d813a22c3972d45b09578773abe16a8 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 2 Aug 2024 09:20:56 -0700 Subject: [PATCH 097/112] full match --- src/models/utils.jl | 8 +++++++- test/test_ace1_compat.jl | 37 +++++++++++++++++-------------------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/models/utils.jl b/src/models/utils.jl index e33cf039..6d55c6e7 100644 --- a/src/models/utils.jl +++ b/src/models/utils.jl @@ -95,7 +95,8 @@ import ACE1 rand_atenv(model::ACEModel, Nat) = rand_atenv(model.rbasis, Nat) -function rand_atenv(rbasis::Union{LearnableRnlrzzBasis, SplineRnlrzzBasis}, Nat) +function rand_atenv(rbasis::Union{LearnableRnlrzzBasis, SplineRnlrzzBasis}, Nat; + rin_fact = 0.9) z0 = rand(rbasis._i2z) zs = rand(rbasis._i2z, Nat) @@ -106,6 +107,11 @@ function rand_atenv(rbasis::Union{LearnableRnlrzzBasis, SplineRnlrzzBasis}, Nat) x = 2 * rand() - 1 t_ij = rbasis.transforms[iz0, izj] r_ij = inv_transform(t_ij, x) + r0_ij = rbasis.rin0cuts[iz0, izj].r0 + r_in = rin_fact * r0_ij + if r_ij < r_in + r_ij = r_in + rand() * (r0_ij - r_in) + end push!(rs, r_ij) end Rs = [ r * ACE1.Random.rand_sphere() for r in rs ] diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl index bd4c1366..c2944a99 100644 --- a/test/test_ace1_compat.jl +++ b/test/test_ace1_compat.jl @@ -27,18 +27,14 @@ params = ( elements = [:Si,], model1 = acemodel(; params...) + model2 = ACE1compat.ace1_model(; params...) ps, st = Lux.setup(rng, model2) - ## @info("check the transform construction") -# params_clean = ACE1compat._clean_args(params) -# rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -# rbasis2 = ACE1compat._radial_basis(params_clean) - rbasis1 = model1.basis.BB[2].pibasis.basis1p.J rbasis2 = model2.rbasis @@ -52,7 +48,6 @@ err_t1_t2 = maximum(abs.(t1 .- t2)) println_slim(@test err_t1_t2 < 1e-12) # the envelope - check that the "choices" are the same - println_slim(@test rbasis1.envelope isa ACE1.OrthPolys.OneEnvelope) println_slim(@test rbasis1.J.pl == rbasis1.J.pr == 2 ) println_slim(@test rbasis2.envelopes[1].p1 == rbasis2.envelopes[1].p2 == 2) @@ -84,7 +79,6 @@ println() @info("Check the pair basis construction") pairbasis1 = model1.basis.BB[1] -# pairbasis2 = ACE1compat._pair_basis(params_clean) pairbasis2 = model2.pairbasis rr = range(0.001, params.rcut, length=200) @@ -121,17 +115,20 @@ _evaluate(basis::JuLIP.MLIPs.IPSuperBasis, Rs, Zs, z0) = B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) B2 = reduce(hcat, [M.evaluate_basis(model2, x..., ps, st) for x in XX2]) - -Nmb = length(spec1) -B1_mb = B1[end-Nmb+1:end, :] -B2_mb = B2[1:Nmb, :] - # Compute linear transform between bases to show match -C_mb = B1_mb' \ B2_mb' -@show norm(B2_mb - C_mb' * B1_mb, Inf) - -B1_pair = B1[1:end-Nmb, :] -B2_pair = B2[Nmb+1:end, :] -C_pair = B1_pair' \ B2_pair' -@show norm(B2_pair - C_pair' * B1_pair, Inf) -norm(B2_pair - Diagonal(scal_pair) \ B1_pair) \ No newline at end of file +C = B1' \ B2' +@show norm(B2 - C' * B1, Inf) +println_slim(@test norm(B2 - C' * B1, Inf) < 1e-3) + +# Nmb = length(spec1) +# B1_mb = B1[end-Nmb+1:end, :] +# B2_mb = B2[1:Nmb, :] +# +# C_mb = B1_mb' \ B2_mb' +# @show norm(B2_mb - C_mb' * B1_mb, Inf) + +# B1_pair = B1[1:end-Nmb, :] +# B2_pair = B2[Nmb+1:end, :] +# C_pair = B1_pair' \ B2_pair' +# @show norm(B2_pair - C_pair' * B1_pair, Inf) + From b49757c6bbca3412be3859773cb363b239fd87ca Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 2 Aug 2024 11:08:30 -0700 Subject: [PATCH 098/112] ACE1-compat E0s and sitepot --- src/ace1_compat.jl | 7 ++++--- src/models/calculators.jl | 8 +++++++- test/test_ace1_compat.jl | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 42e6a10f..ba7b89be 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -3,7 +3,7 @@ module ACE1compat -using NamedTupleTools, StaticArrays +using NamedTupleTools, StaticArrays, Unitful import ACEpotentials: DefaultHypers, Models using ACEpotentials.Models: agnesi_transform, @@ -358,10 +358,11 @@ function ace1_model(; kwargs...) # rpibasis = ACE1x.Purify.pureRPIBasis(dirtybasis; remove = _rem) # else - if ismissing(kwargs[:Eref]) + Eref = kwargs[:Eref] + if ismissing(Eref) E0s = nothing else - E0s = kwargs[:Eref] + E0s = Dict([ key => val * u"eV" for (key, val) in Eref]...) end model = Models.ace_model(; elements=elements, diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 14a7ad62..fe3fa9be 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -17,7 +17,7 @@ using Folds, ChunkSplitters, Unitful, NeighbourLists, import ChainRulesCore: rrule, NoTangent, ZeroTangent -struct ACEPotential{MOD} <: SitePotential +mutable struct ACEPotential{MOD} <: SitePotential model::MOD ps st @@ -42,6 +42,12 @@ set_psst!(V::ACEPotential, ps, st) = (V.ps = ps; V.st = st; V) splinify(V::ACEPotential) = splinify(V, V.ps) splinify(V::ACEPotential, ps) = ACEPotential(splinify(V.model, ps), nothing, nothing) +function set_parameters!(V::ACEPotential, θ::AbstractVector{<: Number}) + ps_vec, _restruct = destructure(V.ps) + ps = _restruct(θ) + return set_parameters!(V, ps) +end + # --------------------------------------------------------------- # EmpiricalPotentials / SitePotential based implementation # diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl index c2944a99..b0c281df 100644 --- a/test/test_ace1_compat.jl +++ b/test/test_ace1_compat.jl @@ -23,6 +23,7 @@ params = ( elements = [:Si,], pure2b = false, pair_envelope = (:r, 1), rcut = 5.5, + Eref = [:Si => -1.234 ] ) @@ -106,7 +107,7 @@ spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] spec2 = M.get_nnll_spec(model2.tensor) println_slim(@test sort(sort.(spec1)) == sort(sort.(spec2))) -Nenv = 300 +Nenv = 1000 XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] @@ -132,3 +133,35 @@ println_slim(@test norm(B2 - C' * B1, Inf) < 1e-3) # C_pair = B1_pair' \ B2_pair' # @show norm(B2_pair - C_pair' * B1_pair, Inf) +## + +@info("Check E0s") +E0s1 = model1.Vref.E0[:Si] +E0s2 = model2.E0s[1] +println_slim(@test E0s1 ≈ E0s2) + +## + +@info("Set some random parameteres and check site energies") + +lenB = size(B1, 1) +θ2 = randn(lenB) ./ (1:lenB).^2 +θ1 = C * θ2 +ACE1x._set_params!(model1, θ1) + +calc2 = M.ACEPotential(model2, ps, st) +M.set_parameters!(calc2, θ2) + + +JuLIP.evaluate(V::JuLIP.OneBody, Rs, Zs, z0) = + V.E0[JuLIP.Chemistry.chemical_symbol(z0)] + +_evaluate(pot::JuLIP.MLIPs.SumIP, Rs, Zs, z0) = + sum(JuLIP.evaluate(V, Rs, Zs, z0) for V in pot.components) + +V1 = [ _evaluate(model1.potential, x...) for x in XX1 ] +V2 = [ M.evaluate(calc2.model, x..., calc2.ps, calc2.st) for x in XX2 ] + +err = norm(V1 - V2, Inf) +@show err +println_slim(@test err < 1e-4) From e09d5582c228bb993770a024178746728990a398 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 2 Aug 2024 21:06:46 -0700 Subject: [PATCH 099/112] ace1-compat, allow ace2 to be a superset --- src/ace1_compat.jl | 5 ----- src/models/ace_heuristics.jl | 3 ++- src/models/utils.jl | 10 +++++++++- test/test_ace1_compat.jl | 36 +++++++++++++++++++++--------------- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index ba7b89be..82a2c806 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -319,11 +319,6 @@ function ace1_model(; kwargs...) rbasis = _radial_basis(kwargs) pairbasis = _pair_basis(kwargs) lvl, maxlvl = _get_degrees(kwargs) - if all(maxlvl .== maximum(maxlvl)) - maxlvl = maximum(maxlvl) - else - error("ACE1-compat only supports a single-number totaldegree at the moment.") - end # if pure2b && kwargs[:pure] # error("Cannot use both `pure2b` and `pure` options.") diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 5359e568..5be3e39c 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -33,8 +33,9 @@ function ace_learnable_Rnlrzz(; NZ = length(zlist) if spec == nothing + _max_lvl = maximum(max_level) spec = [ (n = n, l = l) for n = 1:maxn, l = 0:maxl - if level((n = n, l = l)) <= max_level ] + if level((n = n, l = l)) <= _max_lvl ] end # now the actual maxn is the maximum n in the spec diff --git a/src/models/utils.jl b/src/models/utils.jl index 6d55c6e7..6adadfa3 100644 --- a/src/models/utils.jl +++ b/src/models/utils.jl @@ -30,6 +30,14 @@ function sparse_AA_spec(; order = nothing, r_spec = nothing, max_level = nothing, level = nothing, ) + + # convert the max_level to a list + if max_level isa Number + max_levels = fill(max_level, order) + else + max_levels = max_level + end + # compute the r levels r_level = [ level(b) for b in r_spec ] @@ -51,7 +59,7 @@ function sparse_AA_spec(; order = nothing, # generate the AA basis spec from the A basis spec tup2b = vv -> [ A_spec[v] for v in vv[vv .> 0] ] - admissible = bb -> (length(bb) == 0) || level(bb) <= max_level + admissible = bb -> (length(bb) == 0) || (level(bb) <= max_levels[length(bb)]) filter_ = EquivariantModels.RPE_filter_real(0) AA_spec = EquivariantModels.gensparse(; diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl index b0c281df..25d132bd 100644 --- a/test/test_ace1_compat.jl +++ b/test/test_ace1_compat.jl @@ -18,7 +18,7 @@ rng = Random.MersenneTwister(1234) params = ( elements = [:Si,], order = 3, transform = (:agnesi, 2, 2), - totaldegree = 8, + totaldegree = [12, 10, 8], pure = false, pure2b = false, pair_envelope = (:r, 1), @@ -29,6 +29,7 @@ params = ( elements = [:Si,], model1 = acemodel(; params...) +params2 = (; params..., totaldegree = params.totaldegree .+ 0.1) model2 = ACE1compat.ace1_model(; params...) ps, st = Lux.setup(rng, model2) @@ -65,9 +66,10 @@ R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z1, z1) for r in rp ]) R2 = reduce(hcat, [ rbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rp]) maxn = size(R1, 1) scal = [ maximum(R1[n,:]) / maximum(R2[n,:]) for n = 1:maxn ] -err = norm(R1 - Diagonal(scal) * R2[1:maxn, :], Inf) -@show err -println_slim(@test err < 0.001) +errs = maximum(abs, R1 - Diagonal(scal) * R2[1:maxn, :]; dims=2) +normalizederr = norm(errs ./ (1:length(errs)).^3, Inf) +@show normalizederr +println_slim(@test normalizederr < 1e-4) @info("The remaining checks are for Rn0 = Rnl") for i_nl = 1:size(R2, 1) @@ -90,13 +92,14 @@ println_slim(@test size(P1) == size(P2)) nmax = size(P1, 1) scal_pair = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] P2 = Diagonal(scal_pair) * P2 -scal_err = -(extrema(abs.(scal_pair))...) +scal_err = abs( -(extrema(abs.(scal_pair))...) ) @show scal_err println_slim(@test scal_err < 0.01) -err = norm( (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1), Inf) -@show err -println_slim(@test err < 0.01) +err = maximum(abs, (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1); dims=2) +normalizederr = norm(err ./ (1:length(err)).^3, Inf) +@show normalizederr +println_slim(@test normalizederr < 1e-4) ## @@ -105,7 +108,7 @@ println_slim(@test err < 0.01) _spec1 = ACE1.get_nl(model1.basis.BB[2]) spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] spec2 = M.get_nnll_spec(model2.tensor) -println_slim(@test sort(sort.(spec1)) == sort(sort.(spec2))) +println_slim(@test issubset(sort.(spec1), sort.(spec2))) Nenv = 1000 XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] @@ -117,9 +120,11 @@ _evaluate(basis::JuLIP.MLIPs.IPSuperBasis, Rs, Zs, z0) = B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) B2 = reduce(hcat, [M.evaluate_basis(model2, x..., ps, st) for x in XX2]) # Compute linear transform between bases to show match -C = B1' \ B2' -@show norm(B2 - C' * B1, Inf) -println_slim(@test norm(B2 - C' * B1, Inf) < 1e-3) +# We want full-rank C such that C * B2 = B1 +C = B2' \ B1' +basiserr = norm(B1 - C' * B2, Inf) +@show basiserr +println_slim(@test basiserr < 1e-3) # Nmb = length(spec1) # B1_mb = B1[end-Nmb+1:end, :] @@ -144,9 +149,10 @@ println_slim(@test E0s1 ≈ E0s2) @info("Set some random parameteres and check site energies") -lenB = size(B1, 1) -θ2 = randn(lenB) ./ (1:lenB).^2 -θ1 = C * θ2 +lenB1 = size(B1, 1) +θ1 = randn(lenB1) ./ (1:lenB1).^2 +θ2 = C * θ1 + ACE1x._set_params!(model1, θ1) calc2 = M.ACEPotential(model2, ps, st) From b57de27ad928cd6758ac1092819a724ec97a6682 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Fri, 2 Aug 2024 21:40:02 -0700 Subject: [PATCH 100/112] cleanup compat tests --- test/ace1/ace1_testutils.jl | 194 ++++++++++++++++++++++++++++++++++ test/ace1/test_ace1_compat.jl | 31 ++++++ test/test_ace1_compat.jl | 173 ------------------------------ 3 files changed, 225 insertions(+), 173 deletions(-) create mode 100644 test/ace1/ace1_testutils.jl create mode 100644 test/ace1/test_ace1_compat.jl delete mode 100644 test/test_ace1_compat.jl diff --git a/test/ace1/ace1_testutils.jl b/test/ace1/ace1_testutils.jl new file mode 100644 index 00000000..2939154b --- /dev/null +++ b/test/ace1/ace1_testutils.jl @@ -0,0 +1,194 @@ + +module ACE1_TestUtils + +using Random, Test, ACEbase, LinearAlgebra, Lux +using ACEbase.Testing: print_tf, println_slim +using ACE1, ACE1x, JuLIP + +using ACEpotentials +M = ACEpotentials.Models +ACE1compat = ACEpotentials.ACE1compat +rng = Random.MersenneTwister(1234) + +function JuLIP.cutoff(model::ACE1x.ACE1Model) + return maximum(JuLIP.cutoff.(model.basis.BB)) +end + +_evaluate(basis::JuLIP.MLIPs.IPSuperBasis, Rs, Zs, z0) = + reduce(vcat, [ACE1.evaluate(B, Rs, Zs, z0) for B in basis.BB]) + + +JuLIP.evaluate(V::JuLIP.OneBody, Rs, Zs, z0) = + V.E0[JuLIP.Chemistry.chemical_symbol(z0)] + +_evaluate(pot::JuLIP.MLIPs.SumIP, Rs, Zs, z0) = + sum(JuLIP.evaluate(V, Rs, Zs, z0) for V in pot.components) + +function get_rbasis(model::ACE1x.ACE1Model) + return model.basis.BB[2].pibasis.basis1p.J +end + +function get_rbasis(model::M.ACEModel) + return model.rbasis +end + +function check_rbasis_transforms(model1, model2) + @info("check the radial basis transforms and envelopes") + + NZ = M._get_nz(model2); @assert NZ == 1 + + rbasis1 = get_rbasis(model1) + rbasis2 = get_rbasis(model2) + + trans1 = rbasis1.trans.transforms[1] + trans2 = rbasis2.transforms[1] + + rcut = JuLIP.cutoff(model1) + rr = rcut * rand(200) + t1 = ACE1.Transforms.transform.(Ref(trans1), rr) + t2 = trans2.(rr) + err_t1_t2 = maximum(abs.(t1 .- t2)) + println_slim(@test err_t1_t2 < 1e-12) + + # the envelope - check that the "choices" are the same + println_slim(@test rbasis1.envelope isa ACE1.OrthPolys.OneEnvelope) + println_slim(@test rbasis1.J.pl == rbasis1.J.pr == 2 ) + println_slim(@test rbasis2.envelopes[1].p1 == rbasis2.envelopes[1].p2 == 2) + nothing +end + + +function check_rbasis(model1, model2) + @info("check radial basis") + @info(" error can be a bit larger since the jacobi basis used in ACE1 is constructed from a discrete measure") + @info("The first test checks Rn vs Rn0") + rbasis1 = get_rbasis(model1) + rbasis2 = get_rbasis(model2) + rcut = JuLIP.cutoff(model1) + rp = range(0.0, rcut, length=200) + R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z1, z1) for r in rp ]) + R2 = reduce(hcat, [ rbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rp]) + maxn = size(R1, 1) + scal = [ maximum(R1[n,:]) / maximum(R2[n,:]) for n = 1:maxn ] + errs = maximum(abs, R1 - Diagonal(scal) * R2[1:maxn, :]; dims=2) + normalizederr = norm(errs ./ (1:length(errs)).^3, Inf) + @show normalizederr + println_slim(@test normalizederr < 1e-4) + + @info("The remaining checks are for Rn0 = Rnl") + for i_nl = 1:size(R2, 1) + n = rbasis2.spec[i_nl].n + print_tf(@test R2[i_nl, :] ≈ R2[n, :]) + end + println() +end + + +function check_pairbasis(model1, model2) + @info("Check the pair basis") + pairbasis1 = model1.basis.BB[1] + pairbasis2 = model2.pairbasis + rcut = JuLIP.cutoff(model1) + z1 = AtomicNumber(:Si) + z2 = Int(z1) + rr = range(0.001, rcut, length=200) + P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z1, z1) for r in rr ]) + P2 = reduce(hcat, [ pairbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rr]) + println_slim(@test size(P1) == size(P2)) + + nmax = size(P1, 1) + scal_pair = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] + P2 = Diagonal(scal_pair) * P2 + scal_err = abs( -(extrema(abs.(scal_pair))...) ) + @show scal_err + println_slim(@test scal_err < 0.01) + + err = maximum(abs, (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1); dims=2) + normalizederr = norm(err ./ (1:length(err)).^3, Inf) + @show normalizederr + println_slim(@test normalizederr < 1e-4) +end + + +function check_basis(model1, model2; Nenv = :auto) + ps, st = Lux.setup(rng, model2) + + @info("Check the bases span the same space") + NZ = M._get_nz(model2); @assert NZ == 1 + if NZ == 1 + @info(" NZ == 1 >>> check spec matches") + _spec1 = ACE1.get_nl(model1.basis.BB[2]) + spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] + spec2 = M.get_nnll_spec(model2.tensor) + println_slim(@test issubset(sort.(spec1), sort.(spec2))) + end + + lenB1 = length(model1.basis) + if Nenv == :auto + Nenv = lenB1 * 10 + @show Nenv + end + + XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] + XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] + + B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) + B2 = reduce(hcat, [ M.evaluate_basis(model2, x..., ps, st) for x in XX2]) + @info("Compute linear transform between bases to show match") + # We want full-rank C such that C * B2 = B1 + # (note this allows B2 > B1) + C = B2' \ B1' + basiserr = norm(B1 - C' * B2, Inf) + @show basiserr + println_slim(@test basiserr < 1e-3) + + # # some more fine-grained checks for debugging + # Nmb = length(spec1) + # B1_mb = B1[end-Nmb+1:end, :] + # B2_mb = B2[1:Nmb, :] + # + # C_mb = B1_mb' \ B2_mb' + # @show norm(B2_mb - C_mb' * B1_mb, Inf) + + # B1_pair = B1[1:end-Nmb, :] + # B2_pair = B2[Nmb+1:end, :] + # C_pair = B1_pair' \ B2_pair' + # @show norm(B2_pair - C_pair' * B1_pair, Inf) + + @info("Set some random parameters and check site energies") + θ1 = randn(lenB1) ./ (1:lenB1).^2 + θ2 = C * θ1 + + ACE1x._set_params!(model1, θ1) + + calc2 = M.ACEPotential(model2, ps, st) + M.set_parameters!(calc2, θ2) + + V1 = [ _evaluate(model1.potential, x...) for x in XX1 ] + V2 = [ M.evaluate(calc2.model, x..., calc2.ps, calc2.st) for x in XX2 ] + + err = norm(V1 - V2, Inf) + @show err + println_slim(@test err < 1e-4) + nothing +end + + +function check_compat(params; deginc = 0.1) + model1 = acemodel(; params...) + params2 = (; params..., totaldegree = params.totaldegree .+ deginc) + model2 = ACE1compat.ace1_model(; params...) + + NZ = length(params.elements) + if NZ == 1 + @info("NZ == 1 >>> can do some extra checks") + check_rbasis_transforms(model1, model2) + check_rbasis(model1, model2) + check_pairbasis(model1, model2) + end + + check_basis(model1, model2) + nothing +end + +end \ No newline at end of file diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl new file mode 100644 index 00000000..02f21831 --- /dev/null +++ b/test/ace1/test_ace1_compat.jl @@ -0,0 +1,31 @@ + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) +# using TestEnv; TestEnv.activate(); + +## + +using Plots +using Random, Test, ACEbase, LinearAlgebra, Lux +using ACEbase.Testing: print_tf, println_slim + +using ACEpotentials +M = ACEpotentials.Models +ACE1compat = ACEpotentials.ACE1compat +rng = Random.MersenneTwister(1234) + +## + +params = ( elements = [:Si,], + order = 3, + transform = (:agnesi, 2, 2), + totaldegree = [12, 10, 8], + pure = false, + pure2b = false, + pair_envelope = (:r, 1), + rcut = 5.5, + Eref = [:Si => -1.234 ] + ) + +## + +ACE1_TestUtils.check_compat(params) diff --git a/test/test_ace1_compat.jl b/test/test_ace1_compat.jl deleted file mode 100644 index 25d132bd..00000000 --- a/test/test_ace1_compat.jl +++ /dev/null @@ -1,173 +0,0 @@ - -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) -# using TestEnv; TestEnv.activate(); - -## - -using Plots -using Random, Test, ACEbase, LinearAlgebra, Lux -using ACEbase.Testing: print_tf, println_slim - -using ACEpotentials -M = ACEpotentials.Models -ACE1compat = ACEpotentials.ACE1compat -rng = Random.MersenneTwister(1234) - -## - -params = ( elements = [:Si,], - order = 3, - transform = (:agnesi, 2, 2), - totaldegree = [12, 10, 8], - pure = false, - pure2b = false, - pair_envelope = (:r, 1), - rcut = 5.5, - Eref = [:Si => -1.234 ] - ) - - -model1 = acemodel(; params...) - -params2 = (; params..., totaldegree = params.totaldegree .+ 0.1) -model2 = ACE1compat.ace1_model(; params...) -ps, st = Lux.setup(rng, model2) - -## - -@info("check the transform construction") - -rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -rbasis2 = model2.rbasis - -trans1 = rbasis1.trans.transforms[1] -trans2 = rbasis2.transforms[1] - -rr = params.rcut * rand(200) -t1 = ACE1.Transforms.transform.(Ref(trans1), rr) -t2 = trans2.(rr) -err_t1_t2 = maximum(abs.(t1 .- t2)) -println_slim(@test err_t1_t2 < 1e-12) - -# the envelope - check that the "choices" are the same -println_slim(@test rbasis1.envelope isa ACE1.OrthPolys.OneEnvelope) -println_slim(@test rbasis1.J.pl == rbasis1.J.pr == 2 ) -println_slim(@test rbasis2.envelopes[1].p1 == rbasis2.envelopes[1].p2 == 2) - -## - -@info("check full radial basis construction") -@info(" This error can be a bit larger since the jacobi basis used in ACE1 is constructed from a discrete measure") -@info("The first test checks Rn vs Rn0") -z1 = AtomicNumber(:Si) -z2 = Int(z1) -rp = range(0.0, params.rcut, length=200) -R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z1, z1) for r in rp ]) -R2 = reduce(hcat, [ rbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rp]) -maxn = size(R1, 1) -scal = [ maximum(R1[n,:]) / maximum(R2[n,:]) for n = 1:maxn ] -errs = maximum(abs, R1 - Diagonal(scal) * R2[1:maxn, :]; dims=2) -normalizederr = norm(errs ./ (1:length(errs)).^3, Inf) -@show normalizederr -println_slim(@test normalizederr < 1e-4) - -@info("The remaining checks are for Rn0 = Rnl") -for i_nl = 1:size(R2, 1) - n = rbasis2.spec[i_nl].n - print_tf(@test R2[i_nl, :] ≈ R2[n, :]) -end -println() - -## - -@info("Check the pair basis construction") -pairbasis1 = model1.basis.BB[1] -pairbasis2 = model2.pairbasis - -rr = range(0.001, params.rcut, length=200) -P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z1, z1) for r in rr ]) -P2 = reduce(hcat, [ pairbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rr]) -println_slim(@test size(P1) == size(P2)) - -nmax = size(P1, 1) -scal_pair = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] -P2 = Diagonal(scal_pair) * P2 -scal_err = abs( -(extrema(abs.(scal_pair))...) ) -@show scal_err -println_slim(@test scal_err < 0.01) - -err = maximum(abs, (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1); dims=2) -normalizederr = norm(err ./ (1:length(err)).^3, Inf) -@show normalizederr -println_slim(@test normalizederr < 1e-4) - -## - -@info("Check the bases span the same space") -@info(" check spec matches") -_spec1 = ACE1.get_nl(model1.basis.BB[2]) -spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] -spec2 = M.get_nnll_spec(model2.tensor) -println_slim(@test issubset(sort.(spec1), sort.(spec2))) - -Nenv = 1000 -XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] -XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] - -_evaluate(basis::JuLIP.MLIPs.IPSuperBasis, Rs, Zs, z0) = - reduce(vcat, [ACE1.evaluate(B, Rs, Zs, z0) for B in basis.BB]) - -B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) -B2 = reduce(hcat, [M.evaluate_basis(model2, x..., ps, st) for x in XX2]) -# Compute linear transform between bases to show match -# We want full-rank C such that C * B2 = B1 -C = B2' \ B1' -basiserr = norm(B1 - C' * B2, Inf) -@show basiserr -println_slim(@test basiserr < 1e-3) - -# Nmb = length(spec1) -# B1_mb = B1[end-Nmb+1:end, :] -# B2_mb = B2[1:Nmb, :] -# -# C_mb = B1_mb' \ B2_mb' -# @show norm(B2_mb - C_mb' * B1_mb, Inf) - -# B1_pair = B1[1:end-Nmb, :] -# B2_pair = B2[Nmb+1:end, :] -# C_pair = B1_pair' \ B2_pair' -# @show norm(B2_pair - C_pair' * B1_pair, Inf) - -## - -@info("Check E0s") -E0s1 = model1.Vref.E0[:Si] -E0s2 = model2.E0s[1] -println_slim(@test E0s1 ≈ E0s2) - -## - -@info("Set some random parameteres and check site energies") - -lenB1 = size(B1, 1) -θ1 = randn(lenB1) ./ (1:lenB1).^2 -θ2 = C * θ1 - -ACE1x._set_params!(model1, θ1) - -calc2 = M.ACEPotential(model2, ps, st) -M.set_parameters!(calc2, θ2) - - -JuLIP.evaluate(V::JuLIP.OneBody, Rs, Zs, z0) = - V.E0[JuLIP.Chemistry.chemical_symbol(z0)] - -_evaluate(pot::JuLIP.MLIPs.SumIP, Rs, Zs, z0) = - sum(JuLIP.evaluate(V, Rs, Zs, z0) for V in pot.components) - -V1 = [ _evaluate(model1.potential, x...) for x in XX1 ] -V2 = [ M.evaluate(calc2.model, x..., calc2.ps, calc2.st) for x in XX2 ] - -err = norm(V1 - V2, Inf) -@show err -println_slim(@test err < 1e-4) From 837c8c439b52869d6314570dfbd18be18e62bc23 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 3 Aug 2024 09:41:32 -0700 Subject: [PATCH 101/112] ACE1 style r envelope --- src/ace1_compat.jl | 34 ++++++++++++++++++++++++++++++- src/models/ace_heuristics.jl | 9 +++++++-- test/ace1/ace1_testutils.jl | 6 ++++-- test/ace1/test_ace1_compat.jl | 38 ++++++++++++++++++++++++++++------- test/runtests.jl | 3 +++ 5 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 82a2c806..c46f1a95 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -290,8 +290,11 @@ function _pair_basis(kwargs) rcutkey = :pair_rcut) # ENVELOPE - # here we use the same convention, so this is fine + # here we use a similar convention, just need to convert to ace1-style envelope = kwargs[:pair_envelope] + if envelope isa Tuple && envelope[1] == :r + envelope = (:r_ace1, envelope[2]) + end pair_basis = ace_learnable_Rnlrzz(; spec = pair_spec, maxq = maxn, @@ -375,6 +378,35 @@ function ace1_model(; kwargs...) end +# ------------------------- + +import ACEpotentials + +""" +The pair basis radial envelope implemented in ACE1.jl +""" +struct ACE1_PolyEnvelope1sR{T} + rcut::T + r0::T + p::Int +end + + +ACE1_PolyEnvelope1sR(rcut, r0, p) = + ACE1_PolyEnvelope1sR(rcut, r0, p, Dict{String, Any}()) + +function ACEpotentials.Models.evaluate(env::ACE1_PolyEnvelope1sR, r::T, x::T) where T + p, r0, rcut = env.p, env.r0, env.rcut + if r > rcut; return zero(T); end + s = r/r0; scut = rcut/r0 + return s^(-p) - scut^(-p) + p * scut^(-p-1) * (s - scut) +end + +ACEpotentials.Models.evaluate_d(env::ACE1_PolyEnvelope1sR, r::T, x::T) where {T} = + (ForwardDiff.derivative(x -> evaluate(env, x), r), + zero(T),) + + end diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 5be3e39c..c6ff3be6 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -1,4 +1,5 @@ -import Random +import Random +import ACEpotentials: ACE1compat # ------------------------------------------------------- @@ -71,17 +72,21 @@ function ace_learnable_Rnlrzz(; transforms = agnesi_transform.(rin0cuts, p, q) end + @show envelopes if envelopes == :poly2sx envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) elseif envelopes == :poly1sr envelopes = [ PolyEnvelope1sR(rin0cuts[iz, jz].rcut, 1) for iz = 1:NZ, jz = 1:NZ ] elseif envelopes isa Tuple && envelopes[1] == :x - @assert length(envelopes) == 2 + @assert length(envelopes) == 3 envelopes = PolyEnvelope2sX(-1.0, 1.0, envelopes[2], envelopes[3]) elseif envelopes isa Tuple && envelopes[1] == :r envelopes = [ PolyEnvelope1sR(rin0cuts[iz, jz].rcut, envelopes[2]) for iz = 1:NZ, jz = 1:NZ ] + elseif envelopes isa Tuple && envelopes[1] == :r_ace1 + envelopes = [ ACE1compat.ACE1_PolyEnvelope1sR(rin0cuts[iz, jz].rcut, rin0cuts[iz, jz].r0, envelopes[2]) + for iz = 1:NZ, jz = 1:NZ ] else error("cannot read envelope : $envelopes") end diff --git a/test/ace1/ace1_testutils.jl b/test/ace1/ace1_testutils.jl index 2939154b..0a9dabc7 100644 --- a/test/ace1/ace1_testutils.jl +++ b/test/ace1/ace1_testutils.jl @@ -3,7 +3,7 @@ module ACE1_TestUtils using Random, Test, ACEbase, LinearAlgebra, Lux using ACEbase.Testing: print_tf, println_slim -using ACE1, ACE1x, JuLIP +import ACE1, ACE1x, JuLIP using ACEpotentials M = ACEpotentials.Models @@ -62,6 +62,8 @@ function check_rbasis(model1, model2) @info("check radial basis") @info(" error can be a bit larger since the jacobi basis used in ACE1 is constructed from a discrete measure") @info("The first test checks Rn vs Rn0") + z1 = AtomicNumber(:Si) + z2 = Int(z1) rbasis1 = get_rbasis(model1) rbasis2 = get_rbasis(model2) rcut = JuLIP.cutoff(model1) @@ -106,7 +108,7 @@ function check_pairbasis(model1, model2) err = maximum(abs, (P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1); dims=2) normalizederr = norm(err ./ (1:length(err)).^3, Inf) @show normalizederr - println_slim(@test normalizederr < 1e-4) + println_slim(@test normalizederr < 1e-3) end diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index 02f21831..e1152b9e 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -4,16 +4,20 @@ using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) ## -using Plots -using Random, Test, ACEbase, LinearAlgebra, Lux -using ACEbase.Testing: print_tf, println_slim +include(@__DIR__() * "/ace1_testutils.jl") + +@info( +""" +================================== +=== Testing ACE1 compatibility === +================================== +""") -using ACEpotentials -M = ACEpotentials.Models -ACE1compat = ACEpotentials.ACE1compat -rng = Random.MersenneTwister(1234) ## +# [1] +# a first test that was used to write the original +# ACE1 compat module and tests params = ( elements = [:Si,], order = 3, @@ -26,6 +30,26 @@ params = ( elements = [:Si,], Eref = [:Si => -1.234 ] ) +ACE1_TestUtils.check_compat(params) + ## +# [2] +# same as [1] but with auto cutoff, a different pair envelope, +# different transforms, and a different chemical element, +# different choice of degrees +# + +params = ( elements = [:Si,], + order = 3, + transform = (:agnesi, 2, 4), + totaldegree = [14, 12, 10], + pure = false, + pure2b = false, + pair_transform = (:agnesi, 1, 3), + pair_envelope = (:r, 3), + rcut = 5.0, + Eref = [:Si => -1.234 ] + ) ACE1_TestUtils.check_compat(params) + diff --git a/test/runtests.jl b/test/runtests.jl index dae9c9fd..5359a72d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,9 @@ using ACEpotentials, Test, LazyArtifacts # weird stuff @testset "Weird bugs" begin include("test_bugs.jl") end + # ACE1 compatibility + @testset "ACE1 Compat" begin include("ace1/test_ace1_compat.jl"); end + # outdated # @testset "Read data" begin include("outdated/test_data.jl") end # @testset "Basis" begin include("outdated/test_basis.jl") end From a9fb55e8c7e884c1a541e399688ff1306c29cc24 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 3 Aug 2024 15:11:02 -0700 Subject: [PATCH 102/112] ACE1 compat - two more tests --- src/ace1_compat.jl | 30 +----------------------------- src/models/ace_heuristics.jl | 2 +- src/models/radial_envelopes.jl | 25 +++++++++++++++++++++++++ test/ace1/ace1_testutils.jl | 3 ++- test/ace1/test_ace1_compat.jl | 17 +++++++++++++++-- 5 files changed, 44 insertions(+), 33 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index c46f1a95..8a6834d2 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -130,7 +130,7 @@ end function _get_rcut(kwargs, s1, s2; _rcut = kwargs[:rcut]) if _rcut isa Tuple if _rcut[1] == :bondlen # rcut = (:bondlen, rcut_factor) - return _rcut[2] * get_r0(s1, s2) + return _rcut[2] * _get_r0(kwargs, s1, s2) end elseif _rcut isa Number # rcut = explicit value return _rcut @@ -378,34 +378,6 @@ function ace1_model(; kwargs...) end -# ------------------------- - -import ACEpotentials - -""" -The pair basis radial envelope implemented in ACE1.jl -""" -struct ACE1_PolyEnvelope1sR{T} - rcut::T - r0::T - p::Int -end - - -ACE1_PolyEnvelope1sR(rcut, r0, p) = - ACE1_PolyEnvelope1sR(rcut, r0, p, Dict{String, Any}()) - -function ACEpotentials.Models.evaluate(env::ACE1_PolyEnvelope1sR, r::T, x::T) where T - p, r0, rcut = env.p, env.r0, env.rcut - if r > rcut; return zero(T); end - s = r/r0; scut = rcut/r0 - return s^(-p) - scut^(-p) + p * scut^(-p-1) * (s - scut) -end - -ACEpotentials.Models.evaluate_d(env::ACE1_PolyEnvelope1sR, r::T, x::T) where {T} = - (ForwardDiff.derivative(x -> evaluate(env, x), r), - zero(T),) - end diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index c6ff3be6..ada1a190 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -85,7 +85,7 @@ function ace_learnable_Rnlrzz(; envelopes = [ PolyEnvelope1sR(rin0cuts[iz, jz].rcut, envelopes[2]) for iz = 1:NZ, jz = 1:NZ ] elseif envelopes isa Tuple && envelopes[1] == :r_ace1 - envelopes = [ ACE1compat.ACE1_PolyEnvelope1sR(rin0cuts[iz, jz].rcut, rin0cuts[iz, jz].r0, envelopes[2]) + envelopes = [ ACE1_PolyEnvelope1sR(rin0cuts[iz, jz].rcut, rin0cuts[iz, jz].r0, envelopes[2]) for iz = 1:NZ, jz = 1:NZ ] else error("cannot read envelope : $envelopes") diff --git a/src/models/radial_envelopes.jl b/src/models/radial_envelopes.jl index 5f7fa61a..3174e2bd 100644 --- a/src/models/radial_envelopes.jl +++ b/src/models/radial_envelopes.jl @@ -25,6 +25,31 @@ evaluate_d(env::PolyEnvelope1sR, r::T, x::T) where {T} = (ForwardDiff.derivative(x -> evaluate(env, x), r), zero(T),) +# ---------------------------- + +""" +The pair basis radial envelope implemented in ACE1.jl +""" +struct ACE1_PolyEnvelope1sR{T} + rcut::T + r0::T + p::Int +end + + +ACE1_PolyEnvelope1sR(rcut, r0, p) = + ACE1_PolyEnvelope1sR(rcut, r0, p, Dict{String, Any}()) + +function evaluate(env::ACE1_PolyEnvelope1sR, r::T, x::T) where T + p, r0, rcut = env.p, env.r0, env.rcut + if r > rcut; return zero(T); end + s = r/r0; scut = rcut/r0 + return s^(-p) - scut^(-p) + p * scut^(-p-1) * (s - scut) +end + +evaluate_d(env::ACE1_PolyEnvelope1sR, r::T, x::T) where {T} = + (ForwardDiff.derivative(x -> evaluate(env, x), r), + zero(T),) # ---------------------------- diff --git a/test/ace1/ace1_testutils.jl b/test/ace1/ace1_testutils.jl index 0a9dabc7..da838377 100644 --- a/test/ace1/ace1_testutils.jl +++ b/test/ace1/ace1_testutils.jl @@ -131,6 +131,7 @@ function check_basis(model1, model2; Nenv = :auto) @show Nenv end + Random.seed!(12345) XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] @@ -142,7 +143,7 @@ function check_basis(model1, model2; Nenv = :auto) C = B2' \ B1' basiserr = norm(B1 - C' * B2, Inf) @show basiserr - println_slim(@test basiserr < 1e-3) + println_slim(@test basiserr < .3e-2) # # some more fine-grained checks for debugging # Nmb = length(spec1) diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index e1152b9e..c99f6eeb 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); ## @@ -47,9 +47,22 @@ params = ( elements = [:Si,], pure2b = false, pair_transform = (:agnesi, 1, 3), pair_envelope = (:r, 3), - rcut = 5.0, Eref = [:Si => -1.234 ] ) ACE1_TestUtils.check_compat(params) +## +# [3] +# A minimal example with as many defaults as possible +# + +params = ( elements = [:Si,], + order = 2, + totaldegree = 10, + pure = false, + pure2b = false, + ) + +ACE1_TestUtils.check_compat(params) + From 2f1d00c18435029102f2a56f8e867e21e07fc190 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sat, 3 Aug 2024 21:31:46 -0700 Subject: [PATCH 103/112] replace :linear with :onehot --- docs/src/newkernels/acebasis_analysis.jl | 4 +- docs/src/newkernels/linear.jl | 4 +- docs/src/newkernels/linear_nonlinear.jl | 4 +- docs/src/newkernels/match_bases.jl | 4 +- docs/src/newkernels/rbasis_analysis.jl | 4 +- src/ace1_compat.jl | 10 ++-- src/models/Rnl_learnable.jl | 2 +- src/models/ace_heuristics.jl | 3 +- test/ace1/ace1_testutils.jl | 11 ++-- test/ace1/test_ace1_compat.jl | 69 ++++++++++++++++++++++-- test/models/test_radialweights.jl | 4 +- 11 files changed, 92 insertions(+), 27 deletions(-) diff --git a/docs/src/newkernels/acebasis_analysis.jl b/docs/src/newkernels/acebasis_analysis.jl index d4856169..883758ed 100644 --- a/docs/src/newkernels/acebasis_analysis.jl +++ b/docs/src/newkernels/acebasis_analysis.jl @@ -53,8 +53,8 @@ model2 = M.ace_model(; elements = elements, max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = "linear", # how to initialize the pair potential parameters - init_Wradial = :linear, + init_Wpair = :onehot, # how to initialize the pair potential parameters + init_Wradial = :onehot, pair_transform = (:agnesi, 1, 3), pair_learnable = true, rin0cuts = rin0cuts, diff --git a/docs/src/newkernels/linear.jl b/docs/src/newkernels/linear.jl index a51b292b..282822d9 100644 --- a/docs/src/newkernels/linear.jl +++ b/docs/src/newkernels/linear.jl @@ -33,8 +33,8 @@ model2 = M.ace_model(; elements = elements, max_level = totaldegree+1, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = "linear", # how to initialize the pair potential parameters - init_Wradial = :linear, + init_Wpair = :onehot, # how to initialize the pair potential parameters + init_Wradial = :onehot, pair_transform = (:agnesi, 1, 3), pair_learnable = false, ) diff --git a/docs/src/newkernels/linear_nonlinear.jl b/docs/src/newkernels/linear_nonlinear.jl index aacf5680..a5160ed8 100644 --- a/docs/src/newkernels/linear_nonlinear.jl +++ b/docs/src/newkernels/linear_nonlinear.jl @@ -33,8 +33,8 @@ model2 = M.ace_model(; elements = elements, max_level = totaldegree+1, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = "linear", # how to initialize the pair potential parameters - init_Wradial = :linear, + init_Wpair = :onehot, # how to initialize the pair potential parameters + init_Wradial = :onehot, pair_transform = (:agnesi, 1, 3), pair_learnable = false, ) diff --git a/docs/src/newkernels/match_bases.jl b/docs/src/newkernels/match_bases.jl index 6c7d68b5..2f626b98 100644 --- a/docs/src/newkernels/match_bases.jl +++ b/docs/src/newkernels/match_bases.jl @@ -29,8 +29,8 @@ function matching_bases(; Z = :Si, order = 3, totaldegree = 10, max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = "linear", # how to initialize the pair potential parameters - init_Wradial = :linear, + init_Wpair = :onehot, # how to initialize the pair potential parameters + init_Wradial = :onehot, pair_transform = (:agnesi, 1, 3), pair_learnable = true, rin0cuts = rin0cuts, diff --git a/docs/src/newkernels/rbasis_analysis.jl b/docs/src/newkernels/rbasis_analysis.jl index 3a1dc505..e4850a36 100644 --- a/docs/src/newkernels/rbasis_analysis.jl +++ b/docs/src/newkernels/rbasis_analysis.jl @@ -53,8 +53,8 @@ model2 = M.ace_model(; elements = elements, max_level = totaldegree, # maximum level of the basis functions pair_maxn = totaldegree, # maximum number of basis functions for the pair potential init_WB = :zeros, # how to initialize the ACE basis parmeters - init_Wpair = "linear", # how to initialize the pair potential parameters - init_Wradial = :linear, + init_Wpair = :onehot, # how to initialize the pair potential parameters + init_Wradial = :onehot, pair_transform = (:agnesi, 1, 3), pair_learnable = true, rin0cuts = rin0cuts, diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 8a6834d2..20d76e2a 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -99,8 +99,9 @@ function _get_degrees(kwargs) end wL = kwargs[:wL] + NZ = length(_get_elements(kwargs)) - return Models.TotalDegree(1.0, 1/wL), maxlevels + return Models.TotalDegree(1.0*NZ, 1/wL), maxlevels end error("Cannot determine total degree of ACE basis from the arguments provided.") @@ -206,7 +207,8 @@ end function _get_Rnl_spec(kwargs, maxdeg = maximum(kwargs[:totaldegree]) ) wL = kwargs[:wL] - lvl = Models.TotalDegree(1.0, 1/wL) + NZ = length(_get_elements(kwargs)) + lvl = Models.TotalDegree(1.0*NZ, 1/wL) return Models.oneparticle_spec(lvl, maxdeg) end @@ -257,7 +259,7 @@ function _radial_basis(kwargs) rin0cuts = rin0cuts, transforms = trans_ace, polys = polys, - Winit = "linear") + Winit = :onehot) ps_Rn = Models.initialparameters(nothing, Rn_basis) Rn_spl = Models.splinify(Rn_basis, ps_Rn) @@ -303,7 +305,7 @@ function _pair_basis(kwargs) transforms = trans_pair, envelopes = envelope, polys = :legendre, - Winit = "linear" ) + Winit = :onehot ) ps_pair = Models.initialparameters(nothing, pair_basis) pair_spl = Models.splinify(pair_basis, ps_pair) return pair_spl diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 396583ce..07b37075 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -43,7 +43,7 @@ function initialparameters(rng::Union{AbstractRNG, Nothing}, Wnlq[:, :, i, j] .= glorot_normal(rng, Float64, len_nl, len_q) end - elseif basis.meta["Winit"] == "linear" + elseif basis.meta["Winit"] == :onehot set_I_weights!(basis, ps) elseif basis.meta["Winit"] == "zeros" diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index ada1a190..14be5b6a 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -72,7 +72,6 @@ function ace_learnable_Rnlrzz(; transforms = agnesi_transform.(rin0cuts, p, q) end - @show envelopes if envelopes == :poly2sx envelopes = PolyEnvelope2sX(-1.0, 1.0, 2, 2) elseif envelopes == :poly1sr @@ -124,7 +123,7 @@ function ace_model(; elements = nothing, pair_basis = :auto, pair_learnable = false, pair_transform = (:agnesi, 1, 4), - init_Wpair = "linear", + init_Wpair = :onehot, rng = Random.default_rng(), ) diff --git a/test/ace1/ace1_testutils.jl b/test/ace1/ace1_testutils.jl index da838377..e594b9b0 100644 --- a/test/ace1/ace1_testutils.jl +++ b/test/ace1/ace1_testutils.jl @@ -62,7 +62,7 @@ function check_rbasis(model1, model2) @info("check radial basis") @info(" error can be a bit larger since the jacobi basis used in ACE1 is constructed from a discrete measure") @info("The first test checks Rn vs Rn0") - z1 = AtomicNumber(:Si) + z1 = model1.basis.BB[2].pibasis.zlist.list[1] z2 = Int(z1) rbasis1 = get_rbasis(model1) rbasis2 = get_rbasis(model2) @@ -91,7 +91,7 @@ function check_pairbasis(model1, model2) pairbasis1 = model1.basis.BB[1] pairbasis2 = model2.pairbasis rcut = JuLIP.cutoff(model1) - z1 = AtomicNumber(:Si) + z1 = model1.basis.BB[2].pibasis.zlist.list[1] z2 = Int(z1) rr = range(0.001, rcut, length=200) P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z1, z1) for r in rr ]) @@ -116,7 +116,7 @@ function check_basis(model1, model2; Nenv = :auto) ps, st = Lux.setup(rng, model2) @info("Check the bases span the same space") - NZ = M._get_nz(model2); @assert NZ == 1 + NZ = M._get_nz(model2) if NZ == 1 @info(" NZ == 1 >>> check spec matches") _spec1 = ACE1.get_nl(model1.basis.BB[2]) @@ -137,6 +137,11 @@ function check_basis(model1, model2; Nenv = :auto) B1 = reduce(hcat, [ _evaluate(model1.basis, x...) for x in XX1]) B2 = reduce(hcat, [ M.evaluate_basis(model2, x..., ps, st) for x in XX2]) + + if size(B2, 1) < size(B1, 1) + error("ACE1 compat : ACE2 model must be at least as large as ACE1 model; aborting tests.") + end + @info("Compute linear transform between bases to show match") # We want full-rank C such that C * B2 = B1 # (note this allows B2 > B1) diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index c99f6eeb..72fda21a 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -36,10 +36,10 @@ ACE1_TestUtils.check_compat(params) # [2] # same as [1] but with auto cutoff, a different pair envelope, # different transforms, and a different chemical element, -# different choice of degrees +# different choice of degrees, different element # -params = ( elements = [:Si,], +params = ( elements = [:C,], order = 3, transform = (:agnesi, 2, 4), totaldegree = [14, 12, 10], @@ -47,17 +47,17 @@ params = ( elements = [:Si,], pure2b = false, pair_transform = (:agnesi, 1, 3), pair_envelope = (:r, 3), - Eref = [:Si => -1.234 ] + Eref = [:C => -1.234 ] ) ACE1_TestUtils.check_compat(params) ## # [3] -# A minimal example with as many defaults as possible +# A minimal example with as many defaults as possible # -params = ( elements = [:Si,], +params = ( elements = [:W,], order = 2, totaldegree = 10, pure = false, @@ -66,3 +66,62 @@ params = ( elements = [:Si,], ACE1_TestUtils.check_compat(params) +## +# [4] +# A first multi-species example + +params = ( elements = [:Al, :Ti,], + order = 3, + totaldegree = 8, + pure = false, + pure2b = false, + ) + +ACE1_TestUtils.check_compat(params) + +## + +using Random, Test, ACEbase, LinearAlgebra, Lux +using ACEbase.Testing: print_tf, println_slim +import ACE1, ACE1x, JuLIP + +using ACEpotentials +M = ACEpotentials.Models +ACE1compat = ACEpotentials.ACE1compat +rng = Random.MersenneTwister(1234) + +## + +params = ( elements = [:Al, :Ti,], + order = 3, + totaldegree = 7, + pure = false, + pure2b = false, + ) + +model1 = acemodel(; params...) +params2 = (; params..., totaldegree = params.totaldegree .+ 0.1) +model2 = ACE1compat.ace1_model(; params...) +ps, st = Lux.setup(rng, model2) + +lenB1 = length(model1.basis) +Nenv = 10 * lenB1 + +Random.seed!(12345) +XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] +XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] + +B1 = reduce(hcat, [ ACE1_TestUtils._evaluate(model1.basis, x...) for x in XX1]) +B2 = reduce(hcat, [ M.evaluate_basis(model2, x..., ps, st) for x in XX2]) + +@info("Compute linear transform between bases to show match") +# We want full-rank C such that C * B2 = B1 +# (note this allows B2 > B1) +C = B2' \ B1' +basiserr = norm(B1 - C' * B2, Inf) +@show basiserr +println_slim(@test basiserr < .3e-2) + + +# ACE1_TestUtils.check_basis(model1, model2) + diff --git a/test/models/test_radialweights.jl b/test/models/test_radialweights.jl index 7cb8bd0b..34d081b6 100644 --- a/test/models/test_radialweights.jl +++ b/test/models/test_radialweights.jl @@ -22,7 +22,7 @@ model = M.ace_model(; elements = elements, order = order, Ytype = :solid, maxq_fact = maxq_fact, init_WB = :zeros, init_Wpair = :zeros, - init_Wradial = :linear) + init_Wradial = :onehot) ps, st = LuxCore.setup(rng, model) @@ -42,7 +42,7 @@ model = M.ace_model(; elements = elements, order = order, Ytype = :solid, pair_maxn = 15, init_WB = :zeros, init_Wpair = :zeros, - init_Wradial = :linear) + init_Wradial = :onehot) ps, st = LuxCore.setup(rng, model) From 160effcf5bae14a2e58c5ea3b545135cac2fc14e Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 4 Aug 2024 07:33:15 -0700 Subject: [PATCH 104/112] ACE1-compat, multispecies example --- src/models/Rnl_learnable.jl | 20 ++++++++---- test/ace1/test_ace1_compat.jl | 60 ++++++++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 07b37075..9f6e22a9 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -43,8 +43,8 @@ function initialparameters(rng::Union{AbstractRNG, Nothing}, Wnlq[:, :, i, j] .= glorot_normal(rng, Float64, len_nl, len_q) end - elseif basis.meta["Winit"] == :onehot - set_I_weights!(basis, ps) + elseif basis.meta["Winit"] == "onehot" + set_onehot_weights!(basis, ps) elseif basis.meta["Winit"] == "zeros" @warn("Setting inner basis weights to zero.") @@ -73,15 +73,21 @@ end """ Set the radial weights as they would be in a linear ACE model. """ -function set_I_weights!(rbasis::LearnableRnlrzzBasis, ps) +function set_onehot_weights!(rbasis::LearnableRnlrzzBasis, ps) # Rnl(r, Z1, Z2) = ∑_q W[(nl), q, Z1, Z2] * P_q(r) - # For linear models this becomes Rnl(r, Z1, Z2) = Pn(r) + # For linear models this becomes R(n'z')l(r, Z1, Z2) = Pn'(r) * δ_{z',Z2} + # here, Z1 is the center atom, Z2 the neighbour atom. NZ = _get_nz(rbasis) ps.Wnlq[:] .= 0 - for i = 1:NZ, j = 1:NZ + for iz1 = 1:NZ, iz2 = 1:NZ for (i_nl, nl) in enumerate(rbasis.spec) - if nl.n <= size(ps.Wnlq, 2) - ps.Wnlq[i_nl, nl.n, i, j] = 1 + # n | 1 2 3 4 5 6 7 8 ... + # n'z' | 1,1 1,2 1,3 2,1 2,2 2,3 3,1 3,2 ... + # => z' = mod1(n, NZ), n' = div(n-1, NZ) + 1 + z_ = mod1(nl.n, NZ) + n_ = div(nl.n - 1, NZ) + 1 + if z_ == iz2 && n_ <= size(ps.Wnlq, 2) + ps.Wnlq[i_nl, n_, iz1, iz2] = 1 end end end diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index 72fda21a..b32fb50c 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -68,7 +68,7 @@ ACE1_TestUtils.check_compat(params) ## # [4] -# A first multi-species example +# First multi-species examples params = ( elements = [:Al, :Ti,], order = 3, @@ -79,9 +79,23 @@ params = ( elements = [:Al, :Ti,], ACE1_TestUtils.check_compat(params) +## [5] +# second multi-species example with three elements +# and a few small changes to the basis + +params = ( elements = [:Al, :Ti, :C], + order = 2, + totaldegree = 6, + pure = false, + pure2b = false, + ) + +ACE1_TestUtils.check_compat(params) + + ## -using Random, Test, ACEbase, LinearAlgebra, Lux +using Random, Test, ACEbase, LinearAlgebra, Lux, Plots using ACEbase.Testing: print_tf, println_slim import ACE1, ACE1x, JuLIP @@ -92,15 +106,15 @@ rng = Random.MersenneTwister(1234) ## -params = ( elements = [:Al, :Ti,], - order = 3, - totaldegree = 7, +params = ( elements = [:Al, :Ti, :Cu], + order = 2, + totaldegree = 6, pure = false, pure2b = false, ) model1 = acemodel(; params...) -params2 = (; params..., totaldegree = params.totaldegree .+ 0.1) +params2 = (; params..., totaldegree = params.totaldegree .+ 1) model2 = ACE1compat.ace1_model(; params...) ps, st = Lux.setup(rng, model2) @@ -120,8 +134,38 @@ B2 = reduce(hcat, [ M.evaluate_basis(model2, x..., ps, st) for x in XX2]) C = B2' \ B1' basiserr = norm(B1 - C' * B2, Inf) @show basiserr -println_slim(@test basiserr < .3e-2) +# println_slim(@test basiserr < .3e-2) + +## + +rbasis1 = model1.basis.BB[2].pibasis.basis1p.J +rbasis2 = model2.rbasis +z11 = AtomicNumber(:Al) +z12 = Int(z1) +z21 = AtomicNumber(:Ti) +z22 = Int(z21) + +rr = range(0.001, 5.0, length=200) +R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z11, z21) for r in rr]) +R2 = reduce(hcat, [ rbasis2(r, z12, z22, NamedTuple(), NamedTuple()) for r in rr]) + +# alternating basis functions must be zero. +for n = 1:6 + @assert norm(R2[3*(n-1) + 1, :]) == 0 + @assert norm(R2[3*(n-1) + 3, :]) == 0 + @assert norm(R2[3*(n-1) + 2, :]*sqrt(2) - R1[n, :])/n^2 < 1e-3 +end + +for (i_nl, nl) in enumerate(rbasis2.spec) + @assert R2[i_nl, :] ≈ R2[nl.n, :] +end + +# plt = plot() +# for n = 1:6 +# plot!(R1[n, :], c = n, label = "R1,$n") +# plot!(R2[3*(n-1)+2, :]*sqrt(2), c = n, ls = :dash, label = "R2,$n") +# end +# plt -# ACE1_TestUtils.check_basis(model1, model2) From 662c1354232513c6e49a90cf215ba38c77e5bcf7 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Sun, 4 Aug 2024 08:47:26 -0700 Subject: [PATCH 105/112] failing 3-species test --- src/ace1_compat.jl | 9 ++++++--- src/models/ace_heuristics.jl | 8 ++++---- test/ace1/test_ace1_compat.jl | 17 ++++++++++++++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 20d76e2a..51b2fafb 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -273,15 +273,18 @@ end function _pair_basis(kwargs) rbasis = kwargs[:pair_basis] elements = _get_elements(kwargs) + NZ = length(elements) rin0cuts = _ace1_rin0cuts(kwargs; rcutkey = :pair_rcut) if rbasis == :legendre # SPECIFICATION if kwargs[:pair_degree] == :totaldegree - maxn = maximum(kwargs[:totaldegree]) + maxq = maximum(kwargs[:totaldegree]) + maxn = maxq * NZ elseif kwargs[:pair_degree] isa Integer - maxn = kwargs[:pair_degree] + maxq = kwargs[:pair_degree] + maxn = maxq * NZ else error("Cannot determine `maxn` for pair basis from information provided.") end @@ -299,7 +302,7 @@ function _pair_basis(kwargs) end pair_basis = ace_learnable_Rnlrzz(; spec = pair_spec, - maxq = maxn, + maxq = maxq, elements = elements, rin0cuts = rin0cuts, transforms = trans_pair, diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 14be5b6a..6eed6fc0 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -46,8 +46,8 @@ function ace_learnable_Rnlrzz(; maxq = ceil(Int, actual_maxn * maxq_fact) end - if maxq < actual_maxn - @warn("maxq < actual_maxn; this results in linear dependence") + if maxq < actual_maxn / NZ + @warn("maxq < actual_maxn / NZ; likely linear dependence") end if polys isa Symbol @@ -90,8 +90,8 @@ function ace_learnable_Rnlrzz(; error("cannot read envelope : $envelopes") end - if actual_maxn > length(polys) - error("actual_maxn > length of polynomial basis") + if actual_maxn > length(polys) * NZ + @warn("actual_maxn/NZ > maxq; likely linear dependence") end return LearnableRnlrzzBasis(zlist, polys, transforms, envelopes, diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index b32fb50c..6aa7894d 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -115,7 +115,7 @@ params = ( elements = [:Al, :Ti, :Cu], model1 = acemodel(; params...) params2 = (; params..., totaldegree = params.totaldegree .+ 1) -model2 = ACE1compat.ace1_model(; params...) +model2 = ACE1compat.ace1_model(; params2...) ps, st = Lux.setup(rng, model2) lenB1 = length(model1.basis) @@ -144,9 +144,11 @@ z11 = AtomicNumber(:Al) z12 = Int(z1) z21 = AtomicNumber(:Ti) z22 = Int(z21) +z31 = AtomicNumber(:Cu) +z32 = Int(z31) -rr = range(0.001, 5.0, length=200) -R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z11, z21) for r in rr]) +rr = range(0.001, 6.0, length=200) +R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z21, z11) for r in rr]) R2 = reduce(hcat, [ rbasis2(r, z12, z22, NamedTuple(), NamedTuple()) for r in rr]) # alternating basis functions must be zero. @@ -167,5 +169,14 @@ end # end # plt +## + +pairbasis1 = model1.basis.BB[1] +pairbasis2 = model2.pairbasis +rr = range(1.67, 6.0, length=200) +P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z31, z21) for r in rr]) +P2 = reduce(hcat, [ pairbasis2(r, z22, z32, NamedTuple(), NamedTuple()) for r in rr]) +scal = Diagonal([(-1)^(n+1) for n = 1:6]/sqrt(2)) +norm(P1[25:30, :] - scal * P2[3:3:18, :], Inf) From e05fd039d1fef8fd40218bbab4ff8dcb926932e1 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Mon, 5 Aug 2024 10:14:39 -0700 Subject: [PATCH 106/112] final fix: bring variablecutoffs back --- src/ace1_compat.jl | 13 ++--- src/models/ace_heuristics.jl | 4 +- src/models/smoothness_priors.jl | 2 +- test/ace1/ace1_testutils.jl | 8 +-- test/ace1/test_ace1_compat.jl | 95 ++------------------------------- 5 files changed, 20 insertions(+), 102 deletions(-) diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 51b2fafb..5b2bd9c8 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -36,6 +36,8 @@ const _kw_defaults = Dict(:elements => nothing, :pair_envelope => (:r, 2), # :Eref => missing, + # + :variable_cutoffs => false, ) const _kw_aliases = Dict( :N => :order, @@ -63,10 +65,6 @@ function _clean_args(kwargs) dargs[:pair_rcut] = dargs[:rcut] end - if haskey(dargs, :variable_cutoffs) - @warn("variable_cutoffs argument is ignored") - end - if kwargs[:pure2b] || kwargs[:pure] error("ACE1compat current does not support `pure2b` or `pure` options.") end @@ -148,6 +146,9 @@ function _get_all_rcut(kwargs; _rcut = kwargs[:rcut]) elements = _get_elements(kwargs) rcut = Dict( [ (s1, s2) => _get_rcut(kwargs, s1, s2; _rcut = _rcut) for s1 in elements, s2 in elements]... ) + if !kwargs[:variable_cutoffs] + rcut = maximum(values(rcut)) + end return rcut end @@ -280,10 +281,10 @@ function _pair_basis(kwargs) # SPECIFICATION if kwargs[:pair_degree] == :totaldegree - maxq = maximum(kwargs[:totaldegree]) + maxq = ceil(Int, maximum(kwargs[:totaldegree])) maxn = maxq * NZ elseif kwargs[:pair_degree] isa Integer - maxq = kwargs[:pair_degree] + maxq = ceil(Int, kwargs[:pair_degree]) maxn = maxq * NZ else error("Cannot determine `maxn` for pair basis from information provided.") diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index 6eed6fc0..099c1b98 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -52,7 +52,7 @@ function ace_learnable_Rnlrzz(; if polys isa Symbol if polys == :legendre - polys = Polynomials4ML.legendre_basis(maxq) + polys = Polynomials4ML.legendre_basis(ceil(Int, maxq)) else error("unknown polynomial type : $polys") end @@ -60,7 +60,7 @@ function ace_learnable_Rnlrzz(; if polys[1] == :jacobi α = polys[2] β = polys[3] - polys = Polynomials4ML.jacobi_basis(maxq, α, β) + polys = Polynomials4ML.jacobi_basis(ceil(Int, maxq), α, β) else error("unknown polynomial type : $polys") end diff --git a/src/models/smoothness_priors.jl b/src/models/smoothness_priors.jl index e873c8c2..78bb605e 100644 --- a/src/models/smoothness_priors.jl +++ b/src/models/smoothness_priors.jl @@ -13,7 +13,7 @@ end TotalDegree() = TotalDegree(1.0, 2/3) -(l::TotalDegree)(b::NamedTuple) = b.n/l.wn + b.l/l.wl +(l::TotalDegree)(b::NamedTuple) = b.n / l.wn + b.l/l.wl (l::TotalDegree)(bb::AbstractVector{<: NamedTuple}) = sum(l(b) for b in bb) diff --git a/test/ace1/ace1_testutils.jl b/test/ace1/ace1_testutils.jl index e594b9b0..b40d8b7d 100644 --- a/test/ace1/ace1_testutils.jl +++ b/test/ace1/ace1_testutils.jl @@ -96,11 +96,11 @@ function check_pairbasis(model1, model2) rr = range(0.001, rcut, length=200) P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z1, z1) for r in rr ]) P2 = reduce(hcat, [ pairbasis2(r, z2, z2, NamedTuple(), NamedTuple()) for r in rr]) - println_slim(@test size(P1) == size(P2)) + println_slim(@test size(P1) <= size(P2)) nmax = size(P1, 1) scal_pair = [ sum(P1[n, 70:end]) / sum(P2[n, 70:end]) for n = 1:nmax ] - P2 = Diagonal(scal_pair) * P2 + P2 = Diagonal(scal_pair) * P2[1:nmax, :] scal_err = abs( -(extrema(abs.(scal_pair))...) ) @show scal_err println_slim(@test scal_err < 0.01) @@ -139,6 +139,8 @@ function check_basis(model1, model2; Nenv = :auto) B2 = reduce(hcat, [ M.evaluate_basis(model2, x..., ps, st) for x in XX2]) if size(B2, 1) < size(B1, 1) + @show size(B1, 1) + @show size(B2, 1) error("ACE1 compat : ACE2 model must be at least as large as ACE1 model; aborting tests.") end @@ -185,7 +187,7 @@ end function check_compat(params; deginc = 0.1) model1 = acemodel(; params...) params2 = (; params..., totaldegree = params.totaldegree .+ deginc) - model2 = ACE1compat.ace1_model(; params...) + model2 = ACE1compat.ace1_model(; params2...) NZ = length(params.elements) if NZ == 1 diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index 6aa7894d..9476ef84 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -70,7 +70,10 @@ ACE1_TestUtils.check_compat(params) # [4] # First multi-species examples -params = ( elements = [:Al, :Ti,], +# NB : Ti, Al is a bad example because default bondlengths +# are the same. This can avoid some non-trivial behaviour + +params = ( elements = [:Al, :Cu,], order = 3, totaldegree = 8, pure = false, @@ -85,98 +88,10 @@ ACE1_TestUtils.check_compat(params) params = ( elements = [:Al, :Ti, :C], order = 2, - totaldegree = 6, + totaldegree = 8, pure = false, pure2b = false, ) ACE1_TestUtils.check_compat(params) - -## - -using Random, Test, ACEbase, LinearAlgebra, Lux, Plots -using ACEbase.Testing: print_tf, println_slim -import ACE1, ACE1x, JuLIP - -using ACEpotentials -M = ACEpotentials.Models -ACE1compat = ACEpotentials.ACE1compat -rng = Random.MersenneTwister(1234) - -## - -params = ( elements = [:Al, :Ti, :Cu], - order = 2, - totaldegree = 6, - pure = false, - pure2b = false, - ) - -model1 = acemodel(; params...) -params2 = (; params..., totaldegree = params.totaldegree .+ 1) -model2 = ACE1compat.ace1_model(; params2...) -ps, st = Lux.setup(rng, model2) - -lenB1 = length(model1.basis) -Nenv = 10 * lenB1 - -Random.seed!(12345) -XX2 = [ M.rand_atenv(model2, rand(6:10)) for _=1:Nenv ] -XX1 = [ (x[1], AtomicNumber.(x[2]), AtomicNumber(x[3])) for x in XX2 ] - -B1 = reduce(hcat, [ ACE1_TestUtils._evaluate(model1.basis, x...) for x in XX1]) -B2 = reduce(hcat, [ M.evaluate_basis(model2, x..., ps, st) for x in XX2]) - -@info("Compute linear transform between bases to show match") -# We want full-rank C such that C * B2 = B1 -# (note this allows B2 > B1) -C = B2' \ B1' -basiserr = norm(B1 - C' * B2, Inf) -@show basiserr -# println_slim(@test basiserr < .3e-2) - -## - -rbasis1 = model1.basis.BB[2].pibasis.basis1p.J -rbasis2 = model2.rbasis -z11 = AtomicNumber(:Al) -z12 = Int(z1) -z21 = AtomicNumber(:Ti) -z22 = Int(z21) -z31 = AtomicNumber(:Cu) -z32 = Int(z31) - -rr = range(0.001, 6.0, length=200) -R1 = reduce(hcat, [ ACE1.evaluate(rbasis1, r, z21, z11) for r in rr]) -R2 = reduce(hcat, [ rbasis2(r, z12, z22, NamedTuple(), NamedTuple()) for r in rr]) - -# alternating basis functions must be zero. -for n = 1:6 - @assert norm(R2[3*(n-1) + 1, :]) == 0 - @assert norm(R2[3*(n-1) + 3, :]) == 0 - @assert norm(R2[3*(n-1) + 2, :]*sqrt(2) - R1[n, :])/n^2 < 1e-3 -end - -for (i_nl, nl) in enumerate(rbasis2.spec) - @assert R2[i_nl, :] ≈ R2[nl.n, :] -end - -# plt = plot() -# for n = 1:6 -# plot!(R1[n, :], c = n, label = "R1,$n") -# plot!(R2[3*(n-1)+2, :]*sqrt(2), c = n, ls = :dash, label = "R2,$n") -# end -# plt - -## - -pairbasis1 = model1.basis.BB[1] -pairbasis2 = model2.pairbasis - -rr = range(1.67, 6.0, length=200) -P1 = reduce(hcat, [ ACE1.evaluate(pairbasis1, r, z31, z21) for r in rr]) -P2 = reduce(hcat, [ pairbasis2(r, z22, z32, NamedTuple(), NamedTuple()) for r in rr]) - -scal = Diagonal([(-1)^(n+1) for n = 1:6]/sqrt(2)) -norm(P1[25:30, :] - scal * P2[3:3:18, :], Inf) From f6d42d3878b72ba4cc09532116244d2df710ddf7 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 6 Aug 2024 21:37:08 -0700 Subject: [PATCH 107/112] get tests to pass --- src/io.jl | 18 +++++++++-- src/models/Rnl_learnable.jl | 19 +++++++---- src/models/ace.jl | 6 ++-- test/ace1/test_ace1_compat.jl | 2 +- test/models/test_Rnl.jl | 2 +- test/models/test_ace.jl | 8 ++--- test/test_silicon.jl | 59 ++++++++++++++++++----------------- 7 files changed, 67 insertions(+), 47 deletions(-) diff --git a/src/io.jl b/src/io.jl index 6b19f952..affc8870 100644 --- a/src/io.jl +++ b/src/io.jl @@ -79,6 +79,7 @@ function save_potential(fname, potential; save_version_numbers=true, meta=nothin ) end if !isnothing(meta) + @show meta @assert isa(meta, Dict{String, <:Any}) "meta needs to be a Dict{String, Any}" data["meta"] = convert(Dict{String, Any}, meta) end @@ -88,9 +89,20 @@ end # used to extraction version numbers when saving function extract_version(name::AbstractString) - vals = Pkg.dependencies()|> values |> collect - hit = filter(x->x.name==name, vals) |> only - return hit.version + try + vals = Pkg.dependencies()|> values |> collect + hit = filter(x->x.name==name, vals) |> only + return hit.version + catch + try + if name == Pkg.project().name + return Pkg.project().version + end + catch + @error("Couldn't determine version of $name") + return v"0.0.0" + end + end end diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index 9f6e22a9..be0d04f8 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -24,8 +24,13 @@ end Base.length(basis::LearnableRnlrzzBasis) = length(basis.spec) -function initialparameters(rng::Union{AbstractRNG, Nothing}, - basis::LearnableRnlrzzBasis) +initialparameters(::Nothing, basis::LearnableRnlrzzBasis) = + _initialparameters(nothing, basis) + +initialparameters(rng::AbstractRNG, basis::LearnableRnlrzzBasis) = + _initialparameters(rng, basis) + +function _initialparameters(rng, basis::LearnableRnlrzzBasis) NZ = _get_nz(basis) len_nl = length(basis) len_q = length(basis.polys) @@ -37,21 +42,23 @@ function initialparameters(rng::Union{AbstractRNG, Nothing}, @warn("No key Winit found for radial basis, use glorot_normal to initialize.") basis.meta["Winit"] = "glorot_normal" end + + Winit = String(basis.meta["Winit"]) - if basis.meta["Winit"] == "glorot_normal" + if Winit == "glorot_normal" for i = 1:NZ, j = 1:NZ Wnlq[:, :, i, j] .= glorot_normal(rng, Float64, len_nl, len_q) end - elseif basis.meta["Winit"] == "onehot" + elseif Winit == "onehot" set_onehot_weights!(basis, ps) - elseif basis.meta["Winit"] == "zeros" + elseif Winit == "zero" @warn("Setting inner basis weights to zero.") Wnlq[:] .= 0 else - error("Unknown key Winit = $(basis.meta["Winit"]) to initialize radial basis weights.") + error("Unknown key Winit = $(Winit) to initialize radial basis weights.") end return ps diff --git a/src/models/ace.jl b/src/models/ace.jl index b8f1e52a..3deb15ce 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -216,9 +216,9 @@ function LuxCore.parameterlength(model::ACEModel) return NZ^2 * length(model.pairbasis) + NZ * length(model.tensor) end -function splinify(model::ACEModel, ps::NamedTuple) - rbasis_spl = splinify(model.rbasis, ps.rbasis) - pairbasis_spl = splinify(model.pairbasis, ps.pairbasis) +function splinify(model::ACEModel, ps::NamedTuple; kwargs...) + rbasis_spl = splinify(model.rbasis, ps.rbasis; kwargs...) + pairbasis_spl = splinify(model.pairbasis, ps.pairbasis; kwargs...) return ACEModel(model._i2z, rbasis_spl, model.ybasis, diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index 9476ef84..fc224cc5 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) # using TestEnv; TestEnv.activate(); ## diff --git a/test/models/test_Rnl.jl b/test/models/test_Rnl.jl index 0fb3d2ea..3acdcf62 100644 --- a/test/models/test_Rnl.jl +++ b/test/models/test_Rnl.jl @@ -82,7 +82,7 @@ for ntest = 1:30 Rnl = basis(r, Zi, Zj, ps, st) - for (nnodes, tol) in [(30, 1e-3), (100, 1e-5), (1000, 1e-8)] + for (nnodes, tol) in [(30, 3e-2), (100, 1e-5), (1000, 1e-8)] local basis_spl, ps_spl, st_spl, Rnl_spl basis_spl = M.splinify(basis, ps; nnodes = nnodes) diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index fbccb7a9..8a2bfc86 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -29,7 +29,7 @@ order = 3 msolid = M.ace_model(; elements = elements, order = order, Ytype = :solid, level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :glorot_normal) mspherical = M.ace_model(; elements = elements, order = order, Ytype = :spherical, - level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :zero) + level = level, max_level = max_level, maxl = 8, pair_maxn = 15, init_WB = :glorot_normal, init_Wpair = :glorot_normal) ps, st = LuxCore.setup(rng, msolid) for ntest = 1:30 @@ -46,7 +46,7 @@ println() for ybasis in [:spherical, :solid] # ybasis = :solid @info("=== Testing ybasis = $ybasis === ") - local ps, st, Nat + local ps, st, Nat, model model = M.ace_model(; elements = elements, order = order, Ytype = ybasis, level = level, max_level = max_level, maxl = 8, pair_maxn = 15, @@ -186,7 +186,7 @@ for ybasis in [:spherical, :solid] ## @info("check splinification") - lin_ace = M.splinify(model, ps) + lin_ace = M.splinify(model, ps; nnodes = 1000) ps_lin, st_lin = LuxCore.setup(rng, lin_ace) ps_lin.WB[:] .= ps.WB[:] ps_lin.Wpair[:] .= ps.Wpair[:] @@ -202,7 +202,7 @@ for ybasis in [:spherical, :solid] abs(Ei - Ei_lin) end mae /= len - print_tf(@test mae < 0.02) + print_tf(@test mae < 0.01) end println() diff --git a/test/test_silicon.jl b/test/test_silicon.jl index 6d2351e4..5a223467 100644 --- a/test/test_silicon.jl +++ b/test/test_silicon.jl @@ -1,12 +1,13 @@ using ACEpotentials using Distributed using LazyArtifacts -using PythonCall +# using PythonCall using Test ## ----- setup ----- @warn "test_silicon not fully converted yet." + model = acemodel(elements = [:Si], Eref = [:Si => -158.54496821], rcut = 5.5, @@ -86,35 +87,35 @@ end #test_rmse(results["errors"]["rmse"], rmse_rrqr, 1e-5) end -@testset "SKLEARN_BRR" begin - rmse_brr = Dict( - "isolated_atom" => Dict("E"=>0.0, "F"=>0.0), - "dia" => Dict("V"=>0.0333241, "E"=>0.0013034, "F"=>0.0255757), - "liq" => Dict("V"=>0.0347208, "E"=>0.0003974, "F"=>0.1574544), - "set" => Dict("V"=>0.0619434, "E"=>0.0023868, "F"=>0.1219008), - "bt" => Dict("V"=>0.0823042, "E"=>0.0032196, "F"=>0.0627417),) - acefit!(model, data; - data_keys..., - weights = weights, - solver = ACEfit.SKLEARN_BRR(tol = 1e-4)) - #test_rmse(results["errors"]["rmse"], rmse_brr, 1e-5) -end +# @testset "SKLEARN_BRR" begin +# rmse_brr = Dict( +# "isolated_atom" => Dict("E"=>0.0, "F"=>0.0), +# "dia" => Dict("V"=>0.0333241, "E"=>0.0013034, "F"=>0.0255757), +# "liq" => Dict("V"=>0.0347208, "E"=>0.0003974, "F"=>0.1574544), +# "set" => Dict("V"=>0.0619434, "E"=>0.0023868, "F"=>0.1219008), +# "bt" => Dict("V"=>0.0823042, "E"=>0.0032196, "F"=>0.0627417),) +# acefit!(model, data; +# data_keys..., +# weights = weights, +# solver = ACEfit.SKLEARN_BRR(tol = 1e-4)) +# #test_rmse(results["errors"]["rmse"], rmse_brr, 1e-5) +# end -@testset "SKLEARN_ARD" begin - rmse_ard = Dict( - "isolated_atom" => Dict("E"=>0.0, "F"=>0.0), - "dia" => Dict("V"=>0.1084975, "E"=>0.0070814, "F"=>0.0937790), - "liq" => Dict("V"=>0.0682268, "E"=>0.0090065, "F"=>0.3693146), - "set" => Dict("V"=>0.1839696, "E"=>0.0137778, "F"=>0.2883043), - "bt" => Dict("V"=>0.2413568, "E"=>0.0185958, "F"=>0.1507498),) - acefit!(model, data; - data_keys..., - weights = weights, - solver = ACEfit.SKLEARN_ARD( - tol = 2e-3, threshold_lambda = 5000, n_iter = 1000)) - @warn "The SKLEARN_ARD test tolerance is very loose." - #test_rmse(results["errors"]["rmse"], rmse_ard, 1e-2) -end +# @testset "SKLEARN_ARD" begin +# rmse_ard = Dict( +# "isolated_atom" => Dict("E"=>0.0, "F"=>0.0), +# "dia" => Dict("V"=>0.1084975, "E"=>0.0070814, "F"=>0.0937790), +# "liq" => Dict("V"=>0.0682268, "E"=>0.0090065, "F"=>0.3693146), +# "set" => Dict("V"=>0.1839696, "E"=>0.0137778, "F"=>0.2883043), +# "bt" => Dict("V"=>0.2413568, "E"=>0.0185958, "F"=>0.1507498),) +# acefit!(model, data; +# data_keys..., +# weights = weights, +# solver = ACEfit.SKLEARN_ARD( +# tol = 2e-3, threshold_lambda = 5000, n_iter = 1000)) +# @warn "The SKLEARN_ARD test tolerance is very loose." +# #test_rmse(results["errors"]["rmse"], rmse_ard, 1e-2) +# end @testset "BLR" begin rmse_blr = Dict( From 359f26fa402c83beb505971da0e30a87d82d39a4 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Tue, 6 Aug 2024 22:14:32 -0700 Subject: [PATCH 108/112] rough sketch of priors --- src/models/smoothness_priors.jl | 48 ++++++++++++++++++++++++++++++++- test/models/test_priors.jl | 33 +++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 test/models/test_priors.jl diff --git a/src/models/smoothness_priors.jl b/src/models/smoothness_priors.jl index 78bb605e..0e1513c2 100644 --- a/src/models/smoothness_priors.jl +++ b/src/models/smoothness_priors.jl @@ -1,4 +1,4 @@ - +using LinearAlgebra: Diagonal # -------------------------------------------------- # different notions of "level" / total degree. @@ -43,3 +43,49 @@ function oneparticle_spec(level::Union{TotalDegree, EuclideanDegree}, maxlevel) end # -------------------------------------------------- + + +# this should maybe be moved elsewhere, but for now it can live here. + +function _basis_length(model) + len_tensor = length(get_nnll_spec(model.tensor)) + len_pair = length(model.pairbasis.spec) + return (len_tensor + len_pair) * _get_nz(model) +end + +function _nnll_basis(model) + NTNL = typeof((n = 1, l = 0)) + TBB = Vector{NTNL} + + global_spec = Vector{TBB}(undef, _basis_length(model)) + + nnll_tensor = get_nnll_spec(model.tensor) + nn_pair = [ [b,] for b in model.pairbasis.spec] + + for iz = 1:_get_nz(model) + z = _i2z(model, iz) + global_spec[get_basis_inds(model, z)] = nnll_tensor + global_spec[get_pairbasis_inds(model, z)] = nn_pair + end + + return global_spec +end + + +function smoothness_prior(model, f) + nnll = _nnll_basis(model) + γ = zeros(length(nnll)) + for (i, bb) in enumerate(nnll) + γ[i] = f(bb) + end + return Diagonal(γ) +end + +algebraic_smoothness_prior(model; p = 4, wl = 3/2, wn = 1.0) = + smoothness_prior(model, bb -> sum((b.l/wl)^p + (b.n/wn)^p for b in bb)) + +exp_smoothness_prior(model; wl = 1.0, wn = 2/3) = + smoothness_prior(model, bb -> exp( sum(b.l / wl + b.n / wn for b in bb) )) + +gaussian_smoothness_prior(model; wl = 1/sqrt(2), wn = 1/sqrt(2)) = + smoothness_prior(model, bb -> exp( sum( (b.l/wl)^2 + (b.n/wn)^2 for b in bb) )) diff --git a/test/models/test_priors.jl b/test/models/test_priors.jl new file mode 100644 index 00000000..86c5d909 --- /dev/null +++ b/test/models/test_priors.jl @@ -0,0 +1,33 @@ + +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); + +## + +using Test, ACEbase, Random +using ACEbase.Testing: print_tf, println_slim +using Lux, LuxCore, StaticArrays, LinearAlgebra +rng = Random.MersenneTwister(1234) +Random.seed!(11) + +using ACEpotentials +M = ACEpotentials.Models + +## + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 + +model = M.ace_model(; elements = elements, order = order, Ytype = :spherical, + level = level, max_level = max_level, pair_maxn = max_level, init_WB = :glorot_normal, init_Wpair = :glorot_normal) +ps, st = LuxCore.setup(rng, model) + +## + +Γa = M.algebraic_smoothness_prior(model) +Γe = M.exp_smoothness_prior(model) +Γg = M.gaussian_smoothness_prior(model) + +[Γa.diag Γe.diag Γg.diag] \ No newline at end of file From 15d12b89ea283c9d039fe8a09b3d0a6676ecceb5 Mon Sep 17 00:00:00 2001 From: wcwitt Date: Fri, 9 Aug 2024 15:40:45 -0400 Subject: [PATCH 109/112] Update python_ase.md. --- docs/src/tutorials/python_ase.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/src/tutorials/python_ase.md b/docs/src/tutorials/python_ase.md index b14b4629..003aac88 100644 --- a/docs/src/tutorials/python_ase.md +++ b/docs/src/tutorials/python_ase.md @@ -16,4 +16,8 @@ ats.calc = calc print(ats.get_potential_energy()) ``` -See the `ase` [documentation](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html#module-ase.calculators) for more details. +See the `ase` [documentation](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html#module-ase.calculators) for more details. + +### Another option: ASE's LAMMPSlib calculator + +Alternatively, to avoid direct Julia-Python interaction, one can export to LAMMPS (see [LAMMPS](lammps.md)) and utilize ASE's [`LAMMPSlib` calculator](https://wiki.fysik.dtu.dk/ase/ase/calculators/lammpslib.html). From 7e26c15f5cfa74392b23b7fc7ad6b67883c77129 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 21 Aug 2024 16:43:05 -0700 Subject: [PATCH 110/112] smoothness priors compat with ACE1 --- src/models/Rnl_basis.jl | 5 +-- src/models/Rnl_learnable.jl | 8 ++-- src/models/smoothness_priors.jl | 39 ++++++++++++----- test/ace1/ace1_testutils.jl | 74 ++++++++++++++++++++++++++++++++- test/ace1/test_ace1_compat.jl | 51 +++++++++++++++++++++++ test/models/test_priors.jl | 5 ++- 6 files changed, 161 insertions(+), 21 deletions(-) diff --git a/src/models/Rnl_basis.jl b/src/models/Rnl_basis.jl index 12f1c0c9..719d5257 100644 --- a/src/models/Rnl_basis.jl +++ b/src/models/Rnl_basis.jl @@ -33,7 +33,6 @@ mutable struct LearnableRnlrzzBasis{NZ, TPOLY, TT, TENV, T} <: AbstractExplicitL transforms::SMatrix{NZ, NZ, TT} envelopes::SMatrix{NZ, NZ, TENV} # -------------- - # weights::Array{TW, 4} # learnable weights, `nothing` when using Lux rin0cuts::SMatrix{NZ, NZ, NT_RIN0CUTS{T}} # matrix of (rin, rout, rcut) spec::Vector{NT_NL_SPEC} # -------------- @@ -49,9 +48,9 @@ mutable struct SplineRnlrzzBasis{NZ, TT, TENV, LEN, T} <: AbstractExplicitLayer splines::SMatrix{NZ, NZ, SPL_OF_SVEC{LEN, T}} # -------------- rin0cuts::SMatrix{NZ, NZ, NT_RIN0CUTS{T}} # matrix of (rin, rout, rcut) - spec::Vector{NT_NL_SPEC} + spec::Vector{NT_NL_SPEC} # -------------- - meta::Dict{String, Any} + meta::Dict{String, Any} end diff --git a/src/models/Rnl_learnable.jl b/src/models/Rnl_learnable.jl index be0d04f8..119900c0 100644 --- a/src/models/Rnl_learnable.jl +++ b/src/models/Rnl_learnable.jl @@ -120,7 +120,7 @@ function evaluate!(Rnl, basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, r, x) Rnl[:] .= Wij * (P .* e) - return Rnl + return Rnl end function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) @@ -132,7 +132,7 @@ function evaluate(basis::LearnableRnlrzzBasis, r::Real, Zi, Zj, ps, st) P = Polynomials4ML.evaluate(basis.polys, x) env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, r, x) - return Wij * (P .* e) + return Wij * (P .* e) end @@ -244,7 +244,7 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, rs, zi, zjs, ps, st) @assert length(rs) == length(zjs) # evaluate the first one to get the types and size - Rnl_1, _ = evaluate(basis, rs[1], zi, zjs[1], ps, st) + Rnl_1 = evaluate(basis, rs[1], zi, zjs[1], ps, st) # ... and then allocate storage Rnl = zeros(eltype(Rnl_1), (length(rs), length(Rnl_1))) @@ -262,7 +262,7 @@ function pullback_evaluate_batched(Δ, basis::LearnableRnlrzzBasis, env_ij = basis.envelopes[iz, jz] e = evaluate(env_ij, rs[j], x) P = Polynomials4ML.evaluate(basis.polys, x) .* e - # TODO: the P shouuld be stored inside a closure in the + # TODO: the P should be stored inside a closure in the # forward pass and then resused. # TODO: ... and obviously this part here needs to be moved diff --git a/src/models/smoothness_priors.jl b/src/models/smoothness_priors.jl index 0e1513c2..eca9192b 100644 --- a/src/models/smoothness_priors.jl +++ b/src/models/smoothness_priors.jl @@ -28,18 +28,12 @@ EuclideanDegree() = EuclideanDegree(1.0, 2/3) (l::EuclideanDegree)(bb::AbstractVector{<: NamedTuple}) = sqrt( sum(l(b)^2 for b in bb) ) -# struct SparseBasisSelector -# order::Int -# maxlevels::AbstractVector{<: Number} -# level -# end - function oneparticle_spec(level::Union{TotalDegree, EuclideanDegree}, maxlevel) maxn1 = ceil(Int, maxlevel * level.wn) maxl1 = ceil(Int, maxlevel * level.wl) spec = [ (n = n, l = l) for n = 1:maxn1, l = 0:maxl1 if level((n = n, l = l)) <= maxlevel ] - return sort(spec; by = x -> (x.l, x.n)) + return sort(spec; by = x -> (x.l, x.n)) end # -------------------------------------------------- @@ -71,21 +65,44 @@ function _nnll_basis(model) return global_spec end - +function _coupling_scalings(model) + scal = ones(_basis_length(model)) + for iz = 1:_get_nz(model) + z = _i2z(model, iz) + mb_inds = get_basis_inds(model, z) + @assert length(mb_inds) == size(model.tensor.A2Bmap, 1) + for i = 1:length(mb_inds) + scal[mb_inds[i]] = sqrt(sum(abs2, model.tensor.A2Bmap[i,:])) + end + end + return scal +end + function smoothness_prior(model, f) nnll = _nnll_basis(model) γ = zeros(length(nnll)) for (i, bb) in enumerate(nnll) γ[i] = f(bb) end - return Diagonal(γ) + return Diagonal(γ) # .* _coupling_scalings(model)) end -algebraic_smoothness_prior(model; p = 4, wl = 3/2, wn = 1.0) = +algebraic_smoothness_prior(model; p = 4, wl = 2/3, wn = 1.0) = smoothness_prior(model, bb -> sum((b.l/wl)^p + (b.n/wn)^p for b in bb)) -exp_smoothness_prior(model; wl = 1.0, wn = 2/3) = +exp_smoothness_prior(model; wn = 1.0, wl = 2/3) = smoothness_prior(model, bb -> exp( sum(b.l / wl + b.n / wn for b in bb) )) gaussian_smoothness_prior(model; wl = 1/sqrt(2), wn = 1/sqrt(2)) = smoothness_prior(model, bb -> exp( sum( (b.l/wl)^2 + (b.n/wn)^2 for b in bb) )) + + +function algebraic_smoothness_prior_ace1(model; p = 4, wL = 3/2) + nnll = _nnll_basis(model) + γ = zeros(length(nnll)) + for (i, bb) in enumerate(nnll) + γ[i] = sum(b.n^p + wL * b.l^p * (1 + b.l/(p+1)) for b in bb) + end + scal = _coupling_scalings(model) + return Diagonal(γ .* scal .+ 1) +end \ No newline at end of file diff --git a/test/ace1/ace1_testutils.jl b/test/ace1/ace1_testutils.jl index b40d8b7d..3be03757 100644 --- a/test/ace1/ace1_testutils.jl +++ b/test/ace1/ace1_testutils.jl @@ -183,11 +183,18 @@ function check_basis(model1, model2; Nenv = :auto) nothing end - -function check_compat(params; deginc = 0.1) +function make_models(params; deginc = 0.1) model1 = acemodel(; params...) params2 = (; params..., totaldegree = params.totaldegree .+ deginc) model2 = ACE1compat.ace1_model(; params2...) + return model1, model2 +end + +function check_compat(params; deginc = 0.1) + model1, model2 = make_models(params, deginc = deginc) + # model1 = acemodel(; params...) + # params2 = (; params..., totaldegree = params.totaldegree .+ deginc) + # model2 = ACE1compat.ace1_model(; params2...) NZ = length(params.elements) if NZ == 1 @@ -201,4 +208,67 @@ function check_compat(params; deginc = 0.1) nothing end + +function compare_smoothness_prior(params, + priortype = :algebraic, + priorparams1 = (p = 2, wl = 1.5), + priorparams2 = (p = 2, wl = 2/3, wn = 1.0); + deginc = 0.1) + model1, model2 = make_models(params, deginc = deginc) + # model1 = ACE1x.acemodel(; params...) + # params2 = (; params..., totaldegree = params.totaldegree .+ deginc) + # model2 = ACE1compat.ace1_model(; params2...) + + if priortype == :algebraic + P1 = ACE1x.algebraic_smoothness_prior(model1.basis; priorparams1...) + P2 = M.algebraic_smoothness_prior_ace1(model2; priorparams2...) + elseif priortype == :exponential + P1 = ACE1x.exp_smoothness_prior(model1.basis; priorparams1...) + P2 = M.exp_smoothness_prior(model2; priorparams2...) + elseif priortype == :gaussian + P1 = ACE1x.gaussian_smoothness_prior(model1.basis; priorparams1...) + P2 = M.gaussian_smoothness_prior(model2; priorparams2...) + else + error("unknown priortype: $priortype") + end + + p1 = diag(P1) + p2 = diag(P2); l2 = length(p2) + + _spec1 = ACE1.get_nl(model1.basis.BB[2]) + spec1 = [ [ (n = b.n, l = b.l) for b in bb ] for bb in _spec1 ] + l1 = length(spec1) + p1mb = p1[end-l1+1:end] + + spec2 = M.get_nnll_spec(model2.tensor) + σ = [ findfirst(isequal(bb), spec2) for bb in spec1 ] + + l2 = length(spec2) + p2mb = p2[1:l2] + + ratios = Float64[] + numerr = 0 + for i = 1:length(σ) + if σ[i] == nothing; continue; end + bb = spec1[i] + _p1 = p1mb[i] + _p2 = p2mb[σ[i]] + push!(ratios, _p1/_p2) + if !(0.5 <= _p1/_p2 <= 2.0) + @error("""scaling mismatch: + $("$bb"[32:end]) + $(round(_p1, digits=1)) vs $(round(_p2, digits=1)) + i = $i, σ[i] = $(σ[i]) + """) + println() + numerr += 1 + end + end + @show numerr + @show extrema(ratios) + + @test 0.5 <= minimum(ratios) + @test maximum(ratios) <= 2.0 +end + end \ No newline at end of file diff --git a/test/ace1/test_ace1_compat.jl b/test/ace1/test_ace1_compat.jl index fc224cc5..6d18caed 100644 --- a/test/ace1/test_ace1_compat.jl +++ b/test/ace1/test_ace1_compat.jl @@ -95,3 +95,54 @@ params = ( elements = [:Al, :Ti, :C], ACE1_TestUtils.check_compat(params) + +## [6] +# Confirm that the smoothness prior is the same in ACE1x and ACEpotentials + +import ACE1, ACE1x, JuLIP, Random +using LinearAlgebra, ACEpotentials +M = ACEpotentials.Models +ACE1compat = ACEpotentials.ACE1compat +rng = Random.MersenneTwister(1234) + +## + +@info("Testing scaling of smoothness priors") + +params = ( elements = [:Si,], + order = 3, + transform = (:agnesi, 2, 2), + totaldegree = [16, 13, 11], + pure = false, + pure2b = false, + pair_envelope = (:r, 1), + rcut = 5.5, + Eref = [:Si => -1.234 ] + ) + +ACE1_TestUtils.check_compat(params) + +_p = (p = 1, wL = 1.5) +ACE1_TestUtils.compare_smoothness_prior(params, :algebraic, _p, _p) + +_p = (p = 2, wL = 1.5) +ACE1_TestUtils.compare_smoothness_prior(params, :algebraic, _p, _p) + +_p = (p = 4, wL = 1.5) +ACE1_TestUtils.compare_smoothness_prior(params, :algebraic, _p, _p) + +ACE1_TestUtils.compare_smoothness_prior( + params, :exponential, (al = 1.5, an = 1.0), (wl = 2/3, wn = 1.0) ) + +ACE1_TestUtils.compare_smoothness_prior( + params, :exponential, (al = 0.234, an = 0.321), + (wl = 1 / 0.234, wn = 1/0.321) ) + +ACE1_TestUtils.compare_smoothness_prior( + params, :gaussian, (σl = 2.0, σn = 2.0), + (wl = 1/sqrt(2), wn = 1/sqrt(2)) ) + +ACE1_TestUtils.compare_smoothness_prior( + params, :gaussian, (σl = 1.234, σn = 0.89), + (wl = 1/sqrt(1.234), wn = 1/sqrt(0.89)) ) + diff --git a/test/models/test_priors.jl b/test/models/test_priors.jl index 86c5d909..a559f747 100644 --- a/test/models/test_priors.jl +++ b/test/models/test_priors.jl @@ -30,4 +30,7 @@ ps, st = LuxCore.setup(rng, model) Γe = M.exp_smoothness_prior(model) Γg = M.gaussian_smoothness_prior(model) -[Γa.diag Γe.diag Γg.diag] \ No newline at end of file +[Γa.diag Γe.diag Γg.diag] + +## + From b4a02ad910f3fd087dff2e5346c95841d6edf581 Mon Sep 17 00:00:00 2001 From: Chuck Witt Date: Wed, 21 Aug 2024 21:39:21 -0400 Subject: [PATCH 111/112] First attempt at adding the long tutorial. --- docs/make.jl | 3 + docs/src/tutorials/index.md | 2 + docs/src/tutorials/long_tutorial.jl | 499 ++++++++++++++++++++++++++++ 3 files changed, 504 insertions(+) create mode 100644 docs/src/tutorials/long_tutorial.jl diff --git a/docs/make.jl b/docs/make.jl index 22af71fb..c25a8864 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -44,6 +44,9 @@ Literate.markdown(_tutorial_src * "/committee.jl", Literate.markdown(_tutorial_src * "/experimental.jl", _tutorial_out; documenter = true) +Literate.markdown(_tutorial_src * "/long_tutorial.jl", + _tutorial_out; documenter = true) + # ???? cf Jump.jl docs, they do also this: # postprocess = _link_example, # # Turn off the footer. We manually add a modified one. diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 22bd76d8..7ca617a0 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,6 +1,8 @@ # Tutorials Overview +NEW! There is now a [long tutorial](../literate_tutorials/long_tutorial.md). + ### Fitting potentials from Julia scripts These tutorials use the direct Julia interface provided by `ACEpotentials.jl` (interfacing with `ACE1.jl, ACE1x.jl, ACEfit.jl`). They are provided in [Literate.jl](https://github.com/fredrikekre/Literate.jl) format and can also be run as scripts if that is preferred. diff --git a/docs/src/tutorials/long_tutorial.jl b/docs/src/tutorials/long_tutorial.jl new file mode 100644 index 00000000..78029425 --- /dev/null +++ b/docs/src/tutorials/long_tutorial.jl @@ -0,0 +1,499 @@ +# ACEpotentials.jl Tutorial +# ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡ +# +# CECAM - Psi-k School on ML-IP (November 2023) + +# Introduction +# ============== +# +# The ACEpotentials.jl documentation +# (https://acesuit.github.io/ACEpotentials.jl/) contains a number of short, +# focused tutorials on key topics. This tutorial is longer and has a single +# narrative. Many Julia commands are introduced by example. + +# Installing ACEpotentials +# –––––––––––––––––––––––––– +# +# ACEpotentials requires Julia 1.9. For detailed installation instructions, +# see: +# https://acesuit.github.io/ACEpotentials.jl/dev/gettingstarted/installation/. +# +# Warning: The following installation will take several minutes. + +# add and load general packages used in this notebook. +using Pkg +Pkg.activate(".") +Pkg.add("LaTeXStrings") +Pkg.add("MultivariateStats") +Pkg.add("Plots") +Pkg.add("PrettyTables") +Pkg.add("Suppressor") +using LaTeXStrings, MultivariateStats, Plots, PrettyTables, Printf, Statistics, Suppressor + +# ACEpotentials installation (requires Julia 1.9) +using Pkg +Pkg.activate(".") +Pkg.Registry.add("General") # only needed when installing Julia for the first time +Pkg.Registry.add(RegistrySpec(url="https://github.com/ACEsuit/ACEregistry")) +Pkg.add("ACEpotentials") +using ACEpotentials + +# Now, let's check the status of the installed projects. + +using Pkg +Pkg.status() + +# Part 1: Basic dataset analysis +# ================================ +# +# ACEpotentials provides quick access to several example datasets, which can +# be useful for testing. The following command lists these datasets. (We +# expect to expand this list signifcantly; please feel free to suggest +# additions.) + +ACEpotentials.list_example_datasets() + +# We begin by loading the tiny silicon dataset. + +Si_tiny_dataset, _, _ = ACEpotentials.example_dataset("Si_tiny"); + +# These data were taken from a larger set published with: +# +# │ A. P. Bartók, J. Kermode, N. Bernstein, and G. Csányi, Machine +# │ Learning a General-Purpose Interatomic Potential for Silicon, +# │ Phys. Rev. X 8, 041048 (2018) +# +# To illustrate the procedure for loading extended xyz data from a file, we +# download the larger dataset and load it. + +download("https://www.dropbox.com/scl/fi/mzd7zcb1x1l4rw5eswxcd/gp_iter6_sparse9k.xml.xyz?rlkey=o4avtpkka6jnqn7qg375vg7z0&dl=0", + "Si_dataset.xyz"); + +Si_dataset = read_extxyz("Si_dataset.xyz"); + +# Next, we assess the dataset sizes. + +println("The tiny dataset has ", length(Si_tiny_dataset), " structures.") +println("The large dataset has ", length(Si_dataset), " structures.") + +# Next, we create arrays containing the config_type for each structure in the +# datasets. Afterwards, we count the configurations of each type. + +config_types_tiny = [at.data["config_type"].data for at in Si_tiny_dataset] +config_types = [at.data["config_type"].data for at in Si_dataset] + +function count_configs(config_types) + config_counts = [sum(config_types.==ct) for ct in unique(config_types)] + config_dict = Dict([ct=>cc for (ct,cc) in zip(unique(config_types), config_counts)]) +end; + +println("There are ", length(unique(config_types_tiny)), " unique config_types "* + "in the tiny dataset:") +display(count_configs(config_types_tiny)) + +println("There are ", length(unique(config_types)), " unique config_types "* + "in the full dataset:") +display(count_configs(config_types)) + +# Two basic distributions which indicate how well the data fills space are the +# radial and angular distribution functions. We begin with the radial +# distribution function, plotting using the histogram function in Plots.jl. +# For the RDF we add some vertical lines to indicate the distances and first, +# second neighbours and so forth to confirm that the peaks are in the right +# place. + +r_cut = 6.0 + +rdf_tiny = ACEpotentials.get_rdf(Si_tiny_dataset, r_cut; rescale = true) +plt_rdf_1 = histogram(rdf_tiny[(:Si, :Si)], bins=150, label = "rdf", + title="Si_tiny_dataset", titlefontsize=10, + xlabel = L"r[\AA]", ylabel = "RDF", yticks = [], + xlims=(1.5,6), size=(400,200), left_margin = 2Plots.mm) +vline!(rnn(:Si)*[1.0, 1.633, 1.915, 2.3, 2.5], label = "r1, r2, ...", lw=3) + +rdf = ACEpotentials.get_rdf(Si_dataset, r_cut; rescale = true); +plt_rdf_2 = histogram(rdf[(:Si, :Si)], bins=150, label = "rdf", + title="Si_dataset", titlefontsize=10, + xlabel = L"r[\AA]", ylabel = "RDF", yticks = [], + xlims=(1.5,6), size=(400,200), left_margin = 2Plots.mm) +vline!(rnn(:Si)*[1.0, 1.633, 1.915, 2.3, 2.5], label = "r1, r2, ...", lw=3) + +plot(plt_rdf_1, plt_rdf_2, layout=(2,1), size=(400,400)) + +# The larger dataset clearly has a better-converged radial distribution +# function. + +# For the angular distribution function, we use a cutoff just above the +# nearest-neighbour distance so we can clearly see the equilibrium +# bond-angles. In this case, the vertical line indicates the equilibrium bond +# angle. + +r_cut_adf = 1.25 * rnn(:Si) +eq_angle = 1.91 # radians +adf_tiny = ACEpotentials.get_adf(Si_tiny_dataset, r_cut_adf); +plt_adf_1 = histogram(adf_tiny, bins=50, label = "adf", yticks = [], c = 3, + title = "Si_tiny_dataset", titlefontsize = 10, + xlabel = L"\theta", ylabel = "ADF", + xlims = (0, π), size=(400,200), left_margin = 2Plots.mm) +vline!([ eq_angle,], label = "109.5˚", lw=3) + +adf = ACEpotentials.get_adf(Si_dataset, r_cut_adf); +plt_adf_2 = histogram(adf, bins=50, label = "adf", yticks = [], c = 3, + title = "Si_dataset", titlefontsize = 10, + xlabel = L"\theta", ylabel = "ADF", + xlims = (0, π), size=(400,200), left_margin = 2Plots.mm) +vline!([ eq_angle,], label = "109.5˚", lw=3) + +plot(plt_adf_1, plt_adf_2, layout=(2,1), size=(400,400)) + +# For later use, we define a function that extracts the energies stored in the +# silicon datasets. + +function extract_energies(dataset) + energies = [] + for atoms in dataset + for key in keys(atoms.data) + if lowercase(key) == "dft_energy" + push!(energies, atoms.data[key].data/length(atoms)) + end + end + end + return energies +end; + +Si_dataset_energies = extract_energies(Si_dataset) + +GC.gc() + +# Part 2: ACE descriptors +# ========================= +# +# An ACE basis specifies a vector of invariant features of atomic environments +# and can therefore be used as a general descriptor. +# +# Some important parameters include: +# +# • elements: list of chemical species, as symbols; +# +# • order: correlation/interaction order (body order - 1); +# +# • totaldegree: maximum total polynomial degree used for the basis; +# +# • rcut : cutoff radius (optional, defaults are provided). + +basis = ACE1x.ace_basis(elements = [:Si], + rcut = 5.5, + order = 3, # body-order - 1 + totaldegree = 8); + +# As an example, we compute an averaged structural descriptor for each +# configuration in the tiny dataset. + +descriptors = [] +for atoms in Si_tiny_dataset + struct_descriptor = sum(site_descriptors(basis, atoms)) / length(atoms) + push!(descriptors, struct_descriptor) +end + +# Next, we extract and plot the principal components of the structural +# descriptors. Note the segregation by configuration type. + +descriptors = hcat(descriptors...) # convert to matrix +M = fit(PCA, descriptors; maxoutdim=3, pratio=1) +descriptors_trans = transform(M, descriptors) +p = scatter( + descriptors_trans[1,:], descriptors_trans[2,:], descriptors_trans[3,:], + marker=:circle, linewidth=0, group=config_types_tiny, legend=:right) +plot!(p, xlabel="PC1", ylabel="PC2", zlabel="PC3", camera=(20,10)) + +# Finally, we repeat the procedure for the full dataset. Some clustering is +# apparent, although the results are a bit harder to interpret. + +descriptors = [] +for atoms in Si_dataset + struct_descriptor = sum(site_descriptors(basis, atoms)) / length(atoms) + push!(descriptors, struct_descriptor) +end + +descriptors = hcat(descriptors...) # convert to matrix +M = fit(PCA, descriptors; maxoutdim=3, pratio=1) +descriptors_trans = transform(M, descriptors) +p = scatter( + descriptors_trans[1,:], descriptors_trans[2,:], descriptors_trans[3,:], + marker=:circle, linewidth=0, group=config_types, legend=:right) +plot!(p, xlabel="PC1", ylabel="PC2", zlabel="PC3", camera=(10,10)) + +GC.gc() + +# Part 3: Basic model fitting +# ============================= + +# We begin by defining an (extremely simple) acemodel. The parameters have the +# same meaning as for ace_basis above, with an additional Eref providing a +# reference energy. + +model = acemodel(elements = [:Si,], + order = 3, + totaldegree = 8, + rcut = 5.0, + Eref = [:Si => -158.54496821]) +@show length(model.basis); + +# Next, we fit determine the model parameters using the tiny dataset and ridge +# regression via the QR solver. + +solver = ACEfit.QR(lambda=1e-1) +data_keys = (energy_key = "dft_energy", force_key = "dft_force", virial_key = "dft_virial") +acefit!(model, Si_tiny_dataset; + solver=solver, data_keys...); + +@info("Training Errors") +ACEpotentials.linear_errors(Si_tiny_dataset, model; data_keys...); + +@info("Test Error") +ACEpotentials.linear_errors(Si_dataset, model; data_keys...); + +# A model may be exported to JSON or LAMMPS formats with the following. + +export2json("model.json", model) +export2lammps("model.yace", model) + +# Part 4: Committee models +# ========================== +# +# ACEpotentials.jl can produce committee models using Bayesian linear +# regression. Such committees provide uncertainty estimates useful for active +# learning. +# +# Recall our two silicon datasets. We begin by training a (relatively small) +# model on the tiny version. +# +# Note the use of the BLR solver with a nonzero committee size. + +model = acemodel(elements = [:Si,], + Eref = [:Si => -158.54496821], + order = 3, + totaldegree = 12); + +acefit!(model, Si_tiny_dataset; + solver = ACEfit.BLR(committee_size=50, factorization=:svd), + energy_key = "dft_energy", force_key = "dft_force", + verbose = false); + +# Next we define a function which assesses model performance on the full +# silicon dataset. + +function assess_model(model, train_dataset) + + plot([-164,-158], [-164,-158]; lc=:black, label="") + + model_energies = [] + model_std = [] + for atoms in Si_dataset + ene, co_ene = ACE1.co_energy(model.potential, atoms) + push!(model_energies, ene/length(atoms)) + push!(model_std, std(co_ene/length(atoms))) + end + rmse = sqrt(sum((model_energies-Si_dataset_energies).^2)/length(Si_dataset)) + mae = sum(abs.(model_energies-Si_dataset_energies))/length(Si_dataset) + scatter!(Si_dataset_energies, model_energies; + label="full dataset", + title = @sprintf("Structures Used In Training: %i out of %i\n", length(train_dataset), length(Si_dataset)) * + @sprintf("RMSE (MAE) For Entire Dataset: %.0f (%.0f) meV/atom", 1000*rmse, 1000*mae), + titlefontsize = 8, + yerror = model_std, + xlabel="Energy [eV/atom]", xlims=(-164,-158), + ylabel="Model Energy [eV/atom]", ylims=(-164,-158), + aspect_ratio = :equal, color=1) + + model_energies = [energy(model.potential,atoms)/length(atoms) for atoms in train_dataset] + scatter!(extract_energies(train_dataset), model_energies; + label="training set", color=2) + +end; + +# Applying this function to our current model yields + +assess_model(model, Si_tiny_dataset) + +# Clearly there is room to improve: the model-derived RMSE is 280 meV/atom for +# the full dataset. Moreover, the error bars show the standard deviation of +# the energies predicted by the commmittee, which are quite high for some +# data. +# +# Next, we will define a function that augments the tiny dataset by adding +# structures for which the model is least confident. + +function augment(old_dataset, old_model; num=5) + + new_dataset = deepcopy(old_dataset) + new_model = deepcopy(old_model) + + model_std = [] + for atoms in Si_dataset + ene, co_ene = ACE1.co_energy(new_model.potential, atoms) + push!(model_std, std(co_ene/length(atoms))) + end + for atoms in Si_dataset[sortperm(model_std; rev=true)[1:num]] + push!(new_dataset, atoms) + end + @suppress acefit!(new_model, new_dataset; + solver = ACEfit.BLR(committee_size=50, factorization=:svd), + energy_key = "dft_energy", force_key = "dft_force", + verbose = false); + + return new_dataset, new_model +end; + +# The following applies this strategy, adding the five structures with the +# highest committee deviation. + +new_dataset, new_model = augment(Si_tiny_dataset, model; num=5); +assess_model(new_model, new_dataset) + +# Already, there is notable improvement. The overall errors have dropped, and +# the predictions for the worst-performing structures are much improved. +# +# Next, we perform four additional augmentation steps, adding twenty +# structures in total. + +for i in 1:4 + @show i + new_dataset, new_model = augment(new_dataset, new_model; num=5); +end +assess_model(new_model, new_dataset) + +# Remarkably, although we are using only a small fraction (~3%) of the full +# dataset, our model now performs reasonably well. +# +# Further iterations may improve on this result; however, a larger model is +# necessary to obtain extremely low errors. +# +# Important: While this dataset filtering can be useful, the connection with +# active learning is crucial. Recall that we did not use the reference +# energies when selecting structures, only the committee deviation. + +GC.gc() + +# Part 5: Multiple elements +# =========================== + +# We briefly demonstrate the syntax for multiple elements, using a TiAl +# dataset. + +tial_data, _, _ = ACEpotentials.example_dataset("TiAl_tutorial"); + +# The species-dependent RDFs are obtained as + +r_cut = 6.0 +rdf = ACEpotentials.get_rdf(tial_data, r_cut) +plt_TiTi = histogram(rdf[(:Ti, :Ti)], bins=100, xlabel = "", c = 1, + ylabel = "RDF - TiTi", label = "", yticks = [], xlims = (0, r_cut) ) +plt_TiAl = histogram(rdf[(:Ti, :Ti)], bins=100, xlabel = "", c = 2, + ylabel = "RDF - TiAl", label = "", yticks = [], xlims = (0, r_cut) ) +plt_AlAl = histogram(rdf[(:Al, :Al)], bins=100, xlabel = L"r [\AA]", c = 3, + ylabel = "RDF - AlAl", label = "", yticks = [], xlims = (0, r_cut), ) +plot(plt_TiTi, plt_TiAl, plt_AlAl, layout = (3,1), size = (500, 500), left_margin = 6Plots.mm) + +# An acemodel is defined as + +model = acemodel(elements = [:Ti, :Al], + order = 3, + totaldegree = 6, + rcut = 5.5, + Eref = [:Ti => -1586.0195, :Al => -105.5954]) +@show length(model.basis); + +# and it is fit in the same manner. + +acefit!(model, tial_data[1:5:end]); +ACEpotentials.linear_errors(tial_data[1:5:end], model); + +# Part 6: Recreate data from the ACEpotentials.jl paper +# ======================================================= +# +# The ACEpotentials paper (https://arxiv.org/abs/2309.03161) includes +# comparisons with results from +# +# │ Y. Zuo, C. Chen, X. Li, Z. Deng, Y. Chen, J. Behler, G. Csányi, A. +# │ V. Shapeev, A. P. Thompson, M. A. Wood, and S. P. Ong, Performance +# │ and cost assessment of machine learning interatomic potentials, J. +# │ Chem. Phys. A 124, 731 (2020). +# +# This section can be used to reproduce those data. + +### Choose elements to include +#elements = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] +elements = [:Ni, :Cu] + +### Choose a model size +totaldegree = [ 20, 16, 12 ] # small model: ~ 300 basis functions +#totaldegree = [ 25, 21, 17 ] # large model: ~ 1000 basis functions + +errors = Dict("E" => Dict(), "F" => Dict()) + +for element in elements + + # load the dataset + @info("---------- loading $(element) dataset ----------") + train, test, _ = ACEpotentials.example_dataset("Zuo20_$element") + # specify the model + model = acemodel(elements = [element], order = 3, totaldegree = totaldegree) + @info("$element model length: $(length(model.basis))") + # train the model + acefit!(model, train) + # compute and store errors + err = ACEpotentials.linear_errors(test, model) + errors["E"][element] = err["mae"]["set"]["E"] * 1000 + errors["F"][element] = err["mae"]["set"]["F"] + +end + +# Finally, create the tables. + +header = ([ "", "ACE", "GAP", "MTP"]) + +# create energy table +e_table_gap = Dict( + :Ni => 0.42, :Cu => 0.46, :Li => 0.49, + :Mo => 2.24, :Si => 2.91, :Ge => 2.06) +e_table_mtp = Dict( + :Ni => 0.48, :Cu => 0.41, :Li => 0.49, + :Mo => 2.83, :Si => 2.21, :Ge => 1.79) +e_table = hcat( + string.(elements), + [round(errors["E"][element], digits=3) for element in elements], + [e_table_gap[element] for element in elements], + [e_table_mtp[element] for element in elements]) +println("Energy Error") +pretty_table(e_table; header = header) + +# create force table +f_table_gap = Dict( + :Ni => 0.02, :Cu => 0.01, :Li => 0.01, + :Mo => 0.09, :Si => 0.07, :Ge => 0.05) +f_table_mtp = Dict( + :Ni => 0.01, :Cu => 0.01, :Li => 0.01, + :Mo => 0.09, :Si => 0.06, :Ge => 0.05) +f_table = hcat( + string.(elements), + [round(errors["F"][element], digits=3) for element in elements], + [f_table_gap[element] for element in elements], + [f_table_mtp[element] for element in elements]) +println("Force Error") +pretty_table(f_table; header = header) + +# Part 7: Next steps +# ==================== + +# • Review tutorials from ACEpotentials documentation: +# https://acesuit.github.io/ACEpotentials.jl/dev/tutorials/ +# +# • Parallel fitting: +# https://acesuit.github.io/ACEpotentials.jl/dev/gettingstarted/parallel-fitting/ +# +# • Use an ACEpotentials.jl potential with ASE: +# https://acesuit.github.io/ACEpotentials.jl/dev/tutorials/python_ase/ +# +# • Install LAMMPS with ACEpotentials patch: +# https://acesuit.github.io/ACEpotentials.jl/dev/tutorials/lammps/ From 85323f11e4e1cc1964e9306c11fc4fe98ea926f9 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 22 Aug 2024 22:06:30 -0700 Subject: [PATCH 112/112] fixing CI and docs --- .github/workflows/CI.yml | 2 +- Project.toml | 1 + docs/Project.toml | 2 +- docs/src/ACEpotentials/all_exported.md | 4 ++-- docs/src/tutorials/smoothness_priors.jl | 6 +++--- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 057fbad1..17335aa1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: matrix: julia-version: - '1.10' - - 'nightly' + # - 'nightly' python-version: - '3.8' os: diff --git a/Project.toml b/Project.toml index 703c004d..b5f59c71 100644 --- a/Project.toml +++ b/Project.toml @@ -60,6 +60,7 @@ Interpolations = "0.15" JuLIP = "0.13.9, 0.14.2" PrettyTables = "1.3, 2.0" Reexport = "1" +UltraFastACE = "0.0.6" StaticArrays = "1" YAML = "0.4" julia = "~1.10.0" diff --git a/docs/Project.toml b/docs/Project.toml index 90a94b29..465eb197 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -13,7 +13,7 @@ MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +UltraFastACE = "8bb720ee-daac-48fb-af73-8a282a9cbbd7" [compat] -ACEpotentials = "0.6" Literate = "2.13.2,3" diff --git a/docs/src/ACEpotentials/all_exported.md b/docs/src/ACEpotentials/all_exported.md index b95aaf4d..9e250d25 100644 --- a/docs/src/ACEpotentials/all_exported.md +++ b/docs/src/ACEpotentials/all_exported.md @@ -3,13 +3,13 @@ ### Exported ```@autodocs -Modules = [ACEpotentials] +Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat] Private = false ``` ### Not exported ```@autodocs -Modules = [ACEpotentials] +Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat] Public = false ``` diff --git a/docs/src/tutorials/smoothness_priors.jl b/docs/src/tutorials/smoothness_priors.jl index 6d2ecca5..eb3c2e3a 100644 --- a/docs/src/tutorials/smoothness_priors.jl +++ b/docs/src/tutorials/smoothness_priors.jl @@ -32,9 +32,9 @@ A, Y, W = ACEfit.assemble(data, model.basis); # In the following we demonstrate the usage of algebraic and gaussian priors. The choices for `σl, σn` made here may seem "magical", but there is a good justification and we plan to automate this in future releases. -Pa2 = algebraic_smoothness_prior(model.basis; p=2) -Pa4 = algebraic_smoothness_prior(model.basis; p=4) -Pg = gaussian_smoothness_prior( model.basis, σl = (2/r_nn)^2, σn = (0.5/r_nn)^2); +Pa2 = ACE1x.algebraic_smoothness_prior(model.basis; p=2) +Pa4 = ACE1x.algebraic_smoothness_prior(model.basis; p=4) +Pg = ACE1x.gaussian_smoothness_prior( model.basis, σl = (2/r_nn)^2, σn = (0.5/r_nn)^2); # Each of these object `Pa2, Pa4, Pg` are diagonal matrices. For each prior constructed above we now solve the regularised least squares problem. Note how design matrix need only be assembled once if we want to play with many different priors. Most of the time we would just use defaults however and then these steps are all taken care of behind the scenes.