From 2671a69efd8a3b4ff1043f83685a53fad92ce2c4 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 9 May 2023 18:52:21 -0700
Subject: [PATCH] Revert "address an issue with triton auto-tuner and in-place
 calls. make the assumption that after the first optimizer.step call, things
 are properly cached"

This reverts commit 6ab873a380b47ebc5ea6f68ea588606daebb8b85.
---
 lion_pytorch/lion_pytorch.py | 15 +--------------
 lion_pytorch/triton.py       | 28 ++++------------------------
 setup.py                     |  2 +-
 3 files changed, 6 insertions(+), 39 deletions(-)

diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index e025f89..0a6258a 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -47,8 +47,6 @@ def __init__(
         super().__init__(params, defaults)
 
         self.update_fn = update_fn
-        self.use_triton = use_triton
-        self.took_first_step = False
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
@@ -65,13 +63,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
-        # address an issue with autotune and in-place updates with triton
-        # on the first .step call, simply do not update parameters in-place, if using triton
-
-        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
-
-        # update all parameters
-
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
@@ -91,11 +82,7 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2,
-                    **update_kwargs
+                    beta2
                 )
 
-        if not self.took_first_step:
-            self.took_first_step = True
-
         return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index ca59800..4cc35f0 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -1,5 +1,4 @@
 import torch
-from torch import Tensor
 
 try:
     import triton
@@ -8,7 +7,6 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# triton cuda kernel
 
 @triton.autotune(configs = [
     triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
@@ -74,31 +72,19 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
 
 def update_fn(
-    p: Tensor,
-    grad: Tensor,
-    exp_avg: Tensor,
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float,
-    inplace: bool = True
+    beta2: float
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
     n_elements = p.numel()
 
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 
-    # address autotune and in-place update issue
-
-    if not inplace:
-        orig_p = p
-        orig_exp_avg = exp_avg
-
-        p = p.clone()
-        exp_avg = exp_avg.clone()
-
-    # call triton cuda kernel
-
     update_fn_kernel[grid](
         p,
         grad,
@@ -109,9 +95,3 @@ def update_fn(
         beta2,
         n_elements
     )
-
-    # update if not in-place call
-
-    if not inplace:
-        orig_p.copy_(p)
-        orig_exp_avg.copy_(exp_avg)
diff --git a/setup.py b/setup.py
index 1eace6d..dab77ac 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.7',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
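
For context, a minimal usage sketch (not part of the commit) of the optimizer this revert touches, assuming the package's documented Lion entry point and a CUDA device for the Triton path; the toy model, tensor shapes, and hyperparameter values are illustrative only. With the revert applied, the Triton-backed update_fn no longer accepts an inplace flag and mutates the parameter and exp_avg buffers in place on every call, including the very first step().

import torch
from lion_pytorch import Lion

# toy model and data, purely illustrative
model = torch.nn.Linear(10, 2).cuda()
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

loss = model(torch.randn(8, 10).cuda()).sum()
loss.backward()

# after this revert, the Triton update kernel writes p and exp_avg in place
# on every call, with no special casing of the first optimizer step
opt.step()
opt.zero_grad()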