diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py index ab5b08e..40e4221 100644 --- a/lion_pytorch/triton.py +++ b/lion_pytorch/triton.py @@ -7,19 +7,12 @@ print('triton is not installed, please install by running `pip install triton -U --pre`') exit() -# clone param and exp_avg before autotuning takes place -# as those are updated in-place - -def clone_inplace_updated_params(nargs): - nargs['p_ptr'] = nargs['p_ptr'].clone() - nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone() - # triton cuda kernel @triton.autotune(configs = [ - triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params), - triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params), -], key = ['n_elements']) + triton.Config({'BLOCK_SIZE': 128}, num_warps = 4), + triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8), +], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr']) @triton.jit def update_fn_kernel( p_ptr,