Merge pull request #1059 from enolan:fix-schedule-free-donate
PiperOrigin-RevId: 681476986
OptaxDev committed Oct 2, 2024
2 parents 211760b + b60629d commit dd1daab
Showing 2 changed files with 40 additions and 22 deletions.
28 changes: 14 additions & 14 deletions optax/contrib/_schedule_free.py
@@ -50,9 +50,7 @@ def schedule_free_eval_params(state: base.OptState, params: base.Params):
raise ValueError(
'schedule_free_eval_params requires a ScheduleFreeState as input.'
)
return jax.tree.map(
lambda yi, zi: (yi - (1.0 - b1) * zi) / b1, params, z
)
return jax.tree.map(lambda yi, zi: (yi - (1.0 - b1) * zi) / b1, params, z)


def schedule_free(
@@ -145,6 +143,9 @@ def init_fn(params: base.Params) -> ScheduleFreeState:
z = otu.tree_cast(params, dtype=state_dtype)
else:
z = params
# It's important to copy the params here so that z is a distinct array and
# we can donate both z and the params to JITted functions.
z = jax.tree_util.tree_map(lambda t: t.copy(), z)
return ScheduleFreeState(
b1=jnp.asarray(b1, dtype=params_dtype),
weight_sum=jnp.zeros([], dtype=params_dtype),
@@ -206,9 +207,7 @@ def update_fn(
x,
z,
)
updates = jax.tree.map(
lambda npi, pi: npi - pi, new_params, params
)
updates = jax.tree.map(lambda npi, pi: npi - pi, new_params, params)

next_state = ScheduleFreeState(
b1=state.b1,
@@ -244,10 +243,10 @@ def schedule_free_sgd(
warmup_steps: positive integer, the length of the linear warmup.
b1: beta_1 parameter in the y update.
weight_decay: Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent
with other frameworks such as PyTorch, but different from
(Loshchilov et al, 2019) where the weight decay is only multiplied with
the "schedule multiplier", but not the base learning rate.
weight decay is multiplied with the learning rate. This is consistent with
other frameworks such as PyTorch, but different from (Loshchilov et al,
2019) where the weight decay is only multiplied with the "schedule
multiplier", but not the base learning rate.
weight_lr_power: we downweight the weight of averaging using this. This is
especially helpful in early iterations during warmup.
state_dtype: dtype for z sequence in the schedule free method.
@@ -287,7 +286,8 @@ def schedule_free_sgd(
optimizer = alias.sgd(learning_rate)
if weight_decay is not None:
optimizer = combine.chain(
_adding.add_decayed_weights(weight_decay), optimizer)
_adding.add_decayed_weights(weight_decay), optimizer
)
return schedule_free(
optimizer,
learning_rate=learning_rate,
@@ -319,8 +319,8 @@ def schedule_free_adamw(
warmup_steps: positive integer, the length of the linear warmup.
b1: beta_1 parameter in the y update.
b2: Exponential decay rate to track the second moment of past gradients.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps: A small constant applied to denominator outside of the square root (as
in the Adam paper) to avoid dividing by zero when rescaling.
weight_decay: Strength of the weight decay regularization.
weight_lr_power: we downweight the weight of averaging using this. This is
especially helpful in early iterations during warmup.
@@ -364,7 +364,7 @@ def schedule_free_adamw(
decay=b2, eps=eps, eps_in_sqrt=False, bias_correction=True
),
_adding.add_decayed_weights(weight_decay),
transform.scale_by_learning_rate(learning_rate)
transform.scale_by_learning_rate(learning_rate),
)
return schedule_free(
optimizer,
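
To make the intent of the new copy in init_fn concrete, here is a minimal usage sketch of the donation pattern this change enables (not part of the diff; the quadratic loss_fn and the step function are illustrative, and it assumes an optax version exporting schedule_free_sgd and schedule_free_eval_params from optax.contrib):

import functools

import jax
import jax.numpy as jnp
import optax


def loss_fn(p):
  return jnp.sum(p**2)


opt = optax.contrib.schedule_free_sgd()
params = jnp.array([1.0, 2.0])
state = opt.init(params)


# Donating both arguments lets XLA reuse their buffers for the outputs. This
# relies on state.z not aliasing the params buffer, which the copy added in
# init_fn guarantees.
@functools.partial(jax.jit, donate_argnums=(0, 1))
def step(params, state):
  grads = jax.grad(loss_fn)(params)
  updates, state = opt.update(grads, state, params)
  return optax.apply_updates(params, updates), state


params, state = step(params, state)
# Evaluate with the averaged x sequence, recovered from the stored y params
# and z via x = (y - (1 - b1) * z) / b1, as in the one-liner reformatted above.
eval_params = optax.contrib.schedule_free_eval_params(state, params)
print(loss_fn(eval_params))
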
34 changes: 26 additions & 8 deletions optax/contrib/_schedule_free_test.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `_schedule_free.py`."""
"""Tests for the schedule-free wrapper."""

import functools
from absl.testing import absltest
from absl.testing import parameterized
import chex
@@ -42,7 +43,7 @@ class ScheduleFreeTest(chex.TestCase):
def test_learning_rate_zero(self):
base_opt = alias.sgd(learning_rate=0.0, momentum=0.0)
opt = _schedule_free.schedule_free(base_opt, learning_rate=0.0)
initial_params = jnp.array([1., 2.])
initial_params = jnp.array([1.0, 2.0])
fun = lambda x: jnp.sum(x**2)

@jax.jit
@@ -64,7 +65,7 @@ def step(params, state):

def test_schedule_free_adamw(self):

initial_params = jnp.array([1., 2.])
initial_params = jnp.array([1.0, 2.0])
fun = lambda x: jnp.sum(x**2)

def step(params, state, opt):
@@ -85,18 +86,20 @@ def run(opt):
opt_shortcut = _schedule_free.schedule_free_adamw(
learning_rate=1.0,
b1=0.9,
weight_decay=1-4,
weight_decay=1e-4,
)
params_shortcut = run(opt_shortcut)

# Test with wrapper implementation
opt_wrapper = _schedule_free.schedule_free(
alias.adamw(learning_rate=1.0, b1=0.0, weight_decay=1-4),
alias.adamw(learning_rate=1.0, b1=0.0, weight_decay=1e-4),
learning_rate=1.0,
b1=0.9,
)
params_wrapper = run(opt_wrapper)
chex.assert_trees_all_close(params_shortcut, params_wrapper)
chex.assert_trees_all_close(
params_shortcut, params_wrapper, atol=1e-6, rtol=1e-6
)

def test_scalar_preservance(self):
# Test whether the scalar arrays of shape () are preserved through
@@ -110,6 +113,22 @@ def test_scalar_preservance(self):
chex.assert_equal_shape([params, eval_params])
chex.assert_trees_all_equal_dtypes(params, eval_params)

def test_buffer_donation(self):
# Check that you can donate the params and optimizer state when doing a JIT
# update.
opt = _schedule_free.schedule_free_sgd()
initial_params, _, get_updates = _setup_parabola(jnp.float32)

@functools.partial(jax.jit, donate_argnums=(0, 1))
def step(params, state):
updates = get_updates(params)
updates, state = opt.update(updates, state, params)
params = update.apply_updates(params, updates)
return params, state

state = opt.init(initial_params)
_, _ = step(initial_params, state)

@parameterized.product(
params_dtype=('bfloat16', 'float32', 'complex64', None),
state_dtype=('bfloat16', 'float32', 'complex64', None),
@@ -123,8 +142,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype):
params_dtype = jax.dtypes.canonicalize_dtype(params_dtype)
params = jnp.array([0.0, 0.0], dtype=params_dtype)
state_has_lower_dtype = (
jnp.promote_types(params_dtype, state_dtype)
== params_dtype
jnp.promote_types(params_dtype, state_dtype) == params_dtype
)
if state_dtype is None or state_has_lower_dtype:
state = opt.init(params)
