gradient returns nothing for sum(abs2, x) with a complex CuArray #1121

Closed
LexaLutyi opened this issue Nov 18, 2021 · 10 comments
Labels
CUDA All things GPU

Comments

@LexaLutyi

gradient returns nothing for a CuArray{ComplexF32}, but works fine with an Array{ComplexF32}:

julia> a = CUDA.rand(ComplexF32, 2)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
  0.040781975f0 + 0.31357694f0im
 0.0050712824f0 + 0.464033f0im

julia> gradient(t -> sum(abs2, t), a)
(nothing,)

julia> a = rand(ComplexF32, 2)
2-element Vector{ComplexF32}:
 0.6999197f0 + 0.343145f0im
 0.4877541f0 + 0.47177994f0im

julia> gradient(t -> sum(abs2, t), a)
(ComplexF32[1.3998394f0 + 0.68629f0im, 0.9755082f0 + 0.9435599f0im],)

real and imag behave the same way as abs2.
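
For reference, the CPU result above is exactly 2 .* a, which is what Zygote's convention gives for sum(abs2, x) on complex input. A minimal CPU-only sketch of that check follows; the input values are made up for illustration, and the same comparison could be run against a CuArray result to spot the bug:

```julia
using Zygote

# Illustrative input, not taken from the report.
x = ComplexF32[0.5f0 + 0.25f0im, 1.0f0 - 2.0f0im]

# Gradient of the real-valued function sum(abs2, t) with respect to t.
g, = gradient(t -> sum(abs2, t), x)

# Under Zygote's convention this should equal 2x.
@assert g ≈ 2 .* x
```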

@DhairyaLGandhi
Member

Definitely a bug! Could you check on older Zygote releases, say v0.6.3?

@LexaLutyi
Author

> Definitely a bug! Could you check on older Zygote releases, say v0.6.3?

Yes, on version v0.6.3 it works correctly:

julia> a = CUDA.rand(ComplexF32, 2)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
 0.12610757f0 + 0.47263396f0im
  0.4873352f0 + 0.42900836f0im

julia> gradient(t -> sum(abs2, t), a)
(ComplexF32[0.25221515f0 + 0.9452679f0im, 0.9746704f0 + 0.8580167f0im],)

julia> b = Array(a)
2-element Vector{ComplexF32}:
 0.12610757f0 + 0.47263396f0im
  0.4873352f0 + 0.42900836f0im

julia> gradient(t -> sum(abs2, t), b)
(ComplexF32[0.25221515f0 + 0.9452679f0im, 0.9746704f0 + 0.8580167f0im],)

@DhairyaLGandhi
Member

DhairyaLGandhi commented Nov 18, 2021

I wonder if it's related to the recent projection-related issues too.

@LexaLutyi
Author

The problem occurs when using Flux, where Zygote is v0.6.30. If you install Zygote without Flux, then Zygote is v0.6.12 and everything works correctly.

@mcabbott
Member

This used to go here:
https://github.com/FluxML/Zygote.jl/blob/v0.6.12/src/lib/array.jl#L299

After #990 and #1004 it goes here, which calls the adjoint for broadcasting:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L278-L283

And that won't work, because broadcasting doesn't handle complex CuArrays at all; it treats them as if they were non-differentiable:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L224-L226

You could fix sum(abs2, x) by making it avoid broadcasting (either copying the explicit rule, or sending it to ChainRules). Or by adding a rule for @adjoint broadcasted(::typeof(abs2), x::Numeric), which wouldn't be a bad idea anyway. It would also be easy to make broadcasting over complex CuArrays an error instead of a silent wrong answer, and that should be done anyway. Finally, it would not be very hard to make broadcasting over complex CuArrays just work instead.
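
As a rough user-side illustration of the "explicit rule" option, here is a minimal sketch. It assumes a wrapper named sumabs2, which is hypothetical and not part of Zygote; using a wrapper avoids overloading methods Zygote owns. The pullback only uses an ordinary broadcast, not a differentiated one, so it may also work on a complex CuArray, but it is only checked on the CPU here:

```julia
using Zygote

# Hypothetical helper (not part of Zygote): a named wrapper around sum(abs2, x)
# so that a hand-written rule can be attached to it.
sumabs2(x::AbstractArray) = sum(abs2, x)

# Hand-written rule: the gradient of sum(abs2, x) is 2x under Zygote's
# convention, scaled by the (real) incoming sensitivity Δ.
Zygote.@adjoint sumabs2(x::AbstractArray) =
    sum(abs2, x), Δ -> (2 .* real(Δ) .* x,)

x = ComplexF32[1 + 2im, 3 - 1im]
g, = gradient(sumabs2, x)
@assert g == ComplexF32[2 + 4im, 6 - 2im]
```

Calling sumabs2(x) in a loss function instead of sum(abs2, x) would then bypass the broadcasting adjoint entirely.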

Xref #961. No relation to projection.

@mcabbott changed the title from "gradient returns nothing for complex CuArray" to "gradient returns nothing for sum(abs2, x) with a complex CuArray" on Nov 18, 2021

@DhairyaLGandhi
Member

Good shout on the broadcasting changes. It's hard to say which functions may silently break; the answer is probably to make sure the non-broadcasted adjoint is hit in this case anyway. It would be tedious to keep two adjoint rules and forward between them manually, so maybe we should take a look at which functions are broken by the changes to the map adjoint.

@alexjaffray

Bumping this: Is there any update on this issue other than to use older Zygote and Flux versions pre-broadcasting changes?

@omalled

omalled commented Jan 19, 2022

This is causing problems for me as well. Zygote gradients of other things like f_exp(x) = sum(real(exp.(x))) also return nothing for complex CuArrays, similar to issue #961. That makes me hesitant to go down the custom adjoint rabbit hole. Is there hope for a fix that will cover these cases generally? Thanks for all your hard work - really love Zygote and CUDA.jl!

@mcabbott
Member

mcabbott commented Jan 19, 2022

Sadly, broadcasting with complex numbers has never worked on the GPU, and this hasn't changed. It should be an error but isn't. It could certainly be made to work, but someone has to do it. A few special cases could more easily be made to work, too; I guess .+ and .* probably already do. And the headline issue is the special case sum(abs2, x).
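
Until such an error exists in the library, a caller can guard against the silent failure on their own side. The sketch below assumes a helper named checked_gradient, which is hypothetical and not part of Zygote. Note that nothing is also the legitimate result when a function does not depend on an argument, so this check is only appropriate where a real gradient is expected:

```julia
using Zygote

# Hypothetical helper (not part of Zygote): wrap `gradient` and fail loudly
# if any returned gradient is `nothing`.
function checked_gradient(f, args...)
    grads = gradient(f, args...)
    for (i, g) in enumerate(grads)
        g === nothing && error("gradient for argument $i is `nothing`")
    end
    return grads
end

# A case where a gradient exists passes through unchanged.
g, = checked_gradient(t -> sum(abs2, t), [1.0, 2.0])
@assert g == [2.0, 4.0]
```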

@CarloLucibello
Member

Closed in #1324.
