# weight_averaging_utils.py
import os
import torch
import numpy as np
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
from super_gradients.training.utils.utils import move_state_dict_to_device, get_real_model


class ModelWeightAveraging:
    """
    Utility class for managing the averaging of the best several snapshots into a single model.
    A snapshot dictionary file and the average model are saved/updated at every epoch; the average model is
    evaluated only when training is completed. The snapshot file is deleted only upon completing the training.
    The snapshot dict is managed on CPU.
    """

    def __init__(
        self,
        ckpt_dir,
        greater_is_better,
        metric_to_watch="acc",
        metric_idx=1,
        load_checkpoint=False,
        number_of_models_to_average=10,
    ):
        """
        Init the ModelWeightAveraging

        :param ckpt_dir: the directory where the checkpoints are saved
        :param greater_is_better: whether a higher value of the watched metric is better (True for acc, False for loss)
        :param metric_to_watch: the metric being monitored (loss or acc); identical to the one that determines best_model
        :param metric_idx: index of the watched metric within the validation results tuple
        :param load_checkpoint: whether to load a pre-existing snapshot dict
        :param number_of_models_to_average: number of models to average
        """
        self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
        self.number_of_models_to_average = number_of_models_to_average
        self.metric_to_watch = metric_to_watch
        self.metric_idx = metric_idx
        self.greater_is_better = greater_is_better

        # if continuing training, load the previous snapshot dict if it exists
        if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
            averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)
        else:
            averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}

            # if the watched metric is accuracy-like, hold a -inf array so any real result fills a slot;
            # if loss-like, hold a +inf array
            if self.greater_is_better:
                averaging_snapshots_dict["snapshots_metric"] = -1 * np.inf * np.ones(self.number_of_models_to_average)
            else:
                averaging_snapshots_dict["snapshots_metric"] = np.inf * np.ones(self.number_of_models_to_average)

        torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
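        # For example, with number_of_models_to_average=3 and greater_is_better=True, a freshly
        # initialized dict looks like (illustrative values, not taken from a real run):
        # {"snapshot0": None, "snapshot1": None, "snapshot2": None, "snapshots_metric": array([-inf, -inf, -inf])}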

    def update_snapshots_dict(self, model, validation_results_tuple):
        """
        Update the snapshot dict with the latest model if it beats the worst stored snapshot,
        and return the updated dict for saving.

        :param model: the latest model
        :param validation_results_tuple: performance of the latest model
        """
        averaging_snapshots_dict = self._get_averaging_snapshots_dict()

        # IF THE CURRENT MODEL IS BETTER, IT TAKES THE WORST SNAPSHOT'S PLACE IN THE METRIC LIST,
        # AND THE AVERAGE IS RECOMPUTED FROM THE UPDATED SNAPSHOTS
        require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_tuple)
        if require_update:
            # move the state dict to cpu before storing
            new_sd = get_real_model(model).state_dict()
            new_sd = move_state_dict_to_device(new_sd, "cpu")

            averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
            averaging_snapshots_dict["snapshots_metric"][update_ind] = validation_results_tuple[self.metric_idx]

        return averaging_snapshots_dict

    def get_average_model(self, model, validation_results_tuple=None):
        """
        Return the state dict of the averaged model.

        :param model: the latest model; used to update the snapshot dict when validation_results_tuple is provided
        :param validation_results_tuple: if provided, the snapshot dict is updated with the latest model before averaging
        """
        # if a validation tuple is provided, update the snapshot dict before averaging
        if validation_results_tuple is not None:
            averaging_snapshots_dict = self.update_snapshots_dict(model, validation_results_tuple)
        else:
            averaging_snapshots_dict = self._get_averaging_snapshots_dict()
        torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)

        average_model_sd = averaging_snapshots_dict["snapshot0"]
        for n_model in range(1, self.number_of_models_to_average):
            if averaging_snapshots_dict["snapshot" + str(n_model)] is not None:
                net_sd = averaging_snapshots_dict["snapshot" + str(n_model)]
                # USING MOVING AVERAGE
                for key in average_model_sd:
                    average_model_sd[key] = torch.true_divide(average_model_sd[key] * n_model + net_sd[key], (n_model + 1))

        return average_model_sd
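    # A worked instance of the incremental mean above (illustrative scalar values): averaging 2.0, 4.0, 6.0 gives
    # avg = 2.0 after snapshot0; n_model=1: (2.0 * 1 + 4.0) / 2 = 3.0; n_model=2: (3.0 * 2 + 6.0) / 3 = 4.0,
    # which equals the plain mean (2 + 4 + 6) / 3. Since snapshot slots fill in index order (argmin/argmax of an
    # all-inf metric array returns the first empty slot), None entries appear only at the tail, so n_model always
    # equals the number of snapshots already folded into the average.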

    def cleanup(self):
        """
        Delete the snapshot file when reaching the last epoch
        """
        os.remove(self.averaging_snapshots_file)

    def _is_better(self, averaging_snapshots_dict, validation_results_tuple):
        """
        Determine whether the new model is better according to the watched metric,
        and which snapshot slot it should replace.

        :param averaging_snapshots_dict: snapshot dict
        :param validation_results_tuple: latest model performance
        """
        snapshot_metric_array = averaging_snapshots_dict["snapshots_metric"]
        val = validation_results_tuple[self.metric_idx]

        # find the worst stored snapshot; the candidate replaces it if it scores better
        if self.greater_is_better:
            update_ind = np.argmin(snapshot_metric_array)
        else:
            update_ind = np.argmax(snapshot_metric_array)

        if (self.greater_is_better and val > snapshot_metric_array[update_ind]) or (not self.greater_is_better and val < snapshot_metric_array[update_ind]):
            return True, update_ind

        return False, None

    def _get_averaging_snapshots_dict(self):
        return torch.load(self.averaging_snapshots_file)
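

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): drives the
# averaging utility with a toy model and made-up (loss, acc) validation
# tuples. The toy model and metric values are hypothetical stand-ins, and the
# sketch assumes get_real_model passes a plain nn.Module through unchanged;
# in super_gradients this class is normally driven by the trainer itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    toy_model = torch.nn.Linear(4, 2)  # hypothetical stand-in for a real network
    with tempfile.TemporaryDirectory() as ckpt_dir:
        averager = ModelWeightAveraging(ckpt_dir, greater_is_better=True, metric_to_watch="acc", metric_idx=1)
        for epoch in range(3):
            # a real training step would go here; validation yields (loss, acc)
            validation_results_tuple = (1.0 / (epoch + 1), 0.5 + 0.1 * epoch)
            average_sd = averager.get_average_model(toy_model, validation_results_tuple)
        toy_model.load_state_dict(average_sd)  # evaluate the averaged weights
        averager.cleanup()  # the snapshot file is deleted at the end of training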