Add generic SArray adjoint #65

Merged
willtebbutt merged 9 commits into master from wct/zygote-staticarrays-ctors on Apr 16, 2021


willtebbutt
Member

@willtebbutt willtebbutt commented Apr 13, 2021

In principle, this should resolve the issues seen in #61.

@andreaskoher could you confirm that it does so in your environment?

@andreaskoher
Contributor

Not quite, it seems. With the first definition of ApproxPeriodicKernel, I still get

Need an adjoint for constructor SVector{2, Float64}. Gradient is of type Vector{Float64}

The error references the line H = SVector{2, T}(1, 0) in to_sde(k::OscillatorKernel{T,I} ...)

With the second definition, I get

MethodError: no method matching (::TemporalGPs.var"#SArray_pullback#27"{Tuple{8, 8}})(::Nothing)

Here, the error references the line Fs = SMatrix{D,D,T}(F) in to_sde(k::ApproxPeriodicKernel{J,T} ...)
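For context, a MethodError of this shape is what one would expect when a pullback closure has no method for a `nothing` cotangent, which is how Zygote represents a zero gradient. The following is only a minimal sketch of a pullback that tolerates `nothing`; `to_static` is a hypothetical helper written for illustration, not TemporalGPs code and not necessarily how this PR addresses it:

```julia
using StaticArrays, Zygote

# Hypothetical helper (not from TemporalGPs): dense 2x2 matrix -> static matrix.
to_static(F::Matrix{Float64}) = SMatrix{2, 2, Float64}(F)

Zygote.@adjoint function to_static(F::Matrix{Float64})
    function to_static_pullback(Δ)
        # Zygote uses `nothing` for a zero cotangent, so pass it through.
        Δ === nothing && return (nothing,)
        # Otherwise hand back a cotangent with the same array type as the input.
        return (Matrix(Δ),)
    end
    return to_static(F), to_static_pullback
end

F = [1.0 2.0; 3.0 4.0]

# The output is used: every entry of F contributes once to the sum.
g_used, = Zygote.gradient(F -> sum(to_static(F)), F)

# The output is computed but discarded, so no gradient flows back to F.
g_unused, = Zygote.gradient(F -> (to_static(F); 1.0), F)
```

A single pullback with an explicit `nothing` check keeps both cases in one place; dispatching on `::Nothing` with a second method would work equally well.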

@willtebbutt
Member Author

willtebbutt commented Apr 14, 2021

With the second definition, I get

Ah -- I should have caught that now with the latest commit.

Not quite, it seems. With the first definition of ApproxPeriodicKernel, I still get

Thinking about this a bit more, it's very odd, and definitely bad for performance, that you're getting a Vector as the gradient w.r.t. an SVector. Could you provide a complete minimal working example for this? Seems like there's something else going on.

@andreaskoher
Contributor

andreaskoher commented Apr 14, 2021

Thinking about this a bit more, it's very odd, and definitely bad for performance, that you're getting a Vector as the gradient w.r.t. an SVector

That's right, and there is probably some performance problem with the first definition of ApproxPeriodicKernel. Since we will not go with this approach anyway, I would not dig into it too much. However, this simple example fails too:

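using AbstractGPs, TemporalGPs, Zygote

# OscillatorKernel is a user-defined kernel and is not defined in this snippet.
T = Float64
x = RegularSpacing(0.0, 0.01, 1000)
k = OscillatorKernel{T,1}(T(1))
k = Matern32Kernel()  # note: rebinds k, so y below is sampled from a Matern-3/2 prior
f = to_sde(GP(k), SArrayStorage(Float64))
y = rand(f(x))

# Negative log marginal likelihood as a function of the observation noise variance θ.
function objective(θ)
    k = OscillatorKernel{T,1}(T(1))
    f = to_sde(GP(k), SArrayStorage(Float64))
    return -logpdf(f(x, θ), y)
end

Zygote.gradient(objective, 0.1)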

Need an adjoint for constructor SVector{2, Float64}. Gradient is of type SVector{2, Float64}

Similarly, you would get an error with a Matern32Kernel if you comment out the adjoints for to_sde or stationary_distribution:

with

Zygote.@adjoint function stationary_distribution(k::Matern32Kernel, storage_type)
    return stationary_distribution(k, storage_type), Δ->(nothing, nothing)
end

commented out, I get

Need an adjoint for constructor SVector{2, Float64}. Gradient is of type SVector{2, Float64}

@willtebbutt
Member Author

Ahhh I see. I've made some more changes that I think should fix the problems you've been encountering -- the @nograds that I've now removed were hiding some underlying problems with some @adjoints that should now be fixed. I've tested locally, but could you try again to make sure that I've not missed anything?

@andreaskoher
Contributor

Really nice! It all works now 👍
There is one more thing that is probably more related to Zygote, or rather to my ignorance of its internals: when I try the same example above with k = ApproxPeriodicKernel() from the second implementation in #61, I run into

no method matching iterate(::Nothing)

In those few cases when the Kernel does not crash, it references the line

q = [ j>0 ? 2*besseli(j, 1/k.l^2)*exp(-k.l^2) : besseli(0, 1/k.l^2)*exp(-k.l^2) for j in 0:J]

and this happens even when I simplify to q = [ j>0 ? 2. : 1. for j in 0:J]. Is there something wrong with using array comprehensions with Zygote?

@willtebbutt
Member Author

Really nice! It all works now 👍

Great. I'll merge this once I've got the tests to pass.

Regarding the list comprehensions, it's probably a Zygote thing. Could we take that conversation back to the periodic kernel issue, and could you provide a full stack trace?

@willtebbutt willtebbutt merged commit 38b9b3c into master Apr 16, 2021
@willtebbutt willtebbutt deleted the wct/zygote-staticarrays-ctors branch April 16, 2021 12:27
@andreaskoher andreaskoher mentioned this pull request Apr 22, 2021