From 0781eb1f2edeb315002e3b2f05a3d28e3cdb74c3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 9 May 2023 18:52:14 -0700
Subject: [PATCH] Revert "actually, just follow @ipoletaev advice and remove
 autotuner for now"

This reverts commit 2226ec8aeee03e9fbbf561e50fbf114b9677d3e9.
---
 lion_pytorch/lion_pytorch.py | 17 +++++++++-----
 lion_pytorch/triton.py       | 43 +++++++++++++++++++-----------------
 setup.py                     |  2 +-
 3 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index f96fec6..e025f89 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import Tuple, Optional, Callable
 
 import torch
@@ -34,8 +33,7 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False,
-        triton_block_size: int = 1024
+        use_triton: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -54,7 +52,7 @@ def __init__(
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)
+            self.update_fn = triton_update_fn
 
     @torch.no_grad()
     def step(
@@ -67,6 +65,11 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
+        # address an issue with autotune and in-place updates with triton
+        # on the first .step call, simply do not update parameters in-place, if using triton
+
+        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
+
         # update all parameters
 
         for group in self.param_groups:
@@ -88,7 +91,11 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2
+                    beta2,
+                    **update_kwargs
                 )
 
+        if not self.took_first_step:
+            self.took_first_step = True
+
         return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index 0615dbc..ca59800 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -8,18 +8,12 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# helper functions
-
-def calc_num_warps(block_size):
-    num_warps = 4
-    if block_size >= 2048:
-        num_warps = 8
-    if block_size >= 4096:
-        num_warps = 16
-    return num_warps
-
 # triton cuda kernel
 
+@triton.autotune(configs = [
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
+], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -87,20 +81,25 @@ def update_fn(
     wd: float,
     beta1: float,
     beta2: float,
-    inplace: bool = True,
-    BLOCK_SIZE: int = 1024
+    inplace: bool = True
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-
     n_elements = p.numel()
 
-    block_size = triton.next_power_of_2(BLOCK_SIZE)
-    num_warps = calc_num_warps(block_size)
-    n_rows = triton.cdiv(n_elements, block_size)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+
+    # address autotune and in-place update issue
+
+    if not inplace:
+        orig_p = p
+        orig_exp_avg = exp_avg
+
+        p = p.clone()
+        exp_avg = exp_avg.clone()
 
     # call triton cuda kernel
 
-    update_fn_kernel[(n_rows,)](
+    update_fn_kernel[grid](
         p,
         grad,
         exp_avg,
@@ -108,7 +107,11 @@ def update_fn(
         wd,
         beta1,
         beta2,
-        n_elements,
-        num_warps = num_warps,
-        BLOCK_SIZE = BLOCK_SIZE
+        n_elements
     )
+
+    # update if not in-place call
+
+    if not inplace:
+        orig_p.copy_(p)
+        orig_exp_avg.copy_(exp_avg)
diff --git a/setup.py b/setup.py
index 35eb081..1eace6d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.0.8',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
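
Usage sketch (illustrative, not part of the diff above): assuming the package-level `Lion` export and an available CUDA device, the code path this revert restores would be exercised roughly as follows; the model and hyperparameters are placeholders.

    import torch
    from lion_pytorch import Lion

    model = torch.nn.Linear(512, 512).cuda()

    # use_triton = True selects the autotuned Triton update kernel restored above.
    # Per the in-code comments, the first .step() call is routed with inplace = False
    # when using Triton: update_fn clones p and exp_avg, runs the kernel on the clones,
    # then copies the results back via orig_p.copy_(p) / orig_exp_avg.copy_(exp_avg).
    opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

    loss = model(torch.randn(8, 512, device = 'cuda')).sum()
    loss.backward()
    opt.step()       # first step: non-in-place update, then took_first_step is set
    opt.zero_grad()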