Update softmax implementation with dims argument #90

Merged 3 commits on Jan 3, 2021

Conversation

avik-pal
Member

No description provided.

@DhairyaLGandhi
Member

Add tests? Otherwise lgtm

Co-authored-by: David Widmann <[email protected]>
@avik-pal
Member Author

> Add tests? Otherwise lgtm

I tested it for an application locally. Will add some minimal tests for this by tomorrow.

@avik-pal avik-pal marked this pull request as draft December 23, 2020 12:28
@avik-pal avik-pal marked this pull request as ready for review January 3, 2021 14:25

- @grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
+ @grad softmax(xs; dims=1) = softmax(data(xs); dims=dims), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs); dims=dims)),)
Member

There's a more efficient ∇softmax that takes the forward-pass result as an optional third argument; see here:

https://github.com/FluxML/NNlib.jl/pull/250/files#diff-858614f9f176124bafc3047ab7803e7c512f25a87f5f56661de0d3973c43a9ccR70

I think this will be in NNlib 0.6, so provided Tracker is happy to require that, it could use this (and the equivalent for logsoftmax) too.
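
To illustrate why reusing the forward-pass result helps, here is a minimal sketch in plain Julia. The names `naive_softmax` and `grad_softmax` are hypothetical and are not NNlib's or Tracker's functions; the sketch only assumes the standard softmax pullback `y .* (Δ .- sum(Δ .* y; dims))`. When the caller already holds `y`, the three-argument method skips recomputing the softmax in the backward pass.

```julia
# Hypothetical helper names for illustration only; not the NNlib API.

# Numerically stable softmax along `dims`.
function naive_softmax(x::AbstractArray; dims=1)
    y = exp.(x .- maximum(x; dims=dims))
    return y ./ sum(y; dims=dims)
end

# Pullback given the forward-pass result `y` (no softmax recomputation).
grad_softmax(Δ, x, y; dims=1) = y .* (Δ .- sum(Δ .* y; dims=dims))

# Two-argument fallback: recomputes the forward pass, then reuses the method above.
grad_softmax(Δ, x; dims=1) = grad_softmax(Δ, x, naive_softmax(x; dims=dims); dims=dims)

# Usage: both call styles should give the same gradient.
x = [1.0 2.0; 3.0 4.0]
Δ = [0.5 -1.0; 2.0 0.25]
y = naive_softmax(x; dims=2)            # forward pass, computed once
dx_reuse = grad_softmax(Δ, x, y; dims=2) # reuses y
dx_plain = grad_softmax(Δ, x; dims=2)    # recomputes y internally
```

In a `@grad` definition like the one under review, the forward result is already computed before the pullback closure is built, so it could be captured and passed along instead of being recomputed; whether that fits Tracker depends on the NNlib version it requires.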

Member Author

Shouldn't we tag NNlib with that update before making the changes in the dependent packages?

Member

Certainly we shouldn't tag an update to this package before that! Whether this should wait, or be done in two steps, is your call I guess.

Contributor

The latest version of NNlib broke softmax with Tracker (see FluxML/NNlib.jl#251), so I would prefer if this fix could be made available as soon as possible (e.g., the bug breaks CI tests in multiple Turing packages). I just confirmed that this PR fixes the issue.

@DhairyaLGandhi
Member

Lgtm thanks

@DhairyaLGandhi DhairyaLGandhi merged commit 6a21660 into master Jan 3, 2021
@devmotion
Contributor

Is it possible to make a new release with these changes? I'd like to use the bugfix.
