Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax/array-api): se_e2_a #4217

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 100 additions & 7 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand Down Expand Up @@ -186,31 +190,33 @@
self.reinit_exclude(exclude_types)

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
for ii, embedding_idx in enumerate(
itertools.product(range(self.ntypes), repeat=self.embeddings.ndim)
itertools.product(range(self.ntypes), repeat=embeddings.ndim)
):
self.embeddings[embedding_idx] = EmbeddingNet(
embeddings[embedding_idx] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, ii),
)
self.embeddings = embeddings
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.nnei = np.sum(self.sel).item()
self.davg = np.zeros(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
)
self.dstd = np.ones(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
)
self.orig_sel = self.sel
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]

def __setitem__(self, key, value):
if key in ("avg", "data_avg", "davg"):
Expand Down Expand Up @@ -321,8 +327,9 @@
ss,
embedding_idx,
):
xp = array_api_compat.array_namespace(ss)
nf_times_nloc, nnei = ss.shape[0:2]
ss = ss.reshape(nf_times_nloc, nnei, 1)
ss = xp.reshape(ss, (nf_times_nloc, nnei, 1))
# (nf x nloc) x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
return gg
Expand Down Expand Up @@ -444,8 +451,8 @@
"env_mat": self.env_mat.serialize(),
"embeddings": self.embeddings.serialize(),
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
"davg": to_numpy_array(self.davg),
"dstd": to_numpy_array(self.dstd),
},
"type_map": self.type_map,
}
Expand Down Expand Up @@ -497,3 +504,89 @@
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist


class DescrptSeAArrayAPI(DescrptSeA):
def call(
self,
coord_ext,
atype_ext,
nlist,
mapping: Optional[np.ndarray] = None,
):
"""Compute the descriptor.

Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping from extended to lcoal region. not used by this descriptor.

Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.
"""
if not self.type_one_side:
raise NotImplementedError(

Check warning on line 547 in deepmd/dpmodel/descriptor/se_e2_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_e2_a.py#L547

Added line #L547 was not covered by tests
"type_one_side == False is not supported in DescrptSeAArrayAPI"
)
del mapping
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
input_dtype = coord_ext.dtype
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
)
nf, nloc, nnei, _ = rr.shape
sec = xp.asarray(self.sel_cumsum)

ng = self.neuron[-1]
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, self.dstd.dtype)

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
(tt,) = embedding_idx
mm = exclude_mask[:, sec[tt] : sec[tt + 1]]
tr = rr[:, sec[tt] : sec[tt + 1], :]
tr = tr * xp.astype(mm[:, :, None], tr.dtype)
ss = tr[..., 0:1]
gg = self.cal_g(ss, embedding_idx)
# gr_tmp = xp.einsum("lni,lnj->lij", gg, tr)
gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
gr += gr_tmp
gr = xp.reshape(gr, (nf, nloc, ng, 4))
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
# nf x nloc x ng x ng1
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = xp.astype(
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
)
return grrg, gr[..., 1:], None, None, ww
20 changes: 10 additions & 10 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,20 @@ def nlist_distinguish_types(
xp = array_api_compat.array_namespace(nlist, atype)
nf, nloc, _ = nlist.shape
ret_nlist = []
tmp_atype = xp.tile(atype[:, None], [1, nloc, 1])
tmp_atype = xp.tile(atype[:, None, :], (1, nloc, 1))
mask = nlist == -1
tnlist_0 = nlist.copy()
tnlist_0[mask] = 0
tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze()
tnlist = xp.where(mask, -1, tnlist)
snsel = tnlist.shape[2]
tnlist_0 = xp.where(mask, xp.zeros_like(nlist), nlist)
tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2)
tnlist = xp.where(mask, xp.full_like(tnlist, -1), tnlist)
for ii, ss in enumerate(sel):
pick_mask = (tnlist == ii).astype(xp.int32)
sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1)
pick_mask = xp.astype(tnlist == ii, xp.int32)
sorted_indices = xp.argsort(-pick_mask, stable=True, axis=-1)
pick_mask_sorted = -xp.sort(-pick_mask, axis=-1)
inlist = xp_take_along_axis(nlist, sorted_indices, axis=2)
inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist)
ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0])
inlist = xp.where(
~xp.astype(pick_mask_sorted, xp.bool), xp.full_like(inlist, -1), inlist
)
ret_nlist.append(inlist[..., :ss])
ret = xp.concat(ret_nlist, axis=-1)
return ret

Expand Down
33 changes: 33 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


@flax_module
class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
32 changes: 32 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
55 changes: 55 additions & 0 deletions source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -33,6 +35,17 @@
descrpt_se_a_args,
)

if INSTALLED_JAX:
from deepmd.jax.descriptor.se_e2_a import DescrptSeA as DescrptSeAJAX
else:
DescrptSeAJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_e2_a import (
DescrptSeA as DescrptSeAArrayAPIStrict,
)
else:
DescrptSeAArrayAPIStrict = None


@parameterized(
(True, False), # resnet_dt
Expand Down Expand Up @@ -98,9 +111,33 @@ def skip_tf(self) -> bool:
) = self.param
return env_protection != 0.0

@property
def skip_jax(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
env_protection,
) = self.param
return not type_one_side or not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
env_protection,
) = self.param
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeATF
dp_class = DescrptSeADP
pt_class = DescrptSeAPT
jax_class = DescrptSeAJAX
array_api_strict_class = DescrptSeAArrayAPIStrict
args = descrpt_se_a_args()

def setUp(self):
Expand Down Expand Up @@ -177,6 +214,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down