Skip to content

Commit 38834dc

Browse files
committed
add stress to test_cueq.py
1 parent 214ac8f commit 38834dc

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/test_cueq.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
from typing import Any, Dict
23

34
import pytest
@@ -111,9 +112,11 @@ def test_bidirectional_conversion(
111112
# model_e3nn_back = model_e3nn_back.to(device)
112113

113114
# Test forward pass equivalence
114-
out_e3nn = model_e3nn(batch, training=True)
115-
out_cueq = model_cueq(batch, training=True)
116-
out_e3nn_back = model_e3nn_back(batch, training=True)
115+
out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True)
116+
out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True)
117+
out_e3nn_back = model_e3nn_back(
118+
deepcopy(batch), training=True, compute_stress=True
119+
)
117120

118121
# Check outputs match for both conversions
119122
torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])

0 commit comments

Comments
 (0)