Skip to content

Commit

Permalink
copy tensors rather than modifying in place
Browse files Browse the repository at this point in the history
  • Loading branch information
SanggyuChong committed Feb 21, 2024
1 parent 9006df2 commit 94686f0
Showing 1 changed file with 47 additions and 13 deletions.
60 changes: 47 additions & 13 deletions src/metatensor/models/utils/compute_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,6 @@ def compute_model_loss(
# Based on the keys of the targets, get the outputs of the model:
model_outputs = _get_model_outputs(model, systems, list(targets.keys()))

# modify model predictions and targets where averaging per atom is requested
if len(peratom_targets) > 0:

num_atoms = torch.tensor([len(s) for s in systems]).to(device=device)
num_atoms = torch.reshape(num_atoms, (-1, 1))

for pa_target in peratom_targets:
model_outputs[pa_target].block().values /= num_atoms
# to prevent averaging builds up over epochs
targets = targets.copy()
targets[pa_target].block().values /= num_atoms

for energy_target in energy_targets:
# If the energy target requires gradients, compute them:
target_requires_pos_gradients = (
Expand Down Expand Up @@ -187,8 +175,54 @@ def compute_model_loss(
else:
pass

# Perform averaging by numer of atoms
num_atoms = torch.tensor([len(s) for s in systems]).to(device=device)
num_atoms = torch.reshape(num_atoms, (-1, 1))

new_model_outputs = model_outputs.copy()
new_targets = targets.copy()

for pa_target in peratom_targets:

# Update predictions
cur_model_block = new_model_outputs[pa_target].block().copy()
new_model_block_values = cur_model_block.values / num_atoms
new_model_block = TensorBlock(
values=new_model_block_values,
samples=cur_model_block.samples,
components=cur_model_block.components,
properties=cur_model_block.properties,
)
for param, gradient in cur_model_block.gradients():
new_model_block.add_gradient(param, gradient)

# Update targets
cur_target_block = new_targets[pa_target].block().copy()
new_target_block_values = cur_target_block.values / num_atoms
new_target_block = TensorBlock(
values=new_target_block_values,
samples=cur_target_block.samples,
components=cur_target_block.components,
properties=cur_target_block.properties,
)
for param, gradient in cur_target_block.gradients():
new_target_block.add_gradient(param, gradient)

new_model_tensor_map = TensorMap(
keys=new_model_outputs[pa_target].keys,
blocks=[new_model_block],
)

new_target_tensor_map = TensorMap(
keys=new_targets[pa_target].keys,
blocks=[new_target_block],
)

new_model_outputs[pa_target] = new_model_tensor_map
new_targets[pa_target] = new_target_tensor_map

# Compute and return the loss and associated info:
return loss(model_outputs, targets)
return loss(new_model_outputs, new_targets)


def _position_gradients_to_block(gradients_list):
Expand Down

0 comments on commit 94686f0

Please sign in to comment.