Commit 38834dc 1 parent 214ac8f commit 38834dc Copy full SHA for 38834dc
File tree 1 file changed +6
-3
lines changed
1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change
1
+ from copy import deepcopy
1
2
from typing import Any , Dict
2
3
3
4
import pytest
@@ -111,9 +112,11 @@ def test_bidirectional_conversion(
111
112
# model_e3nn_back = model_e3nn_back.to(device)
112
113
113
114
# Test forward pass equivalence
114
- out_e3nn = model_e3nn (batch , training = True )
115
- out_cueq = model_cueq (batch , training = True )
116
- out_e3nn_back = model_e3nn_back (batch , training = True )
115
+ out_e3nn = model_e3nn (deepcopy (batch ), training = True , compute_stress = True )
116
+ out_cueq = model_cueq (deepcopy (batch ), training = True , compute_stress = True )
117
+ out_e3nn_back = model_e3nn_back (
118
+ deepcopy (batch ), training = True , compute_stress = True
119
+ )
117
120
118
121
# Check outputs match for both conversions
119
122
torch .testing .assert_close (out_e3nn ["energy" ], out_cueq ["energy" ])
You can’t perform that action at this time.
0 commit comments