diff --git a/src/metatensor/models/cli/conf/architecture/experimental.alchemical_model.yaml b/src/metatensor/models/cli/conf/architecture/experimental.alchemical_model.yaml index c0fe79da1..22a196c23 100644 --- a/src/metatensor/models/cli/conf/architecture/experimental.alchemical_model.yaml +++ b/src/metatensor/models/cli/conf/architecture/experimental.alchemical_model.yaml @@ -21,3 +21,4 @@ training: learning_rate: 0.001 log_interval: 10 checkpoint_interval: 25 + peratom_targets: [] diff --git a/src/metatensor/models/cli/conf/architecture/experimental.soap_bpnn.yaml b/src/metatensor/models/cli/conf/architecture/experimental.soap_bpnn.yaml index e43313e25..66873e749 100644 --- a/src/metatensor/models/cli/conf/architecture/experimental.soap_bpnn.yaml +++ b/src/metatensor/models/cli/conf/architecture/experimental.soap_bpnn.yaml @@ -25,3 +25,4 @@ training: learning_rate: 0.001 log_interval: 10 checkpoint_interval: 25 + peratom_targets: [] diff --git a/src/metatensor/models/cli/eval_model.py b/src/metatensor/models/cli/eval_model.py index 852c15675..5cdad8309 100644 --- a/src/metatensor/models/cli/eval_model.py +++ b/src/metatensor/models/cli/eval_model.py @@ -100,7 +100,7 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset]) aggregated_info: Dict[str, Tuple[float, int]] = {} for batch in dataloader: structures, targets = batch - _, info = compute_model_loss(loss_fn, model, structures, targets) + _, info = compute_model_loss(loss_fn, model, structures, targets, []) aggregated_info = update_aggregated_info(aggregated_info, info) finalized_info = finalize_aggregated_info(aggregated_info) diff --git a/src/metatensor/models/experimental/alchemical_model/train.py b/src/metatensor/models/experimental/alchemical_model/train.py index b9d0afdbc..d2eb1852e 100644 --- a/src/metatensor/models/experimental/alchemical_model/train.py +++ b/src/metatensor/models/experimental/alchemical_model/train.py @@ -209,7 +209,9 @@ def train( optimizer.zero_grad() 
structures, targets = batch assert len(structures[0].known_neighbors_lists()) > 0 - loss, info = compute_model_loss(loss_fn, model, structures, targets) + loss, info = compute_model_loss( + loss_fn, model, structures, targets, hypers_training["peratom_targets"] + ) train_loss += loss.item() loss.backward() optimizer.step() @@ -220,7 +222,9 @@ def train( for batch in validation_dataloader: structures, targets = batch # TODO: specify that the model is not training here to save some autograd - loss, info = compute_model_loss(loss_fn, model, structures, targets) + loss, info = compute_model_loss( + loss_fn, model, structures, targets, hypers_training["peratom_targets"] + ) validation_loss += loss.item() aggregated_validation_info = update_aggregated_info( aggregated_validation_info, info diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py index 06ce1b755..b26d543db 100644 --- a/src/metatensor/models/experimental/soap_bpnn/train.py +++ b/src/metatensor/models/experimental/soap_bpnn/train.py @@ -189,7 +189,9 @@ def train( for batch in train_dataloader: optimizer.zero_grad() structures, targets = batch - loss, info = compute_model_loss(loss_fn, model, structures, targets) + loss, info = compute_model_loss( + loss_fn, model, structures, targets, hypers_training["peratom_targets"] + ) train_loss += loss.item() loss.backward() optimizer.step() @@ -200,7 +202,9 @@ def train( for batch in validation_dataloader: structures, targets = batch # TODO: specify that the model is not training here to save some autograd - loss, info = compute_model_loss(loss_fn, model, structures, targets) + loss, info = compute_model_loss( + loss_fn, model, structures, targets, hypers_training["peratom_targets"] + ) validation_loss += loss.item() aggregated_validation_info = update_aggregated_info( aggregated_validation_info, info diff --git a/src/metatensor/models/utils/compute_loss.py b/src/metatensor/models/utils/compute_loss.py 
index 08c601c68..737cc4563 100644 --- a/src/metatensor/models/utils/compute_loss.py +++ b/src/metatensor/models/utils/compute_loss.py @@ -29,6 +29,7 @@ def compute_model_loss( model: Union[torch.nn.Module, torch.jit._script.RecursiveScriptModule], systems: List[System], targets: Dict[str, TensorMap], + peratom_targets: List[str], ): """ Compute the loss of a model on a set of targets. @@ -174,8 +175,54 @@ def compute_model_loss( else: pass + # Perform averaging by number of atoms + num_atoms = torch.tensor([len(s) for s in systems]).to(device=device) + num_atoms = torch.reshape(num_atoms, (-1, 1)) + + new_model_outputs = model_outputs.copy() + new_targets = targets.copy() + + for pa_target in peratom_targets: + + # Update predictions + cur_model_block = new_model_outputs[pa_target].block().copy() + new_model_block_values = cur_model_block.values / num_atoms + new_model_block = TensorBlock( + values=new_model_block_values, + samples=cur_model_block.samples, + components=cur_model_block.components, + properties=cur_model_block.properties, + ) + for param, gradient in cur_model_block.gradients(): + new_model_block.add_gradient(param, gradient) + + # Update targets + cur_target_block = new_targets[pa_target].block().copy() + new_target_block_values = cur_target_block.values / num_atoms + new_target_block = TensorBlock( + values=new_target_block_values, + samples=cur_target_block.samples, + components=cur_target_block.components, + properties=cur_target_block.properties, + ) + for param, gradient in cur_target_block.gradients(): + new_target_block.add_gradient(param, gradient) + + new_model_tensor_map = TensorMap( + keys=new_model_outputs[pa_target].keys, + blocks=[new_model_block], + ) + + new_target_tensor_map = TensorMap( + keys=new_targets[pa_target].keys, + blocks=[new_target_block], + ) + + new_model_outputs[pa_target] = new_model_tensor_map + new_targets[pa_target] = new_target_tensor_map + # Compute and return the loss and associated info: - return 
loss(model_outputs, targets) + return loss(new_model_outputs, new_targets) def _position_gradients_to_block(gradients_list): diff --git a/tests/utils/test_compute_loss.py b/tests/utils/test_compute_loss.py index 709bc6cb7..551d86698 100644 --- a/tests/utils/test_compute_loss.py +++ b/tests/utils/test_compute_loss.py @@ -96,9 +96,22 @@ def test_compute_model_loss(): ), } + peratom_targets = [] + + compute_model_loss( + loss_fn, + model, + structures, + targets, + peratom_targets, + ) + + peratom_targets = ["energy"] + compute_model_loss( loss_fn, model, structures, targets, + peratom_targets, )