Fix initial step of scale_by_optimistic_gradient.
carlosgmartin committed Oct 3, 2024
1 parent dd1daab · commit af5badd
Showing 3 changed files with 37 additions and 21 deletions.
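Background: scale_by_optimistic_gradient applies the optimistic-gradient rule

    update_t = (alpha + beta) * grad_t - beta * grad_{t-1}

Before this commit, the state held a zero previous gradient at initialization, so the very first update was (alpha + beta) * grad_0 (double the gradient under the defaults alpha = beta = 1.0). The fix tracks whether the transformation is at its initial step and substitutes grad_0 for the missing grad_{-1}, so the first update reduces to alpha * grad_0. See https://github.com/google-deepmind/optax/issues/1082.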
optax/_src/alias.py: 2 changes (1 addition & 1 deletion)
@@ -1303,11 +1303,11 @@ def optimistic_gradient_descent(
     ...  updates, opt_state = solver.update(grad, opt_state, params)
     ...  params = optax.apply_updates(params, updates)
     ...  print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 1.38E+01
     Objective function: 1.37E+01
     Objective function: 1.35E+01
     Objective function: 1.33E+01
     Objective function: 1.32E+01
-    Objective function: 1.30E+01
 
     References:
       Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2
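The doctest's first and last printed values shift because the fix halves the very first step. As a quick sanity check, here is a sketch of the first-step arithmetic; the objective f(x) = sum(x**2), the starting point [1., 2., 3.], and learning_rate = 0.003 are assumptions about the doctest setup, which is not visible in this hunk:

import jax.numpy as jnp

params = jnp.array([1., 2., 3.])  # assumed doctest starting point
lr = 0.003                        # assumed doctest learning rate
g = 2 * params                    # gradient of f(x) = sum(x**2)

# After the fix, the first update is the plain gradient grad_0:
print('{:.2E}'.format(jnp.sum((params - lr * g) ** 2)))      # 1.38E+01
# Before the fix, it was (alpha + beta) * grad_0 = 2 * grad_0:
print('{:.2E}'.format(jnp.sum((params - lr * 2 * g) ** 2)))  # 1.37E+01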
optax/_src/transform.py: 29 changes (24 additions & 5 deletions)
@@ -1183,6 +1183,11 @@ def update_fn(updates, state, params):
   return base.GradientTransformation(init_fn, update_fn)
 
 
+class ScaleByOptimisticGradientState(NamedTuple):
+  is_initial_step: chex.Array
+  previous_gradient: base.Updates
+
+
 def scale_by_optimistic_gradient(
     alpha: float = 1.0,
     beta: float = 1.0
@@ -1201,15 +1206,29 @@ def scale_by_optimistic_gradient(
"""

def init_fn(params):
return TraceState(trace=otu.tree_zeros_like(params))
return ScaleByOptimisticGradientState(
is_initial_step=jnp.array(True),
previous_gradient=otu.tree_zeros_like(params),
)

def update_fn(updates, state, params=None):
del params

new_updates = jax.tree.map(
lambda grad_t, grad_tm1: (alpha + beta) * grad_t - beta * grad_tm1,
updates, state.trace)
return new_updates, TraceState(trace=updates)
def f(grad_t, grad_tm1):
# At the initial step, the previous gradient doesn't exist, so we use the
# current gradient instead.
# https://github.com/google-deepmind/optax/issues/1082
grad_tm1 = jnp.where(state.is_initial_step, grad_t, grad_tm1)
return (alpha + beta) * grad_t - beta * grad_tm1

new_updates = jax.tree.map(f, updates, state.previous_gradient)

new_state = ScaleByOptimisticGradientState(
is_initial_step=jnp.array(False),
previous_gradient=updates,
)

return new_updates, new_state

return base.GradientTransformation(init_fn, update_fn)

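A minimal sketch of the fixed behavior, mirroring the new unit test below (this assumes the transformation is also exposed as optax.scale_by_optimistic_gradient, as in current releases):

import jax.numpy as jnp
import optax

tx = optax.scale_by_optimistic_gradient()  # defaults: alpha = beta = 1.0
state = tx.init(jnp.array(0.0))

# First step: the missing previous gradient is replaced by the current one,
# so the update is (1 + 1) * 2.0 - 1 * 2.0 = 2.0 (it was 4.0 before the fix).
u0, state = tx.update(jnp.array(2.0), state)
print(u0)  # 2.0

# Later steps are unchanged: (1 + 1) * 3.0 - 1 * 2.0 = 4.0.
u1, state = tx.update(jnp.array(3.0), state)
print(u1)  # 4.0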
optax/_src/transform_test.py: 27 changes (12 additions & 15 deletions)
@@ -145,25 +145,22 @@ def test_centralize(self, inputs, outputs):
     chex.assert_trees_all_close(centralized_inputs, outputs)
 
   def test_scale_by_optimistic_gradient(self):
+    opt = transform.scale_by_optimistic_gradient()
-    def f(params: jnp.ndarray) -> jnp.ndarray:
-      return params['x'] ** 2
+    state = opt.init(10.0)
 
-    initial_params = {
-        'x': jnp.array(2.0)
-    }
+    grad_0 = 2.0
+    opt_grad_0, state = opt.update(grad_0, state)
+    chex.assert_trees_all_close(opt_grad_0, grad_0)
+    # initial step should yield 2 * grad_0 - grad_0 = grad_0
 
-    og = transform.scale_by_optimistic_gradient()
-    og_state = og.init(initial_params)
-    # Provide some arbitrary previous gradient.
-    getattr(og_state, 'trace')['x'] = 1.5
+    grad_1 = 3.0
+    opt_grad_1, state = opt.update(grad_1, state)
+    chex.assert_trees_all_close(opt_grad_1, 2 * grad_1 - grad_0)
 
-    g = jax.grad(f)(initial_params)
-    og_true = 2 * g['x'] - getattr(og_state, 'trace')['x']
-    og, _ = og.update(g, og_state)
-
-    # Compare transformation output with manually computed optimistic gradient.
-    chex.assert_trees_all_close(og_true, og['x'])
+    grad_2 = 4.0
+    opt_grad_2, state = opt.update(grad_2, state)
+    chex.assert_trees_all_close(opt_grad_2, 2 * grad_2 - grad_1)
 
   def test_scale_by_polyak_l1_norm(self, tol=1e-10):
     """Polyak step-size on L1 norm."""
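For reference, the expected values in the new test follow directly from the update rule with alpha = beta = 1: the initial step passes grad_0 = 2.0 through unchanged, the second step yields 2 * 3.0 - 2.0 = 4.0, and the third yields 2 * 4.0 - 3.0 = 5.0.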
