diff --git a/training/benchmarks/efficientnet/pytorch/train/trainer_adapter.py b/training/benchmarks/efficientnet/pytorch/train/trainer_adapter.py
index 7cefd0970..2bd8e0f07 100644
--- a/training/benchmarks/efficientnet/pytorch/train/trainer_adapter.py
+++ b/training/benchmarks/efficientnet/pytorch/train/trainer_adapter.py
@@ -88,17 +88,20 @@ def create_grad_scaler(args):
 
 def backward(args, step: int, epoch: int, loss: torch.Tensor, model: nn.Module,
              optimizer: Optimizer, scaler):
-    optimizer.zero_grad()
     if scaler is not None:
         scaler.scale(loss).backward()
-        if args.clip_grad_norm is not None:
-            # we should unscale the gradients of optimizer's assigned params if do gradient clipping
-            scaler.unscale_(optimizer)
-            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
-        scaler.step(optimizer)
-        scaler.update()
+        if step % args.gradient_accumulation_steps == 0:
+            if args.clip_grad_norm is not None:
+                # unscale the optimizer's assigned gradients before clipping them
+                scaler.unscale_(optimizer)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+            scaler.step(optimizer)
+            optimizer.zero_grad()
+            scaler.update()
     else:
         loss.backward()
-        if args.clip_grad_norm is not None:
-            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
-        optimizer.step()
+        if step % args.gradient_accumulation_steps == 0:
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+            optimizer.step()
+            optimizer.zero_grad()
diff --git a/training/kunlunxin/efficientnet-pytorch/README.md b/training/kunlunxin/efficientnet-pytorch/README.md
index 1e0a9e617..4a3af61f8 100644
--- a/training/kunlunxin/efficientnet-pytorch/README.md
+++ b/training/kunlunxin/efficientnet-pytorch/README.md
@@ -20,14 +20,17 @@
 
 ### Results
 
-| Training resource | Config file | Runtime (s) | Target accuracy | Converged accuracy | Steps | Performance(samples/s) |
+| Training resource | Config file | Runtime (s) | Target accuracy | Converged accuracy | Steps | Performance (samples/s) |
 | ----------------- | --------------- | ----------- | --------------- | ------------------ | ------ | ----------------------- |
 | 1 node, 1 card    | config_R300x1x1 | | | | | |
 | 1 node, 2 cards   | config_R300x1x2 | | | | | |
 | 1 node, 4 cards   | config_R300x1x4 | | | | | |
-| 1 node, 8 cards   | config_R300x1x8 | | | | | |
+| 1 node, 8 cards   | config_R300x1x8 | | 82.672 | 72.666 | 868540 | |
 | 2 nodes, 8 cards  | config_R300x2x8 | | | | | |
 
+### Convergence curve
+![acc](acc.png)
+
 ### License
 
 Apache 2.0 license.
diff --git a/training/kunlunxin/efficientnet-pytorch/acc.png b/training/kunlunxin/efficientnet-pytorch/acc.png
new file mode 100644
index 000000000..8256f6909
Binary files /dev/null and b/training/kunlunxin/efficientnet-pytorch/acc.png differ
diff --git a/training/kunlunxin/efficientnet-pytorch/config/config_R300x1x8.py b/training/kunlunxin/efficientnet-pytorch/config/config_R300x1x8.py
index fbaea0e5d..1026a0695 100644
--- a/training/kunlunxin/efficientnet-pytorch/config/config_R300x1x8.py
+++ b/training/kunlunxin/efficientnet-pytorch/config/config_R300x1x8.py
@@ -1,4 +1,5 @@
 from config_common import *
 
 train_batch_size = 64
-eval_batch_size = 64
+eval_batch_size = 128
+gradient_accumulation_steps = 2
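
For reference, the new `backward` follows the standard gradient-accumulation pattern: gradients from several micro-batches are summed into `.grad` across calls, and the optimizer steps (and zeroes the gradients) only every `gradient_accumulation_steps` calls. Below is a minimal, self-contained sketch of the non-AMP path, assuming a 1-indexed `step` counter; the toy model, data, and the name `accum_steps` are illustrative, not from the repo. Note that many implementations additionally divide the loss by the accumulation step count so the summed gradient matches one large batch; the patch above does not, which scales the accumulated gradient (and thus the effective learning rate) by the accumulation factor.

```python
# Minimal sketch of the gradient-accumulation pattern (non-AMP path).
# Hypothetical example: model, data, and accum_steps are illustrative only.
import torch
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accum_steps = 2
clip_grad_norm = 1.0

data = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]

for step, (x, y) in enumerate(data, start=1):  # 1-indexed, as assumed above
    # Optionally divide by accum_steps so the summed gradient matches
    # one large batch (the patch itself omits this division).
    loss = criterion(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if step % accum_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()
        optimizer.zero_grad()  # clear only after a real optimizer step
```

Moving `optimizer.zero_grad()` from the top of `backward` to just after `optimizer.step()` is what makes the accumulation work: gradients are kept alive between calls and cleared only once consumed. With `train_batch_size = 64` and `gradient_accumulation_steps = 2`, the config change yields an effective per-card batch of 128 per optimizer step, which is presumably why `eval_batch_size` was also raised to 128.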