-
-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update softmax implementation with dims argument #90
Conversation
Add tests? Otherwise lgtm |
Co-authored-by: David Widmann <[email protected]>
I tested it for an application locally. Will add some minimal tests for this by tomorrow. |
|
||
@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),) | ||
@grad softmax(xs; dims=1) = softmax(data(xs); dims=dims), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs); dims=dims)),) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a more efficient ∇softmax
which takes, as optional 3rd argument, the forward pass result, see here:
I think this will be in NNlib 0.6, so provided Tracker is happy to require that, it could use this (and the equivalent for logsoftmax) too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we tag NNlib with that update before making the changes in the dependent packages?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Certainly we shouldn't tag an update to this package before that! Whether this should wait, or be done in two steps, is your call I guess.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latest version of NNlib broke softmax
with Tracker (see FluxML/NNlib.jl#251), so I would prefer if this fix could be made available as soon as possible (e.g., the bug breaks CI tests in multiple Turing packages). I just confirmed that this PR fixes the issue.
Lgtm thanks |
Is it possible to make a new release with these changes? I'd like to use the bugfix. |
No description provided.