
Commit

Revert "actually, just follow @ipoletaev advice and remove autotuner …
Browse files Browse the repository at this point in the history
…for now"

This reverts commit 2226ec8.
lucidrains committed May 10, 2023
1 parent 2226ec8 commit 0781eb1
Showing 3 changed files with 36 additions and 26 deletions.
17 changes: 12 additions & 5 deletions lion_pytorch/lion_pytorch.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import Tuple, Optional, Callable
 
 import torch
@@ -34,8 +33,7 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False,
-        triton_block_size: int = 1024
+        use_triton: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -54,7 +52,7 @@ def __init__(
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)
+            self.update_fn = triton_update_fn
 
     @torch.no_grad()
     def step(
@@ -67,6 +65,11 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
+        # address an issue with autotune and in-place updates with triton
+        # on the first .step call, simply do not update parameters in-place, if using triton
+
+        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
+
         # update all parameters
 
         for group in self.param_groups:
@@ -88,7 +91,11 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2
+                    beta2,
+                    **update_kwargs
                 )
 
+        if not self.took_first_step:
+            self.took_first_step = True
+
         return loss
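
For context on usage: with the autotuner restored, the Triton block size is selected automatically, so `triton_block_size` is no longer an argument of `Lion.__init__`. A minimal usage sketch, assuming the package's usual `from lion_pytorch import Lion` import and a CUDA device with triton installed:

import torch
from lion_pytorch import Lion

model = torch.nn.Linear(10, 10).cuda()

# block size is now chosen by triton.autotune, so no `triton_block_size` argument here
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

loss = model(torch.randn(2, 10).cuda()).sum()
loss.backward()

opt.step()      # first call avoids in-place updates while the autotuner benchmarks configs
opt.zero_grad()
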
43 changes: 23 additions & 20 deletions lion_pytorch/triton.py
@@ -8,18 +8,12 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# helper functions
-
-def calc_num_warps(block_size):
-    num_warps = 4
-    if block_size >= 2048:
-        num_warps = 8
-    if block_size >= 4096:
-        num_warps = 16
-    return num_warps
-
 # triton cuda kernel
 
+@triton.autotune(configs = [
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
+], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -87,28 +81,37 @@ def update_fn(
     wd: float,
     beta1: float,
     beta2: float,
-    inplace: bool = True,
-    BLOCK_SIZE: int = 1024
+    inplace: bool = True
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
 
     n_elements = p.numel()
 
-    block_size = triton.next_power_of_2(BLOCK_SIZE)
-    num_warps = calc_num_warps(block_size)
-    n_rows = triton.cdiv(n_elements, block_size)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+
+    # address autotune and in-place update issue
+
+    if not inplace:
+        orig_p = p
+        orig_exp_avg = exp_avg
+
+        p = p.clone()
+        exp_avg = exp_avg.clone()
 
     # call triton cuda kernel
 
-    update_fn_kernel[(n_rows,)](
+    update_fn_kernel[grid](
         p,
         grad,
         exp_avg,
         lr,
         wd,
         beta1,
         beta2,
-        n_elements,
-        num_warps = num_warps,
-        BLOCK_SIZE = BLOCK_SIZE
+        n_elements
     )
+
+    # update if not in-place call
+
+    if not inplace:
+        orig_p.copy_(p)
+        orig_exp_avg.copy_(exp_avg)
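
The restored kernel leans on `triton.autotune` to pick `BLOCK_SIZE` and `num_warps` per problem size (keyed on `n_elements`), with the launch grid written as a function of the meta-parameters the tuner selects; the `inplace = False` branch runs the kernel on clones of `p` and `exp_avg` and copies the results back, which is what the optimizer requests on its first `step`. A minimal sketch of the same autotune-plus-grid-lambda pattern on a toy elementwise kernel (`scale_kernel` and `scale` below are hypothetical, not part of this repository):

import torch
import triton
import triton.language as tl

@triton.autotune(configs = [
    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
], key = ['n_elements'])
@triton.jit
def scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program instance handles one BLOCK_SIZE-wide slice of the tensor
    pid = tl.program_id(axis = 0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask = mask)
    tl.store(out_ptr + offsets, x * scale, mask = mask)

def scale(x, scale_value):
    # x is expected to be a CUDA tensor
    out = torch.empty_like(x)
    n_elements = x.numel()

    # the grid is a function of the tuner-chosen meta-parameters,
    # so BLOCK_SIZE is not passed explicitly at the call site
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    scale_kernel[grid](x, out, scale_value, n_elements)
    return out

Since the tuner typically times each candidate config by launching the kernel more than once, a kernel that writes to its inputs can apply its update repeatedly during the first tuned call; that is presumably the in-place issue the optimizer's first-step `inplace = False` workaround is meant to sidestep.
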
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.0.8',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
