Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix for multigpu training using main branch #599

Open
wants to merge 21 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ We are happy to accept pull requests under an [MIT license](https://choosealicen

If you use this code, please cite our papers:

```text
```bibtex
@inproceedings{Batatia2022mace,
title={{MACE}: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields},
author={Ilyes Batatia and David Peter Kovacs and Gregor N. C. Simm and Christoph Ortner and Gabor Csanyi},
Expand Down
6 changes: 6 additions & 0 deletions mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ def run(args: argparse.Namespace):
charges_key=args.charges_key,
)

if args.E0s is not None:
logging.info("Using E0s from command line argument")
E0s = ast.literal_eval(args.E0s)
assert isinstance(E0s, dict)
atomic_energies_dict = E0s

# Atomic number table
# yapf: disable
if args.atomic_numbers is None:
Expand Down
2 changes: 1 addition & 1 deletion mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def compute_average_E0s(
E0s = np.linalg.lstsq(A, B, rcond=None)[0]
atomic_energies_dict = {}
for i, z in enumerate(z_table.zs):
atomic_energies_dict[z] = E0s[i]
atomic_energies_dict[z] = float(E0s[i])
except np.linalg.LinAlgError:
logging.error(
"Failed to compute E0s using least squares regression, using the same for all atoms"
Expand Down
10 changes: 8 additions & 2 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from huggingface_hub import PyTorchModelHubMixin

from mace.data import AtomicData
from mace.modules.radial import ZBLBasis
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
if pair_repulsion:
self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
self.pair_repulsion = True

sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
Expand Down Expand Up @@ -315,7 +316,12 @@ def forward(


@compile_mode("script")
class ScaleShiftMACE(MACE):
class ScaleShiftMACE(
MACE,
PyTorchModelHubMixin,
repo_url="https://github.com/ACEsuit/mace",
docs_url="https://mace-docs.readthedocs.io/en/latest/",
):
def __init__(
self,
atomic_inter_scale: float,
Expand Down
67 changes: 35 additions & 32 deletions mace/modules/radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]

@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""
Equation (8)
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) -- TODO: from where?
"""

p: torch.Tensor
Expand All @@ -125,36 +125,38 @@ def __init__(self, r_max: float, p=6):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# yapf: disable
return self.calculate_envelope(x, self.r_max, self.p)

@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2)
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
# yapf: enable

# noinspection PyUnresolvedReferences
return envelope * (x < self.r_max)
return envelope * (x < r_max)

def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"


@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""
Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6, trainable=False):
def __init__(self, p=6, trainable=False):
super().__init__()
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
# Pre-calculate the p coefficients for the ZBL potential
self.register_buffer(
"c",
Expand All @@ -170,7 +172,6 @@ def __init__(self, r_max: float, p=6, trainable=False):
dtype=torch.get_default_dtype(),
),
)
self.cutoff = PolynomialCutoff(r_max, p)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
Expand Down Expand Up @@ -208,12 +209,7 @@ def forward(
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2)
) * (x < r_max)
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
Expand All @@ -224,8 +220,8 @@ def __repr__(self):

@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""
Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""

def __init__(
Expand Down Expand Up @@ -265,21 +261,27 @@ def forward(
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p)))
) ** (-1)
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()

def __repr__(self):
return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})"
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)


@simplify_if_compile
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Soft Transform
"""
"""Soft Transform."""

def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
super().__init__()
Expand Down Expand Up @@ -312,9 +314,10 @@ def forward(
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
r_over_r_0 = x / r_0
y = (
x
+ (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b))
+ (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b))
+ 1 / 2
)
return y
Expand Down
6 changes: 6 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:

def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--work_dir",
help="set directory for all files and folders",
type=str,
default=".",
)
parser.add_argument(
"--train_file",
help="Training set h5 file",
Expand Down
30 changes: 18 additions & 12 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def valid_err_log(
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_stress"] is not None
and eval_metrics.get("rmse_stress", None) is not None
):
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
Expand All @@ -74,7 +74,7 @@ def valid_err_log(
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_virials_per_atom"] is not None
and eval_metrics.get("rmse_virials_per_atom", None) is not None
):
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
Expand All @@ -83,18 +83,18 @@ def valid_err_log(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV",
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_stress_per_atom"] is not None
log_errors == "PerAtomMAEstress"
and eval_metrics.get("mae_stress", None) is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3"
f"{inintial_phrase}: loss={valid_loss:10.6e}, MAE_E_per_atom={error_e:10.6e} meV, MAE_F={error_f:10.6e} meV / A, MAE_stress={error_stress:10.6e} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_virials_per_atom"] is not None
and eval_metrics.get("mae_virials_per_atom", None) is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
Expand All @@ -111,6 +111,7 @@ def valid_err_log(
elif log_errors == "PerAtomMAE":
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
)
Expand Down Expand Up @@ -177,21 +178,29 @@ def train(
logging.info("Loss metrics on validation set")
epoch = start_epoch

if distributed:
torch.distributed.barrier()

model_to_evaluate = model if distributed_model is None else distributed_model
# log validation loss before _any_ training
valid_loss = 0.0
for valid_loader_name, valid_loader in valid_loaders.items():
valid_loss_head, eval_metrics = evaluate(
model=model,
model=model_to_evaluate,
loss_fn=loss_fn,
data_loader=valid_loader,
output_args=output_args,
device=device,
)
valid_err_log(
if rank == 0:
valid_err_log(
valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name
)
valid_loss = valid_loss_head # consider only the last head for the checkpoint

if distributed:
torch.distributed.barrier()

while epoch < max_num_epochs:
# LR scheduler and SWA update
if swa is None or epoch < swa.start:
Expand Down Expand Up @@ -266,11 +275,8 @@ def train(
wandb_log_dict[valid_loader_name] = {
"epoch": epoch,
"valid_loss": valid_loss_head,
"valid_rmse_e_per_atom": eval_metrics[
"rmse_e_per_atom"
],
"valid_rmse_f": eval_metrics["rmse_f"],
}
wandb_log_dict.update(eval_metrics)
valid_loss = (
valid_loss_head # consider only the last head for the checkpoint
)
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ install_requires =
# for plotting:
matplotlib
pandas
huggingface-hub

[options.entry_points]
console_scripts =
Expand Down
Loading