diff --git a/src/super_gradients/training/utils/weight_averaging_utils.py b/src/super_gradients/training/utils/weight_averaging_utils.py index 4694d4a2e0..5deae56aea 100755 --- a/src/super_gradients/training/utils/weight_averaging_utils.py +++ b/src/super_gradients/training/utils/weight_averaging_utils.py @@ -49,7 +49,7 @@ def __init__( else: averaging_snapshots_dict["snapshots_metric"] = np.inf * np.ones(self.number_of_models_to_average) - torch.save(averaging_snapshots_dict, self.averaging_snapshots_file) + torch.save(averaging_snapshots_dict, self.averaging_snapshots_file) def update_snapshots_dict(self, model, validation_results_tuple): """