Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tri feature which is needed for force calculation #3

Merged
merged 3 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DeepSolid/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def default() -> ml_collections.ConfigDict:
'hidden_dims': ((256, 32), (256, 32), (256, 32)),
'determinants': 8,
'after_determinants': 1,
'distance_type': 'nu',
},
'twist': (0.0, 0.0, 0.0), # Difine the twist of wavefunction,
# twists are given in terms of fractions of supercell reciprocal vectors
Expand Down
140 changes: 68 additions & 72 deletions DeepSolid/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,19 @@ def enforce_pbc(latvec, epos):


def init_solid_fermi_net_params(
key: jnp.ndarray,
data,
atoms: jnp.ndarray,
spins: Tuple[int, int],
envelope_type: str = 'full',
bias_orbitals: bool = False,
use_last_layer: bool = False,
eps: float = 0.01,
full_det: bool = True,
hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)),
determinants: int = 16,
after_determinants: Union[int, Tuple[int, ...]] = 1,
key: jnp.ndarray,
data,
atoms: jnp.ndarray,
spins: Tuple[int, int],
envelope_type: str = 'full',
bias_orbitals: bool = False,
use_last_layer: bool = False,
eps: float = 0.01,
full_det: bool = True,
hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)),
determinants: int = 16,
after_determinants: Union[int, Tuple[int, ...]] = 1,
distance_type='nu',
):
"""Initializes parameters for the Fermionic Neural Network.

Expand Down Expand Up @@ -107,7 +108,13 @@ def init_solid_fermi_net_params(
del data

natom = atoms.shape[0]
in_dims = (natom * 4, 4)
if distance_type == 'nu':
in_dims = (natom * 4, 4)
elif distance_type == 'tri':
in_dims = (natom * 7, 7)
else:
raise ValueError('Unrecognized distance function.')

active_spin_channels = [spin for spin in spins if spin > 0]
nchannels = len(active_spin_channels)
# The input to layer L of the one-electron stream is from
Expand Down Expand Up @@ -179,38 +186,6 @@ def init_solid_fermi_net_params(
return params


def construct_input_features(
x: jnp.ndarray,
atoms: jnp.ndarray,
ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Constructs inputs to Fermi Net from raw electron and atomic positions.
Args:
x: electron positions. Shape (nelectrons*ndim,).
atoms: atom positions. Shape (natoms, ndim).
ndim: dimension of system. Change only with caution.
Returns:
ae, ee, r_ae, r_ee tuple, where:
ae: atom-electron vector. Shape (nelectron, natom, 3).
ee: atom-electron vector. Shape (nelectron, nelectron, 3).
r_ae: atom-electron distance. Shape (nelectron, natom, 1).
r_ee: electron-electron distance. Shape (nelectron, nelectron, 1).
The diagonal terms in r_ee are masked out such that the gradients of these
terms are also zero.
"""

assert atoms.shape[1] == ndim
ae = jnp.reshape(x, [-1, 1, ndim]) - atoms[None, ...]
ee = jnp.reshape(x, [1, -1, ndim]) - jnp.reshape(x, [-1, 1, ndim])

r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True)
# Avoid computing the norm of zero, as is has undefined grad
n = ee.shape[0]
r_ee = (
jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n)))

return ae, ee, r_ae, r_ee[..., None]


def scaled_f(w):
"""
see Phys. Rev. B 94, 035157
Expand All @@ -229,23 +204,6 @@ def scaled_g(w):
return w * (1 - 3. / 2. * jnp.abs(w / jnp.pi) + 1. / 2. * jnp.abs(w / jnp.pi) ** 2)


def sin_relative_distance(xea, a, b):
'''

:param xea:
:param a:
:param b:
:return: periodic relative distance [ne, na, 6]
'''
w = jnp.einsum('...ijk,lk->...ijl', xea, b)
mod = (w + jnp.pi) // (2 * jnp.pi)
w = (w - mod * 2 * jnp.pi)
# w = jnp.mod(w + jnp.pi, 2 * jnp.pi) - jnp.pi
# r1 = jnp.einsum('...i,ij->...j', scaled_f(w), a)
r2 = jnp.einsum('...i,ij->...j', scaled_g(w), a)
return r2


def nu_distance(xea, a, b):
"""
see Phys. Rev. B 94, 035157
Expand All @@ -265,11 +223,36 @@ def nu_distance(xea, a, b):
sd = result ** 0.5
return sd, rel


def tri_distance(xea, a, b):
"""
see Phys. Rev. Lett. 130, 036401 (2023).
:param xea: relative distance between electrons and atoms
:param a: lattice vectors of primitive cell divided by 2\pi.
:param b: reciprocal vectors of primitive cell.
:return: periodic generalized relative and absolute distance of xea.
"""
w = jnp.einsum('...ijk,lk->...ijl', xea, b)
sg = jnp.sin(w)
cg = jnp.cos(w)
rel_sin = jnp.einsum('...i,ij->...j', sg, a)
rel_cos = jnp.einsum('...i,ij->...j', cg, a)
rel = jnp.concatenate([rel_sin, rel_cos], axis=-1)
metric = jnp.einsum('ij,kj->ik', a, a)
vector_sin = sg[..., :, None] * sg[..., None, :]
vector_cos = (1-cg[..., :, None]) * (1-cg[..., None, :])
vector = vector_cos + vector_sin
sd = jnp.einsum('...ij,ij->...', vector, metric) ** 0.5
return sd, rel


