Add `atol` option to `contrib.reduce_on_plateau()` #698
Conversation
Thanks @stefanocortinovis! Having a relative tolerance makes a lot of sense to me. The only downside I see is that it no longer follows the PyTorch implementation, but we don't have to follow it, so I think that's fine. Any opinion on this, @vz415 @vroulet, since you contributed the reduce_on_plateau implementation and its review respectively?
Nice catch! This does handle negative losses correctly, but I'd suggest adding an assert statement or some other safety measure to ensure users set `rtol` and `atol` to valid values.
Thanks for the comments! I added validation checks for `rtol` and `atol`.
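For concreteness, such validation could look like the following sketch (hypothetical code, not necessarily what was merged in this PR; the helper name `_validate_tolerances` is made up):

```python
import warnings

def _validate_tolerances(rtol: float, atol: float) -> None:
    # Hypothetical sketch: validate eagerly at init time, before any
    # jit-traced computation runs, so plain Python checks and warnings work.
    if rtol < 0.0 or atol < 0.0:
        raise ValueError(
            f"rtol and atol must be non-negative, got rtol={rtol} and atol={atol}."
        )
    if rtol == 0.0 and atol == 0.0:
        warnings.warn(
            "rtol and atol are both zero; any decrease in the loss will "
            "count as an improvement."
        )
```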
I suggest we use the warnings module in the Python standard library (https://docs.python.org/3/library/warnings.html), since that seems to be what JAX uses (see for instance https://github.com/google/jax/blob/f539187c053bf1819f05d0f8e9e66e45da2af17b/jax/_src/array.py#L456).
Thanks for the suggestion @fabianp. However, I'm still not convinced that raising warnings based on the value of `loss` would work here.

In theory, raising such a warning would involve adding something like

```python
if loss < 0.0 and some_other_condition:
    warnings.warn('some warning')
```

to the `update()` function. However, `update()` is typically called inside a jitted step function, as in:

```python
@jit
def make_step(params, transform_state):
    updates = {"params": 1.0}
    loss = jnp.asarray(-1.0)
    updates, _ = transform.update(updates=updates, state=transform_state, loss=loss)
    params = optax.apply_updates(params, updates)
    return params, loss

params = {"params": 2.0}
transform = contrib.reduce_on_plateau()
transform_state = transform.init(params)
make_step(params, transform_state)
```

Hence, if `update()` is jitted, `loss` is a traced value whose concrete value isn't available when the Python code runs, so the branch guarding the warning can't be evaluated. Am I missing something here?
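To make the tracing issue concrete, here is a minimal self-contained repro (an illustrative snippet added here, not part of the original comment); the exact error class depends on the JAX version, but branching on a tracer fails at trace time:

```python
import warnings

import jax
import jax.numpy as jnp

@jax.jit
def check(loss):
    # Under jit, `loss` is an abstract tracer with no concrete value, so this
    # Python-level branch raises an error at trace time (a
    # TracerBoolConversionError or ConcretizationTypeError, depending on the
    # JAX version) instead of conditionally warning at run time.
    if loss < 0.0:
        warnings.warn("negative loss passed to reduce_on_plateau")
    return loss

check(jnp.asarray(-1.0))  # fails during tracing, before any value is known
```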
You're absolutely right @stefanocortinovis, I wasn't thinking about jitting.
Thanks for the contribution @stefanocortinovis!
This PR swaps the `threshold` argument in the current implementation of `reduce_on_plateau()` for the two arguments `rtol` and `atol`. The change is propagated to `reduce_on_plateau.py` and `reduce_on_plateau.ipynb` where needed to retain the current behaviour.

The behaviour of the current implementation corresponds to the case `atol = 0` and `rtol > 0` in the new one, and is kept as the default. The addition of `atol` makes it possible to set an absolute tolerance for measuring a new best loss, or to mix relative and absolute tolerances by choosing a value greater than zero for both `rtol` and `atol`.

In order to implement this change, I've modified the inequality `has_improved = jnp.where((loss / state.best_loss - 1) < -threshold, 1, 0)` to `has_improved = jnp.where(loss < (1 - rtol) * state.best_loss - atol, 1, 0)`. Notice that moving the term `state.best_loss` from the denominator on the left-hand side of the first inequality to the right-hand side of the second one has the added benefit of correctly handling losses that can attain negative values (e.g. negative ELBOs with continuous distributions) when `rtol = 0` and `atol > 0`.
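A small numeric illustration of that last point (a sketch using made-up values, not code from the PR):

```python
import jax.numpy as jnp

def has_improved_old(loss, best_loss, threshold):
    # Previous condition: dividing by a negative best_loss flips the sign of
    # the relative change, so genuine improvements can be missed.
    return jnp.where((loss / best_loss - 1) < -threshold, 1, 0)

def has_improved_new(loss, best_loss, rtol, atol):
    # New condition: no division by best_loss, so negative losses are handled
    # correctly when rtol = 0 and atol > 0.
    return jnp.where(loss < (1 - rtol) * best_loss - atol, 1, 0)

best_loss = jnp.asarray(-10.0)
loss = jnp.asarray(-10.5)  # strictly better (lower) than best_loss

print(has_improved_old(loss, best_loss, threshold=0.01))      # 0: improvement missed
print(has_improved_new(loss, best_loss, rtol=0.0, atol=0.1))  # 1: improvement detected
```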