From f17c0f8cad268bbc35d18c46466ea64211efd219 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 24 Feb 2022 19:28:42 +0800 Subject: [PATCH] Simplify EMA to use Pytorch's update_parameters --- references/classification/utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index 4afe9bf68f1..397603dfd76 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -166,17 +166,7 @@ def __init__(self, model, decay, device="cpu"): def ema_avg(avg_model_param, model_param, num_averaged): return decay * avg_model_param + (1 - decay) * model_param - super().__init__(model, device, ema_avg) - - def update_parameters(self, model): - for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()): - device = p_swa.device - p_model_ = p_model.detach().to(device) - if self.n_averaged == 0: - p_swa.detach().copy_(p_model_) - else: - p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device))) - self.n_averaged += 1 + super().__init__(model, device, ema_avg, use_buffers=True) def accuracy(output, target, topk=(1,)):