Skip to content

Commit

Permalink
Changed gradient of LSE (#885)
Browse files Browse the repository at this point in the history
  • Loading branch information
trappmartin authored and yebai committed Aug 9, 2019
1 parent 9f6f127 commit abbc8c0
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,9 @@ end
import StatsFuns: logsumexp
logsumexp(x::Tracker.TrackedArray) = Tracker.track(logsumexp, x)
Tracker.@grad function logsumexp(x::Tracker.TrackedArray)
lse = logsumexp(Tracker.data(x))
se = exp(lse)
lse = logsumexp(Tracker.data(x))
return lse,
Δ->.* exp.(x) ./ se,)
Δ->.* exp.(x .- lse),)
end

import StatsFuns: binomlogpdf
Expand Down

0 comments on commit abbc8c0

Please sign in to comment.