full resnet50 precision(bf16+amp) (#253)
* full resnet50

* add ieee754

shh2000 authored Oct 13, 2023
1 parent 410b73b commit 43d4f85
Showing 4 changed files with 13 additions and 5 deletions.
@@ -8,7 +8,7 @@
 class ToFloat16(object):
 
     def __call__(self, tensor):
-        return tensor.to(dtype=torch.float16)
+        return tensor.to(dtype=torch.bfloat16)
 
 
 def build_train_dataset(config):
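For context on the one-line change above, here is a minimal, hypothetical sketch (not part of the commit) of the transform and why the target dtype matters: bf16 keeps fp32's 8-bit exponent, so values that overflow IEEE 754 fp16 survive the cast.

```python
import torch

class ToFloat16(object):
    """Dataset transform that casts inputs to bfloat16 (post-commit behavior)."""

    def __call__(self, tensor):
        return tensor.to(dtype=torch.bfloat16)

# Hypothetical check: values above fp16's max finite value (65504)
# overflow to inf under IEEE 754 fp16 but survive the cast to bf16.
x = torch.tensor([1e5, 3.14159])
print(x.to(torch.float16))  # roughly: tensor([inf, 3.1406], dtype=torch.float16)
print(ToFloat16()(x))       # roughly: tensor([100352., 3.1406], dtype=torch.bfloat16)
```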
training/benchmarks/resnet50/pytorch/train/trainer.py (6 changes: 4 additions & 2 deletions)
@@ -63,6 +63,8 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):
         device = self.device
         epoch = self.training_state.epoch
         scaler = self.scaler
+        criterion = torch.nn.CrossEntropyLoss()
+
         print("Epoch " + str(epoch + 1))
         if self.config.distributed:
             train_dataloader.batch_sampler.sampler.set_epoch(epoch)
@@ -76,15 +78,14 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):
 
             batch = self.process_batch(batch, device)
 
-            dist_pytorch.barrier(self.config.vendor)
             pure_start_time = time.time()
             optimizer.zero_grad()
 
             images, target = batch
             if scaler is not None:
                 with torch.cuda.amp.autocast(enabled=True):
                     output = model(images)
-                    criterion = torch.nn.CrossEntropyLoss()
+
                     loss = criterion(output, target)
 
                 scaler.scale(loss).backward()
@@ -102,6 +103,7 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):
                 print("Train Step " + str(step) + "/" + str(len(data_loader)) +
                       ", Loss : " + str(float(loss)))
 
+            dist_pytorch.barrier(self.config.vendor)
             self.training_state.purecomputetime += time.time(
             ) - pure_start_time
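To show how the moved pieces fit together, below is a minimal single-process sketch of the step structure after this change, with toy stand-ins (`model`, `images`, and `target` are assumptions, and a CUDA device is required); the distributed barrier is omitted since it needs an initialized process group. The criterion is now built once rather than on every step, and in the real trainer the barrier moves from before `pure_start_time` to just before the `purecomputetime` accounting.

```python
import torch

# Toy stand-ins (assumptions, not from the repo): a tiny model and one batch.
model = torch.nn.Linear(32, 10).cuda()
criterion = torch.nn.CrossEntropyLoss()  # built once, outside the step loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
scaler = torch.cuda.amp.GradScaler()

images = torch.randn(8, 32, device="cuda")
target = torch.randint(0, 10, (8,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=True):  # mixed-precision forward pass
    output = model(images)
    loss = criterion(output, target)
scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscale gradients, then take the optimizer step
scaler.update()                # adjust the scale factor for the next step
```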
@@ -26,7 +26,7 @@ def model_to_fp16(model: nn.Module) -> nn.Module:
     # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
     if config.fp16:
         main_proc_print(" > use fp16...")
-        model.half()
+        model.to(torch.bfloat16)
     return model
 
 
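A quick hypothetical check (toy module, not from the repo) of the two conversions contrasted in the hunk above: `.half()` casts a module's parameters and buffers to IEEE 754 fp16, while `.to(torch.bfloat16)` casts them to bf16; both mutate the module in place and return it.

```python
import torch
import torch.nn as nn

fp16_model = nn.Linear(4, 4).half()              # old behavior: IEEE 754 fp16
bf16_model = nn.Linear(4, 4).to(torch.bfloat16)  # new behavior: bf16

print(next(fp16_model.parameters()).dtype)  # torch.float16
print(next(bf16_model.parameters()).dtype)  # torch.bfloat16
```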
training/nvidia/resnet50-pytorch/README.md (8 changes: 7 additions & 1 deletion)
@@ -8,6 +8,7 @@
 - Accelerator model: NVIDIA_A100-SXM4-40GB
 - CPU model: AMD [email protected]
 - Multi-node network type and bandwidth: InfiniBand, 200Gb/s
+
 - ##### Software environment
   - OS version: Ubuntu 20.04
   - OS kernel version: 5.4.0-113-generic
@@ -16,6 +17,10 @@
   - Training framework version: pytorch-1.8.0a0+52ea372
   - Dependency versions:
     - cuda: 11.4
+
+- Data format
+
+  - On NVIDIA DGX A100 (40GB) hardware, 16-bit floating point (fp16) can be implemented either as IEEE 754 fp16 or as bf16. In the resnet50 benchmark, bf16 delivers higher performance and accuracy, so bf16 is adopted as the implementation of fp16 and is used for 16-bit training.
 
 ### Results
@@ -45,5 +50,6 @@
 | A100 single node, 8 GPUs (1x8) | fp32 | bs=256,lr=0.8 | 22653 | 5663 | 5866 | 6105 | 73.5% | 28.3/40.0 |
 | A100 single node, 1 GPU (1x1) | fp32 | bs=256,lr=0.8 | | 782 | 795 | 799 | | 27.6/40.0 |
 | A100 two nodes, 8 GPUs each (2x8) | fp32 | bs=256,lr=0.8 | | 10576 | 11085 | 11874 | | 27.9/40.0 |
+| A100 single node, 8 GPUs (1x8) | amp | bs=512,lr=0.2 | 15312 | 7544 | 7901 | 9567 | 72.7% | 28.6/40.0 |
+| A100 single node, 8 GPUs (1x8) | bf16 | bs=512,lr=0.2 | 14082 | 8203 | 8550 | 9818 | 64.0% | 28.6/40.0 |

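The data-format note in the README diff above can be verified numerically with PyTorch's `finfo`: both formats are 16 bits wide, but bf16 trades mantissa precision for fp32's 8-bit exponent range. A short sketch:

```python
import torch

print(torch.finfo(torch.float16).max)   # 65504.0 (narrow range)
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38 (fp32-like range)
print(torch.finfo(torch.float16).eps)   # 0.0009765625 = 2**-10 (finer steps)
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125 = 2**-7 (coarser steps)
```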