From 6a7787201b0f8f93198d6647325b425b84a64048 Mon Sep 17 00:00:00 2001
From: Jaemin Choi
Date: Fri, 29 Sep 2023 20:26:44 -0700
Subject: [PATCH] Add hysteresis support for AMP gradient scale update (#1733)

* Add update_scale_hysteresis

* Fix compile errors

* Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715)

* input grad checks out

* adding clamp gamma

* Both old and proposed implementation checks out

* 2 tests not yet passed due to numerical issues

* mem_eff works

* fast-layer-norm done

* Moving mem-eff to templates

* Relax tolerance for memory efficient backward

* Fix backward api of python

* Distributed optimizer infrastructure for FP8 parameters (#1723)

* Add distopt support for param syncs with non-floating-point dtypes

Signed-off-by: Tim Moon

* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Masaki Kozuki

---------

Signed-off-by: Tim Moon
Co-authored-by: Masaki Kozuki

* Add unit test

* Fix comment in unit test

* Remove unnecessary bits

---------

Signed-off-by: Tim Moon
Co-authored-by: Jaemin Choi
Co-authored-by: Rui Wang
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Masaki Kozuki
---
 csrc/amp_C_frontend.cpp                        |  12 +++
 csrc/update_scale_hysteresis.cu                |  71 ++++++++++++
 setup.py                                       |   1 +
 .../run_amp/test_update_scale_hysteresis.py    | 102 ++++++++++++++++++
 4 files changed, 186 insertions(+)
 create mode 100644 csrc/update_scale_hysteresis.cu
 create mode 100644 tests/L0/run_amp/test_update_scale_hysteresis.py

diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp
index 74d36487..d39be60c 100644
--- a/csrc/amp_C_frontend.cpp
+++ b/csrc/amp_C_frontend.cpp
@@ -178,6 +178,16 @@ void multi_tensor_lamb_mp_cuda(
   at::Tensor found_inf,
   at::Tensor inv_scale);
 
+at::Tensor update_scale_hysteresis_cuda(
+  at::Tensor current_scale,
+  at::Tensor growth_tracker,
+  at::Tensor hysteresis_tracker,
+  at::Tensor found_inf,
+  const double growth_factor,
+  const double backoff_factor,
+  const int64_t growth_interval,
+  const int hysteresis);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
@@ -211,4 +221,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Computes and apply update for LAMB optimizer");
   m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
         "Computes and apply update for LAMB optimizer");
+  m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda,
+        "Updates scale while accounting for hysteresis");
 }
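Before the kernel itself, it may help to see the update rule spelled out in plain Python. The sketch below is illustrative only and not part of the patch; ref_update_scale_hysteresis is a hypothetical name, and plain scalars stand in for the one-element CUDA tensors the fused kernel mutates in place:

    import math

    def ref_update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                    found_inf, growth_factor, backoff_factor,
                                    growth_interval, hysteresis):
        # Scalar mirror of update_scale_hysteresis_cuda_kernel below.
        if found_inf:
            hysteresis_tracker -= 1
            if hysteresis_tracker > 0:
                # Hysteresis budget not yet exhausted: hold the scale,
                # but restart the growth countdown.
                return scale, 0, hysteresis_tracker
            scale *= backoff_factor  # budget exhausted: back off on every inf step
            growth_tracker = 0
        else:
            growth_tracker += 1
            if growth_tracker == growth_interval:
                new_scale = scale * growth_factor
                if math.isfinite(new_scale):  # the kernel checks fp32 finiteness
                    scale = new_scale
                growth_tracker = 0
            hysteresis_tracker = hysteresis  # any inf-free step refills the budget
        return scale, growth_tracker, hysteresis_tracker

The only behavioral difference from the stock GradScaler update is the hysteresis counter: the scale is backed off only after `hysteresis` consecutive overflowing steps, rather than on the first one.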
diff --git a/csrc/update_scale_hysteresis.cu b/csrc/update_scale_hysteresis.cu
new file mode 100644
index 00000000..2405130a
--- /dev/null
+++ b/csrc/update_scale_hysteresis.cu
@@ -0,0 +1,71 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+
+__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale,
+                                                    int* growth_tracker,
+                                                    int* hysteresis_tracker,
+                                                    const float* found_inf,
+                                                    double growth_factor,
+                                                    double backoff_factor,
+                                                    int growth_interval,
+                                                    int hysteresis)
+{
+  if (*found_inf > 0) {
+    *hysteresis_tracker -= 1;
+
+    // Only reset the growth tracker when hysteresis is larger than zero
+    if (*hysteresis_tracker > 0) {
+      *growth_tracker = 0;
+      return;
+    }
+  }
+
+  if (*found_inf) {
+    *current_scale = (*current_scale)*backoff_factor;
+    *growth_tracker = 0;
+  } else {
+    // Entering this branch means we just carried out a successful step,
+    // so growth_tracker is incremented before comparing to growth_interval.
+    auto successful = (*growth_tracker) + 1;
+    if (successful == growth_interval) {
+      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
+      // Do not grow the scale past fp32 bounds to inf.
+      if (isfinite(new_scale)) {
+        *current_scale = new_scale;
+      }
+      *growth_tracker = 0;
+    } else {
+      *growth_tracker = successful;
+    }
+  }
+
+  // Reset the hysteresis tracker if no infs are found
+  if (*found_inf <= 0) {
+    *hysteresis_tracker = hysteresis;
+  }
+}
+
+at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale,
+                                        at::Tensor growth_tracker,
+                                        at::Tensor hysteresis_tracker,
+                                        at::Tensor found_inf,
+                                        const double growth_factor,
+                                        const double backoff_factor,
+                                        const int64_t growth_interval,
+                                        const int hysteresis)
+{
+  update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
+    current_scale.mutable_data_ptr<float>(),
+    growth_tracker.mutable_data_ptr<int>(),
+    hysteresis_tracker.mutable_data_ptr<int>(),
+    found_inf.const_data_ptr<float>(),
+    growth_factor,
+    backoff_factor,
+    growth_interval,
+    hysteresis);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+
+  return current_scale;
+}

diff --git a/setup.py b/setup.py
index fd160230..329f8564 100644
--- a/setup.py
+++ b/setup.py
@@ -195,6 +195,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
                 "csrc/multi_tensor_novograd.cu",
                 "csrc/multi_tensor_lamb.cu",
                 "csrc/multi_tensor_lamb_mp.cu",
+                "csrc/update_scale_hysteresis.cu",
             ],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros,
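Once compiled into amp_C, the binding operates in place on one-element CUDA tensors, as the test below exercises. A minimal usage sketch, assuming apex was built with its CUDA extensions so that amp_C imports:

    import torch
    import amp_C

    scale = torch.tensor([65536.0], dtype=torch.float32, device="cuda")
    growth_tracker = torch.tensor([0], dtype=torch.int32, device="cuda")
    hysteresis_tracker = torch.tensor([2], dtype=torch.int32, device="cuda")
    found_inf = torch.tensor([1.0], dtype=torch.float32, device="cuda")

    # First overflow with hysteresis=2: the budget absorbs it, scale is unchanged.
    amp_C.update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                  found_inf, 2.0, 0.5, 2000, 2)
    assert scale.item() == 65536.0

    # Second consecutive overflow exhausts the budget and backs the scale off.
    amp_C.update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                  found_inf, 2.0, 0.5, 2000, 2)
    assert scale.item() == 32768.0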
diff --git a/tests/L0/run_amp/test_update_scale_hysteresis.py b/tests/L0/run_amp/test_update_scale_hysteresis.py
new file mode 100644
index 00000000..6bb52400
--- /dev/null
+++ b/tests/L0/run_amp/test_update_scale_hysteresis.py
@@ -0,0 +1,102 @@
+import unittest
+import random
+import math
+
+import torch
+
+try:
+    import amp_C
+    from amp_C import update_scale_hysteresis
+    disabled = False
+except ImportError as err:
+    print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err)
+    disabled = True
+
+def isfinite(val):
+    return ((val >= torch.finfo(torch.float32).smallest_normal) and (val <= torch.finfo(torch.float32).max))
+
+class TestUpdateScaleHysteresis(unittest.TestCase):
+
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
+
+    def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor,
+                                     growth_interval, hysteresis):
+        scale_ref = float(init_scale)
+        grow_tracker_ref = 0
+        hysteresis_tracker_ref = 0
+
+        scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda')
+        growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
+        hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda')
+
+        # Infs appear for hysteresis-1 iterations, scale shouldn't change
+        found_inf = torch.tensor([1], dtype=torch.float32, device='cuda')
+        for i in range(hysteresis-1):
+            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
+                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
+        self.assertTrue(scale.item() == init_scale)
+
+        # No infs for growth_interval-1 iterations, scale shouldn't change
+        found_inf.zero_()
+        for i in range(growth_interval-1):
+            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
+                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
+        self.assertTrue(scale.item() == init_scale)
+
+        # Infs appear for more than hysteresis iterations, scale should be backed off
+        found_inf.fill_(1)
+        extra_iters = random.randint(0, 1000)
+        scale_before = scale.detach().item()
+        scale_ref = scale_before
+        for i in range(hysteresis + extra_iters):
+            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
+                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
+        for i in range(1 + extra_iters):
+            # Scale is continuously backed off for each iteration with an inf
+            scale_new = scale_ref * backoff_factor
+            if isfinite(scale_new):
+                scale_ref = scale_new
+            else:
+                scale_ref = 0 # Scale update kernel does not check for underflow when backing off, which results in zero
+        self.assertTrue(scale.item() == scale_ref)
+
+        # No infs for more than growth_interval iterations, scale should be increased
+        found_inf.fill_(0)
+        extra_iters = random.randint(0, 1000)
+        scale_before = scale.detach().item()
+        scale_ref = scale_before
+        for i in range(growth_interval + extra_iters):
+            update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
+                    found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
+        for i in range(1 + int(math.floor(extra_iters / growth_interval))):
+            # Scale is grown every growth_interval iterations
+            scale_new = scale_ref * growth_factor
+            if isfinite(scale_new):
+                scale_ref = scale_new
+        self.assertTrue(scale.item() == scale_ref)
+
+
+    @unittest.skipIf(disabled, "amp_C is unavailable")
+    def test_fuzz(self):
+        init_scale_list = [1, 1024, 65536]
+        growth_factor_list = [1.0, 2.0, 4.0]
+        backoff_factor_list = [0.5, 0.25]
+        growth_interval_list = [10, 100]
+        hysteresis_list = [10, 100]
+
+        for init_scale in init_scale_list:
+            for growth_factor in growth_factor_list:
+                for backoff_factor in backoff_factor_list:
+                    for growth_interval in growth_interval_list:
+                        for hysteresis in hysteresis_list:
+                            self.update_scale_hysteresis_body(init_scale, growth_factor,
+                                    backoff_factor, growth_interval, hysteresis)
+
+
+if __name__ == '__main__':
+    unittest.main()
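Since the test file ends with unittest.main(), it can be run standalone on a CUDA machine once apex is installed with its extensions; if amp_C fails to import, the test skips itself via the disabled flag:

    python tests/L0/run_amp/test_update_scale_hysteresis.py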