diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index 59c7ec63b..fcd1c9a2a 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -1303,11 +1303,11 @@ def optimistic_gradient_descent(
     ...   updates, opt_state = solver.update(grad, opt_state, params)
     ...   params = optax.apply_updates(params, updates)
     ...   print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 1.38E+01
     Objective function: 1.37E+01
     Objective function: 1.35E+01
     Objective function: 1.33E+01
     Objective function: 1.32E+01
-    Objective function: 1.30E+01
 
   References:
     Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index fcb1f3859..c9d280772 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1183,6 +1183,11 @@ def update_fn(updates, state, params):
   return base.GradientTransformation(init_fn, update_fn)
 
 
+class ScaleByOptimisticGradientState(NamedTuple):
+  is_initial_step: chex.Array
+  previous_gradient: base.Updates
+
+
 def scale_by_optimistic_gradient(
     alpha: float = 1.0,
     beta: float = 1.0
@@ -1201,15 +1206,29 @@
   """
 
   def init_fn(params):
-    return TraceState(trace=otu.tree_zeros_like(params))
+    return ScaleByOptimisticGradientState(
+        is_initial_step=jnp.array(True),
+        previous_gradient=otu.tree_zeros_like(params),
+    )
 
   def update_fn(updates, state, params=None):
     del params
-    new_updates = jax.tree.map(
-        lambda grad_t, grad_tm1: (alpha + beta) * grad_t - beta * grad_tm1,
-        updates, state.trace)
-    return new_updates, TraceState(trace=updates)
+    def f(grad_t, grad_tm1):
+      # At the initial step, the previous gradient doesn't exist, so we use the
+      # current gradient instead.
+      # https://github.com/google-deepmind/optax/issues/1082
+      grad_tm1 = jnp.where(state.is_initial_step, grad_t, grad_tm1)
+      return (alpha + beta) * grad_t - beta * grad_tm1
+
+    new_updates = jax.tree.map(f, updates, state.previous_gradient)
+
+    new_state = ScaleByOptimisticGradientState(
+        is_initial_step=jnp.array(False),
+        previous_gradient=updates,
+    )
+
+    return new_updates, new_state
 
   return base.GradientTransformation(init_fn, update_fn)
diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index d8537df4a..35686caf9 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -145,25 +145,22 @@ def test_centralize(self, inputs, outputs):
     chex.assert_trees_all_close(centralized_inputs, outputs)
 
   def test_scale_by_optimistic_gradient(self):
+    opt = transform.scale_by_optimistic_gradient()
 
-    def f(params: jnp.ndarray) -> jnp.ndarray:
-      return params['x'] ** 2
+    state = opt.init(10.0)
 
-    initial_params = {
-        'x': jnp.array(2.0)
-    }
+    grad_0 = 2.0
+    opt_grad_0, state = opt.update(grad_0, state)
+    chex.assert_trees_all_close(opt_grad_0, grad_0)
+    # initial step should yield 2 * grad_0 - grad_0 = grad_0
 
-    og = transform.scale_by_optimistic_gradient()
-    og_state = og.init(initial_params)
-    # Provide some arbitrary previous gradient.
-    getattr(og_state, 'trace')['x'] = 1.5
+    grad_1 = 3.0
+    opt_grad_1, state = opt.update(grad_1, state)
+    chex.assert_trees_all_close(opt_grad_1, 2 * grad_1 - grad_0)
 
-    g = jax.grad(f)(initial_params)
-    og_true = 2 * g['x'] - getattr(og_state, 'trace')['x']
-    og, _ = og.update(g, og_state)
-
-    # Compare transformation output with manually computed optimistic gradient.
-    chex.assert_trees_all_close(og_true, og['x'])
+    grad_2 = 4.0
+    opt_grad_2, state = opt.update(grad_2, state)
+    chex.assert_trees_all_close(opt_grad_2, 2 * grad_2 - grad_1)
 
   def test_scale_by_polyak_l1_norm(self, tol=1e-10):
     """Polyak step-size on L1 norm."""
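For reference, a minimal usage sketch (not part of the patch) of the behavior the new test asserts, assuming the patch above is applied and that `scale_by_optimistic_gradient` is reachable through the public `optax` namespace:

```python
import jax.numpy as jnp
import optax

tx = optax.scale_by_optimistic_gradient()  # defaults: alpha=1.0, beta=1.0
state = tx.init(jnp.array(10.0))

# First call: no previous gradient exists yet, so the patched transform
# substitutes the current gradient and returns g_0 unchanged
# (before the fix it returned (alpha + beta) * g_0 - beta * 0 = 2 * g_0).
u0, state = tx.update(jnp.array(2.0), state)
print(u0)  # 2.0

# Subsequent calls apply the usual optimistic rule:
# (alpha + beta) * g_t - beta * g_{t-1}.
u1, state = tx.update(jnp.array(3.0), state)
print(u1)  # 2 * 3.0 - 2.0 = 4.0
```

Note the design choice in the patch: the initial-step fallback is expressed with `jnp.where` on a boolean carried in the state rather than a Python `if`, which keeps `update_fn` traceable under `jax.jit`.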