From 974b754ffd65f7026ad8ea40c95efdb2b6d69230 Mon Sep 17 00:00:00 2001
From: yousufmo <108403694+yousufmo@users.noreply.github.com>
Date: Tue, 19 Mar 2024 17:36:34 +0000
Subject: [PATCH] Fix in-place modification when autotuning triton Lion update

---
 lion_pytorch/triton.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index ab5b08e..40e4221 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -7,19 +7,12 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# clone param and exp_avg before autotuning takes place
-# as those are updated in-place
-
-def clone_inplace_updated_params(nargs):
-    nargs['p_ptr'] = nargs['p_ptr'].clone()
-    nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()
-
 # triton cuda kernel
 
 @triton.autotune(configs = [
-    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
-    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
-], key = ['n_elements'])
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
+], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
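
Context for the change: the Lion update kernel writes p and exp_avg in place, and Triton's autotuner launches the kernel once per candidate config while benchmarking, so the in-place update needs protection from those extra runs. Rather than cloning the buffers in a per-config pre_hook, the patch passes restore_value to triton.autotune, which has the autotuner snapshot the named arguments and restore them after each benchmarking run. Below is a minimal, self-contained sketch of that pattern; it is not lion-pytorch code, and it assumes a CUDA device, a Triton release with restore_value support (roughly 2.2 or newer), and the illustrative names scale_inplace_kernel / scale_inplace.

# Minimal sketch of the restore_value pattern used by the patch above.
# Assumptions (not from lion-pytorch): CUDA device, Triton >= 2.2, and the
# illustrative names scale_inplace_kernel / scale_inplace.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs = [
        triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
    ],
    key = ['n_elements'],
    # autotuning runs the kernel once per config; restore_value tells the
    # autotuner to snapshot these arguments and restore them after each
    # benchmarking run, so the in-place update is only applied once for real
    restore_value = ['x_ptr'],
)
@triton.jit
def scale_inplace_kernel(
    x_ptr,                      # tensor updated in place
    scale,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis = 0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask = mask)
    tl.store(x_ptr + offsets, x * scale, mask = mask)

def scale_inplace(x: torch.Tensor, scale: float):
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    scale_inplace_kernel[grid](x, scale, n_elements)

if __name__ == '__main__':
    x = torch.ones(4096, device = 'cuda')
    scale_inplace(x, 2.0)                                # autotunes on first call
    assert torch.allclose(x, torch.full_like(x, 2.0))    # scaled exactly once

Compared with the cloning pre_hook, the copying here is confined to the autotuning phase, and the real (post-autotune) launch always operates on the original tensors.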