Skip to content

Commit

Permalink
always unthunk results (#79)
Browse files Browse the repository at this point in the history
* always unthunk results

* add unthunk for gradient
  • Loading branch information
oscardssmith authored Jul 25, 2022
1 parent 82096ee commit a2ea087
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ end
# N.B: This means the gradient is not available for zero-arg function, but such
# a gradient would be guaranteed to be `()`, which is a bit of a useless thing
function (::Type{∇})(f, x1, args...)
(f)(x1, args...)
unthunk.((f)(x1, args...))
end

const gradient =
Expand Down Expand Up @@ -159,7 +159,7 @@ function (f::PrimeDerivativeBack)(x)
z = ∂⃖¹(lower_pd(f), x)
y = getfield(z, 1)
f☆ = getfield(z, 2)
return getfield(f☆(dx(y)), 2)
return unthunk(getfield(f☆(dx(y)), 2))
end

# Forwards primal derivative
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ let var"'" = Diffractor.PrimeDerivativeBack
# Control flow cases
@test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0)
@test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0)
@test_broken (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
@test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
@test times_three_while'(1.0) == 3.0

pow5p(x) = (x->mypow(x, 5))'(x)
Expand Down

0 comments on commit a2ea087

Please sign in to comment.