def construct_periodic_input_features(
x: jnp.ndarray,
atoms: jnp.ndarray,
simulation_cell=None,
ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
x: jnp.ndarray,
atoms: jnp.ndarray,
simulation_cell=None,
ndim: int = 3,
distance_type: str = 'nu',
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Constructs a periodic generalized inputs to Fermi Net from raw electron and atomic positions.
see Phys. Rev. B 94, 035157
Args:
Expand All @@ -285,14 +268,21 @@ def construct_periodic_input_features(
The diagonal terms in r_ee are masked out such that the gradients of these
terms are also zero.
"""
if distance_type == 'nu':
distance_func = nu_distance
elif distance_type == 'tri':
distance_func = tri_distance
else:
raise ValueError('Unrecognized distance function.')

primitive_cell = simulation_cell.original_cell
x = x.reshape(-1, ndim)
n = x.shape[0]
prim_x, _ = enforce_pbc(primitive_cell.a, x)

# prim_xea = minimal_imag.dist_i(atoms.ravel(), prim_x.ravel())
prim_xea = prim_x[..., None, :] - atoms
prim_periodic_sea, prim_periodic_xea = nu_distance(prim_xea,
prim_periodic_sea, prim_periodic_xea = distance_func(prim_xea,
primitive_cell.AV,
primitive_cell.BV)
prim_periodic_sea = prim_periodic_sea[..., None]
Expand All @@ -301,7 +291,7 @@ def construct_periodic_input_features(
# sim_xee = sim_minimal_imag.dist_matrix(sim_x.ravel())
sim_xee = sim_x[:, None, :] - sim_x[None, :, :]

sim_periodic_see, sim_periodic_xee = nu_distance(sim_xee + jnp.eye(n)[..., None],
sim_periodic_see, sim_periodic_xee = distance_func(sim_xee + jnp.eye(n)[..., None],
simulation_cell.AV,
simulation_cell.BV)
sim_periodic_see = sim_periodic_see * (1.0 - jnp.eye(n))
Expand Down Expand Up @@ -474,7 +464,8 @@ def solid_fermi_net_orbitals(params, x,
atoms=None,
spins=(None, None),
envelope_type=None,
full_det=False):
full_det=False,
distance_type='nu'):
"""Forward evaluation of the Solid Neural Network up to the orbitals.
Args:
params: A dictionary of parameters, containing fields:
Expand Down Expand Up @@ -506,9 +497,9 @@ def solid_fermi_net_orbitals(params, x,
envelope, depending on the envelope type.
"""

ae_, ee_, r_ae, r_ee = construct_periodic_input_features(x, atoms,
simulation_cell=simulation_cell,
)
ae_, ee_, r_ae, r_ee = construct_periodic_input_features(
x, atoms, simulation_cell=simulation_cell, distance_type=distance_type
)
ae = jnp.concatenate((r_ae, ae_), axis=2)
ae = jnp.reshape(ae, [jnp.shape(ae)[0], -1])
ee = jnp.concatenate((r_ee, ee_), axis=2)
Expand Down Expand Up @@ -576,6 +567,7 @@ def eval_func(params, x,
spins=(None, None),
envelope_type='full',
full_det=False,
distance_type='nu',
method_name='eval_slogdet'):
'''
generates the wavefunction of simulation cell.
Expand All @@ -597,6 +589,7 @@ def eval_func(params, x,
atoms=atoms,
spins=spins,
envelope_type=envelope_type,
distance_type=distance_type,
full_det=full_det)
if method_name == 'eval_slogdet':
_, result = logdet_matmul(orbitals)
Expand All @@ -623,6 +616,7 @@ def make_solid_fermi_net(
hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)),
determinants: int = 16,
after_determinants: Union[int, Tuple[int, ...]] = 1,
distance_type='nu',
method_name='eval_logdet',
):
'''
Expand Down Expand Up @@ -655,6 +649,7 @@ def make_solid_fermi_net(
hidden_dims=hidden_dims,
determinants=determinants,
after_determinants=after_determinants,
distance_type=distance_type,
)
network = functools.partial(
eval_func,
Expand All @@ -664,6 +659,7 @@ def make_solid_fermi_net(
spins=simulation_cell.nelec,
envelope_type=envelope_type,
full_det=full_det,
distance_type=distance_type,
method_name=method_name,
)
method.init = init
Expand Down
1 change: 0 additions & 1 deletion bin/deepsolid
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# modified from FermiNet:https://github.com/deepmind/ferminet

import sys
import os

from absl import app
from absl import flags
Expand Down