Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

full resnet50 precision(bf16+amp) #253

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class ToFloat16(object):

def __call__(self, tensor):
return tensor.to(dtype=torch.float16)
return tensor.to(dtype=torch.bfloat16)


def build_train_dataset(config):
Expand Down
6 changes: 4 additions & 2 deletions training/benchmarks/resnet50/pytorch/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):
device = self.device
epoch = self.training_state.epoch
scaler = self.scaler
criterion = torch.nn.CrossEntropyLoss()

print("Epoch " + str(epoch + 1))
if self.config.distributed:
train_dataloader.batch_sampler.sampler.set_epoch(epoch)
Expand All @@ -76,15 +78,14 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):

batch = self.process_batch(batch, device)

dist_pytorch.barrier(self.config.vendor)
pure_start_time = time.time()
optimizer.zero_grad()

images, target = batch
if scaler is not None:
with torch.cuda.amp.autocast(enabled=True):
output = model(images)

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(output, target)

scaler.scale(loss).backward()
Expand All @@ -102,6 +103,7 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):
print("Train Step " + str(step) + "/" + str(len(data_loader)) +
", Loss : " + str(float(loss)))

dist_pytorch.barrier(self.config.vendor)
self.training_state.purecomputetime += time.time(
) - pure_start_time

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def model_to_fp16(model: nn.Module) -> nn.Module:
# To prevent OOM for model sizes that cannot fit in GPU memory in full precision
if config.fp16:
main_proc_print(" > use fp16...")
model.half()
model.to(torch.bfloat16)
return model


Expand Down
8 changes: 7 additions & 1 deletion training/nvidia/resnet50-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- 加速卡型号: NVIDIA_A100-SXM4-40GB
- CPU型号: AMD [email protected]
- 多机网络类型、带宽: InfiniBand,200Gb/s

- ##### 软件环境
- OS版本:Ubuntu 20.04
- OS kernel版本: 5.4.0-113-generic
Expand All @@ -16,6 +17,10 @@
- 训练框架版本:pytorch-1.8.0a0+52ea372
- 依赖软件版本:
- cuda: 11.4

- 数据格式

- 在NVIDIA DGX A100(40G)硬件上,16bit浮点数(fp16)可以使用IEEE 754 fp16或bf16格式实现。在resnet50测试样例中,bf16的性能、准确度更高。因此采用bf16格式作为fp16的实现,实行在16bit训练和amp训练中

### 运行情况

Expand Down Expand Up @@ -45,5 +50,6 @@
| A100单机8卡(1x8) | fp32 | bs=256,lr=0.8 | 22653 | 5663 | 5866 | 6105 | 73.5% | 28.3/40.0 |
| A100单机单卡(1x1) | fp32 | bs=256,lr=0.8 | | 782 | 795 | 799 | | 27.6/40.0 |
| A100两机8卡(2x8) | fp32 | bs=256,lr=0.8 | | 10576 | 11085 | 11874 | | 27.9/40.0 |

| A100单机8卡(1x8) | amp | bs=512,lr=0.2 | 15312 | 7544 | 7901 | 9567 | 72.7% | 28.6/40.0 |
| A100单机8卡(1x8) | bf16 | bs=512,lr=0.2 | 14082 | 8203 | 8550 | 9818 | 64.0% | 28.6/40.0 |