Skip to content

Commit 2f9a3a5

Browse files
authored
Merge pull request #703 from ACEsuit/fix-extract-equivariant-features-with-num-layers-1
Fix-extract-equivariant-features-with-num-layers-1
2 parents 3cd9bb8 + a6a729a commit 2f9a3a5

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

mace/calculators/mace.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -400,24 +400,34 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
400400
atoms = self.atoms
401401
if self.model_type != "MACE":
402402
raise NotImplementedError("Only implemented for MACE models")
403+
num_interactions = int(self.models[0].num_interactions)
403404
if num_layers == -1:
404-
num_layers = int(self.models[0].num_interactions)
405+
num_layers = num_interactions
405406
batch = self._atoms_to_batch(atoms)
406407
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]
408+
409+
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
410+
l_max = irreps_out.lmax
411+
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
412+
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
413+
per_layer_features[-1] = (
414+
num_invariant_features # Equivariant features not created for the last layer
415+
)
416+
407417
if invariants_only:
408-
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
409-
l_max = irreps_out.lmax
410-
num_features = irreps_out.dim // (l_max + 1) ** 2
411418
descriptors = [
412419
extract_invariant(
413420
descriptor,
414421
num_layers=num_layers,
415-
num_features=num_features,
422+
num_features=num_invariant_features,
416423
l_max=l_max,
417424
)
418425
for descriptor in descriptors
419426
]
420-
descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors]
427+
to_keep = np.sum(per_layer_features[:num_layers])
428+
descriptors = [
429+
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
430+
]
421431

422432
if self.num_models == 1:
423433
return descriptors[0]

tests/test_calculator.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -481,24 +481,42 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
481481

482482
desc_invariant = calc.get_descriptors(at, invariants_only=True)
483483
desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
484-
desc_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1)
485-
desc_single_layer_rotated = calc.get_descriptors(
484+
desc_invariant_single_layer = calc.get_descriptors(
485+
at, invariants_only=True, num_layers=1
486+
)
487+
desc_invariant_single_layer_rotated = calc.get_descriptors(
486488
at_rotated, invariants_only=True, num_layers=1
487489
)
488490
desc = calc.get_descriptors(at, invariants_only=False)
491+
desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1)
489492
desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)
493+
desc_rotated_single_layer = calc.get_descriptors(
494+
at_rotated, invariants_only=False, num_layers=1
495+
)
490496

491497
assert desc_invariant.shape[0] == 3
492498
assert desc_invariant.shape[1] == 32
493-
assert desc_single_layer.shape[0] == 3
494-
assert desc_single_layer.shape[1] == 16
499+
assert desc_invariant_single_layer.shape[0] == 3
500+
assert desc_invariant_single_layer.shape[1] == 16
495501
assert desc.shape[0] == 3
496502
assert desc.shape[1] == 80
503+
assert desc_single_layer.shape[0] == 3
504+
assert desc_single_layer.shape[1] == 16 * 4
505+
assert desc_rotated_single_layer.shape[0] == 3
506+
assert desc_rotated_single_layer.shape[1] == 16 * 4
497507

498508
np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
499-
np.testing.assert_allclose(desc_single_layer, desc_invariant[:, :16], atol=1e-6)
500509
np.testing.assert_allclose(
501-
desc_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
510+
desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6
511+
)
512+
np.testing.assert_allclose(
513+
desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
514+
)
515+
np.testing.assert_allclose(
516+
desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6
517+
)
518+
assert not np.allclose(
519+
desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6
502520
)
503521
assert not np.allclose(desc, desc_rotated, atol=1e-6)
504522

0 commit comments

Comments
 (0)