From 3d1e555a52060ec67d7fd51d93890930b3039346 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 10 May 2023 07:43:10 -0700
Subject: [PATCH] attempt to fix autotune + inplace update issue

---
 lion_pytorch/triton.py | 12 ++++++++++--
 setup.py               |  2 +-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index 4cc35f0..ab5b08e 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -7,10 +7,18 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
+# clone param and exp_avg before autotuning takes place
+# as those are updated in-place
+
+def clone_inplace_updated_params(nargs):
+    nargs['p_ptr'] = nargs['p_ptr'].clone()
+    nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()
+
+# triton cuda kernel
 
 @triton.autotune(configs = [
-    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
-    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
 ], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
diff --git a/setup.py b/setup.py
index dab77ac..ee9ddb1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.1.2',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',