happens-to-be-zero zero for tangent space of primal #476

Open
willtebbutt opened this issue Oct 2, 2021 · 7 comments

Comments

@willtebbutt
Member

willtebbutt commented Oct 2, 2021

Motivation

Sometimes it would be very convenient, for type-stability reasons, to have access to the element of the tangent space of a particular primal which happens to be zero, i.e. 0.0 and zeros(5, 4) rather than ZeroTangent.

The example I recently encountered where this would be helpful was reduce, specifically this line in Zygote's implementation of map. It's not possible to make this bit type-stable at the minute unless you know a priori whether or not the container you're mapping over is empty. So the pullback for map(function_with_fields, (5.0, 4.0)) infers, but the pullback for map(function_with_fields, [5.0, 4.0]) does not. This is because the init kwarg is generally of a different type to the elements of Δf_and_args[1].

However, if we had access to a zero whose type doesn't change when we add cotangents to it, things ought to be type-stable.
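A minimal sketch of the instability described above, using only Base. `StructuralZero`, `sum_structural` and `sum_natural` are hypothetical names for illustration; this is not the actual Zygote code, just the same pattern of reducing with an init whose type differs from the element type.

```julia
# Stand-in for a structural zero such as ZeroTangent (or `nothing` in Zygote).
# `StructuralZero` is a hypothetical type used only in this sketch.
struct StructuralZero end
Base.:+(::StructuralZero, x::Number) = x
Base.:+(x::Number, ::StructuralZero) = x
Base.:+(::StructuralZero, ::StructuralZero) = StructuralZero()

# init has a different type to the elements of ts.
sum_structural(ts) = reduce(+, ts; init=StructuralZero())

# init has the same type as the elements of ts.
sum_natural(ts) = reduce(+, ts; init=zero(eltype(ts)))

# The output type of sum_structural depends on whether ts is empty...
@show typeof(sum_structural(Float64[]))   # StructuralZero
@show typeof(sum_structural([5.0, 4.0]))  # Float64

# ...whereas sum_natural returns a Float64 either way.
@show typeof(sum_natural(Float64[]))      # Float64
@show typeof(sum_natural([5.0, 4.0]))     # Float64
```

Because the emptiness of a Vector is only known at run time, the first version cannot have a single concrete return type, while the second can.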

Implementation

We know that the zero always exists, because the tangent space to a primal is a vector space, so there aren't any concerns regarding existence.

It's pretty clear that the right way to do this for composite types is via recursion (think rand_tangent, but zero rather than random), so we would just need to define it for primitives. It might get a little interesting here, because there are multiple possible tangent types for a Float64 primal (Float64, Float32, Float16, Int, etc) or a Vector primal (any AbstractVector of the same length with appropriate element types), so possibly we would need additional information (such as the target tangent type) in order to do this.
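The recursion over composite types might look roughly like the sketch below. `natural_zero` and `Affine` are hypothetical names, the struct tangent is represented as a plain NamedTuple rather than a ChainRulesCore `Tangent`, and the primitive cases simply assume the tangent type equals the primal type, which sidesteps the open question about multiple possible tangent types.

```julia
# Primitive cases: assume the tangent type is the same as the primal type.
natural_zero(x::AbstractFloat) = zero(x)
natural_zero(x::AbstractArray{<:AbstractFloat}) = zero(x)  # same size as x

# Tuples: recurse element-wise.
natural_zero(x::Tuple) = map(natural_zero, x)

# Generic structs: recurse over fields, returning a NamedTuple of zeros.
function natural_zero(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("no zero defined for $T"))
    names = fieldnames(T)
    return NamedTuple{names}(map(n -> natural_zero(getfield(x, n)), names))
end

# Example composite primal.
struct Affine
    w::Vector{Float64}
    b::Float64
end

z = natural_zero(Affine([1.0, 2.0], 3.0))
@show z  # a NamedTuple with fields w and b, both zero
```

Note that the array case needs the value of the primal (for its size), not just its type.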

Note that we do need to know the value of the primal for the same reasons that we need to have the primals hanging around in our projection functionality.

Anyway, I thought I'd bring this up because it's not something that we've thought much about before (ZeroTangent is often a really good option). It might be easier simply to avoid situations like this most of the time (e.g. in the example I mentioned, using sized containers), but it's pretty annoying that the pullback for map isn't type-stable when mapping a closure over a Vector, because people do that a lot.

@oxinabox
Member

oxinabox commented Oct 4, 2021

Would this be solved by JuliaLang/julia#38241, which would allow us to basically always use ZeroTangent and never have to use 0.0 for performance?
I can look into getting that resolved sooner rather than later if so.

@willtebbutt
Member Author

I don't think so. I've been encountering this with Zygote types (see line linked above), so I'm using nothing rather than ZeroTangent.

@mcabbott
Member

mcabbott commented Oct 4, 2021

So you want a function that's a lot like dx = zero(x), but tries to guarantee that [dx1, dx2, dx3] will be of uniform type, when some are zero and some nonzero.

there are multiple possible tangent types for a Float64 primal (Float64, Float32, Float16, Int, etc)

Projection should convert all of these to Float64. But not all numbers:

julia> p = ProjectTo(1)
ProjectTo{Float64}()

julia> p(2f3)  # projected to Float64
2000.0

julia> p(Dual(1,2))  # passes through
Dual{Nothing}(1,2)

or a Vector primal (any AbstractVector of the same length with

What would be easy to do is to always use Fill(zero(T), ...) so that e.g. [dx1, dx2, dx3] isa Vector{<:AbstractVector} which will at least make some dispatch happy:

julia> @which reduce(hcat, [Fill(2,3), fill(4,3)])
reduce(::typeof(hcat), A::AbstractVector{<:AbstractVecOrMat}) in Base at abstractarray.jl:1624

but still won't be stable.

JuliaLang/julia#38241 which would allow us to basically always use ZeroTangent and never have to use 0.0 for performance.

Even with this, you'd still miss BLAS, right?

@willtebbutt
Member Author

So you want a function that's a lot like dx = zero(x), but tries to guarantee that [dx1, dx2, dx3] will be of uniform type, when some are zero and some nonzero.

Nearly. To be really specific, start with two things: a primal x, and a vector ts::Vector{T} of things in the tangent space of x. I want to sum ts. In general, to implement the summation you're going to need to write something like

reduce(+, ts; init=a_zero_tangent)

where a_zero_tangent is some representation of the zero element of the tangent space of x. What I want is a function zero_element(x, T) which picks the representation of a_zero_tangent such that the concrete type of the output of the sum is inferrable.

So if x is a Float64 and T == Float64 then I would want it to return 0.0.

Similarly, if x is a Vector{Float64} and T == Vector{Float64}, then I would want zeros(size(x)).

Structured array tangents are maybe more interesting. Suppose that x is a Vector{Float64} but T == Fill{Float64}; then I suspect the optimal thing to do would be to make the zero a Fill{Float64}(0.0, size(x)) so that the sum can be performed efficiently. It would, of course, be totally valid to make the zero zeros(size(x)) again. That would be sub-optimal but would get you type-stability, although it might have worse performance (locally) than the type-unstable version.

If x is a Float64, and T == Real (perhaps because there's a mix of tangent precisions somehow), then I think you'd want to make the initialisation a Float64 and somehow assert that the result must be a Float64 or something?
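The two simple cases above (number and dense array) could be sketched as follows. `zero_element` is the hypothetical function proposed in this comment, and `sum_tangents` is a further hypothetical helper; the Fill and T == Real cases are deliberately left out, since what they should do is still an open question.

```julia
# Hypothetical zero_element(x, T): a zero for the tangent space of the primal x
# whose representation matches the tangent type T.
zero_element(x::Number, ::Type{T}) where {T<:Number} = zero(T)
zero_element(x::AbstractArray, ::Type{Array{T,N}}) where {T<:Number,N} =
    zeros(T, size(x))  # needs the primal for its size

# Sum a vector of tangents for x, seeded with a zero of matching type.
sum_tangents(x, ts::Vector{T}) where {T} = reduce(+, ts; init=zero_element(x, T))

@show sum_tangents(5.0, Float64[])                         # 0.0
@show sum_tangents(5.0, [1.0, 2.0])                        # 3.0
@show sum_tangents([1.0, 2.0], Vector{Float64}[])          # [0.0, 0.0]
@show sum_tangents([1.0, 2.0], [[1.0, 2.0], [3.0, 4.0]])   # [4.0, 6.0]
```

In each pair the result has the same concrete type whether or not ts is empty, which is the property being asked for.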

Projection should convert all of these to Float64. But not all numbers:

Oh, interesting, I didn't realise we were being that strict with projection. I guess I had been thinking that an Int is a perfectly fine tangent for a Float64 (no need to project) in the same sense that a Diagonal{Float64} is a perfectly good tangent representation for a Matrix{Float64}. A discussion for a different place perhaps...

@mcabbott
Member

mcabbott commented Oct 4, 2021

being that strict

There's no maths in this, but accidental Float32 -> Float64 is such an easy performance bug to introduce, and we can kill it globally at last. But for Hessians etc. you may need dx::Dual so it can't be too strict. There's an argument that integers should be non-differentiable, but making gradient(sin, 1) fail seems unfriendly. Still, nobody could think of a real use for integer gradients (beyond saving money on chalk).
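To illustrate the kind of accidental promotion meant here, a small Base-only sketch. `project_like` is a hypothetical helper that only mimics the idea of projecting a gradient back to the primal's float type; it is not ChainRulesCore's ProjectTo.

```julia
x = 1f0            # Float32 primal
dx = 2 * x * 0.5   # a Float64 literal sneaks into the gradient computation
@show typeof(dx)   # Float64: the precision was silently promoted

# Hypothetical helper: convert a real gradient back to the primal's float type.
project_like(x::T, dx::Real) where {T<:AbstractFloat} = convert(T, dx)

@show typeof(project_like(x, dx))  # Float32
```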

ts::Vector{T}
such that the concrete type of [sum(ts)] is inferrable.

Maybe this is easier than what I had, in that you're given T. For numbers init=zero(T) is done automatically.
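A quick Base-only illustration of that difference between numbers and arrays. For numbers the zero can be built from the type alone; for array elements the size is not part of the type, so an explicit init built from the primal is needed. The behaviour of the array call without init is recorded rather than asserted, since I have not checked it on every Julia version.

```julia
# For numeric element types an empty reduction falls back to zero(T).
@show reduce(+, Float64[])   # 0.0
@show reduce(+, Float32[])   # 0.0f0

# For array element types there is no size information in the type, so the
# same call without init is expected to fail; record whether it threw.
threw = try
    reduce(+, Vector{Float64}[])
    false
catch
    true
end
@show threw

# Supplying an init of known size (here 2, as if taken from a primal x) works.
@show reduce(+, Vector{Float64}[]; init=zeros(2))   # [0.0, 0.0]
```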

I'm a bit confused by that line in the map adjoint though. How do you get map over an empty array to have a nontrivial gradient? Are any such cases not errors for other reasons?

@willtebbutt
Member Author

Maybe this is easier than what I had, in that you're given T. For numbers init=zero(T) is done automatically.

Exactly. Having both pieces of information is really helpful.

I'm a bit confused by that line in the map adjoint though. How do you get map over an empty array to have a nontrivial gradient? Are any such cases not errors for other reasons?

Here's an example of a programme that works:

julia> function foo(x)
           return x[1] + 3 * sum(map(sin, x[2:end]))
       end
foo (generic function with 1 method)

julia> Zygote.gradient(foo, randn(1))
([1.0],)

julia> Zygote.gradient(foo, randn(2))
([1.0, 2.9999914288727743],)

@oxinabox
Member

oxinabox commented Oct 6, 2021

JuliaLang/julia#38241 which would allow us to basically always use ZeroTangent and never have to use 0.0 for performance.

Even with this, you'd still miss BLAS, right?

Right.
Could we use Tullio for that?


3 participants