diff --git a/src/interface.jl b/src/interface.jl index 9b199e0a..2c0ae7c3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -127,7 +127,7 @@ end # N.B: This means the gradient is not available for zero-arg function, but such # a gradient would be guaranteed to be `()`, which is a bit of a useless thing function (::Type{∇})(f, x1, args...) - ∇(f)(x1, args...) + unthunk.(∇(f)(x1, args...)) end const gradient = ∇ @@ -159,7 +159,7 @@ function (f::PrimeDerivativeBack)(x) z = ∂⃖¹(lower_pd(f), x) y = getfield(z, 1) f☆ = getfield(z, 2) - return getfield(f☆(dx(y)), 2) + return unthunk(getfield(f☆(dx(y)), 2)) end # Forwards primal derivative diff --git a/test/runtests.jl b/test/runtests.jl index c7ae67d8..4b1832e6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -108,7 +108,7 @@ let var"'" = Diffractor.PrimeDerivativeBack # Control flow cases @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) @test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0) - @test_broken (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] + @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] @test times_three_while'(1.0) == 3.0 pow5p(x) = (x->mypow(x, 5))'(x)