# Copyright 2023 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optax implementation of the Lion optimizer."""

from typing import Any, Callable, NamedTuple, Optional, Union

import chex
import jax
import jax.numpy as jnp
import optax


def _scale_by_learning_rate(
    learning_rate: optax.ScalarOrSchedule, flip_sign=True):
  # The sign is flipped so that `optax.apply_updates` performs gradient
  # descent rather than ascent.
  m = -1 if flip_sign else 1
  if callable(learning_rate):
    return optax.scale_by_schedule(lambda count: m * learning_rate(count))
  return optax.scale(m * learning_rate)


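# Illustrative sketch of how `_scale_by_learning_rate` accepts either a
# constant or an optax schedule; the warmup/decay values below are
# assumptions chosen only for the example.
def _example_lr_scaling():
  schedule = optax.warmup_cosine_decay_schedule(
      init_value=0.0, peak_value=3e-4, warmup_steps=1_000, decay_steps=10_000)
  # With the default `flip_sign=True` the transformation multiplies updates
  # by the negative schedule value, so `optax.apply_updates` descends.
  return _scale_by_learning_rate(schedule)

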
def lion(
    learning_rate: optax.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.99,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
  """The Lion optimizer.

  Args:
    learning_rate: A fixed global scaling factor, or a schedule.
    b1: Exponential decay rate to combine the gradient and the moment.
    b2: Exponential decay rate to track the moment of past gradients.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate. This is consistent
      with other frameworks such as PyTorch, but different from
      (Loshchilov et al., 2019), where the weight decay is only multiplied
      with the "schedule multiplier" but not the base learning rate.
    mask: A tree with the same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the
      params/updates. The leaves should be booleans: `True` for
      leaves/subtrees the weight decay should be applied to, and `False` for
      those to skip. Note that the Lion gradient transformation is applied to
      all parameters.

  Returns:
    The corresponding `GradientTransformation`.
  """
  return optax.chain(
      scale_by_lion(b1=b1, b2=b2, mu_dtype=mu_dtype),
      optax.add_decayed_weights(weight_decay, mask),
      _scale_by_learning_rate(learning_rate),
  )


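# Update rule implemented by the chain above (see the Lion paper,
# Chen et al., 2023): with gradient g_t, momentum m_{t-1}, and parameters
# theta_{t-1},
#   update_t = sign(b1 * m_{t-1} + (1 - b1) * g_t)
#   m_t      = b2 * m_{t-1} + (1 - b2) * g_t
#   theta_t  = theta_{t-1} - lr * (update_t + weight_decay * theta_{t-1})
#
# The function below is an illustrative usage sketch; the names `params` and
# `grads`, the hyperparameter values, and the "skip weight decay on 1-D
# leaves" mask are assumptions made only for the example.
def _example_lion_step(params, grads):
  # Apply weight decay only to leaves with more than one dimension, a common
  # convention that skips biases and normalization scales.
  decay_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p)
  optimizer = lion(learning_rate=1e-4, weight_decay=0.1, mask=decay_mask)
  opt_state = optimizer.init(params)
  # `params` must be passed because `optax.add_decayed_weights` needs them.
  updates, opt_state = optimizer.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, opt_state

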
def update_moment(updates, moments, decay, order):
  """Computes the exponential moving average of the `order`-th moment."""
  return jax.tree_util.tree_map(
      lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)


class ScaleByLionState(NamedTuple):
  """State for the Lion algorithm."""
  count: chex.Array  # shape=(), dtype=jnp.int32.
  mu: optax.Updates


def scale_by_lion(
    b1: float = 0.9,
    b2: float = 0.99,
    mu_dtype: Optional[Any] = None,
) -> optax.GradientTransformation:
  """Rescales updates according to the Lion algorithm.

  Args:
    b1: Rate for combining the moment and the current gradient.
    b2: Decay rate for the exponentially weighted average of gradients.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """

  def init_fn(params):
    # Initialize the momentum (first moment) accumulator with zeros.
    mu = jax.tree_util.tree_map(
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu)

  def update_fn(updates, state, params=None):
    del params
    # The parameter update interpolates the *previous* momentum (`state.mu`)
    # with the current gradient using `b1`, while the stored momentum is
    # refreshed with the separate coefficient `b2`.
    mu = update_moment(updates, state.mu, b2, 1)
    if mu_dtype is not None:
      # Only cast when an accumulator dtype is explicitly requested; otherwise
      # keep the dtype inferred from `params`/`updates`.
      mu = jax.tree_util.tree_map(lambda x: x.astype(mu_dtype), mu)
    count_inc = optax.safe_int32_increment(state.count)
    updates = jax.tree_util.tree_map(
        lambda g, m: jnp.sign((1. - b1) * g + b1 * m), updates, state.mu)
    return updates, ScaleByLionState(count=count_inc, mu=mu)

  return optax.GradientTransformation(init_fn, update_fn)
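

# Minimal sanity-check sketch for `scale_by_lion`; the toy shapes and values
# are assumptions made only for the example. It illustrates that the returned
# updates are sign-valued and that the stored momentum follows
# mu_t = b2 * mu_{t-1} + (1 - b2) * g_t.
def _example_scale_by_lion_check():
  tx = scale_by_lion(b1=0.9, b2=0.99)
  params = {'w': jnp.ones((2, 3))}
  grads = {'w': jnp.full((2, 3), -0.5)}
  state = tx.init(params)
  updates, state = tx.update(grads, state)
  # With zero-initialized momentum the update is sign((1 - b1) * g) = -1.
  chex.assert_trees_all_close(updates['w'], -jnp.ones((2, 3)))
  # New momentum: 0.99 * 0 + 0.01 * (-0.5) = -0.005.
  chex.assert_trees_all_close(
      state.mu['w'], jnp.full((2, 3), -0.005), atol=1e-6)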