Add hysteresis support for AMP gradient scale update (#1733)
* Add update_scale_hysteresis

* Fix compile errors

* Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715)

* input grad checks out

* adding clamp gamma

* Both old and proposed implementations check out

* 2 tests not yet passing due to numerical issues

* mem_eff works

* fast-layer-norm done

* Moving mem-eff to templates

* Relax tolerance for memory efficient backward

* Fix backward api of python

* Distributed optimizer infrastructure for FP8 parameters (#1723)

* Add distopt support for param syncs with non-floating-point dtypes

Signed-off-by: Tim Moon <[email protected]>

* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Masaki Kozuki <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>

* Add unit test

* Fix comment in unit test

* Remove unnecessary bits

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Rui Wang <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
5 people authored Sep 30, 2023
1 parent 2386a91 commit 6a77872
Showing 4 changed files with 186 additions and 0 deletions.
12 changes: 12 additions & 0 deletions csrc/amp_C_frontend.cpp
@@ -178,6 +178,16 @@ void multi_tensor_lamb_mp_cuda(
  at::Tensor found_inf,
  at::Tensor inv_scale);

at::Tensor update_scale_hysteresis_cuda(
  at::Tensor current_scale,
  at::Tensor growth_tracker,
  at::Tensor hysteresis_tracker,
  at::Tensor found_inf,
  const double growth_factor,
  const double backoff_factor,
  const int64_t growth_interval,
  const int hysteresis);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
@@ -211,4 +221,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Computes and apply update for LAMB optimizer");
  m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
        "Computes and apply update for LAMB optimizer");
  m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda,
        "Updates scale while accounting for hysteresis");
}
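For orientation, here is a minimal sketch of driving the new binding from Python once the extension is built. The tensor shapes, dtypes, and device follow the unit test added below; the scalar hyperparameters are illustrative values, not taken from this commit.

import torch
import amp_C

# One-element CUDA tensors, since the kernel dereferences a single value of each.
scale = torch.tensor([65536.0], dtype=torch.float32, device="cuda")
growth_tracker = torch.tensor([0], dtype=torch.int32, device="cuda")
hysteresis_tracker = torch.tensor([2], dtype=torch.int32, device="cuda")
found_inf = torch.tensor([1.0], dtype=torch.float32, device="cuda")  # non-zero: an inf/nan was seen this step

# Positional arguments: growth_factor, backoff_factor, growth_interval, hysteresis.
amp_C.update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                              found_inf, 2.0, 0.5, 2000, 2)

# With hysteresis=2 and a freshly initialized tracker, a single overflow only
# decrements the tracker; the scale is backed off once the tracker reaches zero.
print(scale.item(), hysteresis_tracker.item())  # 65536.0 1

The scale tensor is updated in place and also returned by the launcher, so it can be used either way.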
71 changes: 71 additions & 0 deletions csrc/update_scale_hysteresis.cu
@@ -0,0 +1,71 @@
#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale,
                                                    int* growth_tracker,
                                                    int* hysteresis_tracker,
                                                    const float* found_inf,
                                                    double growth_factor,
                                                    double backoff_factor,
                                                    int growth_interval,
                                                    int hysteresis)
{
  if (*found_inf > 0) {
    *hysteresis_tracker -= 1;

    // Only reset the growth tracker when hysteresis is larger than zero
    if (*hysteresis_tracker > 0) {
      *growth_tracker = 0;
      return;
    }
  }

  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }

  // Reset the hysteresis tracker if no infs are found
  if (*found_inf <= 0) {
    *hysteresis_tracker = hysteresis;
  }
}

at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale,
                                        at::Tensor growth_tracker,
                                        at::Tensor hysteresis_tracker,
                                        at::Tensor found_inf,
                                        const double growth_factor,
                                        const double backoff_factor,
                                        const int64_t growth_interval,
                                        const int hysteresis)
{
  update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
      current_scale.mutable_data_ptr<float>(),
      growth_tracker.mutable_data_ptr<int>(),
      hysteresis_tracker.mutable_data_ptr<int>(),
      found_inf.const_data_ptr<float>(),
      growth_factor,
      backoff_factor,
      growth_interval,
      hysteresis);

  AT_CUDA_CHECK(cudaGetLastError());

  return current_scale;
}
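For readers who want a host-side mental model of the kernel, the following plain-Python restatement mirrors its update rule with scalars in place of device pointers. It is an illustrative sketch only, not part of this commit; the kernel additionally performs its finiteness check on the fp32-cast value.

import math

def update_scale_hysteresis_ref(scale, growth_tracker, hysteresis_tracker, found_inf,
                                growth_factor, backoff_factor, growth_interval, hysteresis):
    """Scalar restatement of update_scale_hysteresis_cuda_kernel (illustrative only)."""
    if found_inf > 0:
        hysteresis_tracker -= 1
        if hysteresis_tracker > 0:
            # Still within the hysteresis window: drop growth progress but keep the scale.
            return scale, 0, hysteresis_tracker

    if found_inf > 0:
        # Hysteresis exhausted: back the scale off and restart growth counting.
        scale = scale * backoff_factor
        growth_tracker = 0
    else:
        growth_tracker += 1
        if growth_tracker == growth_interval:
            new_scale = scale * growth_factor
            # The kernel refuses to grow the scale past fp32 bounds to inf.
            if math.isfinite(new_scale):
                scale = new_scale
            growth_tracker = 0

    if found_inf <= 0:
        # A clean step restores the full hysteresis budget.
        hysteresis_tracker = hysteresis

    return scale, growth_tracker, hysteresis_tracker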
1 change: 1 addition & 0 deletions setup.py
@@ -195,6 +195,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
                "csrc/multi_tensor_novograd.cu",
                "csrc/multi_tensor_lamb.cu",
                "csrc/multi_tensor_lamb_mp.cu",
                "csrc/update_scale_hysteresis.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
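Since the kernel is compiled into the existing amp_C extension, a quick sanity check after rebuilding Apex with the CUDA extensions enabled is to confirm the new symbol is exposed; this snippet is a hypothetical check, not part of the commit.

import amp_C

# Registered in csrc/amp_C_frontend.cpp above; absence means the installed
# extension predates this commit or was built without the CUDA extensions.
assert hasattr(amp_C, "update_scale_hysteresis")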
102 changes: 102 additions & 0 deletions tests/L0/run_amp/test_update_scale_hysteresis.py
@@ -0,0 +1,102 @@
import unittest
import random
import math

import torch

try:
    import amp_C
    from amp_C import update_scale_hysteresis
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err)
    disabled = True

def isfinite(val):
    return ((val >= torch.finfo(torch.float32).smallest_normal) and (val <= torch.finfo(torch.float32).max))

class TestUpdateScaleHysteresis(unittest.TestCase):

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor,
                                     growth_interval, hysteresis):
        scale_ref = float(init_scale)
        grow_tracker_ref = 0
        hysteresis_tracker_ref = 0

        scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda')
        growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
        hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda')

        # Infs appear for hysteresis-1 iterations, scale shouldn't change
        found_inf = torch.tensor([1], dtype=torch.float32, device='cuda')
        for i in range(hysteresis-1):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        self.assertTrue(scale.item() == init_scale)

        # No infs for growth_interval-1 iterations, scale shouldn't change
        found_inf.zero_()
        for i in range(growth_interval-1):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        self.assertTrue(scale.item() == init_scale)

        # Infs appear for more than hysteresis iterations, scale should be backed off
        found_inf.fill_(1)
        extra_iters = random.randint(0, 1000)
        scale_before = scale.detach().item()
        scale_ref = scale_before
        for i in range(hysteresis + extra_iters):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        for i in range(1 + extra_iters):
            # Scale is continuously backed off for each iteration with an inf
            scale_new = scale_ref * backoff_factor
            if isfinite(scale_new):
                scale_ref = scale_new
            else:
                scale_ref = 0  # Scale update kernel does not check for underflow when backing off, which results in zero
        self.assertTrue(scale.item() == scale_ref)

        # No infs for more than growth_interval iterations, scale should be increased
        found_inf.fill_(0)
        extra_iters = random.randint(0, 1000)
        scale_before = scale.detach().item()
        scale_ref = scale_before
        for i in range(growth_interval + extra_iters):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        for i in range(1 + int(math.floor(extra_iters / growth_interval))):
            # Scale is grown every growth_interval iterations
            scale_new = scale_ref * growth_factor
            if isfinite(scale_new):
                scale_ref = scale_new
        self.assertTrue(scale.item() == scale_ref)


    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        init_scale_list = [1, 1024, 65536]
        growth_factor_list = [1.0, 2.0, 4.0]
        backoff_factor_list = [0.5, 0.25]
        growth_interval_list = [10, 100]
        hysteresis_list = [10, 100]

        for init_scale in init_scale_list:
            for growth_factor in growth_factor_list:
                for backoff_factor in backoff_factor_list:
                    for growth_interval in growth_interval_list:
                        for hysteresis in hysteresis_list:
                            self.update_scale_hysteresis_body(init_scale, growth_factor,
                                                              backoff_factor, growth_interval, hysteresis)


if __name__ == '__main__':
    unittest.main()
