Add hysteresis support for AMP gradient scale update (#1733)
* Add update_scale_hysteresis
* Fix compile errors
* Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715)
* input grad checks out
* adding clamp gamma
* Both old and proposed implementation checks out
* 2 tests not yet passed due to numerical issues
* mem_eff works
* fast-layer-norm done
* Moving mem-eff to templates
* Relax tolerance for memory efficient backward
* Fix backward api of python
* Distributed optimizer infrastructure for FP8 parameters (#1723)
* Add distopt support for param syncs with non-floating-point dtypes
  Signed-off-by: Tim Moon <[email protected]>
* Update apex/contrib/optimizers/distributed_fused_adam.py
  Co-authored-by: Masaki Kozuki <[email protected]>
  ---------
  Signed-off-by: Tim Moon <[email protected]>
  Co-authored-by: Masaki Kozuki <[email protected]>
* Add unit test
* Fix comment in unit test
* Remove unnecessary bits
---------
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Rui Wang <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
1 parent 2386a91 · commit 6a77872
Showing 4 changed files with 186 additions and 0 deletions.
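The two new source files shown below are the CUDA kernel/launcher for the hysteresis-aware scale update and its unit test. For orientation, here is a minimal sketch of how the new amp_C.update_scale_hysteresis binding can be driven from Python, following the argument order used in the unit test further down; the concrete hyperparameter values and the end_of_step wrapper are illustrative assumptions, not part of this commit.

# Sketch: calling the fused hysteresis-aware scale update from Python.
# Tensor shapes/dtypes and argument order mirror the unit test in this commit;
# the hyperparameter values and the surrounding wrapper are illustrative only.
# Requires apex built with the amp_C extension and a CUDA device.
import torch
import amp_C

device = 'cuda'
scale = torch.tensor([65536.0], dtype=torch.float32, device=device)
growth_tracker = torch.tensor([0], dtype=torch.int32, device=device)
hysteresis_tracker = torch.tensor([2], dtype=torch.int32, device=device)  # starts at hysteresis

growth_factor, backoff_factor = 2.0, 0.5   # illustrative values
growth_interval, hysteresis = 2000, 2      # illustrative values

def end_of_step(found_inf):
    # found_inf is a float32 CUDA tensor holding a value > 0 if any inf/nan
    # was detected in this step's gradients, and 0 otherwise.
    amp_C.update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                  found_inf, growth_factor, backoff_factor,
                                  growth_interval, hysteresis)

Unlike the stock GradScaler update, the scale is only backed off once `hysteresis` consecutive inf/nan steps have been observed; a single clean step restores the full hysteresis budget.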
@@ -0,0 +1,71 @@
#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale,
                                                     int* growth_tracker,
                                                     int* hysteresis_tracker,
                                                     const float* found_inf,
                                                     double growth_factor,
                                                     double backoff_factor,
                                                     int growth_interval,
                                                     int hysteresis)
{
  if (*found_inf > 0) {
    *hysteresis_tracker -= 1;

    // Only reset the growth tracker when hysteresis is larger than zero
    if (*hysteresis_tracker > 0) {
      *growth_tracker = 0;
      return;
    }
  }

  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }

  // Reset the hysteresis tracker if no infs are found
  if (*found_inf <= 0) {
    *hysteresis_tracker = hysteresis;
  }
}

at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale,
                                        at::Tensor growth_tracker,
                                        at::Tensor hysteresis_tracker,
                                        at::Tensor found_inf,
                                        const double growth_factor,
                                        const double backoff_factor,
                                        const int64_t growth_interval,
                                        const int hysteresis)
{
  update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    hysteresis_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval,
    hysteresis);

  AT_CUDA_CHECK(cudaGetLastError());

  return current_scale;
}
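As a reading aid for the control flow above, here is a pure-Python restatement of the kernel's per-step update rule using plain scalars. It is a sketch for illustration only, not code from this commit (for example, Python floats are double precision, whereas the kernel clamps growth via the cast to fp32 before the finiteness check).

# Pure-Python restatement of update_scale_hysteresis_cuda_kernel (reading aid).
# Works on plain scalars; names mirror the kernel arguments.
import math

def update_scale_hysteresis_ref(current_scale, growth_tracker, hysteresis_tracker,
                                found_inf, growth_factor, backoff_factor,
                                growth_interval, hysteresis):
    if found_inf > 0:
        hysteresis_tracker -= 1
        # While the hysteresis budget is not exhausted, only reset the growth
        # tracker and leave the scale untouched.
        if hysteresis_tracker > 0:
            growth_tracker = 0
            return current_scale, growth_tracker, hysteresis_tracker

    if found_inf > 0:
        # Hysteresis exhausted: back off the scale.
        current_scale = current_scale * backoff_factor
        growth_tracker = 0
    else:
        # Successful step: grow the scale once every growth_interval clean steps.
        successful = growth_tracker + 1
        if successful == growth_interval:
            new_scale = current_scale * growth_factor
            if math.isfinite(new_scale):  # the kernel checks this after casting to fp32
                current_scale = new_scale
            growth_tracker = 0
        else:
            growth_tracker = successful

    # A clean step restores the full hysteresis budget.
    if found_inf <= 0:
        hysteresis_tracker = hysteresis

    return current_scale, growth_tracker, hysteresis_tracker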
@@ -0,0 +1,102 @@
import unittest
import random
import math

import torch

try:
    import amp_C
    from amp_C import update_scale_hysteresis
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err)
    disabled = True


def isfinite(val):
    return ((val >= torch.finfo(torch.float32).smallest_normal) and (val <= torch.finfo(torch.float32).max))


class TestUpdateScaleHysteresis(unittest.TestCase):

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor,
                                     growth_interval, hysteresis):
        scale_ref = float(init_scale)
        grow_tracker_ref = 0
        hysteresis_tracker_ref = 0

        scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda')
        growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
        hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda')

        # Infs appear for hysteresis-1 iterations, scale shouldn't change
        found_inf = torch.tensor([1], dtype=torch.float32, device='cuda')
        for i in range(hysteresis-1):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        self.assertTrue(scale.item() == init_scale)

        # No infs for growth_interval-1 iterations, scale shouldn't change
        found_inf.zero_()
        for i in range(growth_interval-1):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        self.assertTrue(scale.item() == init_scale)

        # Infs appear for more than hysteresis iterations, scale should be backed off
        found_inf.fill_(1)
        extra_iters = random.randint(0, 1000)
        scale_before = scale.detach().item()
        scale_ref = scale_before
        for i in range(hysteresis + extra_iters):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        for i in range(1 + extra_iters):
            # Scale is continuously backed off for each iteration with an inf
            scale_new = scale_ref * backoff_factor
            if isfinite(scale_new):
                scale_ref = scale_new
            else:
                scale_ref = 0  # Scale update kernel does not check for underflow when backing off, which results in zero
        self.assertTrue(scale.item() == scale_ref)

        # No infs for more than growth_interval iterations, scale should be increased
        found_inf.fill_(0)
        extra_iters = random.randint(0, 1000)
        scale_before = scale.detach().item()
        scale_ref = scale_before
        for i in range(growth_interval + extra_iters):
            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
        for i in range(1 + int(math.floor(extra_iters / growth_interval))):
            # Scale is grown every growth_interval iterations
            scale_new = scale_ref * growth_factor
            if isfinite(scale_new):
                scale_ref = scale_new
        self.assertTrue(scale.item() == scale_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        init_scale_list = [1, 1024, 65536]
        growth_factor_list = [1.0, 2.0, 4.0]
        backoff_factor_list = [0.5, 0.25]
        growth_interval_list = [10, 100]
        hysteresis_list = [10, 100]

        for init_scale in init_scale_list:
            for growth_factor in growth_factor_list:
                for backoff_factor in backoff_factor_list:
                    for growth_interval in growth_interval_list:
                        for hysteresis in hysteresis_list:
                            self.update_scale_hysteresis_body(init_scale, growth_factor,
                                                              backoff_factor, growth_interval, hysteresis)


if __name__ == '__main__':
    unittest.main()