diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index 0a6258a..e025f89 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -47,6 +47,8 @@ def __init__(
         super().__init__(params, defaults)
 
         self.update_fn = update_fn
+        self.use_triton = use_triton
+        self.took_first_step = False
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
@@ -63,6 +65,13 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
+        # address an issue with autotune and in-place updates with triton
+        # on the first .step call, simply do not update parameters in-place, if using triton
+
+        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
+
+        # update all parameters
+
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
@@ -82,7 +91,11 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2
+                    beta2,
+                    **update_kwargs
                 )
 
+        if not self.took_first_step:
+            self.took_first_step = True
+
         return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index 4cc35f0..ca59800 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -1,4 +1,5 @@
 import torch
+from torch import Tensor
 
 try:
     import triton
@@ -7,6 +8,7 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
+# triton cuda kernel
 
 @triton.autotune(configs = [
     triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
@@ -72,19 +74,31 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
 
 def update_fn(
-    p: torch.Tensor,
-    grad: torch.Tensor,
-    exp_avg: torch.Tensor,
+    p: Tensor,
+    grad: Tensor,
+    exp_avg: Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float
+    beta2: float,
+    inplace: bool = True
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
     n_elements = p.numel()
 
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 
+    # address autotune and in-place update issue
+
+    if not inplace:
+        orig_p = p
+        orig_exp_avg = exp_avg
+
+        p = p.clone()
+        exp_avg = exp_avg.clone()
+
+    # call triton cuda kernel
+
     update_fn_kernel[grid](
         p,
         grad,
@@ -95,3 +109,9 @@ def update_fn(
         beta2,
         n_elements
     )
+
+    # update if not in-place call
+
+    if not inplace:
+        orig_p.copy_(p)
+        orig_exp_avg.copy_(exp_avg)
diff --git a/setup.py b/setup.py
index dab77ac..1eace6d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
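
For context, a minimal usage sketch of the patched optimizer. The `Lion` constructor and `use_triton` flag come from this repo; the toy model, shapes, and hyperparameters are illustrative assumptions, not part of the patch:

```python
import torch
from lion_pytorch import Lion

# toy model and batch, purely illustrative (requires a CUDA device and triton)
model = torch.nn.Linear(16, 4).cuda()
x = torch.randn(8, 16).cuda()

# use_triton = True routes parameter updates through the triton kernel above.
# with this patch, the very first .step() runs with inplace = False: the kernel
# writes into clones of p and exp_avg while triton's autotuner benchmarks block
# sizes (re-running the kernel several times), then results are copied back, so
# the repeated autotune invocations cannot corrupt the real optimizer state
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

model(x).sum().backward()
opt.step()       # first call takes the non-in-place, autotune-safe path
opt.zero_grad()  # subsequent calls update in-place as before
```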