From 2783a90850d9d1ae6d2a9c5eaca0ec63eb84f404 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 08:51:31 +0300 Subject: [PATCH 01/20] format(".") --- docs/make.jl | 29 ++-- docs/src/assets/make_logo.jl | 48 +++--- src/accumulation.jl | 8 +- src/compat.jl | 4 +- src/deprecated.jl | 1 + src/ignore_derivatives.jl | 8 +- src/projection.jl | 85 ++++++----- src/rule_definition_tools.jl | 84 +++++++---- src/tangent_arithmetic.jl | 14 +- src/tangent_types/abstract_zero.jl | 8 +- src/tangent_types/notimplemented.jl | 10 +- src/tangent_types/tangent.jl | 113 ++++++++------- src/tangent_types/thunks.jl | 16 +- test/accumulation.jl | 24 +-- test/config.jl | 76 ++++++---- test/deprecated.jl | 1 + test/ignore_derivatives.jl | 8 +- test/projection.jl | 40 ++--- test/rule_definition_tools.jl | 175 +++++++++++----------- test/rules.jl | 75 ++++++---- test/tangent_types/abstract_zero.jl | 4 +- test/tangent_types/notimplemented.jl | 8 +- test/tangent_types/tangent.jl | 209 +++++++++++++-------------- test/tangent_types/thunks.jl | 2 +- 24 files changed, 567 insertions(+), 483 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1ef3a62a7..608422c25 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,20 +16,20 @@ DocMeta.setdocmeta!( @scalar_rule(sin(x), cos(x)) # frule and rrule doctest @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end + end, ) indigo = DocThemeIndigo.install(ChainRulesCore) makedocs( - modules=[ChainRulesCore], - format=Documenter.HTML( - prettyurls=false, - assets=[indigo], - mathengine=MathJax3( + modules = [ChainRulesCore], + format = Documenter.HTML( + prettyurls = false, + assets = [indigo], + mathengine = MathJax3( Dict( :tex => Dict( - "inlineMath" => [["\$","\$"], ["\\(","\\)"]], + "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], "tags" => "ams", # TODO: remove when using physics package "macros" => Dict( @@ -42,9 +42,9 @@ makedocs( ), ), ), - sitename="ChainRules", - authors="Jarrett Revels and other contributors", - pages=[ + sitename = "ChainRules", + authors = "Jarrett Revels and other contributors", + pages = [ "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", @@ -63,11 +63,8 @@ makedocs( ], "API" => "api.md", ], - strict=true, - checkdocs=:exports, + strict = true, + checkdocs = :exports, ) -deploydocs( - repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", - push_preview=true, -) +deploydocs(repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview = true) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 5bbfd36c1..3e7aeaa08 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -8,34 +8,34 @@ using Random const bridge_len = 50 -function chain(jiggle=0) - shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) - +function chain(jiggle = 0) + shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) + ### 1 shaky_rotate(0) sethue(Luxor.julia_red) link() m1 = getmatrix() - - + + ### 2 sethue(Luxor.julia_green) - translate(-50, 130); - shaky_rotate(π/3); + translate(-50, 130) + shaky_rotate(π / 3) link() m2 = getmatrix() - + setmatrix(m1) sethue(Luxor.julia_red) overlap(-1.3π) setmatrix(m2) - + ### 3 - shaky_rotate(-π/3); - translate(-120,80); + shaky_rotate(-π / 3) + translate(-120, 80) sethue(Luxor.julia_purple) link() - + setmatrix(m2) setcolor(Luxor.julia_green) overlap(-1.5π) @@ -45,24 +45,24 @@ end function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - - - rect(50,-3,40, bridge_len+6, :fill) - rect(-50-40,-3,40, bridge_len+6, :fill) - + + + rect(50, -3, 40, bridge_len + 6, :fill) + rect(-50 - 40, -3, 40, bridge_len + 6, :fill) + sethue("black") move(Point(-50, bridge_len)) - arc(Point(0,0), 50, π, 0, :stoke) + arc(Point(0, 0), 50, π, 0, :stoke) arc(Point(0, bridge_len), 50, 0, -π, :stroke) - + move(Point(-90, bridge_len)) - arc(Point(0,0), 90, π, 0, :stoke) + arc(Point(0, 0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) strokepath() end function overlap(ang_end) - sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) + sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill) sethue("black") arc(Point(0, bridge_len), 50, 0, ang_end, :stoke) move(Point(90, bridge_len)) @@ -75,13 +75,13 @@ end function save_logo(filename) Random.seed!(16) - Drawing(450,450, filename) + Drawing(450, 450, filename) origin() - translate(50, -130); + translate(50, -130) chain(0.5) finish() preview() end save_logo("logo.svg") -save_logo("logo.png") \ No newline at end of file +save_logo("logo.png") diff --git a/src/accumulation.jl b/src/accumulation.jl index 4bcc5c33f..5fbc07fa8 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,7 +26,7 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) -function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N +function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} return if is_inplaceable_destination(x) x .+= y else @@ -75,8 +75,8 @@ end struct BadInplaceException <: Exception ithunk::InplaceableThunk - accumuland - returned_value + accumuland::Any + returned_value::Any end function Base.showerror(io::IO, err::BadInplaceException) @@ -88,7 +88,7 @@ function Base.showerror(io::IO, err::BadInplaceException) if err.accumuland == err.returned_value println( io, - "Which in this case happenned to be equal. But they are not the same object." + "Which in this case happenned to be equal. But they are not the same object.", ) end end diff --git a/src/compat.jl b/src/compat.jl index 8204b66d5..fa66b1d0f 100644 --- a/src/compat.jl +++ b/src/compat.jl @@ -5,7 +5,7 @@ end if VERSION < v"1.1" # Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4 # See: https://github.com/JuliaLang/julia/issues/34292 - function fieldtypes(::Type{T}) where T + function fieldtypes(::Type{T}) where {T} if @generated ntuple(i -> fieldtype(T, i), fieldcount(T)) else @@ -13,7 +13,7 @@ if VERSION < v"1.1" end end - function fieldnames(::Type{T}) where T + function fieldnames(::Type{T}) where {T} if @generated ntuple(i -> fieldname(T, i), fieldcount(T)) else diff --git a/src/deprecated.jl b/src/deprecated.jl index e69de29bb..8b1378917 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -0,0 +1 @@ + diff --git a/src/ignore_derivatives.jl b/src/ignore_derivatives.jl index c66d89d7e..18865f2c9 100644 --- a/src/ignore_derivatives.jl +++ b/src/ignore_derivatives.jl @@ -45,7 +45,9 @@ ignore_derivatives(x) = x Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`. """ macro ignore_derivatives(ex) - return :(ChainRulesCore.ignore_derivatives() do - $(esc(ex)) - end) + return :( + ChainRulesCore.ignore_derivatives() do + $(esc(ex)) + end + ) end diff --git a/src/projection.jl b/src/projection.jl index 4b07b2762..55f6e7bfd 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -131,7 +131,8 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}() +ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} = + ProjectTo{NoTangent}() # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -164,12 +165,14 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) end # In these cases we can just `convert` as we know we are dealing with plain and simple types -(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx) -(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity +(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) +(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = + convert(T, dx) (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = + convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -190,7 +193,7 @@ end # For arrays of numbers, just store one projector: function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) + return ProjectTo{AbstractArray}(; element = _eltype_projectto(T), axes = axes(x)) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -204,7 +207,7 @@ function ProjectTo(xs::AbstractArray) return ProjectTo{NoTangent}() # short-circuit if all elements project to zero else # Arrays of arrays come here, and will apply projectors individually: - return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs)) + return ProjectTo{AbstractArray}(; elements = elements, axes = axes(xs)) end end @@ -214,7 +217,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} dy = if axes(dx) == project.axes dx else - for d in 1:max(M, length(project.axes)) + for d = 1:max(M, length(project.axes)) if size(dx, d) != length(get(project.axes, d, 1)) throw(_projection_mismatch(project.axes, size(dx))) end @@ -244,9 +247,11 @@ end # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers if !(project.axes isa Tuple{}) - throw(DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - )) + throw( + DimensionMismatch( + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", + ), + ) end return fill(project.element(dx)) end @@ -254,7 +259,7 @@ end function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) return DimensionMismatch( - "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx" + "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx", ) end @@ -268,13 +273,13 @@ function ProjectTo(x::Ref) if sub isa ProjectTo{<:AbstractZero} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type=typeof(x), x=sub) + return ProjectTo{Ref}(; type = typeof(x), x = sub) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) +(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x = project.x(dx.x)) +(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x = project.x(dx[])) # Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) +(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x = project.x(dx)) ##### ##### `LinearAlgebra` @@ -283,7 +288,7 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Row vectors -ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x))) +ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent = ProjectTo(parent(x))) # Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec. # Transposed matrices are, like PermutedDimsArray, just a storage detail, # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number @@ -298,7 +303,8 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +ProjectTo(x::LinearAlgebra.TransposeAbsVec) = + ProjectTo{Transpose}(; parent = ProjectTo(parent(x))) function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -311,21 +317,22 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) end # Diagonal -ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) +ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag = ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) # Symmetric -for (SymHerm, chk, fun) in ( - (:Symmetric, :issymmetric, :transpose), - (:Hermitian, :ishermitian, :adjoint), - ) +for (SymHerm, chk, fun) in + ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) @eval begin function ProjectTo(x::$SymHerm) sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) + return ProjectTo{$SymHerm}(; + uplo = LinearAlgebra.sym_uplo(x.uplo), + parent = sub, + ) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) dy = project.parent(dx) @@ -338,9 +345,8 @@ for (SymHerm, chk, fun) in ( # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) sub = project.parent # this is going to be unhappy about the size - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) + sub_one = + ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) return Diagonal(sub_one(dx.diag)) end end @@ -349,13 +355,12 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) + ProjectTo(x::$UL) = ProjectTo{$UL}(; parent = ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) + sub_one = + ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) return Diagonal(sub_one(dx.diag)) end end @@ -392,7 +397,7 @@ end # another strategy is just to use the AbstractArray method function ProjectTo(x::Tridiagonal{T}) where {T<:Number} notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent=notparent) + return ProjectTo{Tridiagonal}(; notparent = notparent) end function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) dy = project.notparent(dx) @@ -411,7 +416,9 @@ using SparseArrays function ProjectTo(x::SparseVector{T}) where {T<:Number} return ProjectTo{SparseVector}(; - element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) + element = ProjectTo(zero(T)), + nzind = x.nzind, + axes = axes(x), ) end function (project::ProjectTo{SparseVector})(dx::AbstractArray) @@ -450,11 +457,11 @@ end function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} return ProjectTo{SparseMatrixCSC}(; - element=ProjectTo(zero(T)), - axes=axes(x), - rowval=rowvals(x), - nzranges=nzrange.(Ref(x), axes(x, 2)), - colptr=x.colptr, + element = ProjectTo(zero(T)), + axes = axes(x), + rowval = rowvals(x), + nzranges = nzrange.(Ref(x), axes(x, 2)), + colptr = x.colptr, ) end # You need not really store nzranges, you can get them from colptr -- TODO @@ -474,7 +481,7 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) for i in project.nzranges[col] row = project.rowval[i] val = dy[row, col] - nzval[k += 1] = project.element(val) + nzval[k+=1] = project.element(val) end end m, n = map(length, project.axes) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 911a32ddd..8a1e1cce4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -83,9 +83,8 @@ For examples, see ChainRules' `rulesets` directory. See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) - call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( - call, maybe_setup, partials - ) + call, setup_stmts, inputs, partials = + _normalize_scalarrules_macro_input(call, maybe_setup, partials) f = call.args[1] # Generate variables to store derivatives named dfi/dxj @@ -101,9 +100,11 @@ macro scalar_rule(call, maybe_setup, partials...) # Final return: building the expression to insert in the place of this macro code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw(ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))" - )) + throw( + ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))", + ), + ) end $(derivative_expr) @@ -175,7 +176,11 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote - function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + function ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), + ::Core.Typeof($f), + $(inputs...), + ) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) @@ -196,9 +201,8 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) end if n_outputs > 1 # For forward-mode we return a Tangent if output actually a tuple. - pushforward_returns = Expr( - :call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns... - ) + pushforward_returns = + Expr(:call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns...) else pushforward_returns = first(pushforward_returns) end @@ -210,7 +214,8 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = + ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -225,7 +230,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) Δs = _propagator_inputs(n_outputs) # Make a projector for each argument - projs, psetup = _make_projectors(call.args[2:end]) + projs, psetup = _make_projectors(call.args[2:end]) append!(setup_stmts, psetup) # 1 partial derivative per input @@ -248,7 +253,8 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = + ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -257,12 +263,12 @@ end # For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" -_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i = 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) projs = map(x -> Symbol(:proj_, x.args[1]), xs) - setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs) + setups = map((x, p) -> :($p = ProjectTo($x)), xs, projs) return projs, setups end @@ -275,7 +281,7 @@ Specify `_conj = true` to conjugate the partials. Projector `proj` is a function that will be applied at the end; for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity` """ -function propagation_expr(Δs, ∂s, _conj=false, proj=identity) +function propagation_expr(Δs, ∂s, _conj = false, proj = identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj @@ -288,9 +294,10 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) - end + summed_∂_mul_Δs = + foldl(Iterators.drop(zip(_∂s, Δs), 1); init = init_expr) do ex, (∂s_i, Δs_i) + :(muladd($∂s_i, $Δs_i, $ex)) + end return :($proj($summed_∂_mul_Δs)) end @@ -381,7 +388,10 @@ end function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... + :call, + call_expr.args[1], + Expr(:parameters, :($(kwargs)...)), + call_expr.args[2:end]..., ) end @@ -389,11 +399,18 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), - frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function (::Core.kwftype(typeof(ChainRulesCore.frule)))( + @nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end - function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function ChainRulesCore.frule( + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), + ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() return ($(esc(primal_invoke)), NoTangent()) @@ -408,7 +425,8 @@ function tuple_expression(primal_sig_parts) Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + length_expr = + :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) end end @@ -426,7 +444,11 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) + function (::Core.kwftype(typeof(rrule)))( + $(esc(kwargs))::Any, + ::typeof(rrule), + $(esc_primal_sig_parts...), + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) @@ -481,7 +503,7 @@ end "Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) - length(expr.args)===0 && error("Malformed method expression. $expr") + length(expr.args) === 0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore @@ -555,12 +577,13 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) + if primal_name isa Symbol || + Meta.isexpr(primal_name, :(.)) || + Meta.isexpr(primal_name, :curly) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name - # e.g. (::T)(x, y) + # e.g. (::T)(x, y) elseif Meta.isexpr(primal_name, :(::)) _primal_name = gensym(Symbol(:instance_, primal_name.args[end])) primal_name_sig = Expr(:(::), _primal_name, primal_name.args[end]) @@ -582,7 +605,8 @@ end function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) + Meta.isexpr(arg, :(...), 1) && + return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 9c1378aab..c2bad7a77 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -81,7 +81,7 @@ LinearAlgebra.dot(::ZeroTangent, ::NoTangent) = ZeroTangent() Base.muladd(::ZeroTangent, x, y) = y Base.muladd(x, ::ZeroTangent, y) = y -Base.muladd(x, y, ::ZeroTangent) = x*y +Base.muladd(x, y, ::ZeroTangent) = x * y Base.muladd(::ZeroTangent, ::ZeroTangent, y) = y Base.muladd(x, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() @@ -125,11 +125,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where P +function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P, typeof(data)}(data) + return Tangent{P,typeof(data)}(data) end -function Base.:+(a::P, d::Tangent{P}) where P +function Base.:+(a::P, d::Tangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -142,12 +142,12 @@ function Base.:+(a::P, d::Tangent{P}) where P end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where P = b + a +Base.:+(a::Tangent{P}, b::P) where {P} = b + a # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 differentials # Only of a differential and a scaling factor (generally `Real`) for T in (:Any,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent) + @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 216357e91..c86fc78ea 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -17,15 +17,15 @@ Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) -Base.Broadcast.broadcasted(::Type{T}) where T<:AbstractZero = T() +Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z -Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) -(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T) +Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) +(::Type{T})(xs::AbstractZero...) where {T<:Number} = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) @@ -33,7 +33,7 @@ Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) Base.getindex(z::AbstractZero, k) = z Base.view(z::AbstractZero, ind...) = z -Base.sum(z::AbstractZero; dims=:) = z +Base.sum(z::AbstractZero; dims = :) = z """ ZeroTangent() <: AbstractZero diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index a2044fbe1..7ceb315ea 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,9 +44,13 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw(NotImplementedException(@not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" -))) +Base.zero(::Type{<:NotImplemented}) = throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" + ) + ), +) Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index e4bbfb8c8..34e822ea8 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -21,42 +21,42 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P, T} <: AbstractTangent +struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T end -function Tangent{P}(; kwargs...) where P +function Tangent{P}(; kwargs...) where {P} backing = (; kwargs...) # construct as NamedTuple - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end -function Tangent{P}(args...) where P - return Tangent{P, typeof(args)}(args) +function Tangent{P}(args...) where {P} + return Tangent{P,typeof(args)}(args) end -function Tangent{P}() where P<:Tuple +function Tangent{P}() where {P<:Tuple} backing = () - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P, typeof(d)}(d) + return Tangent{P,typeof(d)}(d) end -function Base.:(==)(a::Tangent{P, T}, b::Tangent{P, T}) where {P, T} +function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P, T} +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P,T} all_fields = union(keys(backing(a)), keys(backing(b))) return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where P +function Base.show(io::IO, tangent::Tangent{P}) where {P} print(io, "Tangent{") show(io, P) print(io, "}") @@ -68,15 +68,15 @@ function Base.show(io::IO, tangent::Tangent{P}) where P end end -function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}} +function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end -function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict} +function Base.getindex(tangent::Tangent, idx) where {P,T<:AbstractDict} return unthunk(getindex(backing(tangent), idx)) end @@ -84,7 +84,7 @@ function Base.getproperty(tangent::Tangent, idx::Int) back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end @@ -99,26 +99,26 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) -Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T) +Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing) + Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) end -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state = 1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end -function Base.map(f, tangent::Tangent{P, <:Tuple}) where P +function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} vals::Tuple = map(f, backing(tangent)) - return Tangent{P, typeof(vals)}(vals) + return Tangent{P,typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L} +function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L, typeof(vals)}(vals) - return Tangent{P, typeof(named_vals)}(named_vals) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return Tangent{P,typeof(named_vals)}(named_vals) end -function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict} +function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end @@ -140,26 +140,28 @@ backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) # For generic structs -function backing(x::T)::NamedTuple where T +function backing(x::T)::NamedTuple where {T} # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. # see https://github.com/JuliaLang/julia/issues/34283 if @generated - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names, Tuple{$(types...)}}($vals)) + vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) else - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = ntuple(ii->getfield(x, ii), nfields) - return NamedTuple{names, Tuple{types...}}(vals) + vals = ntuple(ii -> getfield(x, ii), nfields) + return NamedTuple{names,Tuple{types...}}(vals) end end @@ -170,36 +172,38 @@ Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. """ -function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L} +function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} nil = _zeroed_backing(P) combined = merge(nil, backing(tangent)) if length(combined) !== fieldcount(P) - throw(ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))" - )) + throw( + ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", + ), + ) end - return Tangent{P, typeof(combined)}(combined) + return Tangent{P,typeof(combined)}(combined) end # Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent +canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent # Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent +canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent # Tangents of unspecified primal types (indicated by specifying exactly `Any`) # all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:Tuple}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:AbstractDict}) where {L} = tangent """ _zeroed_backing(P) Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. """ -@generated function _zeroed_backing(::Type{P}) where P +@generated function _zeroed_backing(::Type{P}) where {P} nil_base = ntuple(fieldcount(P)) do i (fieldname(P, i), ZeroTangent()) end @@ -218,7 +222,7 @@ after an operation such as the addition of a primal to a tangent It should be overloaded, if `T` does not have a default constructor, or if `T` needs to maintain some invarients between its fields. """ -function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} +function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} # Tested and verified that that this avoids a ton of allocations if length(L) !== fieldcount(T) # if length is equal but names differ then we will catch that below anyway. @@ -233,12 +237,12 @@ function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} end end -construct(::Type{T}, fields::T) where T<:NamedTuple = fields -construct(::Type{T}, fields::T) where T<:Tuple = fields +construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields +construct(::Type{T}, fields::T) where {T<:Tuple} = fields elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} # Rule of Tangent addition: any fields not present are implict hard Zeros # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. @@ -281,7 +285,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} end field => value end - return (;vals...) + return (; vals...) end end @@ -297,15 +301,16 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color = :blue) println(io, "\nor overload") - printstyled(io, + printstyled( + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color=:blue + color = :blue, ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color = :blue) println(io, "\nOriginal Exception:") - printstyled(io, err.original; color=:yellow) + printstyled(io, err.original; color = :yellow) println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 16384d69e..c2b570902 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -56,18 +56,22 @@ LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a)) LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a)) LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a)) LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a)) -LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo) -LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo) +LinearAlgebra.Symmetric(a::AbstractThunk, uplo = :U) = Symmetric(unthunk(a), uplo) +LinearAlgebra.Hermitian(a::AbstractThunk, uplo = :U) = Hermitian(unthunk(a), uplo) function LinearAlgebra.diagm( - kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... + kv::Pair{<:Integer,<:AbstractThunk}, + kvs::Pair{<:Integer,<:AbstractThunk}..., ) return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...) end function LinearAlgebra.diagm( - m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... + m, + n, + kv::Pair{<:Integer,<:AbstractThunk}, + kvs::Pair{<:Integer,<:AbstractThunk}..., ) - return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) + return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a)) @@ -118,7 +122,7 @@ function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx) return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx) end -function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1) +function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn = 1) return throw(MutateThunkException()) end diff --git a/test/accumulation.jl b/test/accumulation.jl index 1b41fea55..a796b5289 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -27,7 +27,7 @@ end @testset "misc AbstractTangent subtypes" begin - @test 16 == add!!(12, @thunk(2*2)) + @test 16 == add!!(12, @thunk(2 * 2)) @test 16 == add!!(16, ZeroTangent()) @test 16 == add!!(16, NoTangent()) # Should this be an error? @@ -37,15 +37,15 @@ @testset "LHS Array (inplace)" begin @testset "RHS Array" begin A = [1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] end @testset "RHS StaticArray" begin - A = @SMatrix[1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + A = @SMatrix [1.0 2.0; 3.0 4.0] + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] @@ -53,7 +53,7 @@ @testset "RHS Diagonal" begin A = Diagonal([1.0, 2.0]) - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 -1.0; -1.0 1.0] @@ -79,17 +79,17 @@ @testset "Unhappy Path" begin # wrong length - @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(2, 2)) # wrong shape - @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(16)) # wrong type (adding scalar to array) @test_throws MethodError add!!(ones(4), 21.0) end end @testset "AbstractThunk $(typeof(thunk))" for thunk in ( - @thunk(-1.0*ones(2, 2)), - InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))), + @thunk(-1.0 * ones(2, 2)), + InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -111,12 +111,12 @@ @testset "not actually inplace but said it was" begin # thunk should never be used in this test ithunk = InplaceableThunk(@thunk(@assert false)) do x - 77*ones(2, 2) # not actually inplace (also wrong) + 77 * ones(2, 2) # not actually inplace (also wrong) end accumuland = ones(2, 2) @assert ChainRulesCore.debug_mode() == false # without debug being enabled should return the result, not error - @test 77*ones(2, 2) == add!!(accumuland, ithunk) + @test 77 * ones(2, 2) == add!!(accumuland, ithunk) ChainRulesCore.debug_mode() = true # enable debug mode # with debug being enabled should error @@ -127,7 +127,7 @@ @testset "showerror BadInplaceException" begin BadInplaceException = ChainRulesCore.BadInplaceException - ithunk = InplaceableThunk(x̄->nothing, @thunk(@assert false)) + ithunk = InplaceableThunk(x̄ -> nothing, @thunk(@assert false)) msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) @test occursin("22", msg) diff --git a/test/config.jl b/test/config.jl index 466baed9a..e6e2ab005 100644 --- a/test/config.jl +++ b/test/config.jl @@ -1,7 +1,7 @@ # Define a bunch of configs for testing purposes struct MostBoringConfig <: RuleConfig{Union{}} end -struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} forward_calls::Vector end MockForwardsConfig() = MockForwardsConfig([]) @@ -11,7 +11,7 @@ function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args. return f(args...; kws...), ȧrgs end -struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} reverse_calls::Vector end MockReverseConfig() = MockReverseConfig([]) @@ -23,7 +23,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. end -struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} forward_calls::Vector reverse_calls::Vector end @@ -47,18 +47,21 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, + MostBoringConfig, + MockForwardsConfig, + MockReverseConfig, + MockBothConfig, ) counting_id_count = Ref(0) function counting_id(x) - counting_id_count[]+=1 + counting_id_count[] += 1 return x end function ChainRulesCore.rrule(::typeof(counting_id), x) counting_id_pullback(x̄) = x̄ return counting_id(x), counting_id_pullback end - function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) + function ChainRulesCore.frule((dself, dx), ::typeof(counting_id), x) return counting_id(x), dx end @testset "rrule" begin @@ -76,21 +79,33 @@ end @testset "hitting forwards AD" begin do_thing_2(f, x) = f(x) function ChainRulesCore.frule( - config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x + config::RuleConfig{>:HasForwardsMode}, + (_, df, dx), + ::typeof(do_thing_2), + f, + x, ) return frule_via_ad(config, (df, dx), f, x) end @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) @test nothing === frule( - Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + Config(), + (NoTangent(), NoTangent(), 21.5), + do_thing_2, + identity, + 32.1, ) end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== frule( - bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + bconfig, + (NoTangent(), NoTangent(), 21.5), + do_thing_2, + identity, + 32.1, ) @test bconfig.forward_calls == [(identity, (32.1,))] end @@ -99,7 +114,10 @@ end @testset "hitting reverse AD" begin do_thing_3(f, x) = f(x) function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x + config::RuleConfig{>:HasReverseMode}, + ::typeof(do_thing_3), + f, + x, ) return (NoTangent(), rrule_via_ad(config, f, x)...) end @@ -110,7 +128,7 @@ end end @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) @test bconfig.reverse_calls == [(identity, (32.1,))] end @@ -130,14 +148,14 @@ end ẋ = one(x) y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) - pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ * ȳ return y, pullback_via_forwards_ad end function ChainRulesCore.rrule( - config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, + config::RuleConfig{>:Union{HasReverseMode,NoForwardsMode}}, ::typeof(do_thing_4), f, - x + x, ) y, f_pullback = rrule_via_ad(config, f, x) do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) @@ -147,43 +165,43 @@ end @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) @test bconfig.forward_calls == [(identity, (32.1,))] end - rconfig= MockReverseConfig() + rconfig = MockReverseConfig() @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) @test rconfig.reverse_calls == [(identity, (32.1,))] end @testset "RuleConfig broadcasts like a scaler" begin - @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} + @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} end @testset "fallbacks" begin - no_rule(x; kw="bye") = error() + no_rule(x; kw = "bye") = error() @test frule((1.0,), no_rule, 2.0) === nothing - @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule((1.0,), no_rule, 2.0; kw = "hello") === nothing @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing - @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw = "hello") === nothing @test rrule(no_rule, 2.0) === nothing - @test rrule(no_rule, 2.0; kw="hello") === nothing + @test rrule(no_rule, 2.0; kw = "hello") === nothing @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing - @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw = "hello") === nothing # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(; kw = "hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(sin; kw = "hello") @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(); kw = "hello") @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin; kw = "hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(; kw = "hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig();kw="hello") + @test_throws MethodError rrule(MostBoringConfig(); kw = "hello") end end diff --git a/test/deprecated.jl b/test/deprecated.jl index e69de29bb..8b1378917 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -0,0 +1 @@ + diff --git a/test/ignore_derivatives.jl b/test/ignore_derivatives.jl index 825287b9a..ad4fece9f 100644 --- a/test/ignore_derivatives.jl +++ b/test/ignore_derivatives.jl @@ -7,7 +7,7 @@ end @testset "function" begin f() = return 4.0 - y, ẏ = frule((1.0, ), ignore_derivatives, f) + y, ẏ = frule((1.0,), ignore_derivatives, f) @test y == f() @test ẏ == NoTangent() @@ -19,7 +19,7 @@ end @testset "argument" begin arg = 2.1 - y, ẏ = frule((1.0, ), ignore_derivatives, arg) + y, ẏ = frule((1.0,), ignore_derivatives, arg) @test y == arg @test ẏ == NoTangent() @@ -41,11 +41,11 @@ end @test pb(1.0) == (NoTangent(), NoTangent()) # when called - y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0)) + y, ẏ = frule((1.0,), ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test ẏ == NoTangent() - y, pb = rrule(ignore_derivatives, ()->mf(3.0)) + y, pb = rrule(ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test pb(1.0) == (NoTangent(), NoTangent()) end diff --git a/test/projection.jl b/test/projection.jl index ba61fb8da..ab418ef79 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,20 +24,21 @@ struct NoSuperType end # real / complex @test ProjectTo(1.0)(2.0 + 3im) === 2.0 @test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im - @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im - @test ProjectTo(2.0)(1+1im) === 1.0 - + @test ProjectTo(2.0 + 3.0im)(1 + 1im) === 1.0 + 1.0im + @test ProjectTo(2.0)(1 + 1im) === 1.0 + # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) - @test ProjectTo(1//2)(3//4) === 3//4 + @test ProjectTo(1 // 2)(3 // 4) === 3 // 4 @test ProjectTo(1.0f0)(1 / 2) === 0.5f0 @test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im @test ProjectTo(big(1.0))(2) === 2 @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re = 1, im = NoTangent())) === + 1.0f0 + 0.0f0im end @testset "Dual" begin # some weird Real subtype that we should basically leave alone @@ -46,13 +47,12 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} - @test ProjectTo(1.0 + 1im)( - Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)) - ) isa Complex{<:Dual} + @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa + Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent - @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent + @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value = 1.0)) isa Tangent end @testset "Base: arrays of numbers" begin @@ -99,10 +99,10 @@ struct NoSuperType end # arrays of other things @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} - @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} + @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. - @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) + @test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im) @test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent} # empty arrays @@ -172,7 +172,7 @@ struct NoSuperType end # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called - xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) + xs = adj(Any[Any[1, 2, 3], Any[4+im, 5-im, 6+im, 7-im]]) pvecvec3 = ProjectTo(xs) @test pvecvec3(xs)[1] == [1 2 3] @test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im]) @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()); + dx = Tangent{typeof(x)}(; step = 0.1, ref = NoTangent()) @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3+4im) + dy = Tangent{typeof(Ref(2.0))}(x = 3 + 4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end @@ -365,21 +365,21 @@ struct NoSuperType end # Each "@test 33 > ..." is zero on nightly, 32 on 1.5. pvec = ProjectTo(rand(10^3)) - @test 0 == @ballocated $pvec(dx) setup=(dx = rand(10^3)) # pass through - @test 90 > @ballocated $pvec(dx) setup=(dx = rand(10^3, 1)) # reshape + @test 0 == @ballocated $pvec(dx) setup = (dx = rand(10^3)) # pass through + @test 90 > @ballocated $pvec(dx) setup = (dx = rand(10^3, 1)) # reshape @test 33 > @ballocated ProjectTo(x)(dx) setup = (x = rand(10^3); dx = rand(10^3)) # including construction padj = ProjectTo(adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = transpose(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = adjoint(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = transpose(rand(10^3))) @test 33 > @ballocated ProjectTo(x')(dx') setup = (x = rand(10^3); dx = rand(10^3)) pdiag = ProjectTo(Diagonal(rand(10^3))) - @test 0 == @ballocated $pdiag(dx) setup=(dx = Diagonal(rand(10^3))) + @test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3))) psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) - @test_broken 0 == @ballocated $psymm(dx) setup=(dx = Symmetric(rand(10^3, 10^3))) # 64 + @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 0d6d98535..e99b66c2f 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -19,7 +19,7 @@ macro test_macro_throws(err_expr, expr) end end # Reuse `@test_throws` logic - if err!==nothing + if err !== nothing @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) @@ -29,21 +29,21 @@ end # struct need to be defined outside of tests for julia 1.0 compat struct NonDiffExample - x + x::Any end struct NonDiffCounterExample - x + x::Any end module NonDiffModuleExample - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] +nondiff_2_1(x, y) = fill(7.5, 100)[x+y] end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] + nondiff_2_1(x, y) = fill(7.5, 100)[x+y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((ZeroTangent(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, NoTangent()) res, pullback = rrule(nondiff_2_1, 3, 2) @@ -58,7 +58,7 @@ end res, pullback = rrule(nondiff_1_2, 3.1) @test res == (5.0, 3.0) @test isequal( - pullback(Tangent{Tuple{Float64, Float64}}(1.2, 3.2)), + pullback(Tangent{Tuple{Float64,Float64}}(1.2, 3.2)), (NoTangent(), NoTangent()), ) end @@ -81,7 +81,8 @@ end pointy_identity(x) = x @non_differentiable pointy_identity(::Vector{<:AbstractString}) - @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == (["2"], NoTangent()) + @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == + (["2"], NoTangent()) @test frule((ZeroTangent(), 1.2), pointy_identity, 2.0) == nothing res, pullback = rrule(pointy_identity, ["2"]) @@ -92,7 +93,7 @@ end end @testset "kwargs" begin - kw_demo(x; kw=2.0) = x + kw + kw_demo(x; kw = 2.0) = x + kw @non_differentiable kw_demo(::Any) @testset "not setting kw" begin @@ -106,13 +107,14 @@ end end @testset "setting kw" begin - @assert kw_demo(1.5; kw=3.0) == 4.5 + @assert kw_demo(1.5; kw = 3.0) == 4.5 - res, pullback = rrule(kw_demo, 1.5; kw=3.0) + res, pullback = rrule(kw_demo, 1.5; kw = 3.0) @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw = 3.0) == + (4.5, NoTangent()) end end @@ -121,7 +123,7 @@ end @test isequal( frule((ZeroTangent(), 1.2), NonDiffExample, 2.0), - (NonDiffExample(2.0), NoTangent()) + (NonDiffExample(2.0), NoTangent()), ) res, pullback = rrule(NonDiffExample, 2.0) @@ -151,7 +153,7 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), NoTangent()) @test frule((1, 1), fvarargs, 1, 2) == nothing - @test rrule(fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @@ -194,10 +196,10 @@ end end @testset "Functors" begin - (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] + (f::NonDiffExample)(y) = fill(7.5, 100)[f.x+y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x=1.2), 2.3), NonDiffExample(3), 2) == - (7.5, NoTangent()) + @test frule((Tangent{NonDiffExample}(x = 1.2), 2.3), NonDiffExample(3), 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent()) @@ -205,8 +207,12 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) - @test frule((ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2) == - (7.5, NoTangent()) + @test frule( + (ZeroTangent(), 1.2, 2.3), + NonDiffModuleExample.nondiff_2_1, + 3, + 2, + ) == (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) @@ -216,7 +222,7 @@ end # Where clauses are not supported. @test_macro_throws( ErrorException, - (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) + (@non_differentiable where_identity(::Vector{T}) where {T<:AbstractString}) ) end end @@ -224,32 +230,33 @@ end @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) - @scalar_rule(simo(x), 1f0, 2f0) + @scalar_rule(simo(x), 1.0f0, 2.0f0) y, simo_pb = rrule(simo, π) - @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) + @test simo_pb((10.0f0, 20.0f0)) == (NoTangent(), 50.0f0) - y, ẏ = frule((NoTangent(), 50f0), simo, π) + y, ẏ = frule((NoTangent(), 50.0f0), simo, π) @test y == (π, 2π) - @test ẏ == Tangent{typeof(y)}(50f0, 100f0) + @test ẏ == Tangent{typeof(y)}(50.0f0, 100.0f0) # make sure type is exactly as expected: - @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} + @test ẏ isa Tangent{Tuple{Irrational{:π},Float64},Tuple{Float32,Float32}} xs, Ω = (3,), (3, 6) - @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == + ((1.0f0,), (2.0f0,)) end @testset "@scalar_rule projection" begin - make_imaginary(x) = im*x + make_imaginary(x) = im * x @scalar_rule make_imaginary(x) im # note: the === will make sure that these are Float64, not ComplexF64 - @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im) + @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0 * im) @test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0) - @test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im) - @test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) + @test (NoTangent(), 1.0 + 0.0im) === rrule(make_imaginary, 2.0im)[2](1.0 * im) + @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end @testset "Regression tests against #276 and #265" begin @@ -257,16 +264,16 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 # Symptom of these problems is creation of global variables and type instability - num_globals_before = length(names(ChainRulesCore; all=true)) + num_globals_before = length(names(ChainRulesCore; all = true)) simo2(x) = (x, 2x) @scalar_rule(simo2(x), 1.0, 2.0) _, simo2_pb = rrule(simo2, 43.0) # make sure it infers: inferability implies type stability - @inferred simo2_pb(Tangent{Tuple{Float64, Float64}}(3.0, 6.0)) + @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) # Test no new globals were created - @test length(names(ChainRulesCore; all=true)) == num_globals_before + @test length(names(ChainRulesCore; all = true)) == num_globals_before # Example in #265 simo3(x) = sincos(x) @@ -279,60 +286,60 @@ end module IsolatedModuleForTestingScoping - # check that rules can be defined by macros without any additional imports - using ChainRulesCore: @scalar_rule, @non_differentiable - - # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved - const ChainRulesCore = nothing - - # this is - # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 - fixed(x) = :abc - @non_differentiable fixed(x) - - # check name collision between a primal input called `kwargs` and the actual keyword - # arguments - fixed_kwargs(x; kwargs...) = :abc - @non_differentiable fixed_kwargs(kwargs) - - my_id(x) = x - @scalar_rule(my_id(x), 1.0) - - module IsolatedSubmodule - # check that rules defined in isolated module without imports can be called - # without errors - using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output - using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id - using Test - - @testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((ZeroTangent(), randn()), f, randn()) - @test y === :abc - @test ẏ === NoTangent() - - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end +# check that rules can be defined by macros without any additional imports +using ChainRulesCore: @scalar_rule, @non_differentiable + +# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved +const ChainRulesCore = nothing + +# this is +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 +fixed(x) = :abc +@non_differentiable fixed(x) + +# check name collision between a primal input called `kwargs` and the actual keyword +# arguments +fixed_kwargs(x; kwargs...) = :abc +@non_differentiable fixed_kwargs(kwargs) + +my_id(x) = x +@scalar_rule(my_id(x), 1.0) + +module IsolatedSubmodule +# check that rules defined in isolated module without imports can be called +# without errors +using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output +using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id +using Test + +@testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((ZeroTangent(), randn()), f, randn()) + @test y === :abc + @test ẏ === NoTangent() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword = randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) +end - @testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ +@testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (NoTangent(), Δy) + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (NoTangent(), Δy) - @test derivatives_given_output(y, my_id, x) == ((1.0,),) - end - end + @test derivatives_given_output(y, my_id, x) == ((1.0,),) +end +end end diff --git a/test/rules.jl b/test/rules.jl index d43ca42d2..267b23005 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,8 +28,11 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Any, Any, Vararg}, - ::typeof(mixed_vararg), x, y, z..., + dargs::Tuple{Any,Any,Any,Vararg}, + ::typeof(mixed_vararg), + x, + y, + z..., ) Δx = dargs[2] Δy = dargs[3] @@ -39,16 +42,21 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any, Int, Float64}, - ::typeof(type_constraints), x::Int, y::Float64, + (_, Δx, Δy)::Tuple{Any,Int,Float64}, + ::typeof(type_constraints), + x::Int, + y::Float64, ) return type_constraints(x, y), Δx + Δy end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, - ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, + dargs::Tuple{Any,Float64,Real,Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), + x::Float64, + y::Real, + z::Vararg{Float64}, ) Δx = dargs[2] Δy = dargs[3] @@ -65,9 +73,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = ZeroTangent() @test frule((dself, 1), cool, 1) === nothing - @test frule((dself, 1), cool, 1; iscool=true) === nothing + @test frule((dself, 1), cool, 1; iscool = true) === nothing @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool=true) === nothing + @test rrule(cool, 1; iscool = true) === nothing # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @@ -76,8 +84,10 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}]) + only_methods = Set([ + Tuple{typeof(rrule),typeof(cool),Number}, + Tuple{typeof(rrule),typeof(cool),String}, + ]) @test cool_methods == only_methods frx, cool_pushforward = frule((dself, 1), cool, 1) @@ -98,21 +108,26 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) - @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == + (10.0, 10.0) @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing - @test(frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, - ) == (6.0, 6.0)) + @test( + frule( + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, + 3.0, + 2.0, + 1.0, + 0.0, + ) == (6.0, 6.0) + ) # violates type constraints, thus an frule should not be found. - @test frule( - (nothing, 3, 2.0, 1.0, 5.0), - mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, - ) == nothing + @test frule((nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0) == + nothing @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) @@ -153,27 +168,29 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) - @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where {T<:Float32} @opt_out( - ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float32} ) @testset "rrule" begin @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) - @test rrule(first_oa, 3f0, 4f0) === nothing + @test rrule(first_oa, 3.0f0, 4.0f0) === nothing @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} end) end @testset "frule" begin - @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) - @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing - - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m - m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 - end) + @test frule((NoTangent(), 1, 0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1, 0), first_oa, 3.0f0, 4.0f0) === nothing + + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end end end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7e0ec9398..fdbb92f55 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -9,7 +9,7 @@ @test view(NoTangent(), 1, 2) == NoTangent() @test sum(ZeroTangent()) == ZeroTangent() - @test sum(NoTangent(); dims=2) == NoTangent() + @test sum(NoTangent(); dims = 2) == NoTangent() end @testset "ZeroTangent" begin @@ -55,7 +55,7 @@ @test muladd(x, ZeroTangent(), ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), x, ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), ZeroTangent(), ZeroTangent()) === ZeroTangent() - + @test reim(z) === (ZeroTangent(), ZeroTangent()) @test real(z) === ZeroTangent() @test imag(z) === ZeroTangent() diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 2fd337979..2b7c6347e 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -1,10 +1,14 @@ @testset "NotImplemented" begin @testset "NotImplemented" begin ni = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" + @__MODULE__, + LineNumberNode(@__LINE__, @__FILE__), + "error", ) ni2 = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" + @__MODULE__, + LineNumberNode(@__LINE__, @__FILE__), + "error2", ) x = rand() thunk = @thunk(x^2) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 694e43b53..cc24d988e 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -1,6 +1,6 @@ # For testing Tangent struct Foo - x + x::Any y::Float64 end @@ -12,81 +12,81 @@ end # For testing Tangent: it is an invarient of the type that x2 = 2x # so simple addition can not be defined struct StructWithInvariant - x - x2 + x::Any + x2::Any StructWithInvariant(x) = new(x, 2x) end @testset "Tangent" begin @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{}, Tuple{}} + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} end @testset "==" begin - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(x=0.1, y=2.5) - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(y=2.5, x=0.1) - @test Tangent{Foo}(y=2.5, x=ZeroTangent()) == Tangent{Foo}(y=2.5) + @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(x = 0.1, y = 2.5) + @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(y = 2.5, x = 0.1) + @test Tangent{Foo}(y = 2.5, x = ZeroTangent()) == Tangent{Foo}(y = 2.5) - @test Tangent{Tuple{Float64,}}(2.0) == Tangent{Tuple{Float64,}}(2.0) + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2*1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(;y=2.0,) == Tangent{Foo}(;x=ZeroTangent(), y=Float32(2.0),) + @test Tangent{Foo}(; y = 2.0) == Tangent{Foo}(; x = ZeroTangent(), y = Float32(2.0)) end @testset "hash" begin - @test hash(Tangent{Foo}(x=0.1, y=2.5)) == hash(Tangent{Foo}(y=2.5, x=0.1)) - @test hash(Tangent{Foo}(y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(y=2.5)) + @test hash(Tangent{Foo}(x = 0.1, y = 2.5)) == hash(Tangent{Foo}(y = 2.5, x = 0.1)) + @test hash(Tangent{Foo}(y = 2.5, x = ZeroTangent())) == hash(Tangent{Foo}(y = 2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(x=2.5), :x) == true + @test keys(Tangent{Foo}(x = 2.5)) == (:x,) + @test propertynames(Tangent{Foo}(x = 2.5)) == (:x,) + @test haskey(Tangent{Foo}(x = 2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x=2.5), :y) == false + @test hasproperty(Tangent{Foo}(x = 2.5), :y) == false end - @test Tangent{Foo}(x=2.5).x == 2.5 - - @test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - - NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}} - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 - - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 + @test Tangent{Foo}(x = 2.5).x == 2.5 + + @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + + NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x=2.5)) == 1 - @test length(Tangent{Tuple{Float64,}}(2.0)) == 1 + @test length(Tangent{Foo}(x = 2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64,}}(2.0)) == Float64 + @test eltype(Tangent{Foo}(x = 2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64,}}(2.0)) == [2.0] + @test collect(Tangent{Foo}(x = 2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] # Test indexed_iterate ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function(tangent) + _unpack2tuple = function (tangent) a, b = tangent return (a, b) end @@ -96,33 +96,33 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(x = 2, y = 3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(y = 3.0)) end end @testset "reverse" begin - c = Tangent{Tuple{Int, Int, String}}(1, 2, "something") - cr = Tangent{Tuple{String, Int, Int}}("something", 2, 1) + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(;x=1.0, y=2.0)) + @test_throws MethodError reverse(Tangent{Foo}(; x = 1.0, y = 2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo, typeof(d)}(d) + cdict = Tangent{Foo,typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end @testset "unset properties" begin - @test Tangent{Foo}(; x=1.4).y === ZeroTangent() + @test Tangent{Foo}(; x = 1.4).y === ZeroTangent() end @testset "conj" begin - @test conj(Tangent{Foo}(x=2.0+3.0im)) == Tangent{Foo}(x=2.0-3.0im) + @test conj(Tangent{Foo}(x = 2.0 + 3.0im)) == Tangent{Foo}(x = 2.0 - 3.0im) @test ==( - conj(Tangent{Tuple{Float64,}}(2.0+3.0im)), - Tangent{Tuple{Float64,}}(2.0-3.0im) + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), + Tangent{Tuple{Float64}}(2.0 - 3.0im), ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -132,26 +132,20 @@ end @testset "canonicalize" begin # Testing iterate via collect - @test ==( - canonicalize(Tangent{Tuple{Float64,}}(2.0)), - Tangent{Tuple{Float64,}}(2.0) - ) + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - @test ==( - canonicalize(Tangent{Dict}(Dict(4 => 3))), - Tangent{Dict}(Dict(4 => 3)), - ) + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x=2.5, y=10)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10, x=2.5)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10)) == CFoo(x=ZeroTangent(), y=10) + @test canonicalize(CFoo(x = 2.5, y = 10)) == CFoo(x = 2.5, y = 10) + @test canonicalize(CFoo(y = 10, x = 2.5)) == CFoo(x = 2.5, y = 10) + @test canonicalize(CFoo(y = 10)) == CFoo(x = ZeroTangent(), y = 10) - @test_throws ArgumentError canonicalize(CFoo(q=99.0, x=2.5)) + @test_throws ArgumentError canonicalize(CFoo(q = 99.0, x = 2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(;a=1, b=2) + c1 = Tangent{Any}(; a = 1, b = 2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -164,30 +158,28 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x=1.5) + CFoo(x=2.5) == CFoo(x=4.0) - @test CFoo(y=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=2.5) - @test CFoo(y=1.5, x=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=4.0) + @test CFoo(x = 1.5) + CFoo(x = 2.5) == CFoo(x = 4.0) + @test CFoo(y = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 2.5) + @test CFoo(y = 1.5, x = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 4.0) end @testset "Tuples" begin @test ==( typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{}, Tuple{}} + Tangent{Tuple{},Tuple{}}, ) @test ( - Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64, Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64, Float64}}(2.0, 3.0) + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) end @testset "NamedTuples" begin - nt1 = (;a=1.5, b=0.0) - nt2 = (;a=0.0, b=2.5) - nt_sum = (a=1.5, b=2.5) - @test ( - Tangent{typeof(nt1)}(; nt1...) + - Tangent{typeof(nt2)}(; nt2...) - ) == Tangent{typeof(nt_sum)}(; nt_sum...) + nt1 = (; a = 1.5, b = 0.0) + nt2 = (; a = 0.0, b = 2.5) + nt_sum = (a = 1.5, b = 2.5) + @test (Tangent{typeof(nt1)}(; nt1...) + Tangent{typeof(nt2)}(; nt2...)) == + Tangent{typeof(nt_sum)}(; nt_sum...) end @testset "Dicts" begin @@ -199,8 +191,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x=1.5) - b = CFoo(x=@not_implemented("")) + a = CFoo(x = 1.5) + b = CFoo(x = @not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -215,8 +207,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x=1.5) - b = Tangent{NamedTuple{(:x,)}}(x=@not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(x = 1.5) + b = Tangent{NamedTuple{(:x,)}}(x = @not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -235,35 +227,35 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 + @test Foo(3.5, 1.5) + Tangent{Foo}(x = 2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(x = 2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x = 0.5)) == 0 end @testset "Tuples" begin @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end @testset "NamedTuple" begin - ntx = (; a=1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + ntx = (; a = 1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a = 3.0) - nty = (; a=1.5, b=0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) + nty = (; a = 1.5, b = 0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a = 3.0, b = 1.0) end @testset "Dicts" begin d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 =>5.0)) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x=2.0, x2=6.0) + diff = Tangent{StructWithInvariant}(x = 2.0, x2 = 6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -280,7 +272,7 @@ end # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2/2)/2 + x = (nt.x + nt.x2 / 2) / 2 return StructWithInvariant(x) end @test value + diff == StructWithInvariant(12.5) @@ -288,7 +280,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y=1.5, x=2.5) + c = Tangent{Foo}(y = 1.5, x = 2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -310,14 +302,14 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y=1.5, x=2.5) - == Tangent{Foo}(y=3.0, x=5.0) - == Tangent{Foo}(y=1.5, x=2.5) * 2 + 2 * Tangent{Foo}(y = 1.5, x = 2.5) == + Tangent{Foo}(y = 3.0, x = 5.0) == + Tangent{Foo}(y = 1.5, x = 2.5) * 2 ) @test ( - 2 * Tangent{Tuple{Float64, Float64}}(2.0, 4.0) - == Tangent{Tuple{Float64, Float64}}(4.0, 8.0) - == Tangent{Tuple{Float64, Float64}}(2.0, 4.0) * 2 + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 ) d = Tangent{Dict}(Dict(4 => 3.0)) two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) @@ -325,7 +317,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x=1,)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(x = 1)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -342,8 +334,9 @@ end end @testset "Internals don't allocate a ton" begin - bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + bk = (; x = 1.0, y = 2.0) + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 # weaker version of the above (which should pass on all versions) @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 @@ -352,8 +345,8 @@ end end @testset "non-same-typed differential arithmetic" begin - nt = (; a=1, b=2.0) - c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1); + nt = (; a = 1, b = 2.0) + c = Tangent{typeof(nt)}(; a = NoTangent(), b = 0.1) + @test nt + c == (; a = 1, b = 2.1) end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 89461caa1..af4a747d1 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -141,7 +141,7 @@ # Check against accidential type piracy # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472 @test Base.which(diagm, Tuple{}()).module != ChainRulesCore - @test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore + @test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore end @test tril(a) == tril(t) @test tril(a, 1) == tril(t, 1) From 7e701648c26c1cfa845da5255ed33e55f2fa82c6 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 08:51:38 +0300 Subject: [PATCH 02/20] format workflow --- .github/workflows/format.yml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 .github/workflows/format.yml diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 000000000..06b8dbe46 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,26 @@ +name: Format suggestions + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + tool_name: JuliaFormatter + fail_on_error: true From 8ddb4acc97b32c833bc006aff4b8c1b5e0295f18 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 13:29:16 +0300 Subject: [PATCH 03/20] format with blue style --- .JuliaFormatter.toml | 1 + docs/make.jl | 24 ++--- docs/src/assets/make_logo.jl | 13 ++- src/accumulation.jl | 2 - src/config.jl | 1 - src/projection.jl | 79 ++++++++--------- src/rule_definition_tools.jl | 52 +++++------ src/tangent_types/abstract_zero.jl | 2 +- src/tangent_types/notimplemented.jl | 18 ++-- src/tangent_types/tangent.jl | 14 +-- src/tangent_types/thunks.jl | 15 ++-- test/config.jl | 52 ++++------- test/projection.jl | 22 ++--- test/rule_definition_tools.jl | 28 +++--- test/rules.jl | 35 +++----- test/tangent_types/abstract_zero.jl | 2 +- test/tangent_types/notimplemented.jl | 8 +- test/tangent_types/tangent.jl | 125 +++++++++++++-------------- 18 files changed, 218 insertions(+), 275 deletions(-) create mode 100644 .JuliaFormatter.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..323237bab --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" diff --git a/docs/make.jl b/docs/make.jl index 608422c25..42e39a4c0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,12 +21,12 @@ DocMeta.setdocmeta!( indigo = DocThemeIndigo.install(ChainRulesCore) -makedocs( - modules = [ChainRulesCore], - format = Documenter.HTML( - prettyurls = false, - assets = [indigo], - mathengine = MathJax3( +makedocs(; + modules=[ChainRulesCore], + format=Documenter.HTML(; + prettyurls=false, + assets=[indigo], + mathengine=MathJax3( Dict( :tex => Dict( "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], @@ -42,9 +42,9 @@ makedocs( ), ), ), - sitename = "ChainRules", - authors = "Jarrett Revels and other contributors", - pages = [ + sitename="ChainRules", + authors="Jarrett Revels and other contributors", + pages=[ "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", @@ -63,8 +63,8 @@ makedocs( ], "API" => "api.md", ], - strict = true, - checkdocs = :exports, + strict=true, + checkdocs=:exports, ) -deploydocs(repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview = true) +deploydocs(; repo="github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview=true) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 3e7aeaa08..c023c308f 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -8,7 +8,7 @@ using Random const bridge_len = 50 -function chain(jiggle = 0) +function chain(jiggle=0) shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) ### 1 @@ -17,7 +17,6 @@ function chain(jiggle = 0) link() m1 = getmatrix() - ### 2 sethue(Luxor.julia_green) translate(-50, 130) @@ -38,15 +37,13 @@ function chain(jiggle = 0) setmatrix(m2) setcolor(Luxor.julia_green) - overlap(-1.5π) + return overlap(-1.5π) end - function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - rect(50, -3, 40, bridge_len + 6, :fill) rect(-50 - 40, -3, 40, bridge_len + 6, :fill) @@ -58,7 +55,7 @@ function link() move(Point(-90, bridge_len)) arc(Point(0, 0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) - strokepath() + return strokepath() end function overlap(ang_end) @@ -68,7 +65,7 @@ function overlap(ang_end) move(Point(90, bridge_len)) arc(Point(0, bridge_len), 90, 0, ang_end, :stoke) - strokepath() + return strokepath() end # Actually draw it @@ -80,7 +77,7 @@ function save_logo(filename) translate(50, -130) chain(0.5) finish() - preview() + return preview() end save_logo("logo.svg") diff --git a/src/accumulation.jl b/src/accumulation.jl index 5fbc07fa8..538216f38 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -34,7 +34,6 @@ function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} end end - """ is_inplaceable_destination(x) -> Bool @@ -64,7 +63,6 @@ end is_inplaceable_destination(::LinearAlgebra.Hermitian) = false is_inplaceable_destination(::LinearAlgebra.Symmetric) = false - function debug_add!(accumuland, t::InplaceableThunk) returned_value = t.add!(accumuland) if returned_value !== accumuland diff --git a/src/config.jl b/src/config.jl index 347e05c51..04757e838 100644 --- a/src/config.jl +++ b/src/config.jl @@ -64,7 +64,6 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar """ struct NoForwardsMode <: ForwardsModeCapability end - """ frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) diff --git a/src/projection.jl b/src/projection.jl index 55f6e7bfd..2e1a9340e 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -32,7 +32,7 @@ ProjectTo{P}() where {P} = ProjectTo{P}(EMPTY_NT) const Type_kwfunc = Core.kwftype(Type).instance function (::typeof(Type_kwfunc))(kws::Any, ::Type{ProjectTo{P}}) where {P} - ProjectTo{P}(NamedTuple(kws)) + return ProjectTo{P}(NamedTuple(kws)) end Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) @@ -131,8 +131,9 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} = - ProjectTo{NoTangent}() +function ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} + return ProjectTo{NoTangent}() +end # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -168,11 +169,13 @@ end (::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = - convert(T, dx) +function (::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} + return convert(T, dx) +end (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = - convert(T, dx) +function (::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} + return convert(T, dx) +end (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -193,7 +196,7 @@ end # For arrays of numbers, just store one projector: function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element = _eltype_projectto(T), axes = axes(x)) + return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -207,7 +210,7 @@ function ProjectTo(xs::AbstractArray) return ProjectTo{NoTangent}() # short-circuit if all elements project to zero else # Arrays of arrays come here, and will apply projectors individually: - return ProjectTo{AbstractArray}(; elements = elements, axes = axes(xs)) + return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs)) end end @@ -217,7 +220,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} dy = if axes(dx) == project.axes dx else - for d = 1:max(M, length(project.axes)) + for d in 1:max(M, length(project.axes)) if size(dx, d) != length(get(project.axes, d, 1)) throw(_projection_mismatch(project.axes, size(dx))) end @@ -249,7 +252,7 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro if !(project.axes isa Tuple{}) throw( DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number" ), ) end @@ -259,7 +262,7 @@ end function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) return DimensionMismatch( - "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx", + "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx" ) end @@ -273,13 +276,13 @@ function ProjectTo(x::Ref) if sub isa ProjectTo{<:AbstractZero} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type = typeof(x), x = sub) + return ProjectTo{Ref}(; type=typeof(x), x=sub) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x = project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x = project.x(dx[])) +(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) +(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) # Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x = project.x(dx)) +(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) ##### ##### `LinearAlgebra` @@ -288,7 +291,7 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Row vectors -ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent = ProjectTo(parent(x))) +ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x))) # Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec. # Transposed matrices are, like PermutedDimsArray, just a storage detail, # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number @@ -303,8 +306,9 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = - ProjectTo{Transpose}(; parent = ProjectTo(parent(x))) +function ProjectTo(x::LinearAlgebra.TransposeAbsVec) + return ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +end function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -317,7 +321,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) end # Diagonal -ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag = ProjectTo(x.diag)) +ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) @@ -329,10 +333,7 @@ for (SymHerm, chk, fun) in sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; - uplo = LinearAlgebra.sym_uplo(x.uplo), - parent = sub, - ) + return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) dy = project.parent(dx) @@ -345,8 +346,9 @@ for (SymHerm, chk, fun) in # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) sub = project.parent # this is going to be unhappy about the size - sub_one = - ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) + sub_one = ProjectTo{project_type(sub)}(; + element=sub.element, axes=(sub.axes[1],) + ) return Diagonal(sub_one(dx.diag)) end end @@ -355,12 +357,13 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent = ProjectTo(parent(x))) + ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent - sub_one = - ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) + sub_one = ProjectTo{project_type(sub)}(; + element=sub.element, axes=(sub.axes[1],) + ) return Diagonal(sub_one(dx.diag)) end end @@ -397,7 +400,7 @@ end # another strategy is just to use the AbstractArray method function ProjectTo(x::Tridiagonal{T}) where {T<:Number} notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent = notparent) + return ProjectTo{Tridiagonal}(; notparent=notparent) end function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) dy = project.notparent(dx) @@ -416,9 +419,7 @@ using SparseArrays function ProjectTo(x::SparseVector{T}) where {T<:Number} return ProjectTo{SparseVector}(; - element = ProjectTo(zero(T)), - nzind = x.nzind, - axes = axes(x), + element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) ) end function (project::ProjectTo{SparseVector})(dx::AbstractArray) @@ -457,11 +458,11 @@ end function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} return ProjectTo{SparseMatrixCSC}(; - element = ProjectTo(zero(T)), - axes = axes(x), - rowval = rowvals(x), - nzranges = nzrange.(Ref(x), axes(x, 2)), - colptr = x.colptr, + element=ProjectTo(zero(T)), + axes=axes(x), + rowval=rowvals(x), + nzranges=nzrange.(Ref(x), axes(x, 2)), + colptr=x.colptr, ) end # You need not really store nzranges, you can get them from colptr -- TODO @@ -481,7 +482,7 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) for i in project.nzranges[col] row = project.rowval[i] val = dy[row, col] - nzval[k+=1] = project.element(val) + nzval[k += 1] = project.element(val) end end m, n = map(length, project.axes) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 8a1e1cce4..d7510c8d6 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -83,8 +83,9 @@ For examples, see ChainRules' `rulesets` directory. See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) - call, setup_stmts, inputs, partials = - _normalize_scalarrules_macro_input(call, maybe_setup, partials) + call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( + call, maybe_setup, partials + ) f = call.args[1] # Generate variables to store derivatives named dfi/dxj @@ -98,11 +99,11 @@ macro scalar_rule(call, maybe_setup, partials...) rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro - code = quote + return code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 throw( ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))", + "@scalar_rule cannot be used on closures/functors (such as $($f))" ), ) end @@ -113,7 +114,6 @@ macro scalar_rule(call, maybe_setup, partials...) end end - """ _normalize_scalarrules_macro_input(call, maybe_setup, partials) @@ -177,9 +177,7 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote function ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), - ::Core.Typeof($f), - $(inputs...), + $(esc(:Ω)), ::Core.Typeof($f), $(inputs...) ) $(__source__) $(setup_stmts...) @@ -201,8 +199,9 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) end if n_outputs > 1 # For forward-mode we return a Tangent if output actually a tuple. - pushforward_returns = - Expr(:call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns...) + pushforward_returns = Expr( + :call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns... + ) else pushforward_returns = first(pushforward_returns) end @@ -214,8 +213,9 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = - ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) return $(esc(:Ω)), $pushforward_returns end end @@ -253,8 +253,9 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = - ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) return $(esc(:Ω)), $pullback end end @@ -263,7 +264,7 @@ end # For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" -_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i = 1:n] +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) @@ -281,7 +282,7 @@ Specify `_conj = true` to conjugate the partials. Projector `proj` is a function that will be applied at the end; for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity` """ -function propagation_expr(Δs, ∂s, _conj = false, proj = identity) +function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj @@ -295,7 +296,7 @@ function propagation_expr(Δs, ∂s, _conj = false, proj = identity) # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) summed_∂_mul_Δs = - foldl(Iterators.drop(zip(_∂s, Δs), 1); init = init_expr) do ex, (∂s_i, Δs_i) + foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) :(muladd($∂s_i, $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) @@ -373,7 +374,7 @@ macro non_differentiable(sig_expr) primal_invoke = if !has_vararg :($(primal_name)($(unconstrained_args...))) else - normal_args = unconstrained_args[1:end-1] + normal_args = unconstrained_args[1:(end - 1)] var_arg = unconstrained_args[end] :($(primal_name)($(normal_args...), $(var_arg)...)) end @@ -408,8 +409,7 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::Any), - $(map(esc, primal_sig_parts)...), + @nospecialize(::Any), $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() @@ -445,9 +445,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl function (::Core.kwftype(typeof(rrule)))( - $(esc(kwargs))::Any, - ::typeof(rrule), - $(esc_primal_sig_parts...), + $(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...) ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end @@ -458,7 +456,6 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end end - ############################################################################################ # @opt_out @@ -524,8 +521,6 @@ function _no_rule_target_rewrite!(call_target::Symbol) end end - - ############################################################################################ # Helpers @@ -580,7 +575,6 @@ function _split_primal_name(primal_name) if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || Meta.isexpr(primal_name, :curly) - primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name # e.g. (::T)(x, y) @@ -598,7 +592,7 @@ _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. Meta.isexpr(arg, :(...), 1) && return _unconstrain(arg.args[1]) - error("malformed arguments: $arg") + return error("malformed arguments: $arg") end "turn both `a` and `::constraint` into `a::constraint` etc" @@ -607,6 +601,6 @@ function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) - error("malformed arguments: $arg") + return error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index c86fc78ea..5993d32b4 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -33,7 +33,7 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) Base.getindex(z::AbstractZero, k) = z Base.view(z::AbstractZero, ind...) = z -Base.sum(z::AbstractZero; dims = :) = z +Base.sum(z::AbstractZero; dims=:) = z """ ZeroTangent() <: AbstractZero diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index 7ceb315ea..a6b9cc5f9 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,13 +44,15 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw( - NotImplementedException( - @not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" - ) - ), -) +function Base.zero(::Type{<:NotImplemented}) + return throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" + ) + ), + ) +end Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) @@ -79,5 +81,5 @@ function Base.showerror(io::IO, e::NotImplementedException) if e.info !== nothing print(io, "\nInfo: ", e.info) end - return + return nothing end diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 34e822ea8..bb91e431e 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -102,10 +102,10 @@ Base.length(tangent::Tangent) = length(backing(tangent)) Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) + return Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) end -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state = 1) where {P} +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end @@ -301,16 +301,16 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color = :blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")"; color=:blue) println(io, "\nor overload") printstyled( io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color = :blue, + color=:blue, ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color = :blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) println(io, "\nOriginal Exception:") - printstyled(io, err.original; color = :yellow) - println(io) + printstyled(io, err.original; color=:yellow) + return println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index c2b570902..e065bea62 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -56,20 +56,16 @@ LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a)) LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a)) LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a)) LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a)) -LinearAlgebra.Symmetric(a::AbstractThunk, uplo = :U) = Symmetric(unthunk(a), uplo) -LinearAlgebra.Hermitian(a::AbstractThunk, uplo = :U) = Hermitian(unthunk(a), uplo) +LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo) +LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo) function LinearAlgebra.diagm( - kv::Pair{<:Integer,<:AbstractThunk}, - kvs::Pair{<:Integer,<:AbstractThunk}..., + kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...) end function LinearAlgebra.diagm( - m, - n, - kv::Pair{<:Integer,<:AbstractThunk}, - kvs::Pair{<:Integer,<:AbstractThunk}..., + m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end @@ -122,7 +118,7 @@ function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx) return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx) end -function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn = 1) +function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1) return throw(MutateThunkException()) end @@ -201,7 +197,6 @@ Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a) - """ InplaceableThunk(add!::Function, val::Thunk) diff --git a/test/config.jl b/test/config.jl index e6e2ab005..58d943252 100644 --- a/test/config.jl +++ b/test/config.jl @@ -22,7 +22,6 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. return f(args...; kws...), pullback_via_ad end - struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} forward_calls::Vector reverse_calls::Vector @@ -47,10 +46,7 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, - MockForwardsConfig, - MockReverseConfig, - MockBothConfig, + MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig ) counting_id_count = Ref(0) function counting_id(x) @@ -79,33 +75,21 @@ end @testset "hitting forwards AD" begin do_thing_2(f, x) = f(x) function ChainRulesCore.frule( - config::RuleConfig{>:HasForwardsMode}, - (_, df, dx), - ::typeof(do_thing_2), - f, - x, + config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x ) return frule_via_ad(config, (df, dx), f, x) end @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) @test nothing === frule( - Config(), - (NoTangent(), NoTangent(), 21.5), - do_thing_2, - identity, - 32.1, + Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) bconfig = Config() @test nothing !== frule( - bconfig, - (NoTangent(), NoTangent(), 21.5), - do_thing_2, - identity, - 32.1, + bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) @test bconfig.forward_calls == [(identity, (32.1,))] end @@ -114,15 +98,11 @@ end @testset "hitting reverse AD" begin do_thing_3(f, x) = f(x) function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, - ::typeof(do_thing_3), - f, - x, + config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x ) return (NoTangent(), rrule_via_ad(config, f, x)...) end - @testset "$Config" for Config in (MostBoringConfig, MockForwardsConfig) @test nothing === rrule(Config(), do_thing_3, identity, 32.1) end @@ -180,28 +160,28 @@ end end @testset "fallbacks" begin - no_rule(x; kw = "bye") = error() + no_rule(x; kw="bye") = error() @test frule((1.0,), no_rule, 2.0) === nothing - @test frule((1.0,), no_rule, 2.0; kw = "hello") === nothing + @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing - @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw = "hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing @test rrule(no_rule, 2.0) === nothing - @test rrule(no_rule, 2.0; kw = "hello") === nothing + @test rrule(no_rule, 2.0; kw="hello") === nothing @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing - @test rrule(MostBoringConfig(), no_rule, 2.0; kw = "hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(; kw = "hello") + @test_throws MethodError frule(; kw="hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin; kw = "hello") + @test_throws MethodError frule(sin; kw="hello") @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw = "hello") + @test_throws MethodError frule(MostBoringConfig(); kw="hello") @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw = "hello") + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(; kw = "hello") + @test_throws MethodError rrule(; kw="hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig(); kw = "hello") + @test_throws MethodError rrule(MostBoringConfig(); kw="hello") end end diff --git a/test/projection.jl b/test/projection.jl index ab418ef79..cbfdcf6da 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -30,14 +30,14 @@ struct NoSuperType end # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) - @test ProjectTo(1 // 2)(3 // 4) === 3 // 4 + @test ProjectTo(1//2)(3//4) === 3//4 @test ProjectTo(1.0f0)(1 / 2) === 0.5f0 @test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im @test ProjectTo(big(1.0))(2) === 2 @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re = 1, im = NoTangent())) === + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(; re=1, im=NoTangent())) === 1.0f0 + 0.0f0im end @@ -52,7 +52,7 @@ struct NoSuperType end @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent - @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value = 1.0)) isa Tangent + @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent end @testset "Base: arrays of numbers" begin @@ -102,7 +102,7 @@ struct NoSuperType end @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. - @test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im) + @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) @test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent} # empty arrays @@ -126,18 +126,18 @@ struct NoSuperType end @testset "Base: Ref" begin pref = ProjectTo(Ref(2.0)) @test pref(Ref(3 + im)).x === 3.0 - @test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0 + @test pref(Tangent{Base.RefValue}(; x=3 + im)).x === 3.0 @test pref(4).x === 4.0 # also re-wraps scalars @test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue} pref2 = ProjectTo(Ref{Any}(6 + 7im)) @test pref2(Ref(8)).x === 8.0 + 0.0im - @test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im + @test pref2(Tangent{Base.RefValue}(; x=8)).x === 8.0 + 0.0im prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents @test prefvec(Ref(1:3)).x isa Vector{ComplexF64} - @test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64} - @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5)) + @test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64} + @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5)) @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} @@ -172,7 +172,7 @@ struct NoSuperType end # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called - xs = adj(Any[Any[1, 2, 3], Any[4+im, 5-im, 6+im, 7-im]]) + xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) pvecvec3 = ProjectTo(xs) @test pvecvec3(xs)[1] == [1 2 3] @test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im]) @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step = 0.1, ref = NoTangent()) + dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()) @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3 + 4im) + dy = Tangent{typeof(Ref(2.0))}(; x=3 + 4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index e99b66c2f..f4be16218 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -37,13 +37,13 @@ struct NonDiffCounterExample end module NonDiffModuleExample -nondiff_2_1(x, y) = fill(7.5, 100)[x+y] +nondiff_2_1(x, y) = fill(7.5, 100)[x + y] end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin - nondiff_2_1(x, y) = fill(7.5, 100)[x+y] + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((ZeroTangent(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, NoTangent()) res, pullback = rrule(nondiff_2_1, 3, 2) @@ -93,7 +93,7 @@ end end @testset "kwargs" begin - kw_demo(x; kw = 2.0) = x + kw + kw_demo(x; kw=2.0) = x + kw @non_differentiable kw_demo(::Any) @testset "not setting kw" begin @@ -107,13 +107,13 @@ end end @testset "setting kw" begin - @assert kw_demo(1.5; kw = 3.0) == 4.5 + @assert kw_demo(1.5; kw=3.0) == 4.5 - res, pullback = rrule(kw_demo, 1.5; kw = 3.0) + res, pullback = rrule(kw_demo, 1.5; kw=3.0) @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw = 3.0) == + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) end end @@ -196,9 +196,9 @@ end end @testset "Functors" begin - (f::NonDiffExample)(y) = fill(7.5, 100)[f.x+y] + (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x = 1.2), 2.3), NonDiffExample(3), 2) == + @test frule((Tangent{NonDiffExample}(; x=1.2), 2.3), NonDiffExample(3), 2) == (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @@ -208,10 +208,7 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) @test frule( - (ZeroTangent(), 1.2, 2.3), - NonDiffModuleExample.nondiff_2_1, - 3, - 2, + (ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2 ) == (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @@ -264,7 +261,7 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 # Symptom of these problems is creation of global variables and type instability - num_globals_before = length(names(ChainRulesCore; all = true)) + num_globals_before = length(names(ChainRulesCore; all=true)) simo2(x) = (x, 2x) @scalar_rule(simo2(x), 1.0, 2.0) @@ -273,7 +270,7 @@ end @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) # Test no new globals were created - @test length(names(ChainRulesCore; all = true)) == num_globals_before + @test length(names(ChainRulesCore; all=true)) == num_globals_before # Example in #265 simo3(x) = sincos(x) @@ -284,7 +281,6 @@ end end end - module IsolatedModuleForTestingScoping # check that rules can be defined by macros without any additional imports using ChainRulesCore: @scalar_rule, @non_differentiable @@ -323,7 +319,7 @@ using Test @test f_pullback(randn()) === (NoTangent(), NoTangent()) end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword = randn()) + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) @test y === :abc @test f_pullback(randn()) === (NoTangent(), NoTangent()) end diff --git a/test/rules.jl b/test/rules.jl index 267b23005..54c10b160 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,11 +28,7 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any,Any,Any,Vararg}, - ::typeof(mixed_vararg), - x, - y, - z..., + dargs::Tuple{Any,Any,Any,Vararg}, ::typeof(mixed_vararg), x, y, z... ) Δx = dargs[2] Δy = dargs[3] @@ -42,10 +38,7 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any,Int,Float64}, - ::typeof(type_constraints), - x::Int, - y::Float64, + (_, Δx, Δy)::Tuple{Any,Int,Float64}, ::typeof(type_constraints), x::Int, y::Float64 ) return type_constraints(x, y), Δx + Δy end @@ -73,9 +66,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = ZeroTangent() @test frule((dself, 1), cool, 1) === nothing - @test frule((dself, 1), cool, 1; iscool = true) === nothing + @test frule((dself, 1), cool, 1; iscool=true) === nothing @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool = true) === nothing + @test rrule(cool, 1; iscool=true) === nothing # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @@ -85,8 +78,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) only_methods = Set([ - Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}, + Tuple{typeof(rrule),typeof(cool),Number}, Tuple{typeof(rrule),typeof(cool),String} ]) @test cool_methods == only_methods @@ -104,7 +96,6 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) rrx, nice_pullback = rrule(nice, 1) @test (NoTangent(), ZeroTangent()) === nice_pullback(1) - # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) @@ -116,12 +107,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test( frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, - 3.0, - 2.0, - 1.0, - 0.0, + (nothing, 3.0, 2.0, 1.0, 0.0), mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0 ) == (6.0, 6.0) ) @@ -164,7 +150,6 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test_skip ∂xr ≈ real(∂x) end - @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) @@ -177,9 +162,11 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) @test rrule(first_oa, 3.0f0, 4.0f0) === nothing - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} - end) + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_rrule)) do m + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end @testset "frule" begin diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index fdbb92f55..f8222d942 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -9,7 +9,7 @@ @test view(NoTangent(), 1, 2) == NoTangent() @test sum(ZeroTangent()) == ZeroTangent() - @test sum(NoTangent(); dims = 2) == NoTangent() + @test sum(NoTangent(); dims=2) == NoTangent() end @testset "ZeroTangent" begin diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 2b7c6347e..2fd337979 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -1,14 +1,10 @@ @testset "NotImplemented" begin @testset "NotImplemented" begin ni = ChainRulesCore.NotImplemented( - @__MODULE__, - LineNumberNode(@__LINE__, @__FILE__), - "error", + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" ) ni2 = ChainRulesCore.NotImplemented( - @__MODULE__, - LineNumberNode(@__LINE__, @__FILE__), - "error2", + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" ) x = rand() thunk = @thunk(x^2) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index cc24d988e..a9f022920 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -24,9 +24,9 @@ end end @testset "==" begin - @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(x = 0.1, y = 2.5) - @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(y = 2.5, x = 0.1) - @test Tangent{Foo}(y = 2.5, x = ZeroTangent()) == Tangent{Foo}(y = 2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) + @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) @@ -35,22 +35,22 @@ end @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(; y = 2.0) == Tangent{Foo}(; x = ZeroTangent(), y = Float32(2.0)) + @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) end @testset "hash" begin - @test hash(Tangent{Foo}(x = 0.1, y = 2.5)) == hash(Tangent{Foo}(y = 2.5, x = 0.1)) - @test hash(Tangent{Foo}(y = 2.5, x = ZeroTangent())) == hash(Tangent{Foo}(y = 2.5)) + @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) + @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x = 2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x = 2.5)) == (:x,) - @test haskey(Tangent{Foo}(x = 2.5), :x) == true + @test keys(Tangent{Foo}(; x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(; x=2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x = 2.5), :y) == false + @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false end - @test Tangent{Foo}(x = 2.5).x == 2.5 + @test Tangent{Foo}(; x=2.5).x == 2.5 @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) @@ -60,28 +60,28 @@ end @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x = 2.5)) == 1 + @test length(Tangent{Foo}(; x=2.5)) == 1 @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x = 2.5)) == Float64 + @test eltype(Tangent{Foo}(; x=2.5)) == Float64 @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x = 2.5)) == [2.5] + @test collect(Tangent{Foo}(; x=2.5)) == [2.5] @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] # Test indexed_iterate @@ -96,8 +96,8 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x = 2, y = 3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y = 3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) end end @@ -107,7 +107,7 @@ end @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x = 1.0, y = 2.0)) + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) cdict = Tangent{Foo,typeof(d)}(d) @@ -115,14 +115,13 @@ end end @testset "unset properties" begin - @test Tangent{Foo}(; x = 1.4).y === ZeroTangent() + @test Tangent{Foo}(; x=1.4).y === ZeroTangent() end @testset "conj" begin - @test conj(Tangent{Foo}(x = 2.0 + 3.0im)) == Tangent{Foo}(x = 2.0 - 3.0im) + @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), - Tangent{Tuple{Float64}}(2.0 - 3.0im), + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -138,14 +137,14 @@ end # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x = 2.5, y = 10)) == CFoo(x = 2.5, y = 10) - @test canonicalize(CFoo(y = 10, x = 2.5)) == CFoo(x = 2.5, y = 10) - @test canonicalize(CFoo(y = 10)) == CFoo(x = ZeroTangent(), y = 10) + @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(q = 99.0, x = 2.5)) + @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a = 1, b = 2) + c1 = Tangent{Any}(; a=1, b=2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -158,15 +157,14 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x = 1.5) + CFoo(x = 2.5) == CFoo(x = 4.0) - @test CFoo(y = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 2.5) - @test CFoo(y = 1.5, x = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 4.0) + @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) + @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) + @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) end @testset "Tuples" begin @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{},Tuple{}}, + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} ) @test ( Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + @@ -175,9 +173,9 @@ end end @testset "NamedTuples" begin - nt1 = (; a = 1.5, b = 0.0) - nt2 = (; a = 0.0, b = 2.5) - nt_sum = (a = 1.5, b = 2.5) + nt1 = (; a=1.5, b=0.0) + nt2 = (; a=0.0, b=2.5) + nt_sum = (a=1.5, b=2.5) @test (Tangent{typeof(nt1)}(; nt1...) + Tangent{typeof(nt2)}(; nt2...)) == Tangent{typeof(nt_sum)}(; nt_sum...) end @@ -191,8 +189,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x = 1.5) - b = CFoo(x = @not_implemented("")) + a = CFoo(; x=1.5) + b = CFoo(; x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -207,8 +205,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x = 1.5) - b = Tangent{NamedTuple{(:x,)}}(x = @not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(; x=1.5) + b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -227,9 +225,9 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x = 2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x = 2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x = 0.5)) == 0 + @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 end @testset "Tuples" begin @@ -239,11 +237,11 @@ end end @testset "NamedTuple" begin - ntx = (; a = 1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a = 3.0) + ntx = (; a=1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) - nty = (; a = 1.5, b = 0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a = 3.0, b = 1.0) + nty = (; a=1.5, b=0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) end @testset "Dicts" begin @@ -255,7 +253,7 @@ end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x = 2.0, x2 = 6.0) + diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -268,7 +266,6 @@ end ChainRulesCore.debug_mode() = false # disable it again end - # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) @@ -280,7 +277,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y = 1.5, x = 2.5) + c = Tangent{Foo}(; y=1.5, x=2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -302,9 +299,9 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y = 1.5, x = 2.5) == - Tangent{Foo}(y = 3.0, x = 5.0) == - Tangent{Foo}(y = 1.5, x = 2.5) * 2 + 2 * Tangent{Foo}(; y=1.5, x=2.5) == + Tangent{Foo}(; y=3.0, x=5.0) == + Tangent{Foo}(; y=1.5, x=2.5) * 2 ) @test ( 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == @@ -317,7 +314,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x = 1)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -334,7 +331,7 @@ end end @testset "Internals don't allocate a ton" begin - bk = (; x = 1.0, y = 2.0) + bk = (; x=1.0, y=2.0) VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 @@ -345,8 +342,8 @@ end end @testset "non-same-typed differential arithmetic" begin - nt = (; a = 1, b = 2.0) - c = Tangent{typeof(nt)}(; a = NoTangent(), b = 0.1) - @test nt + c == (; a = 1, b = 2.1) + nt = (; a=1, b=2.0) + c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) + @test nt + c == (; a=1, b=2.1) end end From cd9376351609b3c7da47fad4b84576ead271c772 Mon Sep 17 00:00:00 2001 From: st-- Date: Thu, 7 Oct 2021 13:46:55 +0300 Subject: [PATCH 04/20] Update src/rule_definition_tools.jl --- src/rule_definition_tools.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d7510c8d6..fd32fbbbd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -577,8 +577,7 @@ function _split_primal_name(primal_name) Meta.isexpr(primal_name, :curly) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name - # e.g. (::T)(x, y) - elseif Meta.isexpr(primal_name, :(::)) + elseif Meta.isexpr(primal_name, :(::)) # e.g. (::T)(x, y) _primal_name = gensym(Symbol(:instance_, primal_name.args[end])) primal_name_sig = Expr(:(::), _primal_name, primal_name.args[end]) return primal_name_sig, _primal_name From 299b70bb87c7c3ba03753b8f4576f2112fd9f0a0 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 13:48:01 +0300 Subject: [PATCH 05/20] Revert "format with blue style" This reverts commit 8ddb4acc97b32c833bc006aff4b8c1b5e0295f18. --- .JuliaFormatter.toml | 1 - docs/make.jl | 24 ++--- docs/src/assets/make_logo.jl | 13 +-- src/accumulation.jl | 2 + src/config.jl | 1 + src/projection.jl | 79 +++++++++-------- src/rule_definition_tools.jl | 52 ++++++----- src/tangent_types/abstract_zero.jl | 2 +- src/tangent_types/notimplemented.jl | 18 ++-- src/tangent_types/tangent.jl | 14 +-- src/tangent_types/thunks.jl | 15 ++-- test/config.jl | 52 +++++++---- test/projection.jl | 22 ++--- test/rule_definition_tools.jl | 28 +++--- test/rules.jl | 35 +++++--- test/tangent_types/abstract_zero.jl | 2 +- test/tangent_types/notimplemented.jl | 8 +- test/tangent_types/tangent.jl | 125 ++++++++++++++------------- 18 files changed, 275 insertions(+), 218 deletions(-) delete mode 100644 .JuliaFormatter.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 323237bab..000000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1 +0,0 @@ -style = "blue" diff --git a/docs/make.jl b/docs/make.jl index 42e39a4c0..608422c25 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,12 +21,12 @@ DocMeta.setdocmeta!( indigo = DocThemeIndigo.install(ChainRulesCore) -makedocs(; - modules=[ChainRulesCore], - format=Documenter.HTML(; - prettyurls=false, - assets=[indigo], - mathengine=MathJax3( +makedocs( + modules = [ChainRulesCore], + format = Documenter.HTML( + prettyurls = false, + assets = [indigo], + mathengine = MathJax3( Dict( :tex => Dict( "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], @@ -42,9 +42,9 @@ makedocs(; ), ), ), - sitename="ChainRules", - authors="Jarrett Revels and other contributors", - pages=[ + sitename = "ChainRules", + authors = "Jarrett Revels and other contributors", + pages = [ "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", @@ -63,8 +63,8 @@ makedocs(; ], "API" => "api.md", ], - strict=true, - checkdocs=:exports, + strict = true, + checkdocs = :exports, ) -deploydocs(; repo="github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview=true) +deploydocs(repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview = true) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index c023c308f..3e7aeaa08 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -8,7 +8,7 @@ using Random const bridge_len = 50 -function chain(jiggle=0) +function chain(jiggle = 0) shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) ### 1 @@ -17,6 +17,7 @@ function chain(jiggle=0) link() m1 = getmatrix() + ### 2 sethue(Luxor.julia_green) translate(-50, 130) @@ -37,13 +38,15 @@ function chain(jiggle=0) setmatrix(m2) setcolor(Luxor.julia_green) - return overlap(-1.5π) + overlap(-1.5π) end + function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) + rect(50, -3, 40, bridge_len + 6, :fill) rect(-50 - 40, -3, 40, bridge_len + 6, :fill) @@ -55,7 +58,7 @@ function link() move(Point(-90, bridge_len)) arc(Point(0, 0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) - return strokepath() + strokepath() end function overlap(ang_end) @@ -65,7 +68,7 @@ function overlap(ang_end) move(Point(90, bridge_len)) arc(Point(0, bridge_len), 90, 0, ang_end, :stoke) - return strokepath() + strokepath() end # Actually draw it @@ -77,7 +80,7 @@ function save_logo(filename) translate(50, -130) chain(0.5) finish() - return preview() + preview() end save_logo("logo.svg") diff --git a/src/accumulation.jl b/src/accumulation.jl index 538216f38..5fbc07fa8 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -34,6 +34,7 @@ function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} end end + """ is_inplaceable_destination(x) -> Bool @@ -63,6 +64,7 @@ end is_inplaceable_destination(::LinearAlgebra.Hermitian) = false is_inplaceable_destination(::LinearAlgebra.Symmetric) = false + function debug_add!(accumuland, t::InplaceableThunk) returned_value = t.add!(accumuland) if returned_value !== accumuland diff --git a/src/config.jl b/src/config.jl index 04757e838..347e05c51 100644 --- a/src/config.jl +++ b/src/config.jl @@ -64,6 +64,7 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar """ struct NoForwardsMode <: ForwardsModeCapability end + """ frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) diff --git a/src/projection.jl b/src/projection.jl index 2e1a9340e..55f6e7bfd 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -32,7 +32,7 @@ ProjectTo{P}() where {P} = ProjectTo{P}(EMPTY_NT) const Type_kwfunc = Core.kwftype(Type).instance function (::typeof(Type_kwfunc))(kws::Any, ::Type{ProjectTo{P}}) where {P} - return ProjectTo{P}(NamedTuple(kws)) + ProjectTo{P}(NamedTuple(kws)) end Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) @@ -131,9 +131,8 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -function ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} - return ProjectTo{NoTangent}() -end +ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} = + ProjectTo{NoTangent}() # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -169,13 +168,11 @@ end (::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -function (::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} - return convert(T, dx) -end +(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = + convert(T, dx) (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -function (::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} - return convert(T, dx) -end +(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = + convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -196,7 +193,7 @@ end # For arrays of numbers, just store one projector: function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) + return ProjectTo{AbstractArray}(; element = _eltype_projectto(T), axes = axes(x)) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -210,7 +207,7 @@ function ProjectTo(xs::AbstractArray) return ProjectTo{NoTangent}() # short-circuit if all elements project to zero else # Arrays of arrays come here, and will apply projectors individually: - return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs)) + return ProjectTo{AbstractArray}(; elements = elements, axes = axes(xs)) end end @@ -220,7 +217,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} dy = if axes(dx) == project.axes dx else - for d in 1:max(M, length(project.axes)) + for d = 1:max(M, length(project.axes)) if size(dx, d) != length(get(project.axes, d, 1)) throw(_projection_mismatch(project.axes, size(dx))) end @@ -252,7 +249,7 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro if !(project.axes isa Tuple{}) throw( DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number" + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", ), ) end @@ -262,7 +259,7 @@ end function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) return DimensionMismatch( - "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx" + "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx", ) end @@ -276,13 +273,13 @@ function ProjectTo(x::Ref) if sub isa ProjectTo{<:AbstractZero} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type=typeof(x), x=sub) + return ProjectTo{Ref}(; type = typeof(x), x = sub) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) +(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x = project.x(dx.x)) +(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x = project.x(dx[])) # Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) +(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x = project.x(dx)) ##### ##### `LinearAlgebra` @@ -291,7 +288,7 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Row vectors -ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x))) +ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent = ProjectTo(parent(x))) # Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec. # Transposed matrices are, like PermutedDimsArray, just a storage detail, # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number @@ -306,9 +303,8 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -function ProjectTo(x::LinearAlgebra.TransposeAbsVec) - return ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) -end +ProjectTo(x::LinearAlgebra.TransposeAbsVec) = + ProjectTo{Transpose}(; parent = ProjectTo(parent(x))) function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -321,7 +317,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) end # Diagonal -ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) +ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag = ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) @@ -333,7 +329,10 @@ for (SymHerm, chk, fun) in sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) + return ProjectTo{$SymHerm}(; + uplo = LinearAlgebra.sym_uplo(x.uplo), + parent = sub, + ) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) dy = project.parent(dx) @@ -346,9 +345,8 @@ for (SymHerm, chk, fun) in # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) sub = project.parent # this is going to be unhappy about the size - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) + sub_one = + ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) return Diagonal(sub_one(dx.diag)) end end @@ -357,13 +355,12 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) + ProjectTo(x::$UL) = ProjectTo{$UL}(; parent = ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) + sub_one = + ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) return Diagonal(sub_one(dx.diag)) end end @@ -400,7 +397,7 @@ end # another strategy is just to use the AbstractArray method function ProjectTo(x::Tridiagonal{T}) where {T<:Number} notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent=notparent) + return ProjectTo{Tridiagonal}(; notparent = notparent) end function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) dy = project.notparent(dx) @@ -419,7 +416,9 @@ using SparseArrays function ProjectTo(x::SparseVector{T}) where {T<:Number} return ProjectTo{SparseVector}(; - element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) + element = ProjectTo(zero(T)), + nzind = x.nzind, + axes = axes(x), ) end function (project::ProjectTo{SparseVector})(dx::AbstractArray) @@ -458,11 +457,11 @@ end function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} return ProjectTo{SparseMatrixCSC}(; - element=ProjectTo(zero(T)), - axes=axes(x), - rowval=rowvals(x), - nzranges=nzrange.(Ref(x), axes(x, 2)), - colptr=x.colptr, + element = ProjectTo(zero(T)), + axes = axes(x), + rowval = rowvals(x), + nzranges = nzrange.(Ref(x), axes(x, 2)), + colptr = x.colptr, ) end # You need not really store nzranges, you can get them from colptr -- TODO @@ -482,7 +481,7 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) for i in project.nzranges[col] row = project.rowval[i] val = dy[row, col] - nzval[k += 1] = project.element(val) + nzval[k+=1] = project.element(val) end end m, n = map(length, project.axes) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d7510c8d6..8a1e1cce4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -83,9 +83,8 @@ For examples, see ChainRules' `rulesets` directory. See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) - call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( - call, maybe_setup, partials - ) + call, setup_stmts, inputs, partials = + _normalize_scalarrules_macro_input(call, maybe_setup, partials) f = call.args[1] # Generate variables to store derivatives named dfi/dxj @@ -99,11 +98,11 @@ macro scalar_rule(call, maybe_setup, partials...) rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro - return code = quote + code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 throw( ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))" + "@scalar_rule cannot be used on closures/functors (such as $($f))", ), ) end @@ -114,6 +113,7 @@ macro scalar_rule(call, maybe_setup, partials...) end end + """ _normalize_scalarrules_macro_input(call, maybe_setup, partials) @@ -177,7 +177,9 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote function ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), ::Core.Typeof($f), $(inputs...) + $(esc(:Ω)), + ::Core.Typeof($f), + $(inputs...), ) $(__source__) $(setup_stmts...) @@ -199,9 +201,8 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) end if n_outputs > 1 # For forward-mode we return a Tangent if output actually a tuple. - pushforward_returns = Expr( - :call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns... - ) + pushforward_returns = + Expr(:call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns...) else pushforward_returns = first(pushforward_returns) end @@ -213,9 +214,8 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), $f, $(inputs...) - ) + $(Expr(:tuple, partials...)) = + ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -253,9 +253,8 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), $f, $(inputs...) - ) + $(Expr(:tuple, partials...)) = + ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -264,7 +263,7 @@ end # For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" -_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i = 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) @@ -282,7 +281,7 @@ Specify `_conj = true` to conjugate the partials. Projector `proj` is a function that will be applied at the end; for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity` """ -function propagation_expr(Δs, ∂s, _conj=false, proj=identity) +function propagation_expr(Δs, ∂s, _conj = false, proj = identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj @@ -296,7 +295,7 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) summed_∂_mul_Δs = - foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) + foldl(Iterators.drop(zip(_∂s, Δs), 1); init = init_expr) do ex, (∂s_i, Δs_i) :(muladd($∂s_i, $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) @@ -374,7 +373,7 @@ macro non_differentiable(sig_expr) primal_invoke = if !has_vararg :($(primal_name)($(unconstrained_args...))) else - normal_args = unconstrained_args[1:(end - 1)] + normal_args = unconstrained_args[1:end-1] var_arg = unconstrained_args[end] :($(primal_name)($(normal_args...), $(var_arg)...)) end @@ -409,7 +408,8 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::Any), $(map(esc, primal_sig_parts)...) + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() @@ -445,7 +445,9 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl function (::Core.kwftype(typeof(rrule)))( - $(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...) + $(esc(kwargs))::Any, + ::typeof(rrule), + $(esc_primal_sig_parts...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end @@ -456,6 +458,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end end + ############################################################################################ # @opt_out @@ -521,6 +524,8 @@ function _no_rule_target_rewrite!(call_target::Symbol) end end + + ############################################################################################ # Helpers @@ -575,6 +580,7 @@ function _split_primal_name(primal_name) if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || Meta.isexpr(primal_name, :curly) + primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name # e.g. (::T)(x, y) @@ -592,7 +598,7 @@ _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. Meta.isexpr(arg, :(...), 1) && return _unconstrain(arg.args[1]) - return error("malformed arguments: $arg") + error("malformed arguments: $arg") end "turn both `a` and `::constraint` into `a::constraint` etc" @@ -601,6 +607,6 @@ function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) - return error("malformed arguments: $arg") + error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 5993d32b4..c86fc78ea 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -33,7 +33,7 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) Base.getindex(z::AbstractZero, k) = z Base.view(z::AbstractZero, ind...) = z -Base.sum(z::AbstractZero; dims=:) = z +Base.sum(z::AbstractZero; dims = :) = z """ ZeroTangent() <: AbstractZero diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index a6b9cc5f9..7ceb315ea 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,15 +44,13 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -function Base.zero(::Type{<:NotImplemented}) - return throw( - NotImplementedException( - @not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" - ) - ), - ) -end +Base.zero(::Type{<:NotImplemented}) = throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" + ) + ), +) Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) @@ -81,5 +79,5 @@ function Base.showerror(io::IO, e::NotImplementedException) if e.info !== nothing print(io, "\nInfo: ", e.info) end - return nothing + return end diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index bb91e431e..34e822ea8 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -102,10 +102,10 @@ Base.length(tangent::Tangent) = length(backing(tangent)) Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - return Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) + Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) end -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state = 1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end @@ -301,16 +301,16 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")"; color=:blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color = :blue) println(io, "\nor overload") printstyled( io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color=:blue, + color = :blue, ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color = :blue) println(io, "\nOriginal Exception:") - printstyled(io, err.original; color=:yellow) - return println(io) + printstyled(io, err.original; color = :yellow) + println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index e065bea62..c2b570902 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -56,16 +56,20 @@ LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a)) LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a)) LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a)) LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a)) -LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo) -LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo) +LinearAlgebra.Symmetric(a::AbstractThunk, uplo = :U) = Symmetric(unthunk(a), uplo) +LinearAlgebra.Hermitian(a::AbstractThunk, uplo = :U) = Hermitian(unthunk(a), uplo) function LinearAlgebra.diagm( - kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... + kv::Pair{<:Integer,<:AbstractThunk}, + kvs::Pair{<:Integer,<:AbstractThunk}..., ) return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...) end function LinearAlgebra.diagm( - m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... + m, + n, + kv::Pair{<:Integer,<:AbstractThunk}, + kvs::Pair{<:Integer,<:AbstractThunk}..., ) return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end @@ -118,7 +122,7 @@ function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx) return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx) end -function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1) +function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn = 1) return throw(MutateThunkException()) end @@ -197,6 +201,7 @@ Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a) + """ InplaceableThunk(add!::Function, val::Thunk) diff --git a/test/config.jl b/test/config.jl index 58d943252..e6e2ab005 100644 --- a/test/config.jl +++ b/test/config.jl @@ -22,6 +22,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. return f(args...; kws...), pullback_via_ad end + struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} forward_calls::Vector reverse_calls::Vector @@ -46,7 +47,10 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig + MostBoringConfig, + MockForwardsConfig, + MockReverseConfig, + MockBothConfig, ) counting_id_count = Ref(0) function counting_id(x) @@ -75,21 +79,33 @@ end @testset "hitting forwards AD" begin do_thing_2(f, x) = f(x) function ChainRulesCore.frule( - config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x + config::RuleConfig{>:HasForwardsMode}, + (_, df, dx), + ::typeof(do_thing_2), + f, + x, ) return frule_via_ad(config, (df, dx), f, x) end @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) @test nothing === frule( - Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + Config(), + (NoTangent(), NoTangent(), 21.5), + do_thing_2, + identity, + 32.1, ) end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) bconfig = Config() @test nothing !== frule( - bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + bconfig, + (NoTangent(), NoTangent(), 21.5), + do_thing_2, + identity, + 32.1, ) @test bconfig.forward_calls == [(identity, (32.1,))] end @@ -98,11 +114,15 @@ end @testset "hitting reverse AD" begin do_thing_3(f, x) = f(x) function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x + config::RuleConfig{>:HasReverseMode}, + ::typeof(do_thing_3), + f, + x, ) return (NoTangent(), rrule_via_ad(config, f, x)...) end + @testset "$Config" for Config in (MostBoringConfig, MockForwardsConfig) @test nothing === rrule(Config(), do_thing_3, identity, 32.1) end @@ -160,28 +180,28 @@ end end @testset "fallbacks" begin - no_rule(x; kw="bye") = error() + no_rule(x; kw = "bye") = error() @test frule((1.0,), no_rule, 2.0) === nothing - @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule((1.0,), no_rule, 2.0; kw = "hello") === nothing @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing - @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw = "hello") === nothing @test rrule(no_rule, 2.0) === nothing - @test rrule(no_rule, 2.0; kw="hello") === nothing + @test rrule(no_rule, 2.0; kw = "hello") === nothing @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing - @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw = "hello") === nothing # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(; kw="hello") + @test_throws MethodError frule(; kw = "hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin; kw="hello") + @test_throws MethodError frule(sin; kw = "hello") @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(); kw = "hello") @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin; kw = "hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(; kw="hello") + @test_throws MethodError rrule(; kw = "hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig(); kw="hello") + @test_throws MethodError rrule(MostBoringConfig(); kw = "hello") end end diff --git a/test/projection.jl b/test/projection.jl index cbfdcf6da..ab418ef79 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -30,14 +30,14 @@ struct NoSuperType end # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) - @test ProjectTo(1//2)(3//4) === 3//4 + @test ProjectTo(1 // 2)(3 // 4) === 3 // 4 @test ProjectTo(1.0f0)(1 / 2) === 0.5f0 @test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im @test ProjectTo(big(1.0))(2) === 2 @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(; re=1, im=NoTangent())) === + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re = 1, im = NoTangent())) === 1.0f0 + 0.0f0im end @@ -52,7 +52,7 @@ struct NoSuperType end @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent - @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent + @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value = 1.0)) isa Tangent end @testset "Base: arrays of numbers" begin @@ -102,7 +102,7 @@ struct NoSuperType end @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. - @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) + @test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im) @test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent} # empty arrays @@ -126,18 +126,18 @@ struct NoSuperType end @testset "Base: Ref" begin pref = ProjectTo(Ref(2.0)) @test pref(Ref(3 + im)).x === 3.0 - @test pref(Tangent{Base.RefValue}(; x=3 + im)).x === 3.0 + @test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0 @test pref(4).x === 4.0 # also re-wraps scalars @test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue} pref2 = ProjectTo(Ref{Any}(6 + 7im)) @test pref2(Ref(8)).x === 8.0 + 0.0im - @test pref2(Tangent{Base.RefValue}(; x=8)).x === 8.0 + 0.0im + @test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents @test prefvec(Ref(1:3)).x isa Vector{ComplexF64} - @test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64} - @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5)) + @test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64} + @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5)) @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} @@ -172,7 +172,7 @@ struct NoSuperType end # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called - xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) + xs = adj(Any[Any[1, 2, 3], Any[4+im, 5-im, 6+im, 7-im]]) pvecvec3 = ProjectTo(xs) @test pvecvec3(xs)[1] == [1 2 3] @test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im]) @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()) + dx = Tangent{typeof(x)}(; step = 0.1, ref = NoTangent()) @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(; x=3 + 4im) + dy = Tangent{typeof(Ref(2.0))}(x = 3 + 4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index f4be16218..e99b66c2f 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -37,13 +37,13 @@ struct NonDiffCounterExample end module NonDiffModuleExample -nondiff_2_1(x, y) = fill(7.5, 100)[x + y] +nondiff_2_1(x, y) = fill(7.5, 100)[x+y] end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] + nondiff_2_1(x, y) = fill(7.5, 100)[x+y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((ZeroTangent(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, NoTangent()) res, pullback = rrule(nondiff_2_1, 3, 2) @@ -93,7 +93,7 @@ end end @testset "kwargs" begin - kw_demo(x; kw=2.0) = x + kw + kw_demo(x; kw = 2.0) = x + kw @non_differentiable kw_demo(::Any) @testset "not setting kw" begin @@ -107,13 +107,13 @@ end end @testset "setting kw" begin - @assert kw_demo(1.5; kw=3.0) == 4.5 + @assert kw_demo(1.5; kw = 3.0) == 4.5 - res, pullback = rrule(kw_demo, 1.5; kw=3.0) + res, pullback = rrule(kw_demo, 1.5; kw = 3.0) @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw = 3.0) == (4.5, NoTangent()) end end @@ -196,9 +196,9 @@ end end @testset "Functors" begin - (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] + (f::NonDiffExample)(y) = fill(7.5, 100)[f.x+y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(; x=1.2), 2.3), NonDiffExample(3), 2) == + @test frule((Tangent{NonDiffExample}(x = 1.2), 2.3), NonDiffExample(3), 2) == (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @@ -208,7 +208,10 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) @test frule( - (ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2 + (ZeroTangent(), 1.2, 2.3), + NonDiffModuleExample.nondiff_2_1, + 3, + 2, ) == (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @@ -261,7 +264,7 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 # Symptom of these problems is creation of global variables and type instability - num_globals_before = length(names(ChainRulesCore; all=true)) + num_globals_before = length(names(ChainRulesCore; all = true)) simo2(x) = (x, 2x) @scalar_rule(simo2(x), 1.0, 2.0) @@ -270,7 +273,7 @@ end @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) # Test no new globals were created - @test length(names(ChainRulesCore; all=true)) == num_globals_before + @test length(names(ChainRulesCore; all = true)) == num_globals_before # Example in #265 simo3(x) = sincos(x) @@ -281,6 +284,7 @@ end end end + module IsolatedModuleForTestingScoping # check that rules can be defined by macros without any additional imports using ChainRulesCore: @scalar_rule, @non_differentiable @@ -319,7 +323,7 @@ using Test @test f_pullback(randn()) === (NoTangent(), NoTangent()) end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + y, f_pullback = rrule(fixed_kwargs, randn(); keyword = randn()) @test y === :abc @test f_pullback(randn()) === (NoTangent(), NoTangent()) end diff --git a/test/rules.jl b/test/rules.jl index 54c10b160..267b23005 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,7 +28,11 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any,Any,Any,Vararg}, ::typeof(mixed_vararg), x, y, z... + dargs::Tuple{Any,Any,Any,Vararg}, + ::typeof(mixed_vararg), + x, + y, + z..., ) Δx = dargs[2] Δy = dargs[3] @@ -38,7 +42,10 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any,Int,Float64}, ::typeof(type_constraints), x::Int, y::Float64 + (_, Δx, Δy)::Tuple{Any,Int,Float64}, + ::typeof(type_constraints), + x::Int, + y::Float64, ) return type_constraints(x, y), Δx + Δy end @@ -66,9 +73,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = ZeroTangent() @test frule((dself, 1), cool, 1) === nothing - @test frule((dself, 1), cool, 1; iscool=true) === nothing + @test frule((dself, 1), cool, 1; iscool = true) === nothing @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool=true) === nothing + @test rrule(cool, 1; iscool = true) === nothing # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @@ -78,7 +85,8 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) only_methods = Set([ - Tuple{typeof(rrule),typeof(cool),Number}, Tuple{typeof(rrule),typeof(cool),String} + Tuple{typeof(rrule),typeof(cool),Number}, + Tuple{typeof(rrule),typeof(cool),String}, ]) @test cool_methods == only_methods @@ -96,6 +104,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) rrx, nice_pullback = rrule(nice, 1) @test (NoTangent(), ZeroTangent()) === nice_pullback(1) + # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) @@ -107,7 +116,12 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test( frule( - (nothing, 3.0, 2.0, 1.0, 0.0), mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0 + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, + 3.0, + 2.0, + 1.0, + 0.0, ) == (6.0, 6.0) ) @@ -150,6 +164,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test_skip ∂xr ≈ real(∂x) end + @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) @@ -162,11 +177,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) @test rrule(first_oa, 3.0f0, 4.0f0) === nothing - @test !isempty( - Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} - end, - ) + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} + end) end @testset "frule" begin diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index f8222d942..fdbb92f55 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -9,7 +9,7 @@ @test view(NoTangent(), 1, 2) == NoTangent() @test sum(ZeroTangent()) == ZeroTangent() - @test sum(NoTangent(); dims=2) == NoTangent() + @test sum(NoTangent(); dims = 2) == NoTangent() end @testset "ZeroTangent" begin diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 2fd337979..2b7c6347e 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -1,10 +1,14 @@ @testset "NotImplemented" begin @testset "NotImplemented" begin ni = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" + @__MODULE__, + LineNumberNode(@__LINE__, @__FILE__), + "error", ) ni2 = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" + @__MODULE__, + LineNumberNode(@__LINE__, @__FILE__), + "error2", ) x = rand() thunk = @thunk(x^2) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index a9f022920..cc24d988e 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -24,9 +24,9 @@ end end @testset "==" begin - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) - @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) + @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(x = 0.1, y = 2.5) + @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(y = 2.5, x = 0.1) + @test Tangent{Foo}(y = 2.5, x = ZeroTangent()) == Tangent{Foo}(y = 2.5) @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) @@ -35,22 +35,22 @@ end @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) + @test Tangent{Foo}(; y = 2.0) == Tangent{Foo}(; x = ZeroTangent(), y = Float32(2.0)) end @testset "hash" begin - @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) - @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) + @test hash(Tangent{Foo}(x = 0.1, y = 2.5)) == hash(Tangent{Foo}(y = 2.5, x = 0.1)) + @test hash(Tangent{Foo}(y = 2.5, x = ZeroTangent())) == hash(Tangent{Foo}(y = 2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(; x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(; x=2.5), :x) == true + @test keys(Tangent{Foo}(x = 2.5)) == (:x,) + @test propertynames(Tangent{Foo}(x = 2.5)) == (:x,) + @test haskey(Tangent{Foo}(x = 2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false + @test hasproperty(Tangent{Foo}(x = 2.5), :y) == false end - @test Tangent{Foo}(; x=2.5).x == 2.5 + @test Tangent{Foo}(x = 2.5).x == 2.5 @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) @@ -60,28 +60,28 @@ end @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(; x=2.5)) == 1 + @test length(Tangent{Foo}(x = 2.5)) == 1 @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - @test eltype(Tangent{Foo}(; x=2.5)) == Float64 + @test eltype(Tangent{Foo}(x = 2.5)) == Float64 @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(; x=2.5)) == [2.5] + @test collect(Tangent{Foo}(x = 2.5)) == [2.5] @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] # Test indexed_iterate @@ -96,8 +96,8 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(x = 2, y = 3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(y = 3.0)) end end @@ -107,7 +107,7 @@ end @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) + @test_throws MethodError reverse(Tangent{Foo}(; x = 1.0, y = 2.0)) d = Dict(:x => 1, :y => 2.0) cdict = Tangent{Foo,typeof(d)}(d) @@ -115,13 +115,14 @@ end end @testset "unset properties" begin - @test Tangent{Foo}(; x=1.4).y === ZeroTangent() + @test Tangent{Foo}(; x = 1.4).y === ZeroTangent() end @testset "conj" begin - @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) + @test conj(Tangent{Foo}(x = 2.0 + 3.0im)) == Tangent{Foo}(x = 2.0 - 3.0im) @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), + Tangent{Tuple{Float64}}(2.0 - 3.0im), ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -137,14 +138,14 @@ end # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) + @test canonicalize(CFoo(x = 2.5, y = 10)) == CFoo(x = 2.5, y = 10) + @test canonicalize(CFoo(y = 10, x = 2.5)) == CFoo(x = 2.5, y = 10) + @test canonicalize(CFoo(y = 10)) == CFoo(x = ZeroTangent(), y = 10) - @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) + @test_throws ArgumentError canonicalize(CFoo(q = 99.0, x = 2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a=1, b=2) + c1 = Tangent{Any}(; a = 1, b = 2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -157,14 +158,15 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) - @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) - @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) + @test CFoo(x = 1.5) + CFoo(x = 2.5) == CFoo(x = 4.0) + @test CFoo(y = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 2.5) + @test CFoo(y = 1.5, x = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 4.0) end @testset "Tuples" begin @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), + Tangent{Tuple{},Tuple{}}, ) @test ( Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + @@ -173,9 +175,9 @@ end end @testset "NamedTuples" begin - nt1 = (; a=1.5, b=0.0) - nt2 = (; a=0.0, b=2.5) - nt_sum = (a=1.5, b=2.5) + nt1 = (; a = 1.5, b = 0.0) + nt2 = (; a = 0.0, b = 2.5) + nt_sum = (a = 1.5, b = 2.5) @test (Tangent{typeof(nt1)}(; nt1...) + Tangent{typeof(nt2)}(; nt2...)) == Tangent{typeof(nt_sum)}(; nt_sum...) end @@ -189,8 +191,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(; x=1.5) - b = CFoo(; x=@not_implemented("")) + a = CFoo(x = 1.5) + b = CFoo(x = @not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -205,8 +207,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(; x=1.5) - b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(x = 1.5) + b = Tangent{NamedTuple{(:x,)}}(x = @not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -225,9 +227,9 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 + @test Foo(3.5, 1.5) + Tangent{Foo}(x = 2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(x = 2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x = 0.5)) == 0 end @testset "Tuples" begin @@ -237,11 +239,11 @@ end end @testset "NamedTuple" begin - ntx = (; a=1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + ntx = (; a = 1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a = 3.0) - nty = (; a=1.5, b=0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) + nty = (; a = 1.5, b = 0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a = 3.0, b = 1.0) end @testset "Dicts" begin @@ -253,7 +255,7 @@ end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) + diff = Tangent{StructWithInvariant}(x = 2.0, x2 = 6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -266,6 +268,7 @@ end ChainRulesCore.debug_mode() = false # disable it again end + # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) @@ -277,7 +280,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(; y=1.5, x=2.5) + c = Tangent{Foo}(y = 1.5, x = 2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -299,9 +302,9 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(; y=1.5, x=2.5) == - Tangent{Foo}(; y=3.0, x=5.0) == - Tangent{Foo}(; y=1.5, x=2.5) * 2 + 2 * Tangent{Foo}(y = 1.5, x = 2.5) == + Tangent{Foo}(y = 3.0, x = 5.0) == + Tangent{Foo}(y = 1.5, x = 2.5) * 2 ) @test ( 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == @@ -314,7 +317,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(x = 1)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -331,7 +334,7 @@ end end @testset "Internals don't allocate a ton" begin - bk = (; x=1.0, y=2.0) + bk = (; x = 1.0, y = 2.0) VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 @@ -342,8 +345,8 @@ end end @testset "non-same-typed differential arithmetic" begin - nt = (; a=1, b=2.0) - c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1) + nt = (; a = 1, b = 2.0) + c = Tangent{typeof(nt)}(; a = NoTangent(), b = 0.1) + @test nt + c == (; a = 1, b = 2.1) end end From a873a30a5405ec3d9c871aa59dc45eac9fdcc5fd Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 13:48:09 +0300 Subject: [PATCH 06/20] Revert "format(".")" This reverts commit 2783a90850d9d1ae6d2a9c5eaca0ec63eb84f404. --- docs/make.jl | 29 ++-- docs/src/assets/make_logo.jl | 48 +++--- src/accumulation.jl | 8 +- src/compat.jl | 4 +- src/deprecated.jl | 1 - src/ignore_derivatives.jl | 8 +- src/projection.jl | 85 +++++------ src/rule_definition_tools.jl | 84 ++++------- src/tangent_arithmetic.jl | 14 +- src/tangent_types/abstract_zero.jl | 8 +- src/tangent_types/notimplemented.jl | 10 +- src/tangent_types/tangent.jl | 113 +++++++-------- src/tangent_types/thunks.jl | 16 +- test/accumulation.jl | 24 +-- test/config.jl | 76 ++++------ test/deprecated.jl | 1 - test/ignore_derivatives.jl | 8 +- test/projection.jl | 40 ++--- test/rule_definition_tools.jl | 175 +++++++++++----------- test/rules.jl | 75 ++++------ test/tangent_types/abstract_zero.jl | 4 +- test/tangent_types/notimplemented.jl | 8 +- test/tangent_types/tangent.jl | 209 ++++++++++++++------------- test/tangent_types/thunks.jl | 2 +- 24 files changed, 483 insertions(+), 567 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 608422c25..1ef3a62a7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,20 +16,20 @@ DocMeta.setdocmeta!( @scalar_rule(sin(x), cos(x)) # frule and rrule doctest @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end, + end ) indigo = DocThemeIndigo.install(ChainRulesCore) makedocs( - modules = [ChainRulesCore], - format = Documenter.HTML( - prettyurls = false, - assets = [indigo], - mathengine = MathJax3( + modules=[ChainRulesCore], + format=Documenter.HTML( + prettyurls=false, + assets=[indigo], + mathengine=MathJax3( Dict( :tex => Dict( - "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], + "inlineMath" => [["\$","\$"], ["\\(","\\)"]], "tags" => "ams", # TODO: remove when using physics package "macros" => Dict( @@ -42,9 +42,9 @@ makedocs( ), ), ), - sitename = "ChainRules", - authors = "Jarrett Revels and other contributors", - pages = [ + sitename="ChainRules", + authors="Jarrett Revels and other contributors", + pages=[ "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", @@ -63,8 +63,11 @@ makedocs( ], "API" => "api.md", ], - strict = true, - checkdocs = :exports, + strict=true, + checkdocs=:exports, ) -deploydocs(repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview = true) +deploydocs( + repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", + push_preview=true, +) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 3e7aeaa08..5bbfd36c1 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -8,34 +8,34 @@ using Random const bridge_len = 50 -function chain(jiggle = 0) - shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) - +function chain(jiggle=0) + shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) + ### 1 shaky_rotate(0) sethue(Luxor.julia_red) link() m1 = getmatrix() - - + + ### 2 sethue(Luxor.julia_green) - translate(-50, 130) - shaky_rotate(π / 3) + translate(-50, 130); + shaky_rotate(π/3); link() m2 = getmatrix() - + setmatrix(m1) sethue(Luxor.julia_red) overlap(-1.3π) setmatrix(m2) - + ### 3 - shaky_rotate(-π / 3) - translate(-120, 80) + shaky_rotate(-π/3); + translate(-120,80); sethue(Luxor.julia_purple) link() - + setmatrix(m2) setcolor(Luxor.julia_green) overlap(-1.5π) @@ -45,24 +45,24 @@ end function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - - - rect(50, -3, 40, bridge_len + 6, :fill) - rect(-50 - 40, -3, 40, bridge_len + 6, :fill) - + + + rect(50,-3,40, bridge_len+6, :fill) + rect(-50-40,-3,40, bridge_len+6, :fill) + sethue("black") move(Point(-50, bridge_len)) - arc(Point(0, 0), 50, π, 0, :stoke) + arc(Point(0,0), 50, π, 0, :stoke) arc(Point(0, bridge_len), 50, 0, -π, :stroke) - + move(Point(-90, bridge_len)) - arc(Point(0, 0), 90, π, 0, :stoke) + arc(Point(0,0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) strokepath() end function overlap(ang_end) - sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill) + sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) sethue("black") arc(Point(0, bridge_len), 50, 0, ang_end, :stoke) move(Point(90, bridge_len)) @@ -75,13 +75,13 @@ end function save_logo(filename) Random.seed!(16) - Drawing(450, 450, filename) + Drawing(450,450, filename) origin() - translate(50, -130) + translate(50, -130); chain(0.5) finish() preview() end save_logo("logo.svg") -save_logo("logo.png") +save_logo("logo.png") \ No newline at end of file diff --git a/src/accumulation.jl b/src/accumulation.jl index 5fbc07fa8..4bcc5c33f 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,7 +26,7 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) -function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} +function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N return if is_inplaceable_destination(x) x .+= y else @@ -75,8 +75,8 @@ end struct BadInplaceException <: Exception ithunk::InplaceableThunk - accumuland::Any - returned_value::Any + accumuland + returned_value end function Base.showerror(io::IO, err::BadInplaceException) @@ -88,7 +88,7 @@ function Base.showerror(io::IO, err::BadInplaceException) if err.accumuland == err.returned_value println( io, - "Which in this case happenned to be equal. But they are not the same object.", + "Which in this case happenned to be equal. But they are not the same object." ) end end diff --git a/src/compat.jl b/src/compat.jl index fa66b1d0f..8204b66d5 100644 --- a/src/compat.jl +++ b/src/compat.jl @@ -5,7 +5,7 @@ end if VERSION < v"1.1" # Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4 # See: https://github.com/JuliaLang/julia/issues/34292 - function fieldtypes(::Type{T}) where {T} + function fieldtypes(::Type{T}) where T if @generated ntuple(i -> fieldtype(T, i), fieldcount(T)) else @@ -13,7 +13,7 @@ if VERSION < v"1.1" end end - function fieldnames(::Type{T}) where {T} + function fieldnames(::Type{T}) where T if @generated ntuple(i -> fieldname(T, i), fieldcount(T)) else diff --git a/src/deprecated.jl b/src/deprecated.jl index 8b1378917..e69de29bb 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1 +0,0 @@ - diff --git a/src/ignore_derivatives.jl b/src/ignore_derivatives.jl index 18865f2c9..c66d89d7e 100644 --- a/src/ignore_derivatives.jl +++ b/src/ignore_derivatives.jl @@ -45,9 +45,7 @@ ignore_derivatives(x) = x Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`. """ macro ignore_derivatives(ex) - return :( - ChainRulesCore.ignore_derivatives() do - $(esc(ex)) - end - ) + return :(ChainRulesCore.ignore_derivatives() do + $(esc(ex)) + end) end diff --git a/src/projection.jl b/src/projection.jl index 55f6e7bfd..4b07b2762 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -131,8 +131,7 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} = - ProjectTo{NoTangent}() +ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}() # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -165,14 +164,12 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) end # In these cases we can just `convert` as we know we are dealing with plain and simple types -(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) -(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity +(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx) +(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = - convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = - convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -193,7 +190,7 @@ end # For arrays of numbers, just store one projector: function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element = _eltype_projectto(T), axes = axes(x)) + return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -207,7 +204,7 @@ function ProjectTo(xs::AbstractArray) return ProjectTo{NoTangent}() # short-circuit if all elements project to zero else # Arrays of arrays come here, and will apply projectors individually: - return ProjectTo{AbstractArray}(; elements = elements, axes = axes(xs)) + return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs)) end end @@ -217,7 +214,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} dy = if axes(dx) == project.axes dx else - for d = 1:max(M, length(project.axes)) + for d in 1:max(M, length(project.axes)) if size(dx, d) != length(get(project.axes, d, 1)) throw(_projection_mismatch(project.axes, size(dx))) end @@ -247,11 +244,9 @@ end # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers if !(project.axes isa Tuple{}) - throw( - DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - ), - ) + throw(DimensionMismatch( + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", + )) end return fill(project.element(dx)) end @@ -259,7 +254,7 @@ end function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) return DimensionMismatch( - "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx", + "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx" ) end @@ -273,13 +268,13 @@ function ProjectTo(x::Ref) if sub isa ProjectTo{<:AbstractZero} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type = typeof(x), x = sub) + return ProjectTo{Ref}(; type=typeof(x), x=sub) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x = project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x = project.x(dx[])) +(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) +(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) # Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x = project.x(dx)) +(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) ##### ##### `LinearAlgebra` @@ -288,7 +283,7 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Row vectors -ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent = ProjectTo(parent(x))) +ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x))) # Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec. # Transposed matrices are, like PermutedDimsArray, just a storage detail, # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number @@ -303,8 +298,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = - ProjectTo{Transpose}(; parent = ProjectTo(parent(x))) +ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -317,22 +311,21 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) end # Diagonal -ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag = ProjectTo(x.diag)) +ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) # Symmetric -for (SymHerm, chk, fun) in - ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) +for (SymHerm, chk, fun) in ( + (:Symmetric, :issymmetric, :transpose), + (:Hermitian, :ishermitian, :adjoint), + ) @eval begin function ProjectTo(x::$SymHerm) sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; - uplo = LinearAlgebra.sym_uplo(x.uplo), - parent = sub, - ) + return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) dy = project.parent(dx) @@ -345,8 +338,9 @@ for (SymHerm, chk, fun) in # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) sub = project.parent # this is going to be unhappy about the size - sub_one = - ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) + sub_one = ProjectTo{project_type(sub)}(; + element=sub.element, axes=(sub.axes[1],) + ) return Diagonal(sub_one(dx.diag)) end end @@ -355,12 +349,13 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent = ProjectTo(parent(x))) + ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent - sub_one = - ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) + sub_one = ProjectTo{project_type(sub)}(; + element=sub.element, axes=(sub.axes[1],) + ) return Diagonal(sub_one(dx.diag)) end end @@ -397,7 +392,7 @@ end # another strategy is just to use the AbstractArray method function ProjectTo(x::Tridiagonal{T}) where {T<:Number} notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent = notparent) + return ProjectTo{Tridiagonal}(; notparent=notparent) end function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) dy = project.notparent(dx) @@ -416,9 +411,7 @@ using SparseArrays function ProjectTo(x::SparseVector{T}) where {T<:Number} return ProjectTo{SparseVector}(; - element = ProjectTo(zero(T)), - nzind = x.nzind, - axes = axes(x), + element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) ) end function (project::ProjectTo{SparseVector})(dx::AbstractArray) @@ -457,11 +450,11 @@ end function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} return ProjectTo{SparseMatrixCSC}(; - element = ProjectTo(zero(T)), - axes = axes(x), - rowval = rowvals(x), - nzranges = nzrange.(Ref(x), axes(x, 2)), - colptr = x.colptr, + element=ProjectTo(zero(T)), + axes=axes(x), + rowval=rowvals(x), + nzranges=nzrange.(Ref(x), axes(x, 2)), + colptr=x.colptr, ) end # You need not really store nzranges, you can get them from colptr -- TODO @@ -481,7 +474,7 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) for i in project.nzranges[col] row = project.rowval[i] val = dy[row, col] - nzval[k+=1] = project.element(val) + nzval[k += 1] = project.element(val) end end m, n = map(length, project.axes) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 8a1e1cce4..911a32ddd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -83,8 +83,9 @@ For examples, see ChainRules' `rulesets` directory. See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) - call, setup_stmts, inputs, partials = - _normalize_scalarrules_macro_input(call, maybe_setup, partials) + call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( + call, maybe_setup, partials + ) f = call.args[1] # Generate variables to store derivatives named dfi/dxj @@ -100,11 +101,9 @@ macro scalar_rule(call, maybe_setup, partials...) # Final return: building the expression to insert in the place of this macro code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw( - ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))", - ), - ) + throw(ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))" + )) end $(derivative_expr) @@ -176,11 +175,7 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote - function ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), - ::Core.Typeof($f), - $(inputs...), - ) + function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) @@ -201,8 +196,9 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) end if n_outputs > 1 # For forward-mode we return a Tangent if output actually a tuple. - pushforward_returns = - Expr(:call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns...) + pushforward_returns = Expr( + :call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns... + ) else pushforward_returns = first(pushforward_returns) end @@ -214,8 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = - ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -230,7 +225,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) Δs = _propagator_inputs(n_outputs) # Make a projector for each argument - projs, psetup = _make_projectors(call.args[2:end]) + projs, psetup = _make_projectors(call.args[2:end]) append!(setup_stmts, psetup) # 1 partial derivative per input @@ -253,8 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = - ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -263,12 +257,12 @@ end # For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" -_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i = 1:n] +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) projs = map(x -> Symbol(:proj_, x.args[1]), xs) - setups = map((x, p) -> :($p = ProjectTo($x)), xs, projs) + setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs) return projs, setups end @@ -281,7 +275,7 @@ Specify `_conj = true` to conjugate the partials. Projector `proj` is a function that will be applied at the end; for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity` """ -function propagation_expr(Δs, ∂s, _conj = false, proj = identity) +function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj @@ -294,10 +288,9 @@ function propagation_expr(Δs, ∂s, _conj = false, proj = identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = - foldl(Iterators.drop(zip(_∂s, Δs), 1); init = init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) - end + summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) + :(muladd($∂s_i, $Δs_i, $ex)) + end return :($proj($summed_∂_mul_Δs)) end @@ -388,10 +381,7 @@ end function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, - call_expr.args[1], - Expr(:parameters, :($(kwargs)...)), - call_expr.args[2:end]..., + :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... ) end @@ -399,18 +389,11 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.frule)))( - @nospecialize($kwargs::Any), - frule::typeof(ChainRulesCore.frule), - @nospecialize(::Any), - $(map(esc, primal_sig_parts)...), - ) + function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end - function ChainRulesCore.frule( - @nospecialize(::Any), - $(map(esc, primal_sig_parts)...), - ) + function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() return ($(esc(primal_invoke)), NoTangent()) @@ -425,8 +408,7 @@ function tuple_expression(primal_sig_parts) Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = - :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) end end @@ -444,11 +426,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))( - $(esc(kwargs))::Any, - ::typeof(rrule), - $(esc_primal_sig_parts...), - ) + function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) @@ -503,7 +481,7 @@ end "Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) - length(expr.args) === 0 && error("Malformed method expression. $expr") + length(expr.args)===0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore @@ -577,13 +555,12 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || - Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) + if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || + Meta.isexpr(primal_name, :curly) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name - # e.g. (::T)(x, y) + # e.g. (::T)(x, y) elseif Meta.isexpr(primal_name, :(::)) _primal_name = gensym(Symbol(:instance_, primal_name.args[end])) primal_name_sig = Expr(:(::), _primal_name, primal_name.args[end]) @@ -605,8 +582,7 @@ end function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && - return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) + Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index c2bad7a77..9c1378aab 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -81,7 +81,7 @@ LinearAlgebra.dot(::ZeroTangent, ::NoTangent) = ZeroTangent() Base.muladd(::ZeroTangent, x, y) = y Base.muladd(x, ::ZeroTangent, y) = y -Base.muladd(x, y, ::ZeroTangent) = x * y +Base.muladd(x, y, ::ZeroTangent) = x*y Base.muladd(::ZeroTangent, ::ZeroTangent, y) = y Base.muladd(x, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() @@ -125,11 +125,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} +function Base.:+(a::Tangent{P}, b::Tangent{P}) where P data = elementwise_add(backing(a), backing(b)) - return Tangent{P,typeof(data)}(data) + return Tangent{P, typeof(data)}(data) end -function Base.:+(a::P, d::Tangent{P}) where {P} +function Base.:+(a::P, d::Tangent{P}) where P net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -142,12 +142,12 @@ function Base.:+(a::P, d::Tangent{P}) where {P} end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where {P} = b + a +Base.:+(a::Tangent{P}, b::P) where P = b + a # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 differentials # Only of a differential and a scaling factor (generally `Real`) for T in (:Any,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) + @eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent) + @eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index c86fc78ea..216357e91 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -17,15 +17,15 @@ Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) -Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() +Base.Broadcast.broadcasted(::Type{T}) where T<:AbstractZero = T() # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z -Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) -(::Type{T})(xs::AbstractZero...) where {T<:Number} = zero(T) +Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) +(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) @@ -33,7 +33,7 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) Base.getindex(z::AbstractZero, k) = z Base.view(z::AbstractZero, ind...) = z -Base.sum(z::AbstractZero; dims = :) = z +Base.sum(z::AbstractZero; dims=:) = z """ ZeroTangent() <: AbstractZero diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index 7ceb315ea..a2044fbe1 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,13 +44,9 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw( - NotImplementedException( - @not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" - ) - ), -) +Base.zero(::Type{<:NotImplemented}) = throw(NotImplementedException(@not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" +))) Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 34e822ea8..e4bbfb8c8 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -21,42 +21,42 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P,T} <: AbstractTangent +struct Tangent{P, T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T end -function Tangent{P}(; kwargs...) where {P} +function Tangent{P}(; kwargs...) where P backing = (; kwargs...) # construct as NamedTuple - return Tangent{P,typeof(backing)}(backing) + return Tangent{P, typeof(backing)}(backing) end -function Tangent{P}(args...) where {P} - return Tangent{P,typeof(args)}(args) +function Tangent{P}(args...) where P + return Tangent{P, typeof(args)}(args) end -function Tangent{P}() where {P<:Tuple} +function Tangent{P}() where P<:Tuple backing = () - return Tangent{P,typeof(backing)}(backing) + return Tangent{P, typeof(backing)}(backing) end function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P,typeof(d)}(d) + return Tangent{P, typeof(d)}(d) end -function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} +function Base.:(==)(a::Tangent{P, T}, b::Tangent{P, T}) where {P, T} return backing(a) == backing(b) end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P,T} +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P, T} all_fields = union(keys(backing(a)), keys(backing(b))) return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where {P} +function Base.show(io::IO, tangent::Tangent{P}) where P print(io, "Tangent{") show(io, P) print(io, "}") @@ -68,15 +68,15 @@ function Base.show(io::IO, tangent::Tangent{P}) where {P} end end -function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} +function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} +function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end -function Base.getindex(tangent::Tangent, idx) where {P,T<:AbstractDict} +function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict} return unthunk(getindex(backing(tangent), idx)) end @@ -84,7 +84,7 @@ function Base.getproperty(tangent::Tangent, idx::Int) back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} +function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end @@ -99,26 +99,26 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) -Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) +Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) + Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing) end -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state = 1) where {P} +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end -function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} +function Base.map(f, tangent::Tangent{P, <:Tuple}) where P vals::Tuple = map(f, backing(tangent)) - return Tangent{P,typeof(vals)}(vals) + return Tangent{P, typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} +function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L} vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L,typeof(vals)}(vals) - return Tangent{P,typeof(named_vals)}(named_vals) + named_vals = NamedTuple{L, typeof(vals)}(vals) + return Tangent{P, typeof(named_vals)}(named_vals) end -function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} +function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end @@ -140,28 +140,26 @@ backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) # For generic structs -function backing(x::T)::NamedTuple where {T} +function backing(x::T)::NamedTuple where T # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. # see https://github.com/JuliaLang/julia/issues/34283 if @generated - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) + vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names, Tuple{$(types...)}}($vals)) else - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = ntuple(ii -> getfield(x, ii), nfields) - return NamedTuple{names,Tuple{types...}}(vals) + vals = ntuple(ii->getfield(x, ii), nfields) + return NamedTuple{names, Tuple{types...}}(vals) end end @@ -172,38 +170,36 @@ Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. """ -function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} +function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L} nil = _zeroed_backing(P) combined = merge(nil, backing(tangent)) if length(combined) !== fieldcount(P) - throw( - ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", - ), - ) + throw(ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))" + )) end - return Tangent{P,typeof(combined)}(combined) + return Tangent{P, typeof(combined)}(combined) end # Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent +canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent # Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent +canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent # Tangents of unspecified primal types (indicated by specifying exactly `Any`) # all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any,<:Tuple}) where {L} = tangent -canonicalize(tangent::Tangent{Any,<:AbstractDict}) where {L} = tangent +canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent +canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent """ _zeroed_backing(P) Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. """ -@generated function _zeroed_backing(::Type{P}) where {P} +@generated function _zeroed_backing(::Type{P}) where P nil_base = ntuple(fieldcount(P)) do i (fieldname(P, i), ZeroTangent()) end @@ -222,7 +218,7 @@ after an operation such as the addition of a primal to a tangent It should be overloaded, if `T` does not have a default constructor, or if `T` needs to maintain some invarients between its fields. """ -function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} +function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} # Tested and verified that that this avoids a ton of allocations if length(L) !== fieldcount(T) # if length is equal but names differ then we will catch that below anyway. @@ -237,12 +233,12 @@ function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} end end -construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields -construct(::Type{T}, fields::T) where {T<:Tuple} = fields +construct(::Type{T}, fields::T) where T<:NamedTuple = fields +construct(::Type{T}, fields::T) where T<:Tuple = fields elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} # Rule of Tangent addition: any fields not present are implict hard Zeros # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. @@ -285,7 +281,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} end field => value end - return (; vals...) + return (;vals...) end end @@ -301,16 +297,15 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color = :blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) println(io, "\nor overload") - printstyled( - io, + printstyled(io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color = :blue, + color=:blue ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color = :blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) println(io, "\nOriginal Exception:") - printstyled(io, err.original; color = :yellow) + printstyled(io, err.original; color=:yellow) println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index c2b570902..16384d69e 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -56,22 +56,18 @@ LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a)) LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a)) LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a)) LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a)) -LinearAlgebra.Symmetric(a::AbstractThunk, uplo = :U) = Symmetric(unthunk(a), uplo) -LinearAlgebra.Hermitian(a::AbstractThunk, uplo = :U) = Hermitian(unthunk(a), uplo) +LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo) +LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo) function LinearAlgebra.diagm( - kv::Pair{<:Integer,<:AbstractThunk}, - kvs::Pair{<:Integer,<:AbstractThunk}..., + kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...) end function LinearAlgebra.diagm( - m, - n, - kv::Pair{<:Integer,<:AbstractThunk}, - kvs::Pair{<:Integer,<:AbstractThunk}..., + m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) - return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) + return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a)) @@ -122,7 +118,7 @@ function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx) return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx) end -function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn = 1) +function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1) return throw(MutateThunkException()) end diff --git a/test/accumulation.jl b/test/accumulation.jl index a796b5289..1b41fea55 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -27,7 +27,7 @@ end @testset "misc AbstractTangent subtypes" begin - @test 16 == add!!(12, @thunk(2 * 2)) + @test 16 == add!!(12, @thunk(2*2)) @test 16 == add!!(16, ZeroTangent()) @test 16 == add!!(16, NoTangent()) # Should this be an error? @@ -37,15 +37,15 @@ @testset "LHS Array (inplace)" begin @testset "RHS Array" begin A = [1.0 2.0; 3.0 4.0] - accumuland = -1.0 * ones(2, 2) + accumuland = -1.0*ones(2,2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] end @testset "RHS StaticArray" begin - A = @SMatrix [1.0 2.0; 3.0 4.0] - accumuland = -1.0 * ones(2, 2) + A = @SMatrix[1.0 2.0; 3.0 4.0] + accumuland = -1.0*ones(2,2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] @@ -53,7 +53,7 @@ @testset "RHS Diagonal" begin A = Diagonal([1.0, 2.0]) - accumuland = -1.0 * ones(2, 2) + accumuland = -1.0*ones(2,2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 -1.0; -1.0 1.0] @@ -79,17 +79,17 @@ @testset "Unhappy Path" begin # wrong length - @test_throws DimensionMismatch add!!(ones(4, 4), ones(2, 2)) + @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) # wrong shape - @test_throws DimensionMismatch add!!(ones(4, 4), ones(16)) + @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) # wrong type (adding scalar to array) @test_throws MethodError add!!(ones(4), 21.0) end end @testset "AbstractThunk $(typeof(thunk))" for thunk in ( - @thunk(-1.0 * ones(2, 2)), - InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), + @thunk(-1.0*ones(2, 2)), + InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))), ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -111,12 +111,12 @@ @testset "not actually inplace but said it was" begin # thunk should never be used in this test ithunk = InplaceableThunk(@thunk(@assert false)) do x - 77 * ones(2, 2) # not actually inplace (also wrong) + 77*ones(2, 2) # not actually inplace (also wrong) end accumuland = ones(2, 2) @assert ChainRulesCore.debug_mode() == false # without debug being enabled should return the result, not error - @test 77 * ones(2, 2) == add!!(accumuland, ithunk) + @test 77*ones(2, 2) == add!!(accumuland, ithunk) ChainRulesCore.debug_mode() = true # enable debug mode # with debug being enabled should error @@ -127,7 +127,7 @@ @testset "showerror BadInplaceException" begin BadInplaceException = ChainRulesCore.BadInplaceException - ithunk = InplaceableThunk(x̄ -> nothing, @thunk(@assert false)) + ithunk = InplaceableThunk(x̄->nothing, @thunk(@assert false)) msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) @test occursin("22", msg) diff --git a/test/config.jl b/test/config.jl index e6e2ab005..466baed9a 100644 --- a/test/config.jl +++ b/test/config.jl @@ -1,7 +1,7 @@ # Define a bunch of configs for testing purposes struct MostBoringConfig <: RuleConfig{Union{}} end -struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} forward_calls::Vector end MockForwardsConfig() = MockForwardsConfig([]) @@ -11,7 +11,7 @@ function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args. return f(args...; kws...), ȧrgs end -struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} reverse_calls::Vector end MockReverseConfig() = MockReverseConfig([]) @@ -23,7 +23,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. end -struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} forward_calls::Vector reverse_calls::Vector end @@ -47,21 +47,18 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, - MockForwardsConfig, - MockReverseConfig, - MockBothConfig, + MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, ) counting_id_count = Ref(0) function counting_id(x) - counting_id_count[] += 1 + counting_id_count[]+=1 return x end function ChainRulesCore.rrule(::typeof(counting_id), x) counting_id_pullback(x̄) = x̄ return counting_id(x), counting_id_pullback end - function ChainRulesCore.frule((dself, dx), ::typeof(counting_id), x) + function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) return counting_id(x), dx end @testset "rrule" begin @@ -79,33 +76,21 @@ end @testset "hitting forwards AD" begin do_thing_2(f, x) = f(x) function ChainRulesCore.frule( - config::RuleConfig{>:HasForwardsMode}, - (_, df, dx), - ::typeof(do_thing_2), - f, - x, + config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x ) return frule_via_ad(config, (df, dx), f, x) end @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) @test nothing === frule( - Config(), - (NoTangent(), NoTangent(), 21.5), - do_thing_2, - identity, - 32.1, + Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig = Config() + bconfig= Config() @test nothing !== frule( - bconfig, - (NoTangent(), NoTangent(), 21.5), - do_thing_2, - identity, - 32.1, + bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) @test bconfig.forward_calls == [(identity, (32.1,))] end @@ -114,10 +99,7 @@ end @testset "hitting reverse AD" begin do_thing_3(f, x) = f(x) function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, - ::typeof(do_thing_3), - f, - x, + config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x ) return (NoTangent(), rrule_via_ad(config, f, x)...) end @@ -128,7 +110,7 @@ end end @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) - bconfig = Config() + bconfig= Config() @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) @test bconfig.reverse_calls == [(identity, (32.1,))] end @@ -148,14 +130,14 @@ end ẋ = one(x) y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) - pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ * ȳ + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ return y, pullback_via_forwards_ad end function ChainRulesCore.rrule( - config::RuleConfig{>:Union{HasReverseMode,NoForwardsMode}}, + config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, ::typeof(do_thing_4), f, - x, + x ) y, f_pullback = rrule_via_ad(config, f, x) do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) @@ -165,43 +147,43 @@ end @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig = Config() + bconfig= Config() @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) @test bconfig.forward_calls == [(identity, (32.1,))] end - rconfig = MockReverseConfig() + rconfig= MockReverseConfig() @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) @test rconfig.reverse_calls == [(identity, (32.1,))] end @testset "RuleConfig broadcasts like a scaler" begin - @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} + @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} end @testset "fallbacks" begin - no_rule(x; kw = "bye") = error() + no_rule(x; kw="bye") = error() @test frule((1.0,), no_rule, 2.0) === nothing - @test frule((1.0,), no_rule, 2.0; kw = "hello") === nothing + @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing - @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw = "hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing @test rrule(no_rule, 2.0) === nothing - @test rrule(no_rule, 2.0; kw = "hello") === nothing + @test rrule(no_rule, 2.0; kw="hello") === nothing @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing - @test rrule(MostBoringConfig(), no_rule, 2.0; kw = "hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(; kw = "hello") + @test_throws MethodError frule(;kw="hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin; kw = "hello") + @test_throws MethodError frule(sin;kw="hello") @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw = "hello") + @test_throws MethodError frule(MostBoringConfig(); kw="hello") @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw = "hello") + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(; kw = "hello") + @test_throws MethodError rrule(;kw="hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig(); kw = "hello") + @test_throws MethodError rrule(MostBoringConfig();kw="hello") end end diff --git a/test/deprecated.jl b/test/deprecated.jl index 8b1378917..e69de29bb 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -1 +0,0 @@ - diff --git a/test/ignore_derivatives.jl b/test/ignore_derivatives.jl index ad4fece9f..825287b9a 100644 --- a/test/ignore_derivatives.jl +++ b/test/ignore_derivatives.jl @@ -7,7 +7,7 @@ end @testset "function" begin f() = return 4.0 - y, ẏ = frule((1.0,), ignore_derivatives, f) + y, ẏ = frule((1.0, ), ignore_derivatives, f) @test y == f() @test ẏ == NoTangent() @@ -19,7 +19,7 @@ end @testset "argument" begin arg = 2.1 - y, ẏ = frule((1.0,), ignore_derivatives, arg) + y, ẏ = frule((1.0, ), ignore_derivatives, arg) @test y == arg @test ẏ == NoTangent() @@ -41,11 +41,11 @@ end @test pb(1.0) == (NoTangent(), NoTangent()) # when called - y, ẏ = frule((1.0,), ignore_derivatives, () -> mf(3.0)) + y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0)) @test y == mf(3.0) @test ẏ == NoTangent() - y, pb = rrule(ignore_derivatives, () -> mf(3.0)) + y, pb = rrule(ignore_derivatives, ()->mf(3.0)) @test y == mf(3.0) @test pb(1.0) == (NoTangent(), NoTangent()) end diff --git a/test/projection.jl b/test/projection.jl index ab418ef79..ba61fb8da 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,21 +24,20 @@ struct NoSuperType end # real / complex @test ProjectTo(1.0)(2.0 + 3im) === 2.0 @test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im - @test ProjectTo(2.0 + 3.0im)(1 + 1im) === 1.0 + 1.0im - @test ProjectTo(2.0)(1 + 1im) === 1.0 - + @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im + @test ProjectTo(2.0)(1+1im) === 1.0 + # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) - @test ProjectTo(1 // 2)(3 // 4) === 3 // 4 + @test ProjectTo(1//2)(3//4) === 3//4 @test ProjectTo(1.0f0)(1 / 2) === 0.5f0 @test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im @test ProjectTo(big(1.0))(2) === 2 @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re = 1, im = NoTangent())) === - 1.0f0 + 0.0f0im + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im end @testset "Dual" begin # some weird Real subtype that we should basically leave alone @@ -47,12 +46,13 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} - @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa - Complex{<:Dual} + @test ProjectTo(1.0 + 1im)( + Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)) + ) isa Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent - @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value = 1.0)) isa Tangent + @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent end @testset "Base: arrays of numbers" begin @@ -99,10 +99,10 @@ struct NoSuperType end # arrays of other things @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} - @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} + @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. - @test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im) + @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) @test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent} # empty arrays @@ -172,7 +172,7 @@ struct NoSuperType end # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called - xs = adj(Any[Any[1, 2, 3], Any[4+im, 5-im, 6+im, 7-im]]) + xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) pvecvec3 = ProjectTo(xs) @test pvecvec3(xs)[1] == [1 2 3] @test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im]) @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step = 0.1, ref = NoTangent()) + dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()); @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3 + 4im) + dy = Tangent{typeof(Ref(2.0))}(x = 3+4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end @@ -365,21 +365,21 @@ struct NoSuperType end # Each "@test 33 > ..." is zero on nightly, 32 on 1.5. pvec = ProjectTo(rand(10^3)) - @test 0 == @ballocated $pvec(dx) setup = (dx = rand(10^3)) # pass through - @test 90 > @ballocated $pvec(dx) setup = (dx = rand(10^3, 1)) # reshape + @test 0 == @ballocated $pvec(dx) setup=(dx = rand(10^3)) # pass through + @test 90 > @ballocated $pvec(dx) setup=(dx = rand(10^3, 1)) # reshape @test 33 > @ballocated ProjectTo(x)(dx) setup = (x = rand(10^3); dx = rand(10^3)) # including construction padj = ProjectTo(adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup = (dx = adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup = (dx = transpose(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup=(dx = adjoint(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup=(dx = transpose(rand(10^3))) @test 33 > @ballocated ProjectTo(x')(dx') setup = (x = rand(10^3); dx = rand(10^3)) pdiag = ProjectTo(Diagonal(rand(10^3))) - @test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3))) + @test 0 == @ballocated $pdiag(dx) setup=(dx = Diagonal(rand(10^3))) psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) - @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 + @test_broken 0 == @ballocated $psymm(dx) setup=(dx = Symmetric(rand(10^3, 10^3))) # 64 end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index e99b66c2f..0d6d98535 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -19,7 +19,7 @@ macro test_macro_throws(err_expr, expr) end end # Reuse `@test_throws` logic - if err !== nothing + if err!==nothing @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) @@ -29,21 +29,21 @@ end # struct need to be defined outside of tests for julia 1.0 compat struct NonDiffExample - x::Any + x end struct NonDiffCounterExample - x::Any + x end module NonDiffModuleExample -nondiff_2_1(x, y) = fill(7.5, 100)[x+y] + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin - nondiff_2_1(x, y) = fill(7.5, 100)[x+y] + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((ZeroTangent(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, NoTangent()) res, pullback = rrule(nondiff_2_1, 3, 2) @@ -58,7 +58,7 @@ end res, pullback = rrule(nondiff_1_2, 3.1) @test res == (5.0, 3.0) @test isequal( - pullback(Tangent{Tuple{Float64,Float64}}(1.2, 3.2)), + pullback(Tangent{Tuple{Float64, Float64}}(1.2, 3.2)), (NoTangent(), NoTangent()), ) end @@ -81,8 +81,7 @@ end pointy_identity(x) = x @non_differentiable pointy_identity(::Vector{<:AbstractString}) - @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == - (["2"], NoTangent()) + @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == (["2"], NoTangent()) @test frule((ZeroTangent(), 1.2), pointy_identity, 2.0) == nothing res, pullback = rrule(pointy_identity, ["2"]) @@ -93,7 +92,7 @@ end end @testset "kwargs" begin - kw_demo(x; kw = 2.0) = x + kw + kw_demo(x; kw=2.0) = x + kw @non_differentiable kw_demo(::Any) @testset "not setting kw" begin @@ -107,14 +106,13 @@ end end @testset "setting kw" begin - @assert kw_demo(1.5; kw = 3.0) == 4.5 + @assert kw_demo(1.5; kw=3.0) == 4.5 - res, pullback = rrule(kw_demo, 1.5; kw = 3.0) + res, pullback = rrule(kw_demo, 1.5; kw=3.0) @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw = 3.0) == - (4.5, NoTangent()) + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) end end @@ -123,7 +121,7 @@ end @test isequal( frule((ZeroTangent(), 1.2), NonDiffExample, 2.0), - (NonDiffExample(2.0), NoTangent()), + (NonDiffExample(2.0), NoTangent()) ) res, pullback = rrule(NonDiffExample, 2.0) @@ -153,7 +151,7 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), NoTangent()) @test frule((1, 1), fvarargs, 1, 2) == nothing - @test rrule(fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @@ -196,10 +194,10 @@ end end @testset "Functors" begin - (f::NonDiffExample)(y) = fill(7.5, 100)[f.x+y] + (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x = 1.2), 2.3), NonDiffExample(3), 2) == - (7.5, NoTangent()) + @test frule((Tangent{NonDiffExample}(x=1.2), 2.3), NonDiffExample(3), 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent()) @@ -207,12 +205,8 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) - @test frule( - (ZeroTangent(), 1.2, 2.3), - NonDiffModuleExample.nondiff_2_1, - 3, - 2, - ) == (7.5, NoTangent()) + @test frule((ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) @@ -222,7 +216,7 @@ end # Where clauses are not supported. @test_macro_throws( ErrorException, - (@non_differentiable where_identity(::Vector{T}) where {T<:AbstractString}) + (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) ) end end @@ -230,33 +224,32 @@ end @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) - @scalar_rule(simo(x), 1.0f0, 2.0f0) + @scalar_rule(simo(x), 1f0, 2f0) y, simo_pb = rrule(simo, π) - @test simo_pb((10.0f0, 20.0f0)) == (NoTangent(), 50.0f0) + @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) - y, ẏ = frule((NoTangent(), 50.0f0), simo, π) + y, ẏ = frule((NoTangent(), 50f0), simo, π) @test y == (π, 2π) - @test ẏ == Tangent{typeof(y)}(50.0f0, 100.0f0) + @test ẏ == Tangent{typeof(y)}(50f0, 100f0) # make sure type is exactly as expected: - @test ẏ isa Tangent{Tuple{Irrational{:π},Float64},Tuple{Float32,Float32}} + @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} xs, Ω = (3,), (3, 6) - @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == - ((1.0f0,), (2.0f0,)) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) end @testset "@scalar_rule projection" begin - make_imaginary(x) = im * x + make_imaginary(x) = im*x @scalar_rule make_imaginary(x) im # note: the === will make sure that these are Float64, not ComplexF64 - @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0 * im) + @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im) @test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0) - @test (NoTangent(), 1.0 + 0.0im) === rrule(make_imaginary, 2.0im)[2](1.0 * im) - @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) + @test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im) + @test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end @testset "Regression tests against #276 and #265" begin @@ -264,16 +257,16 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 # Symptom of these problems is creation of global variables and type instability - num_globals_before = length(names(ChainRulesCore; all = true)) + num_globals_before = length(names(ChainRulesCore; all=true)) simo2(x) = (x, 2x) @scalar_rule(simo2(x), 1.0, 2.0) _, simo2_pb = rrule(simo2, 43.0) # make sure it infers: inferability implies type stability - @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) + @inferred simo2_pb(Tangent{Tuple{Float64, Float64}}(3.0, 6.0)) # Test no new globals were created - @test length(names(ChainRulesCore; all = true)) == num_globals_before + @test length(names(ChainRulesCore; all=true)) == num_globals_before # Example in #265 simo3(x) = sincos(x) @@ -286,60 +279,60 @@ end module IsolatedModuleForTestingScoping -# check that rules can be defined by macros without any additional imports -using ChainRulesCore: @scalar_rule, @non_differentiable - -# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved -const ChainRulesCore = nothing - -# this is -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 -fixed(x) = :abc -@non_differentiable fixed(x) - -# check name collision between a primal input called `kwargs` and the actual keyword -# arguments -fixed_kwargs(x; kwargs...) = :abc -@non_differentiable fixed_kwargs(kwargs) - -my_id(x) = x -@scalar_rule(my_id(x), 1.0) - -module IsolatedSubmodule -# check that rules defined in isolated module without imports can be called -# without errors -using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output -using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id -using Test - -@testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((ZeroTangent(), randn()), f, randn()) - @test y === :abc - @test ẏ === NoTangent() - - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end + # check that rules can be defined by macros without any additional imports + using ChainRulesCore: @scalar_rule, @non_differentiable + + # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved + const ChainRulesCore = nothing + + # this is + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 + fixed(x) = :abc + @non_differentiable fixed(x) + + # check name collision between a primal input called `kwargs` and the actual keyword + # arguments + fixed_kwargs(x; kwargs...) = :abc + @non_differentiable fixed_kwargs(kwargs) + + my_id(x) = x + @scalar_rule(my_id(x), 1.0) + + module IsolatedSubmodule + # check that rules defined in isolated module without imports can be called + # without errors + using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output + using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + using Test + + @testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((ZeroTangent(), randn()), f, randn()) + @test y === :abc + @test ẏ === NoTangent() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword = randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) -end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end -@testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ + @testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (NoTangent(), Δy) + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (NoTangent(), Δy) - @test derivatives_given_output(y, my_id, x) == ((1.0,),) -end -end + @test derivatives_given_output(y, my_id, x) == ((1.0,),) + end + end end diff --git a/test/rules.jl b/test/rules.jl index 267b23005..d43ca42d2 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,11 +28,8 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any,Any,Any,Vararg}, - ::typeof(mixed_vararg), - x, - y, - z..., + dargs::Tuple{Any, Any, Any, Vararg}, + ::typeof(mixed_vararg), x, y, z..., ) Δx = dargs[2] Δy = dargs[3] @@ -42,21 +39,16 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any,Int,Float64}, - ::typeof(type_constraints), - x::Int, - y::Float64, + (_, Δx, Δy)::Tuple{Any, Int, Float64}, + ::typeof(type_constraints), x::Int, y::Float64, ) return type_constraints(x, y), Δx + Δy end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any,Float64,Real,Vararg{Float64}}, - ::typeof(mixed_vararg_type_constaint), - x::Float64, - y::Real, - z::Vararg{Float64}, + dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, ) Δx = dargs[2] Δy = dargs[3] @@ -73,9 +65,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = ZeroTangent() @test frule((dself, 1), cool, 1) === nothing - @test frule((dself, 1), cool, 1; iscool = true) === nothing + @test frule((dself, 1), cool, 1; iscool=true) === nothing @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool = true) === nothing + @test rrule(cool, 1; iscool=true) === nothing # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @@ -84,10 +76,8 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([ - Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}, - ]) + only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, + Tuple{typeof(rrule),typeof(cool),String}]) @test cool_methods == only_methods frx, cool_pushforward = frule((dself, 1), cool, 1) @@ -108,26 +98,21 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) - @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == - (10.0, 10.0) + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing - @test( - frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, - 3.0, - 2.0, - 1.0, - 0.0, - ) == (6.0, 6.0) - ) + @test(frule( + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, + ) == (6.0, 6.0)) # violates type constraints, thus an frule should not be found. - @test frule((nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0) == - nothing + @test frule( + (nothing, 3, 2.0, 1.0, 5.0), + mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, + ) == nothing @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) @@ -168,29 +153,27 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) - @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where {T<:Float32} + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 @opt_out( - ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float32} + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 ) @testset "rrule" begin @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) - @test rrule(first_oa, 3.0f0, 4.0f0) === nothing + @test rrule(first_oa, 3f0, 4f0) === nothing @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} + m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 end) end @testset "frule" begin - @test frule((NoTangent(), 1, 0), first_oa, 3.0, 4.0) == (3.0, 1) - @test frule((NoTangent(), 1, 0), first_oa, 3.0f0, 4.0f0) === nothing - - @test !isempty( - Iterators.filter(methods(ChainRulesCore.no_frule)) do m - m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float32} - end, - ) + @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing + + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 + end) end end end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index fdbb92f55..7e0ec9398 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -9,7 +9,7 @@ @test view(NoTangent(), 1, 2) == NoTangent() @test sum(ZeroTangent()) == ZeroTangent() - @test sum(NoTangent(); dims = 2) == NoTangent() + @test sum(NoTangent(); dims=2) == NoTangent() end @testset "ZeroTangent" begin @@ -55,7 +55,7 @@ @test muladd(x, ZeroTangent(), ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), x, ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), ZeroTangent(), ZeroTangent()) === ZeroTangent() - + @test reim(z) === (ZeroTangent(), ZeroTangent()) @test real(z) === ZeroTangent() @test imag(z) === ZeroTangent() diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 2b7c6347e..2fd337979 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -1,14 +1,10 @@ @testset "NotImplemented" begin @testset "NotImplemented" begin ni = ChainRulesCore.NotImplemented( - @__MODULE__, - LineNumberNode(@__LINE__, @__FILE__), - "error", + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" ) ni2 = ChainRulesCore.NotImplemented( - @__MODULE__, - LineNumberNode(@__LINE__, @__FILE__), - "error2", + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" ) x = rand() thunk = @thunk(x^2) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index cc24d988e..694e43b53 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -1,6 +1,6 @@ # For testing Tangent struct Foo - x::Any + x y::Float64 end @@ -12,81 +12,81 @@ end # For testing Tangent: it is an invarient of the type that x2 = 2x # so simple addition can not be defined struct StructWithInvariant - x::Any - x2::Any + x + x2 StructWithInvariant(x) = new(x, 2x) end @testset "Tangent" begin @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{}, Tuple{}} end @testset "==" begin - @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(x = 0.1, y = 2.5) - @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(y = 2.5, x = 0.1) - @test Tangent{Foo}(y = 2.5, x = ZeroTangent()) == Tangent{Foo}(y = 2.5) + @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(x=0.1, y=2.5) + @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(y=2.5, x=0.1) + @test Tangent{Foo}(y=2.5, x=ZeroTangent()) == Tangent{Foo}(y=2.5) - @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) + @test Tangent{Tuple{Float64,}}(2.0) == Tangent{Tuple{Float64,}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2*1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(; y = 2.0) == Tangent{Foo}(; x = ZeroTangent(), y = Float32(2.0)) + @test Tangent{Foo}(;y=2.0,) == Tangent{Foo}(;x=ZeroTangent(), y=Float32(2.0),) end @testset "hash" begin - @test hash(Tangent{Foo}(x = 0.1, y = 2.5)) == hash(Tangent{Foo}(y = 2.5, x = 0.1)) - @test hash(Tangent{Foo}(y = 2.5, x = ZeroTangent())) == hash(Tangent{Foo}(y = 2.5)) + @test hash(Tangent{Foo}(x=0.1, y=2.5)) == hash(Tangent{Foo}(y=2.5, x=0.1)) + @test hash(Tangent{Foo}(y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(y=2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x = 2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x = 2.5)) == (:x,) - @test haskey(Tangent{Foo}(x = 2.5), :x) == true + @test keys(Tangent{Foo}(x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(x=2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x = 2.5), :y) == false + @test hasproperty(Tangent{Foo}(x=2.5), :y) == false end - @test Tangent{Foo}(x = 2.5).x == 2.5 - - @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - - NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 - - @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 + @test Tangent{Foo}(x=2.5).x == 2.5 + + @test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 + + NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}} + @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 + @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() + @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() + @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 + + @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 + @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x = 2.5)) == 1 - @test length(Tangent{Tuple{Float64}}(2.0)) == 1 + @test length(Tangent{Foo}(x=2.5)) == 1 + @test length(Tangent{Tuple{Float64,}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x = 2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 + @test eltype(Tangent{Foo}(x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64,}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x = 2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] + @test collect(Tangent{Foo}(x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64,}}(2.0)) == [2.0] # Test indexed_iterate ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function (tangent) + _unpack2tuple = function(tangent) a, b = tangent return (a, b) end @@ -96,33 +96,33 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x = 2, y = 3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y = 3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(y=3.0)) end end @testset "reverse" begin - c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") - cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) + c = Tangent{Tuple{Int, Int, String}}(1, 2, "something") + cr = Tangent{Tuple{String, Int, Int}}("something", 2, 1) @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x = 1.0, y = 2.0)) + @test_throws MethodError reverse(Tangent{Foo}(;x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo,typeof(d)}(d) + cdict = Tangent{Foo, typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end @testset "unset properties" begin - @test Tangent{Foo}(; x = 1.4).y === ZeroTangent() + @test Tangent{Foo}(; x=1.4).y === ZeroTangent() end @testset "conj" begin - @test conj(Tangent{Foo}(x = 2.0 + 3.0im)) == Tangent{Foo}(x = 2.0 - 3.0im) + @test conj(Tangent{Foo}(x=2.0+3.0im)) == Tangent{Foo}(x=2.0-3.0im) @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), - Tangent{Tuple{Float64}}(2.0 - 3.0im), + conj(Tangent{Tuple{Float64,}}(2.0+3.0im)), + Tangent{Tuple{Float64,}}(2.0-3.0im) ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -132,20 +132,26 @@ end @testset "canonicalize" begin # Testing iterate via collect - @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) + @test ==( + canonicalize(Tangent{Tuple{Float64,}}(2.0)), + Tangent{Tuple{Float64,}}(2.0) + ) - @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) + @test ==( + canonicalize(Tangent{Dict}(Dict(4 => 3))), + Tangent{Dict}(Dict(4 => 3)), + ) # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x = 2.5, y = 10)) == CFoo(x = 2.5, y = 10) - @test canonicalize(CFoo(y = 10, x = 2.5)) == CFoo(x = 2.5, y = 10) - @test canonicalize(CFoo(y = 10)) == CFoo(x = ZeroTangent(), y = 10) + @test canonicalize(CFoo(x=2.5, y=10)) == CFoo(x=2.5, y=10) + @test canonicalize(CFoo(y=10, x=2.5)) == CFoo(x=2.5, y=10) + @test canonicalize(CFoo(y=10)) == CFoo(x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(q = 99.0, x = 2.5)) + @test_throws ArgumentError canonicalize(CFoo(q=99.0, x=2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a = 1, b = 2) + c1 = Tangent{Any}(;a=1, b=2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -158,28 +164,30 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x = 1.5) + CFoo(x = 2.5) == CFoo(x = 4.0) - @test CFoo(y = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 2.5) - @test CFoo(y = 1.5, x = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 4.0) + @test CFoo(x=1.5) + CFoo(x=2.5) == CFoo(x=4.0) + @test CFoo(y=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=2.5) + @test CFoo(y=1.5, x=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=4.0) end @testset "Tuples" begin @test ==( typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{},Tuple{}}, + Tangent{Tuple{}, Tuple{}} ) @test ( - Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64,Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64, Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64, Float64}}(2.0, 3.0) end @testset "NamedTuples" begin - nt1 = (; a = 1.5, b = 0.0) - nt2 = (; a = 0.0, b = 2.5) - nt_sum = (a = 1.5, b = 2.5) - @test (Tangent{typeof(nt1)}(; nt1...) + Tangent{typeof(nt2)}(; nt2...)) == - Tangent{typeof(nt_sum)}(; nt_sum...) + nt1 = (;a=1.5, b=0.0) + nt2 = (;a=0.0, b=2.5) + nt_sum = (a=1.5, b=2.5) + @test ( + Tangent{typeof(nt1)}(; nt1...) + + Tangent{typeof(nt2)}(; nt2...) + ) == Tangent{typeof(nt_sum)}(; nt_sum...) end @testset "Dicts" begin @@ -191,8 +199,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x = 1.5) - b = CFoo(x = @not_implemented("")) + a = CFoo(x=1.5) + b = CFoo(x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -207,8 +215,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x = 1.5) - b = Tangent{NamedTuple{(:x,)}}(x = @not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(x=1.5) + b = Tangent{NamedTuple{(:x,)}}(x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -227,35 +235,35 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x = 2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x = 2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x = 0.5)) == 0 + @test Foo(3.5, 1.5) + Tangent{Foo}(x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 end @testset "Tuples" begin @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + @test ((1.0, 2.0) + Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end @testset "NamedTuple" begin - ntx = (; a = 1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a = 3.0) + ntx = (; a=1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) - nty = (; a = 1.5, b = 0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a = 3.0, b = 1.0) + nty = (; a=1.5, b=0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) end @testset "Dicts" begin d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 =>5.0)) @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x = 2.0, x2 = 6.0) + diff = Tangent{StructWithInvariant}(x=2.0, x2=6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -272,7 +280,7 @@ end # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2 / 2) / 2 + x = (nt.x + nt.x2/2)/2 return StructWithInvariant(x) end @test value + diff == StructWithInvariant(12.5) @@ -280,7 +288,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y = 1.5, x = 2.5) + c = Tangent{Foo}(y=1.5, x=2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -302,14 +310,14 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y = 1.5, x = 2.5) == - Tangent{Foo}(y = 3.0, x = 5.0) == - Tangent{Foo}(y = 1.5, x = 2.5) * 2 + 2 * Tangent{Foo}(y=1.5, x=2.5) + == Tangent{Foo}(y=3.0, x=5.0) + == Tangent{Foo}(y=1.5, x=2.5) * 2 ) @test ( - 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == - Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == - Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 + 2 * Tangent{Tuple{Float64, Float64}}(2.0, 4.0) + == Tangent{Tuple{Float64, Float64}}(4.0, 8.0) + == Tangent{Tuple{Float64, Float64}}(2.0, 4.0) * 2 ) d = Tangent{Dict}(Dict(4 => 3.0)) two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) @@ -317,7 +325,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x = 1)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(x=1,)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -334,9 +342,8 @@ end end @testset "Internals don't allocate a ton" begin - bk = (; x = 1.0, y = 2.0) - VERSION >= v"1.5" && - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + bk = (; x=1.0, y=2.0) + VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 # weaker version of the above (which should pass on all versions) @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 @@ -345,8 +352,8 @@ end end @testset "non-same-typed differential arithmetic" begin - nt = (; a = 1, b = 2.0) - c = Tangent{typeof(nt)}(; a = NoTangent(), b = 0.1) - @test nt + c == (; a = 1, b = 2.1) + nt = (; a=1, b=2.0) + c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) + @test nt + c == (; a=1, b=2.1); end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index af4a747d1..89461caa1 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -141,7 +141,7 @@ # Check against accidential type piracy # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472 @test Base.which(diagm, Tuple{}()).module != ChainRulesCore - @test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore + @test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore end @test tril(a) == tril(t) @test tril(a, 1) == tril(t, 1) From 8b54ef67146db7c1e085535d781063fef94ecbb0 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 13:50:00 +0300 Subject: [PATCH 07/20] format(".") with blue style from scratch --- docs/make.jl | 29 ++-- docs/src/assets/make_logo.jl | 48 +++--- src/accumulation.jl | 8 +- src/compat.jl | 4 +- src/deprecated.jl | 1 + src/ignore_derivatives.jl | 8 +- src/projection.jl | 85 ++++++----- src/rule_definition_tools.jl | 84 +++++++---- src/tangent_arithmetic.jl | 14 +- src/tangent_types/abstract_zero.jl | 8 +- src/tangent_types/notimplemented.jl | 10 +- src/tangent_types/tangent.jl | 113 ++++++++------- src/tangent_types/thunks.jl | 16 +- test/accumulation.jl | 24 +-- test/config.jl | 76 ++++++---- test/deprecated.jl | 1 + test/ignore_derivatives.jl | 8 +- test/projection.jl | 40 ++--- test/rule_definition_tools.jl | 175 +++++++++++----------- test/rules.jl | 75 ++++++---- test/tangent_types/abstract_zero.jl | 4 +- test/tangent_types/notimplemented.jl | 8 +- test/tangent_types/tangent.jl | 209 +++++++++++++-------------- test/tangent_types/thunks.jl | 2 +- 24 files changed, 567 insertions(+), 483 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1ef3a62a7..608422c25 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,20 +16,20 @@ DocMeta.setdocmeta!( @scalar_rule(sin(x), cos(x)) # frule and rrule doctest @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end + end, ) indigo = DocThemeIndigo.install(ChainRulesCore) makedocs( - modules=[ChainRulesCore], - format=Documenter.HTML( - prettyurls=false, - assets=[indigo], - mathengine=MathJax3( + modules = [ChainRulesCore], + format = Documenter.HTML( + prettyurls = false, + assets = [indigo], + mathengine = MathJax3( Dict( :tex => Dict( - "inlineMath" => [["\$","\$"], ["\\(","\\)"]], + "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], "tags" => "ams", # TODO: remove when using physics package "macros" => Dict( @@ -42,9 +42,9 @@ makedocs( ), ), ), - sitename="ChainRules", - authors="Jarrett Revels and other contributors", - pages=[ + sitename = "ChainRules", + authors = "Jarrett Revels and other contributors", + pages = [ "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", @@ -63,11 +63,8 @@ makedocs( ], "API" => "api.md", ], - strict=true, - checkdocs=:exports, + strict = true, + checkdocs = :exports, ) -deploydocs( - repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", - push_preview=true, -) +deploydocs(repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview = true) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 5bbfd36c1..3e7aeaa08 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -8,34 +8,34 @@ using Random const bridge_len = 50 -function chain(jiggle=0) - shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) - +function chain(jiggle = 0) + shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) + ### 1 shaky_rotate(0) sethue(Luxor.julia_red) link() m1 = getmatrix() - - + + ### 2 sethue(Luxor.julia_green) - translate(-50, 130); - shaky_rotate(π/3); + translate(-50, 130) + shaky_rotate(π / 3) link() m2 = getmatrix() - + setmatrix(m1) sethue(Luxor.julia_red) overlap(-1.3π) setmatrix(m2) - + ### 3 - shaky_rotate(-π/3); - translate(-120,80); + shaky_rotate(-π / 3) + translate(-120, 80) sethue(Luxor.julia_purple) link() - + setmatrix(m2) setcolor(Luxor.julia_green) overlap(-1.5π) @@ -45,24 +45,24 @@ end function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - - - rect(50,-3,40, bridge_len+6, :fill) - rect(-50-40,-3,40, bridge_len+6, :fill) - + + + rect(50, -3, 40, bridge_len + 6, :fill) + rect(-50 - 40, -3, 40, bridge_len + 6, :fill) + sethue("black") move(Point(-50, bridge_len)) - arc(Point(0,0), 50, π, 0, :stoke) + arc(Point(0, 0), 50, π, 0, :stoke) arc(Point(0, bridge_len), 50, 0, -π, :stroke) - + move(Point(-90, bridge_len)) - arc(Point(0,0), 90, π, 0, :stoke) + arc(Point(0, 0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) strokepath() end function overlap(ang_end) - sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) + sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill) sethue("black") arc(Point(0, bridge_len), 50, 0, ang_end, :stoke) move(Point(90, bridge_len)) @@ -75,13 +75,13 @@ end function save_logo(filename) Random.seed!(16) - Drawing(450,450, filename) + Drawing(450, 450, filename) origin() - translate(50, -130); + translate(50, -130) chain(0.5) finish() preview() end save_logo("logo.svg") -save_logo("logo.png") \ No newline at end of file +save_logo("logo.png") diff --git a/src/accumulation.jl b/src/accumulation.jl index 4bcc5c33f..5fbc07fa8 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,7 +26,7 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) -function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N +function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} return if is_inplaceable_destination(x) x .+= y else @@ -75,8 +75,8 @@ end struct BadInplaceException <: Exception ithunk::InplaceableThunk - accumuland - returned_value + accumuland::Any + returned_value::Any end function Base.showerror(io::IO, err::BadInplaceException) @@ -88,7 +88,7 @@ function Base.showerror(io::IO, err::BadInplaceException) if err.accumuland == err.returned_value println( io, - "Which in this case happenned to be equal. But they are not the same object." + "Which in this case happenned to be equal. But they are not the same object.", ) end end diff --git a/src/compat.jl b/src/compat.jl index 8204b66d5..fa66b1d0f 100644 --- a/src/compat.jl +++ b/src/compat.jl @@ -5,7 +5,7 @@ end if VERSION < v"1.1" # Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4 # See: https://github.com/JuliaLang/julia/issues/34292 - function fieldtypes(::Type{T}) where T + function fieldtypes(::Type{T}) where {T} if @generated ntuple(i -> fieldtype(T, i), fieldcount(T)) else @@ -13,7 +13,7 @@ if VERSION < v"1.1" end end - function fieldnames(::Type{T}) where T + function fieldnames(::Type{T}) where {T} if @generated ntuple(i -> fieldname(T, i), fieldcount(T)) else diff --git a/src/deprecated.jl b/src/deprecated.jl index e69de29bb..8b1378917 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -0,0 +1 @@ + diff --git a/src/ignore_derivatives.jl b/src/ignore_derivatives.jl index c66d89d7e..18865f2c9 100644 --- a/src/ignore_derivatives.jl +++ b/src/ignore_derivatives.jl @@ -45,7 +45,9 @@ ignore_derivatives(x) = x Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`. """ macro ignore_derivatives(ex) - return :(ChainRulesCore.ignore_derivatives() do - $(esc(ex)) - end) + return :( + ChainRulesCore.ignore_derivatives() do + $(esc(ex)) + end + ) end diff --git a/src/projection.jl b/src/projection.jl index 4b07b2762..55f6e7bfd 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -131,7 +131,8 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}() +ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} = + ProjectTo{NoTangent}() # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -164,12 +165,14 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) end # In these cases we can just `convert` as we know we are dealing with plain and simple types -(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx) -(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity +(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) +(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = + convert(T, dx) (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = + convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -190,7 +193,7 @@ end # For arrays of numbers, just store one projector: function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) + return ProjectTo{AbstractArray}(; element = _eltype_projectto(T), axes = axes(x)) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -204,7 +207,7 @@ function ProjectTo(xs::AbstractArray) return ProjectTo{NoTangent}() # short-circuit if all elements project to zero else # Arrays of arrays come here, and will apply projectors individually: - return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs)) + return ProjectTo{AbstractArray}(; elements = elements, axes = axes(xs)) end end @@ -214,7 +217,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} dy = if axes(dx) == project.axes dx else - for d in 1:max(M, length(project.axes)) + for d = 1:max(M, length(project.axes)) if size(dx, d) != length(get(project.axes, d, 1)) throw(_projection_mismatch(project.axes, size(dx))) end @@ -244,9 +247,11 @@ end # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers if !(project.axes isa Tuple{}) - throw(DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - )) + throw( + DimensionMismatch( + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", + ), + ) end return fill(project.element(dx)) end @@ -254,7 +259,7 @@ end function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) return DimensionMismatch( - "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx" + "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx", ) end @@ -268,13 +273,13 @@ function ProjectTo(x::Ref) if sub isa ProjectTo{<:AbstractZero} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type=typeof(x), x=sub) + return ProjectTo{Ref}(; type = typeof(x), x = sub) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) +(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x = project.x(dx.x)) +(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x = project.x(dx[])) # Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) +(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x = project.x(dx)) ##### ##### `LinearAlgebra` @@ -283,7 +288,7 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Row vectors -ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x))) +ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent = ProjectTo(parent(x))) # Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec. # Transposed matrices are, like PermutedDimsArray, just a storage detail, # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number @@ -298,7 +303,8 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +ProjectTo(x::LinearAlgebra.TransposeAbsVec) = + ProjectTo{Transpose}(; parent = ProjectTo(parent(x))) function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -311,21 +317,22 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) end # Diagonal -ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) +ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag = ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) # Symmetric -for (SymHerm, chk, fun) in ( - (:Symmetric, :issymmetric, :transpose), - (:Hermitian, :ishermitian, :adjoint), - ) +for (SymHerm, chk, fun) in + ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) @eval begin function ProjectTo(x::$SymHerm) sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) + return ProjectTo{$SymHerm}(; + uplo = LinearAlgebra.sym_uplo(x.uplo), + parent = sub, + ) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) dy = project.parent(dx) @@ -338,9 +345,8 @@ for (SymHerm, chk, fun) in ( # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) sub = project.parent # this is going to be unhappy about the size - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) + sub_one = + ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) return Diagonal(sub_one(dx.diag)) end end @@ -349,13 +355,12 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) + ProjectTo(x::$UL) = ProjectTo{$UL}(; parent = ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) + sub_one = + ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) return Diagonal(sub_one(dx.diag)) end end @@ -392,7 +397,7 @@ end # another strategy is just to use the AbstractArray method function ProjectTo(x::Tridiagonal{T}) where {T<:Number} notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent=notparent) + return ProjectTo{Tridiagonal}(; notparent = notparent) end function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) dy = project.notparent(dx) @@ -411,7 +416,9 @@ using SparseArrays function ProjectTo(x::SparseVector{T}) where {T<:Number} return ProjectTo{SparseVector}(; - element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) + element = ProjectTo(zero(T)), + nzind = x.nzind, + axes = axes(x), ) end function (project::ProjectTo{SparseVector})(dx::AbstractArray) @@ -450,11 +457,11 @@ end function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} return ProjectTo{SparseMatrixCSC}(; - element=ProjectTo(zero(T)), - axes=axes(x), - rowval=rowvals(x), - nzranges=nzrange.(Ref(x), axes(x, 2)), - colptr=x.colptr, + element = ProjectTo(zero(T)), + axes = axes(x), + rowval = rowvals(x), + nzranges = nzrange.(Ref(x), axes(x, 2)), + colptr = x.colptr, ) end # You need not really store nzranges, you can get them from colptr -- TODO @@ -474,7 +481,7 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) for i in project.nzranges[col] row = project.rowval[i] val = dy[row, col] - nzval[k += 1] = project.element(val) + nzval[k+=1] = project.element(val) end end m, n = map(length, project.axes) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 911a32ddd..8a1e1cce4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -83,9 +83,8 @@ For examples, see ChainRules' `rulesets` directory. See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) - call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( - call, maybe_setup, partials - ) + call, setup_stmts, inputs, partials = + _normalize_scalarrules_macro_input(call, maybe_setup, partials) f = call.args[1] # Generate variables to store derivatives named dfi/dxj @@ -101,9 +100,11 @@ macro scalar_rule(call, maybe_setup, partials...) # Final return: building the expression to insert in the place of this macro code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw(ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))" - )) + throw( + ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))", + ), + ) end $(derivative_expr) @@ -175,7 +176,11 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote - function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + function ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), + ::Core.Typeof($f), + $(inputs...), + ) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) @@ -196,9 +201,8 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) end if n_outputs > 1 # For forward-mode we return a Tangent if output actually a tuple. - pushforward_returns = Expr( - :call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns... - ) + pushforward_returns = + Expr(:call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns...) else pushforward_returns = first(pushforward_returns) end @@ -210,7 +214,8 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = + ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -225,7 +230,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) Δs = _propagator_inputs(n_outputs) # Make a projector for each argument - projs, psetup = _make_projectors(call.args[2:end]) + projs, psetup = _make_projectors(call.args[2:end]) append!(setup_stmts, psetup) # 1 partial derivative per input @@ -248,7 +253,8 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = + ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -257,12 +263,12 @@ end # For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" -_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i = 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) projs = map(x -> Symbol(:proj_, x.args[1]), xs) - setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs) + setups = map((x, p) -> :($p = ProjectTo($x)), xs, projs) return projs, setups end @@ -275,7 +281,7 @@ Specify `_conj = true` to conjugate the partials. Projector `proj` is a function that will be applied at the end; for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity` """ -function propagation_expr(Δs, ∂s, _conj=false, proj=identity) +function propagation_expr(Δs, ∂s, _conj = false, proj = identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj @@ -288,9 +294,10 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) - end + summed_∂_mul_Δs = + foldl(Iterators.drop(zip(_∂s, Δs), 1); init = init_expr) do ex, (∂s_i, Δs_i) + :(muladd($∂s_i, $Δs_i, $ex)) + end return :($proj($summed_∂_mul_Δs)) end @@ -381,7 +388,10 @@ end function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... + :call, + call_expr.args[1], + Expr(:parameters, :($(kwargs)...)), + call_expr.args[2:end]..., ) end @@ -389,11 +399,18 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), - frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function (::Core.kwftype(typeof(ChainRulesCore.frule)))( + @nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end - function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function ChainRulesCore.frule( + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), + ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() return ($(esc(primal_invoke)), NoTangent()) @@ -408,7 +425,8 @@ function tuple_expression(primal_sig_parts) Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + length_expr = + :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) end end @@ -426,7 +444,11 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) + function (::Core.kwftype(typeof(rrule)))( + $(esc(kwargs))::Any, + ::typeof(rrule), + $(esc_primal_sig_parts...), + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) @@ -481,7 +503,7 @@ end "Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) - length(expr.args)===0 && error("Malformed method expression. $expr") + length(expr.args) === 0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore @@ -555,12 +577,13 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) + if primal_name isa Symbol || + Meta.isexpr(primal_name, :(.)) || + Meta.isexpr(primal_name, :curly) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name - # e.g. (::T)(x, y) + # e.g. (::T)(x, y) elseif Meta.isexpr(primal_name, :(::)) _primal_name = gensym(Symbol(:instance_, primal_name.args[end])) primal_name_sig = Expr(:(::), _primal_name, primal_name.args[end]) @@ -582,7 +605,8 @@ end function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) + Meta.isexpr(arg, :(...), 1) && + return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 9c1378aab..c2bad7a77 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -81,7 +81,7 @@ LinearAlgebra.dot(::ZeroTangent, ::NoTangent) = ZeroTangent() Base.muladd(::ZeroTangent, x, y) = y Base.muladd(x, ::ZeroTangent, y) = y -Base.muladd(x, y, ::ZeroTangent) = x*y +Base.muladd(x, y, ::ZeroTangent) = x * y Base.muladd(::ZeroTangent, ::ZeroTangent, y) = y Base.muladd(x, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() @@ -125,11 +125,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where P +function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P, typeof(data)}(data) + return Tangent{P,typeof(data)}(data) end -function Base.:+(a::P, d::Tangent{P}) where P +function Base.:+(a::P, d::Tangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -142,12 +142,12 @@ function Base.:+(a::P, d::Tangent{P}) where P end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where P = b + a +Base.:+(a::Tangent{P}, b::P) where {P} = b + a # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 differentials # Only of a differential and a scaling factor (generally `Real`) for T in (:Any,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent) + @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 216357e91..c86fc78ea 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -17,15 +17,15 @@ Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) -Base.Broadcast.broadcasted(::Type{T}) where T<:AbstractZero = T() +Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z -Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) -(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T) +Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) +(::Type{T})(xs::AbstractZero...) where {T<:Number} = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) @@ -33,7 +33,7 @@ Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) Base.getindex(z::AbstractZero, k) = z Base.view(z::AbstractZero, ind...) = z -Base.sum(z::AbstractZero; dims=:) = z +Base.sum(z::AbstractZero; dims = :) = z """ ZeroTangent() <: AbstractZero diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index a2044fbe1..7ceb315ea 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,9 +44,13 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw(NotImplementedException(@not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" -))) +Base.zero(::Type{<:NotImplemented}) = throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" + ) + ), +) Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index e4bbfb8c8..34e822ea8 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -21,42 +21,42 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P, T} <: AbstractTangent +struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T end -function Tangent{P}(; kwargs...) where P +function Tangent{P}(; kwargs...) where {P} backing = (; kwargs...) # construct as NamedTuple - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end -function Tangent{P}(args...) where P - return Tangent{P, typeof(args)}(args) +function Tangent{P}(args...) where {P} + return Tangent{P,typeof(args)}(args) end -function Tangent{P}() where P<:Tuple +function Tangent{P}() where {P<:Tuple} backing = () - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P, typeof(d)}(d) + return Tangent{P,typeof(d)}(d) end -function Base.:(==)(a::Tangent{P, T}, b::Tangent{P, T}) where {P, T} +function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P, T} +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P,T} all_fields = union(keys(backing(a)), keys(backing(b))) return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where P +function Base.show(io::IO, tangent::Tangent{P}) where {P} print(io, "Tangent{") show(io, P) print(io, "}") @@ -68,15 +68,15 @@ function Base.show(io::IO, tangent::Tangent{P}) where P end end -function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}} +function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end -function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict} +function Base.getindex(tangent::Tangent, idx) where {P,T<:AbstractDict} return unthunk(getindex(backing(tangent), idx)) end @@ -84,7 +84,7 @@ function Base.getproperty(tangent::Tangent, idx::Int) back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end @@ -99,26 +99,26 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) -Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T) +Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing) + Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) end -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state = 1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end -function Base.map(f, tangent::Tangent{P, <:Tuple}) where P +function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} vals::Tuple = map(f, backing(tangent)) - return Tangent{P, typeof(vals)}(vals) + return Tangent{P,typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L} +function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L, typeof(vals)}(vals) - return Tangent{P, typeof(named_vals)}(named_vals) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return Tangent{P,typeof(named_vals)}(named_vals) end -function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict} +function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end @@ -140,26 +140,28 @@ backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) # For generic structs -function backing(x::T)::NamedTuple where T +function backing(x::T)::NamedTuple where {T} # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. # see https://github.com/JuliaLang/julia/issues/34283 if @generated - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names, Tuple{$(types...)}}($vals)) + vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) else - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = ntuple(ii->getfield(x, ii), nfields) - return NamedTuple{names, Tuple{types...}}(vals) + vals = ntuple(ii -> getfield(x, ii), nfields) + return NamedTuple{names,Tuple{types...}}(vals) end end @@ -170,36 +172,38 @@ Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. """ -function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L} +function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} nil = _zeroed_backing(P) combined = merge(nil, backing(tangent)) if length(combined) !== fieldcount(P) - throw(ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))" - )) + throw( + ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", + ), + ) end - return Tangent{P, typeof(combined)}(combined) + return Tangent{P,typeof(combined)}(combined) end # Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent +canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent # Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent +canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent # Tangents of unspecified primal types (indicated by specifying exactly `Any`) # all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:Tuple}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:AbstractDict}) where {L} = tangent """ _zeroed_backing(P) Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. """ -@generated function _zeroed_backing(::Type{P}) where P +@generated function _zeroed_backing(::Type{P}) where {P} nil_base = ntuple(fieldcount(P)) do i (fieldname(P, i), ZeroTangent()) end @@ -218,7 +222,7 @@ after an operation such as the addition of a primal to a tangent It should be overloaded, if `T` does not have a default constructor, or if `T` needs to maintain some invarients between its fields. """ -function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} +function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} # Tested and verified that that this avoids a ton of allocations if length(L) !== fieldcount(T) # if length is equal but names differ then we will catch that below anyway. @@ -233,12 +237,12 @@ function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} end end -construct(::Type{T}, fields::T) where T<:NamedTuple = fields -construct(::Type{T}, fields::T) where T<:Tuple = fields +construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields +construct(::Type{T}, fields::T) where {T<:Tuple} = fields elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} # Rule of Tangent addition: any fields not present are implict hard Zeros # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. @@ -281,7 +285,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} end field => value end - return (;vals...) + return (; vals...) end end @@ -297,15 +301,16 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color = :blue) println(io, "\nor overload") - printstyled(io, + printstyled( + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color=:blue + color = :blue, ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color = :blue) println(io, "\nOriginal Exception:") - printstyled(io, err.original; color=:yellow) + printstyled(io, err.original; color = :yellow) println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 16384d69e..c2b570902 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -56,18 +56,22 @@ LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a)) LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a)) LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a)) LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a)) -LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo) -LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo) +LinearAlgebra.Symmetric(a::AbstractThunk, uplo = :U) = Symmetric(unthunk(a), uplo) +LinearAlgebra.Hermitian(a::AbstractThunk, uplo = :U) = Hermitian(unthunk(a), uplo) function LinearAlgebra.diagm( - kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... + kv::Pair{<:Integer,<:AbstractThunk}, + kvs::Pair{<:Integer,<:AbstractThunk}..., ) return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...) end function LinearAlgebra.diagm( - m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... + m, + n, + kv::Pair{<:Integer,<:AbstractThunk}, + kvs::Pair{<:Integer,<:AbstractThunk}..., ) - return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) + return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a)) @@ -118,7 +122,7 @@ function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx) return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx) end -function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1) +function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn = 1) return throw(MutateThunkException()) end diff --git a/test/accumulation.jl b/test/accumulation.jl index 1b41fea55..a796b5289 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -27,7 +27,7 @@ end @testset "misc AbstractTangent subtypes" begin - @test 16 == add!!(12, @thunk(2*2)) + @test 16 == add!!(12, @thunk(2 * 2)) @test 16 == add!!(16, ZeroTangent()) @test 16 == add!!(16, NoTangent()) # Should this be an error? @@ -37,15 +37,15 @@ @testset "LHS Array (inplace)" begin @testset "RHS Array" begin A = [1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] end @testset "RHS StaticArray" begin - A = @SMatrix[1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + A = @SMatrix [1.0 2.0; 3.0 4.0] + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] @@ -53,7 +53,7 @@ @testset "RHS Diagonal" begin A = Diagonal([1.0, 2.0]) - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 -1.0; -1.0 1.0] @@ -79,17 +79,17 @@ @testset "Unhappy Path" begin # wrong length - @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(2, 2)) # wrong shape - @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(16)) # wrong type (adding scalar to array) @test_throws MethodError add!!(ones(4), 21.0) end end @testset "AbstractThunk $(typeof(thunk))" for thunk in ( - @thunk(-1.0*ones(2, 2)), - InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))), + @thunk(-1.0 * ones(2, 2)), + InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -111,12 +111,12 @@ @testset "not actually inplace but said it was" begin # thunk should never be used in this test ithunk = InplaceableThunk(@thunk(@assert false)) do x - 77*ones(2, 2) # not actually inplace (also wrong) + 77 * ones(2, 2) # not actually inplace (also wrong) end accumuland = ones(2, 2) @assert ChainRulesCore.debug_mode() == false # without debug being enabled should return the result, not error - @test 77*ones(2, 2) == add!!(accumuland, ithunk) + @test 77 * ones(2, 2) == add!!(accumuland, ithunk) ChainRulesCore.debug_mode() = true # enable debug mode # with debug being enabled should error @@ -127,7 +127,7 @@ @testset "showerror BadInplaceException" begin BadInplaceException = ChainRulesCore.BadInplaceException - ithunk = InplaceableThunk(x̄->nothing, @thunk(@assert false)) + ithunk = InplaceableThunk(x̄ -> nothing, @thunk(@assert false)) msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) @test occursin("22", msg) diff --git a/test/config.jl b/test/config.jl index 466baed9a..e6e2ab005 100644 --- a/test/config.jl +++ b/test/config.jl @@ -1,7 +1,7 @@ # Define a bunch of configs for testing purposes struct MostBoringConfig <: RuleConfig{Union{}} end -struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} forward_calls::Vector end MockForwardsConfig() = MockForwardsConfig([]) @@ -11,7 +11,7 @@ function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args. return f(args...; kws...), ȧrgs end -struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} reverse_calls::Vector end MockReverseConfig() = MockReverseConfig([]) @@ -23,7 +23,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. end -struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} forward_calls::Vector reverse_calls::Vector end @@ -47,18 +47,21 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, + MostBoringConfig, + MockForwardsConfig, + MockReverseConfig, + MockBothConfig, ) counting_id_count = Ref(0) function counting_id(x) - counting_id_count[]+=1 + counting_id_count[] += 1 return x end function ChainRulesCore.rrule(::typeof(counting_id), x) counting_id_pullback(x̄) = x̄ return counting_id(x), counting_id_pullback end - function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) + function ChainRulesCore.frule((dself, dx), ::typeof(counting_id), x) return counting_id(x), dx end @testset "rrule" begin @@ -76,21 +79,33 @@ end @testset "hitting forwards AD" begin do_thing_2(f, x) = f(x) function ChainRulesCore.frule( - config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x + config::RuleConfig{>:HasForwardsMode}, + (_, df, dx), + ::typeof(do_thing_2), + f, + x, ) return frule_via_ad(config, (df, dx), f, x) end @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) @test nothing === frule( - Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + Config(), + (NoTangent(), NoTangent(), 21.5), + do_thing_2, + identity, + 32.1, ) end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== frule( - bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + bconfig, + (NoTangent(), NoTangent(), 21.5), + do_thing_2, + identity, + 32.1, ) @test bconfig.forward_calls == [(identity, (32.1,))] end @@ -99,7 +114,10 @@ end @testset "hitting reverse AD" begin do_thing_3(f, x) = f(x) function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x + config::RuleConfig{>:HasReverseMode}, + ::typeof(do_thing_3), + f, + x, ) return (NoTangent(), rrule_via_ad(config, f, x)...) end @@ -110,7 +128,7 @@ end end @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) @test bconfig.reverse_calls == [(identity, (32.1,))] end @@ -130,14 +148,14 @@ end ẋ = one(x) y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) - pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ * ȳ return y, pullback_via_forwards_ad end function ChainRulesCore.rrule( - config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, + config::RuleConfig{>:Union{HasReverseMode,NoForwardsMode}}, ::typeof(do_thing_4), f, - x + x, ) y, f_pullback = rrule_via_ad(config, f, x) do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) @@ -147,43 +165,43 @@ end @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) @test bconfig.forward_calls == [(identity, (32.1,))] end - rconfig= MockReverseConfig() + rconfig = MockReverseConfig() @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) @test rconfig.reverse_calls == [(identity, (32.1,))] end @testset "RuleConfig broadcasts like a scaler" begin - @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} + @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} end @testset "fallbacks" begin - no_rule(x; kw="bye") = error() + no_rule(x; kw = "bye") = error() @test frule((1.0,), no_rule, 2.0) === nothing - @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule((1.0,), no_rule, 2.0; kw = "hello") === nothing @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing - @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw = "hello") === nothing @test rrule(no_rule, 2.0) === nothing - @test rrule(no_rule, 2.0; kw="hello") === nothing + @test rrule(no_rule, 2.0; kw = "hello") === nothing @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing - @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw = "hello") === nothing # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(; kw = "hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(sin; kw = "hello") @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(); kw = "hello") @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin; kw = "hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(; kw = "hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig();kw="hello") + @test_throws MethodError rrule(MostBoringConfig(); kw = "hello") end end diff --git a/test/deprecated.jl b/test/deprecated.jl index e69de29bb..8b1378917 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -0,0 +1 @@ + diff --git a/test/ignore_derivatives.jl b/test/ignore_derivatives.jl index 825287b9a..ad4fece9f 100644 --- a/test/ignore_derivatives.jl +++ b/test/ignore_derivatives.jl @@ -7,7 +7,7 @@ end @testset "function" begin f() = return 4.0 - y, ẏ = frule((1.0, ), ignore_derivatives, f) + y, ẏ = frule((1.0,), ignore_derivatives, f) @test y == f() @test ẏ == NoTangent() @@ -19,7 +19,7 @@ end @testset "argument" begin arg = 2.1 - y, ẏ = frule((1.0, ), ignore_derivatives, arg) + y, ẏ = frule((1.0,), ignore_derivatives, arg) @test y == arg @test ẏ == NoTangent() @@ -41,11 +41,11 @@ end @test pb(1.0) == (NoTangent(), NoTangent()) # when called - y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0)) + y, ẏ = frule((1.0,), ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test ẏ == NoTangent() - y, pb = rrule(ignore_derivatives, ()->mf(3.0)) + y, pb = rrule(ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test pb(1.0) == (NoTangent(), NoTangent()) end diff --git a/test/projection.jl b/test/projection.jl index ba61fb8da..ab418ef79 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,20 +24,21 @@ struct NoSuperType end # real / complex @test ProjectTo(1.0)(2.0 + 3im) === 2.0 @test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im - @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im - @test ProjectTo(2.0)(1+1im) === 1.0 - + @test ProjectTo(2.0 + 3.0im)(1 + 1im) === 1.0 + 1.0im + @test ProjectTo(2.0)(1 + 1im) === 1.0 + # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) - @test ProjectTo(1//2)(3//4) === 3//4 + @test ProjectTo(1 // 2)(3 // 4) === 3 // 4 @test ProjectTo(1.0f0)(1 / 2) === 0.5f0 @test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im @test ProjectTo(big(1.0))(2) === 2 @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re = 1, im = NoTangent())) === + 1.0f0 + 0.0f0im end @testset "Dual" begin # some weird Real subtype that we should basically leave alone @@ -46,13 +47,12 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} - @test ProjectTo(1.0 + 1im)( - Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)) - ) isa Complex{<:Dual} + @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa + Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent - @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent + @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value = 1.0)) isa Tangent end @testset "Base: arrays of numbers" begin @@ -99,10 +99,10 @@ struct NoSuperType end # arrays of other things @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} - @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} + @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. - @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) + @test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im) @test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent} # empty arrays @@ -172,7 +172,7 @@ struct NoSuperType end # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called - xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) + xs = adj(Any[Any[1, 2, 3], Any[4+im, 5-im, 6+im, 7-im]]) pvecvec3 = ProjectTo(xs) @test pvecvec3(xs)[1] == [1 2 3] @test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im]) @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()); + dx = Tangent{typeof(x)}(; step = 0.1, ref = NoTangent()) @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3+4im) + dy = Tangent{typeof(Ref(2.0))}(x = 3 + 4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end @@ -365,21 +365,21 @@ struct NoSuperType end # Each "@test 33 > ..." is zero on nightly, 32 on 1.5. pvec = ProjectTo(rand(10^3)) - @test 0 == @ballocated $pvec(dx) setup=(dx = rand(10^3)) # pass through - @test 90 > @ballocated $pvec(dx) setup=(dx = rand(10^3, 1)) # reshape + @test 0 == @ballocated $pvec(dx) setup = (dx = rand(10^3)) # pass through + @test 90 > @ballocated $pvec(dx) setup = (dx = rand(10^3, 1)) # reshape @test 33 > @ballocated ProjectTo(x)(dx) setup = (x = rand(10^3); dx = rand(10^3)) # including construction padj = ProjectTo(adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = transpose(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = adjoint(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = transpose(rand(10^3))) @test 33 > @ballocated ProjectTo(x')(dx') setup = (x = rand(10^3); dx = rand(10^3)) pdiag = ProjectTo(Diagonal(rand(10^3))) - @test 0 == @ballocated $pdiag(dx) setup=(dx = Diagonal(rand(10^3))) + @test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3))) psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) - @test_broken 0 == @ballocated $psymm(dx) setup=(dx = Symmetric(rand(10^3, 10^3))) # 64 + @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 0d6d98535..e99b66c2f 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -19,7 +19,7 @@ macro test_macro_throws(err_expr, expr) end end # Reuse `@test_throws` logic - if err!==nothing + if err !== nothing @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) @@ -29,21 +29,21 @@ end # struct need to be defined outside of tests for julia 1.0 compat struct NonDiffExample - x + x::Any end struct NonDiffCounterExample - x + x::Any end module NonDiffModuleExample - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] +nondiff_2_1(x, y) = fill(7.5, 100)[x+y] end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] + nondiff_2_1(x, y) = fill(7.5, 100)[x+y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((ZeroTangent(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, NoTangent()) res, pullback = rrule(nondiff_2_1, 3, 2) @@ -58,7 +58,7 @@ end res, pullback = rrule(nondiff_1_2, 3.1) @test res == (5.0, 3.0) @test isequal( - pullback(Tangent{Tuple{Float64, Float64}}(1.2, 3.2)), + pullback(Tangent{Tuple{Float64,Float64}}(1.2, 3.2)), (NoTangent(), NoTangent()), ) end @@ -81,7 +81,8 @@ end pointy_identity(x) = x @non_differentiable pointy_identity(::Vector{<:AbstractString}) - @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == (["2"], NoTangent()) + @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == + (["2"], NoTangent()) @test frule((ZeroTangent(), 1.2), pointy_identity, 2.0) == nothing res, pullback = rrule(pointy_identity, ["2"]) @@ -92,7 +93,7 @@ end end @testset "kwargs" begin - kw_demo(x; kw=2.0) = x + kw + kw_demo(x; kw = 2.0) = x + kw @non_differentiable kw_demo(::Any) @testset "not setting kw" begin @@ -106,13 +107,14 @@ end end @testset "setting kw" begin - @assert kw_demo(1.5; kw=3.0) == 4.5 + @assert kw_demo(1.5; kw = 3.0) == 4.5 - res, pullback = rrule(kw_demo, 1.5; kw=3.0) + res, pullback = rrule(kw_demo, 1.5; kw = 3.0) @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw = 3.0) == + (4.5, NoTangent()) end end @@ -121,7 +123,7 @@ end @test isequal( frule((ZeroTangent(), 1.2), NonDiffExample, 2.0), - (NonDiffExample(2.0), NoTangent()) + (NonDiffExample(2.0), NoTangent()), ) res, pullback = rrule(NonDiffExample, 2.0) @@ -151,7 +153,7 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), NoTangent()) @test frule((1, 1), fvarargs, 1, 2) == nothing - @test rrule(fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @@ -194,10 +196,10 @@ end end @testset "Functors" begin - (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] + (f::NonDiffExample)(y) = fill(7.5, 100)[f.x+y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x=1.2), 2.3), NonDiffExample(3), 2) == - (7.5, NoTangent()) + @test frule((Tangent{NonDiffExample}(x = 1.2), 2.3), NonDiffExample(3), 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent()) @@ -205,8 +207,12 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) - @test frule((ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2) == - (7.5, NoTangent()) + @test frule( + (ZeroTangent(), 1.2, 2.3), + NonDiffModuleExample.nondiff_2_1, + 3, + 2, + ) == (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) @@ -216,7 +222,7 @@ end # Where clauses are not supported. @test_macro_throws( ErrorException, - (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) + (@non_differentiable where_identity(::Vector{T}) where {T<:AbstractString}) ) end end @@ -224,32 +230,33 @@ end @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) - @scalar_rule(simo(x), 1f0, 2f0) + @scalar_rule(simo(x), 1.0f0, 2.0f0) y, simo_pb = rrule(simo, π) - @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) + @test simo_pb((10.0f0, 20.0f0)) == (NoTangent(), 50.0f0) - y, ẏ = frule((NoTangent(), 50f0), simo, π) + y, ẏ = frule((NoTangent(), 50.0f0), simo, π) @test y == (π, 2π) - @test ẏ == Tangent{typeof(y)}(50f0, 100f0) + @test ẏ == Tangent{typeof(y)}(50.0f0, 100.0f0) # make sure type is exactly as expected: - @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} + @test ẏ isa Tangent{Tuple{Irrational{:π},Float64},Tuple{Float32,Float32}} xs, Ω = (3,), (3, 6) - @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == + ((1.0f0,), (2.0f0,)) end @testset "@scalar_rule projection" begin - make_imaginary(x) = im*x + make_imaginary(x) = im * x @scalar_rule make_imaginary(x) im # note: the === will make sure that these are Float64, not ComplexF64 - @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im) + @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0 * im) @test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0) - @test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im) - @test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) + @test (NoTangent(), 1.0 + 0.0im) === rrule(make_imaginary, 2.0im)[2](1.0 * im) + @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end @testset "Regression tests against #276 and #265" begin @@ -257,16 +264,16 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 # Symptom of these problems is creation of global variables and type instability - num_globals_before = length(names(ChainRulesCore; all=true)) + num_globals_before = length(names(ChainRulesCore; all = true)) simo2(x) = (x, 2x) @scalar_rule(simo2(x), 1.0, 2.0) _, simo2_pb = rrule(simo2, 43.0) # make sure it infers: inferability implies type stability - @inferred simo2_pb(Tangent{Tuple{Float64, Float64}}(3.0, 6.0)) + @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) # Test no new globals were created - @test length(names(ChainRulesCore; all=true)) == num_globals_before + @test length(names(ChainRulesCore; all = true)) == num_globals_before # Example in #265 simo3(x) = sincos(x) @@ -279,60 +286,60 @@ end module IsolatedModuleForTestingScoping - # check that rules can be defined by macros without any additional imports - using ChainRulesCore: @scalar_rule, @non_differentiable - - # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved - const ChainRulesCore = nothing - - # this is - # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 - fixed(x) = :abc - @non_differentiable fixed(x) - - # check name collision between a primal input called `kwargs` and the actual keyword - # arguments - fixed_kwargs(x; kwargs...) = :abc - @non_differentiable fixed_kwargs(kwargs) - - my_id(x) = x - @scalar_rule(my_id(x), 1.0) - - module IsolatedSubmodule - # check that rules defined in isolated module without imports can be called - # without errors - using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output - using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id - using Test - - @testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((ZeroTangent(), randn()), f, randn()) - @test y === :abc - @test ẏ === NoTangent() - - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end +# check that rules can be defined by macros without any additional imports +using ChainRulesCore: @scalar_rule, @non_differentiable + +# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved +const ChainRulesCore = nothing + +# this is +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 +fixed(x) = :abc +@non_differentiable fixed(x) + +# check name collision between a primal input called `kwargs` and the actual keyword +# arguments +fixed_kwargs(x; kwargs...) = :abc +@non_differentiable fixed_kwargs(kwargs) + +my_id(x) = x +@scalar_rule(my_id(x), 1.0) + +module IsolatedSubmodule +# check that rules defined in isolated module without imports can be called +# without errors +using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output +using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id +using Test + +@testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((ZeroTangent(), randn()), f, randn()) + @test y === :abc + @test ẏ === NoTangent() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword = randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) +end - @testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ +@testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (NoTangent(), Δy) + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (NoTangent(), Δy) - @test derivatives_given_output(y, my_id, x) == ((1.0,),) - end - end + @test derivatives_given_output(y, my_id, x) == ((1.0,),) +end +end end diff --git a/test/rules.jl b/test/rules.jl index d43ca42d2..267b23005 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,8 +28,11 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Any, Any, Vararg}, - ::typeof(mixed_vararg), x, y, z..., + dargs::Tuple{Any,Any,Any,Vararg}, + ::typeof(mixed_vararg), + x, + y, + z..., ) Δx = dargs[2] Δy = dargs[3] @@ -39,16 +42,21 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any, Int, Float64}, - ::typeof(type_constraints), x::Int, y::Float64, + (_, Δx, Δy)::Tuple{Any,Int,Float64}, + ::typeof(type_constraints), + x::Int, + y::Float64, ) return type_constraints(x, y), Δx + Δy end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, - ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, + dargs::Tuple{Any,Float64,Real,Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), + x::Float64, + y::Real, + z::Vararg{Float64}, ) Δx = dargs[2] Δy = dargs[3] @@ -65,9 +73,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = ZeroTangent() @test frule((dself, 1), cool, 1) === nothing - @test frule((dself, 1), cool, 1; iscool=true) === nothing + @test frule((dself, 1), cool, 1; iscool = true) === nothing @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool=true) === nothing + @test rrule(cool, 1; iscool = true) === nothing # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @@ -76,8 +84,10 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}]) + only_methods = Set([ + Tuple{typeof(rrule),typeof(cool),Number}, + Tuple{typeof(rrule),typeof(cool),String}, + ]) @test cool_methods == only_methods frx, cool_pushforward = frule((dself, 1), cool, 1) @@ -98,21 +108,26 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) - @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == + (10.0, 10.0) @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing - @test(frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, - ) == (6.0, 6.0)) + @test( + frule( + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, + 3.0, + 2.0, + 1.0, + 0.0, + ) == (6.0, 6.0) + ) # violates type constraints, thus an frule should not be found. - @test frule( - (nothing, 3, 2.0, 1.0, 5.0), - mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, - ) == nothing + @test frule((nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0) == + nothing @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) @@ -153,27 +168,29 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) - @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where {T<:Float32} @opt_out( - ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float32} ) @testset "rrule" begin @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) - @test rrule(first_oa, 3f0, 4f0) === nothing + @test rrule(first_oa, 3.0f0, 4.0f0) === nothing @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} end) end @testset "frule" begin - @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) - @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing - - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m - m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 - end) + @test frule((NoTangent(), 1, 0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1, 0), first_oa, 3.0f0, 4.0f0) === nothing + + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end end end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7e0ec9398..fdbb92f55 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -9,7 +9,7 @@ @test view(NoTangent(), 1, 2) == NoTangent() @test sum(ZeroTangent()) == ZeroTangent() - @test sum(NoTangent(); dims=2) == NoTangent() + @test sum(NoTangent(); dims = 2) == NoTangent() end @testset "ZeroTangent" begin @@ -55,7 +55,7 @@ @test muladd(x, ZeroTangent(), ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), x, ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), ZeroTangent(), ZeroTangent()) === ZeroTangent() - + @test reim(z) === (ZeroTangent(), ZeroTangent()) @test real(z) === ZeroTangent() @test imag(z) === ZeroTangent() diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 2fd337979..2b7c6347e 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -1,10 +1,14 @@ @testset "NotImplemented" begin @testset "NotImplemented" begin ni = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" + @__MODULE__, + LineNumberNode(@__LINE__, @__FILE__), + "error", ) ni2 = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" + @__MODULE__, + LineNumberNode(@__LINE__, @__FILE__), + "error2", ) x = rand() thunk = @thunk(x^2) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 694e43b53..cc24d988e 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -1,6 +1,6 @@ # For testing Tangent struct Foo - x + x::Any y::Float64 end @@ -12,81 +12,81 @@ end # For testing Tangent: it is an invarient of the type that x2 = 2x # so simple addition can not be defined struct StructWithInvariant - x - x2 + x::Any + x2::Any StructWithInvariant(x) = new(x, 2x) end @testset "Tangent" begin @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{}, Tuple{}} + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} end @testset "==" begin - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(x=0.1, y=2.5) - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(y=2.5, x=0.1) - @test Tangent{Foo}(y=2.5, x=ZeroTangent()) == Tangent{Foo}(y=2.5) + @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(x = 0.1, y = 2.5) + @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(y = 2.5, x = 0.1) + @test Tangent{Foo}(y = 2.5, x = ZeroTangent()) == Tangent{Foo}(y = 2.5) - @test Tangent{Tuple{Float64,}}(2.0) == Tangent{Tuple{Float64,}}(2.0) + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2*1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(;y=2.0,) == Tangent{Foo}(;x=ZeroTangent(), y=Float32(2.0),) + @test Tangent{Foo}(; y = 2.0) == Tangent{Foo}(; x = ZeroTangent(), y = Float32(2.0)) end @testset "hash" begin - @test hash(Tangent{Foo}(x=0.1, y=2.5)) == hash(Tangent{Foo}(y=2.5, x=0.1)) - @test hash(Tangent{Foo}(y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(y=2.5)) + @test hash(Tangent{Foo}(x = 0.1, y = 2.5)) == hash(Tangent{Foo}(y = 2.5, x = 0.1)) + @test hash(Tangent{Foo}(y = 2.5, x = ZeroTangent())) == hash(Tangent{Foo}(y = 2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(x=2.5), :x) == true + @test keys(Tangent{Foo}(x = 2.5)) == (:x,) + @test propertynames(Tangent{Foo}(x = 2.5)) == (:x,) + @test haskey(Tangent{Foo}(x = 2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x=2.5), :y) == false + @test hasproperty(Tangent{Foo}(x = 2.5), :y) == false end - @test Tangent{Foo}(x=2.5).x == 2.5 - - @test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - - NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}} - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 - - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 + @test Tangent{Foo}(x = 2.5).x == 2.5 + + @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + + NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x=2.5)) == 1 - @test length(Tangent{Tuple{Float64,}}(2.0)) == 1 + @test length(Tangent{Foo}(x = 2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64,}}(2.0)) == Float64 + @test eltype(Tangent{Foo}(x = 2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64,}}(2.0)) == [2.0] + @test collect(Tangent{Foo}(x = 2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] # Test indexed_iterate ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function(tangent) + _unpack2tuple = function (tangent) a, b = tangent return (a, b) end @@ -96,33 +96,33 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(x = 2, y = 3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(y = 3.0)) end end @testset "reverse" begin - c = Tangent{Tuple{Int, Int, String}}(1, 2, "something") - cr = Tangent{Tuple{String, Int, Int}}("something", 2, 1) + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(;x=1.0, y=2.0)) + @test_throws MethodError reverse(Tangent{Foo}(; x = 1.0, y = 2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo, typeof(d)}(d) + cdict = Tangent{Foo,typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end @testset "unset properties" begin - @test Tangent{Foo}(; x=1.4).y === ZeroTangent() + @test Tangent{Foo}(; x = 1.4).y === ZeroTangent() end @testset "conj" begin - @test conj(Tangent{Foo}(x=2.0+3.0im)) == Tangent{Foo}(x=2.0-3.0im) + @test conj(Tangent{Foo}(x = 2.0 + 3.0im)) == Tangent{Foo}(x = 2.0 - 3.0im) @test ==( - conj(Tangent{Tuple{Float64,}}(2.0+3.0im)), - Tangent{Tuple{Float64,}}(2.0-3.0im) + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), + Tangent{Tuple{Float64}}(2.0 - 3.0im), ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -132,26 +132,20 @@ end @testset "canonicalize" begin # Testing iterate via collect - @test ==( - canonicalize(Tangent{Tuple{Float64,}}(2.0)), - Tangent{Tuple{Float64,}}(2.0) - ) + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - @test ==( - canonicalize(Tangent{Dict}(Dict(4 => 3))), - Tangent{Dict}(Dict(4 => 3)), - ) + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x=2.5, y=10)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10, x=2.5)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10)) == CFoo(x=ZeroTangent(), y=10) + @test canonicalize(CFoo(x = 2.5, y = 10)) == CFoo(x = 2.5, y = 10) + @test canonicalize(CFoo(y = 10, x = 2.5)) == CFoo(x = 2.5, y = 10) + @test canonicalize(CFoo(y = 10)) == CFoo(x = ZeroTangent(), y = 10) - @test_throws ArgumentError canonicalize(CFoo(q=99.0, x=2.5)) + @test_throws ArgumentError canonicalize(CFoo(q = 99.0, x = 2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(;a=1, b=2) + c1 = Tangent{Any}(; a = 1, b = 2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -164,30 +158,28 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x=1.5) + CFoo(x=2.5) == CFoo(x=4.0) - @test CFoo(y=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=2.5) - @test CFoo(y=1.5, x=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=4.0) + @test CFoo(x = 1.5) + CFoo(x = 2.5) == CFoo(x = 4.0) + @test CFoo(y = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 2.5) + @test CFoo(y = 1.5, x = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 4.0) end @testset "Tuples" begin @test ==( typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{}, Tuple{}} + Tangent{Tuple{},Tuple{}}, ) @test ( - Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64, Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64, Float64}}(2.0, 3.0) + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) end @testset "NamedTuples" begin - nt1 = (;a=1.5, b=0.0) - nt2 = (;a=0.0, b=2.5) - nt_sum = (a=1.5, b=2.5) - @test ( - Tangent{typeof(nt1)}(; nt1...) + - Tangent{typeof(nt2)}(; nt2...) - ) == Tangent{typeof(nt_sum)}(; nt_sum...) + nt1 = (; a = 1.5, b = 0.0) + nt2 = (; a = 0.0, b = 2.5) + nt_sum = (a = 1.5, b = 2.5) + @test (Tangent{typeof(nt1)}(; nt1...) + Tangent{typeof(nt2)}(; nt2...)) == + Tangent{typeof(nt_sum)}(; nt_sum...) end @testset "Dicts" begin @@ -199,8 +191,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x=1.5) - b = CFoo(x=@not_implemented("")) + a = CFoo(x = 1.5) + b = CFoo(x = @not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -215,8 +207,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x=1.5) - b = Tangent{NamedTuple{(:x,)}}(x=@not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(x = 1.5) + b = Tangent{NamedTuple{(:x,)}}(x = @not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -235,35 +227,35 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 + @test Foo(3.5, 1.5) + Tangent{Foo}(x = 2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(x = 2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x = 0.5)) == 0 end @testset "Tuples" begin @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end @testset "NamedTuple" begin - ntx = (; a=1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + ntx = (; a = 1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a = 3.0) - nty = (; a=1.5, b=0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) + nty = (; a = 1.5, b = 0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a = 3.0, b = 1.0) end @testset "Dicts" begin d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 =>5.0)) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x=2.0, x2=6.0) + diff = Tangent{StructWithInvariant}(x = 2.0, x2 = 6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -280,7 +272,7 @@ end # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2/2)/2 + x = (nt.x + nt.x2 / 2) / 2 return StructWithInvariant(x) end @test value + diff == StructWithInvariant(12.5) @@ -288,7 +280,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y=1.5, x=2.5) + c = Tangent{Foo}(y = 1.5, x = 2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -310,14 +302,14 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y=1.5, x=2.5) - == Tangent{Foo}(y=3.0, x=5.0) - == Tangent{Foo}(y=1.5, x=2.5) * 2 + 2 * Tangent{Foo}(y = 1.5, x = 2.5) == + Tangent{Foo}(y = 3.0, x = 5.0) == + Tangent{Foo}(y = 1.5, x = 2.5) * 2 ) @test ( - 2 * Tangent{Tuple{Float64, Float64}}(2.0, 4.0) - == Tangent{Tuple{Float64, Float64}}(4.0, 8.0) - == Tangent{Tuple{Float64, Float64}}(2.0, 4.0) * 2 + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 ) d = Tangent{Dict}(Dict(4 => 3.0)) two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) @@ -325,7 +317,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x=1,)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(x = 1)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -342,8 +334,9 @@ end end @testset "Internals don't allocate a ton" begin - bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + bk = (; x = 1.0, y = 2.0) + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 # weaker version of the above (which should pass on all versions) @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 @@ -352,8 +345,8 @@ end end @testset "non-same-typed differential arithmetic" begin - nt = (; a=1, b=2.0) - c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1); + nt = (; a = 1, b = 2.0) + c = Tangent{typeof(nt)}(; a = NoTangent(), b = 0.1) + @test nt + c == (; a = 1, b = 2.1) end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 89461caa1..af4a747d1 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -141,7 +141,7 @@ # Check against accidential type piracy # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472 @test Base.which(diagm, Tuple{}()).module != ChainRulesCore - @test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore + @test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore end @test tril(a) == tril(t) @test tril(a, 1) == tril(t, 1) From f20585edacd804b084db67cbd6affddeb355fd9d Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 13:53:32 +0300 Subject: [PATCH 08/20] manual format cleanup --- src/rule_definition_tools.jl | 3 +-- test/tangent_types/tangent.jl | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 8a1e1cce4..67818fb71 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -583,8 +583,7 @@ function _split_primal_name(primal_name) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name - # e.g. (::T)(x, y) - elseif Meta.isexpr(primal_name, :(::)) + elseif Meta.isexpr(primal_name, :(::)) # e.g. (::T)(x, y) _primal_name = gensym(Symbol(:instance_, primal_name.args[end])) primal_name_sig = Expr(:(::), _primal_name, primal_name.args[end]) return primal_name_sig, _primal_name diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index cc24d988e..5aa16c9f1 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -175,11 +175,11 @@ end end @testset "NamedTuples" begin - nt1 = (; a = 1.5, b = 0.0) - nt2 = (; a = 0.0, b = 2.5) - nt_sum = (a = 1.5, b = 2.5) - @test (Tangent{typeof(nt1)}(; nt1...) + Tangent{typeof(nt2)}(; nt2...)) == - Tangent{typeof(nt_sum)}(; nt_sum...) + NTTangent(nt) = Tangent{typeof(nt)}(; nt...) + t1 = NTTangent((; a = 1.5, b = 0.0)) + t2 = NTTangent((; a = 0.0, b = 2.5)) + t_sum = NTTangent((a = 1.5, b = 2.5)) + @test t1 + t2 == t_sum end @testset "Dicts" begin From 2dd87fb6d34d7ddf4c22e8f4c11641ad178a35ca Mon Sep 17 00:00:00 2001 From: st-- Date: Thu, 7 Oct 2021 14:06:37 +0300 Subject: [PATCH 09/20] Update test/tangent_types/tangent.jl --- test/tangent_types/tangent.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 5aa16c9f1..f3ff1085a 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -175,10 +175,10 @@ end end @testset "NamedTuples" begin - NTTangent(nt) = Tangent{typeof(nt)}(; nt...) - t1 = NTTangent((; a = 1.5, b = 0.0)) - t2 = NTTangent((; a = 0.0, b = 2.5)) - t_sum = NTTangent((a = 1.5, b = 2.5)) + make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) + t1 = make_tangent((; a = 1.5, b = 0.0)) + t2 = make_tangent((; a = 0.0, b = 2.5)) + t_sum = make_tangent((a = 1.5, b = 2.5)) @test t1 + t2 == t_sum end From 8275a587ed27ff538940e2d4b483b31faaa230e2 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 19:24:19 +0300 Subject: [PATCH 10/20] Revert "format(".") with blue style from scratch" This reverts commit 8b54ef67146db7c1e085535d781063fef94ecbb0. --- docs/make.jl | 29 ++-- docs/src/assets/make_logo.jl | 48 +++---- src/accumulation.jl | 8 +- src/compat.jl | 4 +- src/deprecated.jl | 1 - src/ignore_derivatives.jl | 8 +- src/projection.jl | 85 ++++++------ src/rule_definition_tools.jl | 82 ++++------- src/tangent_arithmetic.jl | 14 +- src/tangent_types/abstract_zero.jl | 8 +- src/tangent_types/notimplemented.jl | 10 +- src/tangent_types/tangent.jl | 113 ++++++++------- src/tangent_types/thunks.jl | 16 +-- test/accumulation.jl | 24 ++-- test/config.jl | 76 ++++------- test/deprecated.jl | 1 - test/ignore_derivatives.jl | 8 +- test/projection.jl | 40 +++--- test/rule_definition_tools.jl | 175 ++++++++++++------------ test/rules.jl | 75 ++++------ test/tangent_types/abstract_zero.jl | 4 +- test/tangent_types/notimplemented.jl | 8 +- test/tangent_types/tangent.jl | 197 ++++++++++++++------------- test/tangent_types/thunks.jl | 2 +- 24 files changed, 475 insertions(+), 561 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 608422c25..1ef3a62a7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,20 +16,20 @@ DocMeta.setdocmeta!( @scalar_rule(sin(x), cos(x)) # frule and rrule doctest @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end, + end ) indigo = DocThemeIndigo.install(ChainRulesCore) makedocs( - modules = [ChainRulesCore], - format = Documenter.HTML( - prettyurls = false, - assets = [indigo], - mathengine = MathJax3( + modules=[ChainRulesCore], + format=Documenter.HTML( + prettyurls=false, + assets=[indigo], + mathengine=MathJax3( Dict( :tex => Dict( - "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], + "inlineMath" => [["\$","\$"], ["\\(","\\)"]], "tags" => "ams", # TODO: remove when using physics package "macros" => Dict( @@ -42,9 +42,9 @@ makedocs( ), ), ), - sitename = "ChainRules", - authors = "Jarrett Revels and other contributors", - pages = [ + sitename="ChainRules", + authors="Jarrett Revels and other contributors", + pages=[ "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", @@ -63,8 +63,11 @@ makedocs( ], "API" => "api.md", ], - strict = true, - checkdocs = :exports, + strict=true, + checkdocs=:exports, ) -deploydocs(repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview = true) +deploydocs( + repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", + push_preview=true, +) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 3e7aeaa08..5bbfd36c1 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -8,34 +8,34 @@ using Random const bridge_len = 50 -function chain(jiggle = 0) - shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) - +function chain(jiggle=0) + shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) + ### 1 shaky_rotate(0) sethue(Luxor.julia_red) link() m1 = getmatrix() - - + + ### 2 sethue(Luxor.julia_green) - translate(-50, 130) - shaky_rotate(π / 3) + translate(-50, 130); + shaky_rotate(π/3); link() m2 = getmatrix() - + setmatrix(m1) sethue(Luxor.julia_red) overlap(-1.3π) setmatrix(m2) - + ### 3 - shaky_rotate(-π / 3) - translate(-120, 80) + shaky_rotate(-π/3); + translate(-120,80); sethue(Luxor.julia_purple) link() - + setmatrix(m2) setcolor(Luxor.julia_green) overlap(-1.5π) @@ -45,24 +45,24 @@ end function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - - - rect(50, -3, 40, bridge_len + 6, :fill) - rect(-50 - 40, -3, 40, bridge_len + 6, :fill) - + + + rect(50,-3,40, bridge_len+6, :fill) + rect(-50-40,-3,40, bridge_len+6, :fill) + sethue("black") move(Point(-50, bridge_len)) - arc(Point(0, 0), 50, π, 0, :stoke) + arc(Point(0,0), 50, π, 0, :stoke) arc(Point(0, bridge_len), 50, 0, -π, :stroke) - + move(Point(-90, bridge_len)) - arc(Point(0, 0), 90, π, 0, :stoke) + arc(Point(0,0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) strokepath() end function overlap(ang_end) - sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill) + sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) sethue("black") arc(Point(0, bridge_len), 50, 0, ang_end, :stoke) move(Point(90, bridge_len)) @@ -75,13 +75,13 @@ end function save_logo(filename) Random.seed!(16) - Drawing(450, 450, filename) + Drawing(450,450, filename) origin() - translate(50, -130) + translate(50, -130); chain(0.5) finish() preview() end save_logo("logo.svg") -save_logo("logo.png") +save_logo("logo.png") \ No newline at end of file diff --git a/src/accumulation.jl b/src/accumulation.jl index 5fbc07fa8..4bcc5c33f 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,7 +26,7 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) -function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} +function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N return if is_inplaceable_destination(x) x .+= y else @@ -75,8 +75,8 @@ end struct BadInplaceException <: Exception ithunk::InplaceableThunk - accumuland::Any - returned_value::Any + accumuland + returned_value end function Base.showerror(io::IO, err::BadInplaceException) @@ -88,7 +88,7 @@ function Base.showerror(io::IO, err::BadInplaceException) if err.accumuland == err.returned_value println( io, - "Which in this case happenned to be equal. But they are not the same object.", + "Which in this case happenned to be equal. But they are not the same object." ) end end diff --git a/src/compat.jl b/src/compat.jl index fa66b1d0f..8204b66d5 100644 --- a/src/compat.jl +++ b/src/compat.jl @@ -5,7 +5,7 @@ end if VERSION < v"1.1" # Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4 # See: https://github.com/JuliaLang/julia/issues/34292 - function fieldtypes(::Type{T}) where {T} + function fieldtypes(::Type{T}) where T if @generated ntuple(i -> fieldtype(T, i), fieldcount(T)) else @@ -13,7 +13,7 @@ if VERSION < v"1.1" end end - function fieldnames(::Type{T}) where {T} + function fieldnames(::Type{T}) where T if @generated ntuple(i -> fieldname(T, i), fieldcount(T)) else diff --git a/src/deprecated.jl b/src/deprecated.jl index 8b1378917..e69de29bb 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1 +0,0 @@ - diff --git a/src/ignore_derivatives.jl b/src/ignore_derivatives.jl index 18865f2c9..c66d89d7e 100644 --- a/src/ignore_derivatives.jl +++ b/src/ignore_derivatives.jl @@ -45,9 +45,7 @@ ignore_derivatives(x) = x Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`. """ macro ignore_derivatives(ex) - return :( - ChainRulesCore.ignore_derivatives() do - $(esc(ex)) - end - ) + return :(ChainRulesCore.ignore_derivatives() do + $(esc(ex)) + end) end diff --git a/src/projection.jl b/src/projection.jl index 55f6e7bfd..4b07b2762 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -131,8 +131,7 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} = - ProjectTo{NoTangent}() +ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}() # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -165,14 +164,12 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) end # In these cases we can just `convert` as we know we are dealing with plain and simple types -(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) -(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity +(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx) +(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = - convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = - convert(T, dx) +(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -193,7 +190,7 @@ end # For arrays of numbers, just store one projector: function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element = _eltype_projectto(T), axes = axes(x)) + return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -207,7 +204,7 @@ function ProjectTo(xs::AbstractArray) return ProjectTo{NoTangent}() # short-circuit if all elements project to zero else # Arrays of arrays come here, and will apply projectors individually: - return ProjectTo{AbstractArray}(; elements = elements, axes = axes(xs)) + return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs)) end end @@ -217,7 +214,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} dy = if axes(dx) == project.axes dx else - for d = 1:max(M, length(project.axes)) + for d in 1:max(M, length(project.axes)) if size(dx, d) != length(get(project.axes, d, 1)) throw(_projection_mismatch(project.axes, size(dx))) end @@ -247,11 +244,9 @@ end # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers if !(project.axes isa Tuple{}) - throw( - DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - ), - ) + throw(DimensionMismatch( + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", + )) end return fill(project.element(dx)) end @@ -259,7 +254,7 @@ end function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) return DimensionMismatch( - "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx", + "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx" ) end @@ -273,13 +268,13 @@ function ProjectTo(x::Ref) if sub isa ProjectTo{<:AbstractZero} return ProjectTo{NoTangent}() else - return ProjectTo{Ref}(; type = typeof(x), x = sub) + return ProjectTo{Ref}(; type=typeof(x), x=sub) end end -(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x = project.x(dx.x)) -(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x = project.x(dx[])) +(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) +(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) # Since this works like a zero-array in broadcasting, it should also accept a number: -(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x = project.x(dx)) +(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) ##### ##### `LinearAlgebra` @@ -288,7 +283,7 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec # Row vectors -ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent = ProjectTo(parent(x))) +ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x))) # Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec. # Transposed matrices are, like PermutedDimsArray, just a storage detail, # but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number @@ -303,8 +298,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = - ProjectTo{Transpose}(; parent = ProjectTo(parent(x))) +ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -317,22 +311,21 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) end # Diagonal -ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag = ProjectTo(x.diag)) +ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) # Symmetric -for (SymHerm, chk, fun) in - ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) +for (SymHerm, chk, fun) in ( + (:Symmetric, :issymmetric, :transpose), + (:Hermitian, :ishermitian, :adjoint), + ) @eval begin function ProjectTo(x::$SymHerm) sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; - uplo = LinearAlgebra.sym_uplo(x.uplo), - parent = sub, - ) + return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) dy = project.parent(dx) @@ -345,8 +338,9 @@ for (SymHerm, chk, fun) in # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) sub = project.parent # this is going to be unhappy about the size - sub_one = - ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) + sub_one = ProjectTo{project_type(sub)}(; + element=sub.element, axes=(sub.axes[1],) + ) return Diagonal(sub_one(dx.diag)) end end @@ -355,12 +349,13 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent = ProjectTo(parent(x))) + ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent - sub_one = - ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],)) + sub_one = ProjectTo{project_type(sub)}(; + element=sub.element, axes=(sub.axes[1],) + ) return Diagonal(sub_one(dx.diag)) end end @@ -397,7 +392,7 @@ end # another strategy is just to use the AbstractArray method function ProjectTo(x::Tridiagonal{T}) where {T<:Number} notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent = notparent) + return ProjectTo{Tridiagonal}(; notparent=notparent) end function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) dy = project.notparent(dx) @@ -416,9 +411,7 @@ using SparseArrays function ProjectTo(x::SparseVector{T}) where {T<:Number} return ProjectTo{SparseVector}(; - element = ProjectTo(zero(T)), - nzind = x.nzind, - axes = axes(x), + element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) ) end function (project::ProjectTo{SparseVector})(dx::AbstractArray) @@ -457,11 +450,11 @@ end function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} return ProjectTo{SparseMatrixCSC}(; - element = ProjectTo(zero(T)), - axes = axes(x), - rowval = rowvals(x), - nzranges = nzrange.(Ref(x), axes(x, 2)), - colptr = x.colptr, + element=ProjectTo(zero(T)), + axes=axes(x), + rowval=rowvals(x), + nzranges=nzrange.(Ref(x), axes(x, 2)), + colptr=x.colptr, ) end # You need not really store nzranges, you can get them from colptr -- TODO @@ -481,7 +474,7 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) for i in project.nzranges[col] row = project.rowval[i] val = dy[row, col] - nzval[k+=1] = project.element(val) + nzval[k += 1] = project.element(val) end end m, n = map(length, project.axes) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 67818fb71..6f217a3ac 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -83,8 +83,9 @@ For examples, see ChainRules' `rulesets` directory. See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) - call, setup_stmts, inputs, partials = - _normalize_scalarrules_macro_input(call, maybe_setup, partials) + call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( + call, maybe_setup, partials + ) f = call.args[1] # Generate variables to store derivatives named dfi/dxj @@ -100,11 +101,9 @@ macro scalar_rule(call, maybe_setup, partials...) # Final return: building the expression to insert in the place of this macro code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw( - ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))", - ), - ) + throw(ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))" + )) end $(derivative_expr) @@ -176,11 +175,7 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote - function ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), - ::Core.Typeof($f), - $(inputs...), - ) + function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) @@ -201,8 +196,9 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) end if n_outputs > 1 # For forward-mode we return a Tangent if output actually a tuple. - pushforward_returns = - Expr(:call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns...) + pushforward_returns = Expr( + :call, :(Tangent{typeof($(esc(:Ω)))}), pushforward_returns... + ) else pushforward_returns = first(pushforward_returns) end @@ -214,8 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = - ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -230,7 +225,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) Δs = _propagator_inputs(n_outputs) # Make a projector for each argument - projs, psetup = _make_projectors(call.args[2:end]) + projs, psetup = _make_projectors(call.args[2:end]) append!(setup_stmts, psetup) # 1 partial derivative per input @@ -253,8 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = - ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -263,12 +257,12 @@ end # For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" -_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i = 1:n] +_propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) projs = map(x -> Symbol(:proj_, x.args[1]), xs) - setups = map((x, p) -> :($p = ProjectTo($x)), xs, projs) + setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs) return projs, setups end @@ -281,7 +275,7 @@ Specify `_conj = true` to conjugate the partials. Projector `proj` is a function that will be applied at the end; for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity` """ -function propagation_expr(Δs, ∂s, _conj = false, proj = identity) +function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj @@ -294,10 +288,9 @@ function propagation_expr(Δs, ∂s, _conj = false, proj = identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = - foldl(Iterators.drop(zip(_∂s, Δs), 1); init = init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) - end + summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) + :(muladd($∂s_i, $Δs_i, $ex)) + end return :($proj($summed_∂_mul_Δs)) end @@ -388,10 +381,7 @@ end function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, - call_expr.args[1], - Expr(:parameters, :($(kwargs)...)), - call_expr.args[2:end]..., + :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... ) end @@ -399,18 +389,11 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.frule)))( - @nospecialize($kwargs::Any), - frule::typeof(ChainRulesCore.frule), - @nospecialize(::Any), - $(map(esc, primal_sig_parts)...), - ) + function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end - function ChainRulesCore.frule( - @nospecialize(::Any), - $(map(esc, primal_sig_parts)...), - ) + function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() return ($(esc(primal_invoke)), NoTangent()) @@ -425,8 +408,7 @@ function tuple_expression(primal_sig_parts) Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = - :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) end end @@ -444,11 +426,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))( - $(esc(kwargs))::Any, - ::typeof(rrule), - $(esc_primal_sig_parts...), - ) + function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) @@ -503,7 +481,7 @@ end "Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) - length(expr.args) === 0 && error("Malformed method expression. $expr") + length(expr.args)===0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore @@ -577,9 +555,8 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || - Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) + if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || + Meta.isexpr(primal_name, :curly) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name @@ -604,8 +581,7 @@ end function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && - return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) + Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index c2bad7a77..9c1378aab 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -81,7 +81,7 @@ LinearAlgebra.dot(::ZeroTangent, ::NoTangent) = ZeroTangent() Base.muladd(::ZeroTangent, x, y) = y Base.muladd(x, ::ZeroTangent, y) = y -Base.muladd(x, y, ::ZeroTangent) = x * y +Base.muladd(x, y, ::ZeroTangent) = x*y Base.muladd(::ZeroTangent, ::ZeroTangent, y) = y Base.muladd(x, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() @@ -125,11 +125,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} +function Base.:+(a::Tangent{P}, b::Tangent{P}) where P data = elementwise_add(backing(a), backing(b)) - return Tangent{P,typeof(data)}(data) + return Tangent{P, typeof(data)}(data) end -function Base.:+(a::P, d::Tangent{P}) where {P} +function Base.:+(a::P, d::Tangent{P}) where P net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -142,12 +142,12 @@ function Base.:+(a::P, d::Tangent{P}) where {P} end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where {P} = b + a +Base.:+(a::Tangent{P}, b::P) where P = b + a # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 differentials # Only of a differential and a scaling factor (generally `Real`) for T in (:Any,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) + @eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent) + @eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index c86fc78ea..216357e91 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -17,15 +17,15 @@ Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) -Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() +Base.Broadcast.broadcasted(::Type{T}) where T<:AbstractZero = T() # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z -Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) -(::Type{T})(xs::AbstractZero...) where {T<:Number} = zero(T) +Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) +(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) @@ -33,7 +33,7 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) Base.getindex(z::AbstractZero, k) = z Base.view(z::AbstractZero, ind...) = z -Base.sum(z::AbstractZero; dims = :) = z +Base.sum(z::AbstractZero; dims=:) = z """ ZeroTangent() <: AbstractZero diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index 7ceb315ea..a2044fbe1 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,13 +44,9 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw( - NotImplementedException( - @not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" - ) - ), -) +Base.zero(::Type{<:NotImplemented}) = throw(NotImplementedException(@not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" +))) Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 34e822ea8..e4bbfb8c8 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -21,42 +21,42 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P,T} <: AbstractTangent +struct Tangent{P, T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T end -function Tangent{P}(; kwargs...) where {P} +function Tangent{P}(; kwargs...) where P backing = (; kwargs...) # construct as NamedTuple - return Tangent{P,typeof(backing)}(backing) + return Tangent{P, typeof(backing)}(backing) end -function Tangent{P}(args...) where {P} - return Tangent{P,typeof(args)}(args) +function Tangent{P}(args...) where P + return Tangent{P, typeof(args)}(args) end -function Tangent{P}() where {P<:Tuple} +function Tangent{P}() where P<:Tuple backing = () - return Tangent{P,typeof(backing)}(backing) + return Tangent{P, typeof(backing)}(backing) end function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P,typeof(d)}(d) + return Tangent{P, typeof(d)}(d) end -function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} +function Base.:(==)(a::Tangent{P, T}, b::Tangent{P, T}) where {P, T} return backing(a) == backing(b) end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P,T} +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P, T} all_fields = union(keys(backing(a)), keys(backing(b))) return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where {P} +function Base.show(io::IO, tangent::Tangent{P}) where P print(io, "Tangent{") show(io, P) print(io, "}") @@ -68,15 +68,15 @@ function Base.show(io::IO, tangent::Tangent{P}) where {P} end end -function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} +function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} +function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end -function Base.getindex(tangent::Tangent, idx) where {P,T<:AbstractDict} +function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict} return unthunk(getindex(backing(tangent), idx)) end @@ -84,7 +84,7 @@ function Base.getproperty(tangent::Tangent, idx::Int) back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} +function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end @@ -99,26 +99,26 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) -Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) +Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) + Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing) end -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state = 1) where {P} +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end -function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} +function Base.map(f, tangent::Tangent{P, <:Tuple}) where P vals::Tuple = map(f, backing(tangent)) - return Tangent{P,typeof(vals)}(vals) + return Tangent{P, typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} +function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L} vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L,typeof(vals)}(vals) - return Tangent{P,typeof(named_vals)}(named_vals) + named_vals = NamedTuple{L, typeof(vals)}(vals) + return Tangent{P, typeof(named_vals)}(named_vals) end -function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} +function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end @@ -140,28 +140,26 @@ backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) # For generic structs -function backing(x::T)::NamedTuple where {T} +function backing(x::T)::NamedTuple where T # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. # see https://github.com/JuliaLang/julia/issues/34283 if @generated - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) + vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names, Tuple{$(types...)}}($vals)) else - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = ntuple(ii -> getfield(x, ii), nfields) - return NamedTuple{names,Tuple{types...}}(vals) + vals = ntuple(ii->getfield(x, ii), nfields) + return NamedTuple{names, Tuple{types...}}(vals) end end @@ -172,38 +170,36 @@ Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. """ -function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} +function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L} nil = _zeroed_backing(P) combined = merge(nil, backing(tangent)) if length(combined) !== fieldcount(P) - throw( - ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", - ), - ) + throw(ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))" + )) end - return Tangent{P,typeof(combined)}(combined) + return Tangent{P, typeof(combined)}(combined) end # Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent +canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent # Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent +canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent # Tangents of unspecified primal types (indicated by specifying exactly `Any`) # all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any,<:Tuple}) where {L} = tangent -canonicalize(tangent::Tangent{Any,<:AbstractDict}) where {L} = tangent +canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent +canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent """ _zeroed_backing(P) Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. """ -@generated function _zeroed_backing(::Type{P}) where {P} +@generated function _zeroed_backing(::Type{P}) where P nil_base = ntuple(fieldcount(P)) do i (fieldname(P, i), ZeroTangent()) end @@ -222,7 +218,7 @@ after an operation such as the addition of a primal to a tangent It should be overloaded, if `T` does not have a default constructor, or if `T` needs to maintain some invarients between its fields. """ -function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} +function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} # Tested and verified that that this avoids a ton of allocations if length(L) !== fieldcount(T) # if length is equal but names differ then we will catch that below anyway. @@ -237,12 +233,12 @@ function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} end end -construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields -construct(::Type{T}, fields::T) where {T<:Tuple} = fields +construct(::Type{T}, fields::T) where T<:NamedTuple = fields +construct(::Type{T}, fields::T) where T<:Tuple = fields elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} # Rule of Tangent addition: any fields not present are implict hard Zeros # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. @@ -285,7 +281,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} end field => value end - return (; vals...) + return (;vals...) end end @@ -301,16 +297,15 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color = :blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) println(io, "\nor overload") - printstyled( - io, + printstyled(io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color = :blue, + color=:blue ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color = :blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) println(io, "\nOriginal Exception:") - printstyled(io, err.original; color = :yellow) + printstyled(io, err.original; color=:yellow) println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index c2b570902..16384d69e 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -56,22 +56,18 @@ LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a)) LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a)) LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a)) LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a)) -LinearAlgebra.Symmetric(a::AbstractThunk, uplo = :U) = Symmetric(unthunk(a), uplo) -LinearAlgebra.Hermitian(a::AbstractThunk, uplo = :U) = Hermitian(unthunk(a), uplo) +LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo) +LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo) function LinearAlgebra.diagm( - kv::Pair{<:Integer,<:AbstractThunk}, - kvs::Pair{<:Integer,<:AbstractThunk}..., + kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) return diagm((k => unthunk(v) for (k, v) in (kv, kvs...))...) end function LinearAlgebra.diagm( - m, - n, - kv::Pair{<:Integer,<:AbstractThunk}, - kvs::Pair{<:Integer,<:AbstractThunk}..., + m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) - return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) + return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a)) @@ -122,7 +118,7 @@ function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx) return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx) end -function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn = 1) +function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1) return throw(MutateThunkException()) end diff --git a/test/accumulation.jl b/test/accumulation.jl index a796b5289..1b41fea55 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -27,7 +27,7 @@ end @testset "misc AbstractTangent subtypes" begin - @test 16 == add!!(12, @thunk(2 * 2)) + @test 16 == add!!(12, @thunk(2*2)) @test 16 == add!!(16, ZeroTangent()) @test 16 == add!!(16, NoTangent()) # Should this be an error? @@ -37,15 +37,15 @@ @testset "LHS Array (inplace)" begin @testset "RHS Array" begin A = [1.0 2.0; 3.0 4.0] - accumuland = -1.0 * ones(2, 2) + accumuland = -1.0*ones(2,2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] end @testset "RHS StaticArray" begin - A = @SMatrix [1.0 2.0; 3.0 4.0] - accumuland = -1.0 * ones(2, 2) + A = @SMatrix[1.0 2.0; 3.0 4.0] + accumuland = -1.0*ones(2,2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] @@ -53,7 +53,7 @@ @testset "RHS Diagonal" begin A = Diagonal([1.0, 2.0]) - accumuland = -1.0 * ones(2, 2) + accumuland = -1.0*ones(2,2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 -1.0; -1.0 1.0] @@ -79,17 +79,17 @@ @testset "Unhappy Path" begin # wrong length - @test_throws DimensionMismatch add!!(ones(4, 4), ones(2, 2)) + @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) # wrong shape - @test_throws DimensionMismatch add!!(ones(4, 4), ones(16)) + @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) # wrong type (adding scalar to array) @test_throws MethodError add!!(ones(4), 21.0) end end @testset "AbstractThunk $(typeof(thunk))" for thunk in ( - @thunk(-1.0 * ones(2, 2)), - InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), + @thunk(-1.0*ones(2, 2)), + InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))), ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -111,12 +111,12 @@ @testset "not actually inplace but said it was" begin # thunk should never be used in this test ithunk = InplaceableThunk(@thunk(@assert false)) do x - 77 * ones(2, 2) # not actually inplace (also wrong) + 77*ones(2, 2) # not actually inplace (also wrong) end accumuland = ones(2, 2) @assert ChainRulesCore.debug_mode() == false # without debug being enabled should return the result, not error - @test 77 * ones(2, 2) == add!!(accumuland, ithunk) + @test 77*ones(2, 2) == add!!(accumuland, ithunk) ChainRulesCore.debug_mode() = true # enable debug mode # with debug being enabled should error @@ -127,7 +127,7 @@ @testset "showerror BadInplaceException" begin BadInplaceException = ChainRulesCore.BadInplaceException - ithunk = InplaceableThunk(x̄ -> nothing, @thunk(@assert false)) + ithunk = InplaceableThunk(x̄->nothing, @thunk(@assert false)) msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) @test occursin("22", msg) diff --git a/test/config.jl b/test/config.jl index e6e2ab005..466baed9a 100644 --- a/test/config.jl +++ b/test/config.jl @@ -1,7 +1,7 @@ # Define a bunch of configs for testing purposes struct MostBoringConfig <: RuleConfig{Union{}} end -struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} forward_calls::Vector end MockForwardsConfig() = MockForwardsConfig([]) @@ -11,7 +11,7 @@ function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args. return f(args...; kws...), ȧrgs end -struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} reverse_calls::Vector end MockReverseConfig() = MockReverseConfig([]) @@ -23,7 +23,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. end -struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} forward_calls::Vector reverse_calls::Vector end @@ -47,21 +47,18 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, - MockForwardsConfig, - MockReverseConfig, - MockBothConfig, + MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, ) counting_id_count = Ref(0) function counting_id(x) - counting_id_count[] += 1 + counting_id_count[]+=1 return x end function ChainRulesCore.rrule(::typeof(counting_id), x) counting_id_pullback(x̄) = x̄ return counting_id(x), counting_id_pullback end - function ChainRulesCore.frule((dself, dx), ::typeof(counting_id), x) + function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) return counting_id(x), dx end @testset "rrule" begin @@ -79,33 +76,21 @@ end @testset "hitting forwards AD" begin do_thing_2(f, x) = f(x) function ChainRulesCore.frule( - config::RuleConfig{>:HasForwardsMode}, - (_, df, dx), - ::typeof(do_thing_2), - f, - x, + config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x ) return frule_via_ad(config, (df, dx), f, x) end @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) @test nothing === frule( - Config(), - (NoTangent(), NoTangent(), 21.5), - do_thing_2, - identity, - 32.1, + Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig = Config() + bconfig= Config() @test nothing !== frule( - bconfig, - (NoTangent(), NoTangent(), 21.5), - do_thing_2, - identity, - 32.1, + bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) @test bconfig.forward_calls == [(identity, (32.1,))] end @@ -114,10 +99,7 @@ end @testset "hitting reverse AD" begin do_thing_3(f, x) = f(x) function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, - ::typeof(do_thing_3), - f, - x, + config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x ) return (NoTangent(), rrule_via_ad(config, f, x)...) end @@ -128,7 +110,7 @@ end end @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) - bconfig = Config() + bconfig= Config() @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) @test bconfig.reverse_calls == [(identity, (32.1,))] end @@ -148,14 +130,14 @@ end ẋ = one(x) y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) - pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ * ȳ + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ return y, pullback_via_forwards_ad end function ChainRulesCore.rrule( - config::RuleConfig{>:Union{HasReverseMode,NoForwardsMode}}, + config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, ::typeof(do_thing_4), f, - x, + x ) y, f_pullback = rrule_via_ad(config, f, x) do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) @@ -165,43 +147,43 @@ end @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig = Config() + bconfig= Config() @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) @test bconfig.forward_calls == [(identity, (32.1,))] end - rconfig = MockReverseConfig() + rconfig= MockReverseConfig() @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) @test rconfig.reverse_calls == [(identity, (32.1,))] end @testset "RuleConfig broadcasts like a scaler" begin - @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} + @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} end @testset "fallbacks" begin - no_rule(x; kw = "bye") = error() + no_rule(x; kw="bye") = error() @test frule((1.0,), no_rule, 2.0) === nothing - @test frule((1.0,), no_rule, 2.0; kw = "hello") === nothing + @test frule((1.0,), no_rule, 2.0; kw="hello") === nothing @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0) === nothing - @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw = "hello") === nothing + @test frule(MostBoringConfig(), (1.0,), no_rule, 2.0; kw="hello") === nothing @test rrule(no_rule, 2.0) === nothing - @test rrule(no_rule, 2.0; kw = "hello") === nothing + @test rrule(no_rule, 2.0; kw="hello") === nothing @test rrule(MostBoringConfig(), no_rule, 2.0) === nothing - @test rrule(MostBoringConfig(), no_rule, 2.0; kw = "hello") === nothing + @test rrule(MostBoringConfig(), no_rule, 2.0; kw="hello") === nothing # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(; kw = "hello") + @test_throws MethodError frule(;kw="hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin; kw = "hello") + @test_throws MethodError frule(sin;kw="hello") @test_throws MethodError frule(MostBoringConfig()) - @test_throws MethodError frule(MostBoringConfig(); kw = "hello") + @test_throws MethodError frule(MostBoringConfig(); kw="hello") @test_throws MethodError frule(MostBoringConfig(), sin) - @test_throws MethodError frule(MostBoringConfig(), sin; kw = "hello") + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(; kw = "hello") + @test_throws MethodError rrule(;kw="hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig(); kw = "hello") + @test_throws MethodError rrule(MostBoringConfig();kw="hello") end end diff --git a/test/deprecated.jl b/test/deprecated.jl index 8b1378917..e69de29bb 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -1 +0,0 @@ - diff --git a/test/ignore_derivatives.jl b/test/ignore_derivatives.jl index ad4fece9f..825287b9a 100644 --- a/test/ignore_derivatives.jl +++ b/test/ignore_derivatives.jl @@ -7,7 +7,7 @@ end @testset "function" begin f() = return 4.0 - y, ẏ = frule((1.0,), ignore_derivatives, f) + y, ẏ = frule((1.0, ), ignore_derivatives, f) @test y == f() @test ẏ == NoTangent() @@ -19,7 +19,7 @@ end @testset "argument" begin arg = 2.1 - y, ẏ = frule((1.0,), ignore_derivatives, arg) + y, ẏ = frule((1.0, ), ignore_derivatives, arg) @test y == arg @test ẏ == NoTangent() @@ -41,11 +41,11 @@ end @test pb(1.0) == (NoTangent(), NoTangent()) # when called - y, ẏ = frule((1.0,), ignore_derivatives, () -> mf(3.0)) + y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0)) @test y == mf(3.0) @test ẏ == NoTangent() - y, pb = rrule(ignore_derivatives, () -> mf(3.0)) + y, pb = rrule(ignore_derivatives, ()->mf(3.0)) @test y == mf(3.0) @test pb(1.0) == (NoTangent(), NoTangent()) end diff --git a/test/projection.jl b/test/projection.jl index ab418ef79..ba61fb8da 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,21 +24,20 @@ struct NoSuperType end # real / complex @test ProjectTo(1.0)(2.0 + 3im) === 2.0 @test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im - @test ProjectTo(2.0 + 3.0im)(1 + 1im) === 1.0 + 1.0im - @test ProjectTo(2.0)(1 + 1im) === 1.0 - + @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im + @test ProjectTo(2.0)(1+1im) === 1.0 + # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) - @test ProjectTo(1 // 2)(3 // 4) === 3 // 4 + @test ProjectTo(1//2)(3//4) === 3//4 @test ProjectTo(1.0f0)(1 / 2) === 0.5f0 @test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im @test ProjectTo(big(1.0))(2) === 2 @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re = 1, im = NoTangent())) === - 1.0f0 + 0.0f0im + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im end @testset "Dual" begin # some weird Real subtype that we should basically leave alone @@ -47,12 +46,13 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} - @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa - Complex{<:Dual} + @test ProjectTo(1.0 + 1im)( + Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)) + ) isa Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent - @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value = 1.0)) isa Tangent + @test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent end @testset "Base: arrays of numbers" begin @@ -99,10 +99,10 @@ struct NoSuperType end # arrays of other things @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} - @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} + @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. - @test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im) + @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) @test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent} # empty arrays @@ -172,7 +172,7 @@ struct NoSuperType end # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called - xs = adj(Any[Any[1, 2, 3], Any[4+im, 5-im, 6+im, 7-im]]) + xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) pvecvec3 = ProjectTo(xs) @test pvecvec3(xs)[1] == [1 2 3] @test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im]) @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step = 0.1, ref = NoTangent()) + dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()); @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3 + 4im) + dy = Tangent{typeof(Ref(2.0))}(x = 3+4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end @@ -365,21 +365,21 @@ struct NoSuperType end # Each "@test 33 > ..." is zero on nightly, 32 on 1.5. pvec = ProjectTo(rand(10^3)) - @test 0 == @ballocated $pvec(dx) setup = (dx = rand(10^3)) # pass through - @test 90 > @ballocated $pvec(dx) setup = (dx = rand(10^3, 1)) # reshape + @test 0 == @ballocated $pvec(dx) setup=(dx = rand(10^3)) # pass through + @test 90 > @ballocated $pvec(dx) setup=(dx = rand(10^3, 1)) # reshape @test 33 > @ballocated ProjectTo(x)(dx) setup = (x = rand(10^3); dx = rand(10^3)) # including construction padj = ProjectTo(adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup = (dx = adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup = (dx = transpose(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup=(dx = adjoint(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup=(dx = transpose(rand(10^3))) @test 33 > @ballocated ProjectTo(x')(dx') setup = (x = rand(10^3); dx = rand(10^3)) pdiag = ProjectTo(Diagonal(rand(10^3))) - @test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3))) + @test 0 == @ballocated $pdiag(dx) setup=(dx = Diagonal(rand(10^3))) psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) - @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 + @test_broken 0 == @ballocated $psymm(dx) setup=(dx = Symmetric(rand(10^3, 10^3))) # 64 end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index e99b66c2f..0d6d98535 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -19,7 +19,7 @@ macro test_macro_throws(err_expr, expr) end end # Reuse `@test_throws` logic - if err !== nothing + if err!==nothing @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) @@ -29,21 +29,21 @@ end # struct need to be defined outside of tests for julia 1.0 compat struct NonDiffExample - x::Any + x end struct NonDiffCounterExample - x::Any + x end module NonDiffModuleExample -nondiff_2_1(x, y) = fill(7.5, 100)[x+y] + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin - nondiff_2_1(x, y) = fill(7.5, 100)[x+y] + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((ZeroTangent(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, NoTangent()) res, pullback = rrule(nondiff_2_1, 3, 2) @@ -58,7 +58,7 @@ end res, pullback = rrule(nondiff_1_2, 3.1) @test res == (5.0, 3.0) @test isequal( - pullback(Tangent{Tuple{Float64,Float64}}(1.2, 3.2)), + pullback(Tangent{Tuple{Float64, Float64}}(1.2, 3.2)), (NoTangent(), NoTangent()), ) end @@ -81,8 +81,7 @@ end pointy_identity(x) = x @non_differentiable pointy_identity(::Vector{<:AbstractString}) - @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == - (["2"], NoTangent()) + @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == (["2"], NoTangent()) @test frule((ZeroTangent(), 1.2), pointy_identity, 2.0) == nothing res, pullback = rrule(pointy_identity, ["2"]) @@ -93,7 +92,7 @@ end end @testset "kwargs" begin - kw_demo(x; kw = 2.0) = x + kw + kw_demo(x; kw=2.0) = x + kw @non_differentiable kw_demo(::Any) @testset "not setting kw" begin @@ -107,14 +106,13 @@ end end @testset "setting kw" begin - @assert kw_demo(1.5; kw = 3.0) == 4.5 + @assert kw_demo(1.5; kw=3.0) == 4.5 - res, pullback = rrule(kw_demo, 1.5; kw = 3.0) + res, pullback = rrule(kw_demo, 1.5; kw=3.0) @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw = 3.0) == - (4.5, NoTangent()) + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) end end @@ -123,7 +121,7 @@ end @test isequal( frule((ZeroTangent(), 1.2), NonDiffExample, 2.0), - (NonDiffExample(2.0), NoTangent()), + (NonDiffExample(2.0), NoTangent()) ) res, pullback = rrule(NonDiffExample, 2.0) @@ -153,7 +151,7 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), NoTangent()) @test frule((1, 1), fvarargs, 1, 2) == nothing - @test rrule(fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @@ -196,10 +194,10 @@ end end @testset "Functors" begin - (f::NonDiffExample)(y) = fill(7.5, 100)[f.x+y] + (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x = 1.2), 2.3), NonDiffExample(3), 2) == - (7.5, NoTangent()) + @test frule((Tangent{NonDiffExample}(x=1.2), 2.3), NonDiffExample(3), 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent()) @@ -207,12 +205,8 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) - @test frule( - (ZeroTangent(), 1.2, 2.3), - NonDiffModuleExample.nondiff_2_1, - 3, - 2, - ) == (7.5, NoTangent()) + @test frule((ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) @@ -222,7 +216,7 @@ end # Where clauses are not supported. @test_macro_throws( ErrorException, - (@non_differentiable where_identity(::Vector{T}) where {T<:AbstractString}) + (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) ) end end @@ -230,33 +224,32 @@ end @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) - @scalar_rule(simo(x), 1.0f0, 2.0f0) + @scalar_rule(simo(x), 1f0, 2f0) y, simo_pb = rrule(simo, π) - @test simo_pb((10.0f0, 20.0f0)) == (NoTangent(), 50.0f0) + @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) - y, ẏ = frule((NoTangent(), 50.0f0), simo, π) + y, ẏ = frule((NoTangent(), 50f0), simo, π) @test y == (π, 2π) - @test ẏ == Tangent{typeof(y)}(50.0f0, 100.0f0) + @test ẏ == Tangent{typeof(y)}(50f0, 100f0) # make sure type is exactly as expected: - @test ẏ isa Tangent{Tuple{Irrational{:π},Float64},Tuple{Float32,Float32}} + @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} xs, Ω = (3,), (3, 6) - @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == - ((1.0f0,), (2.0f0,)) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) end @testset "@scalar_rule projection" begin - make_imaginary(x) = im * x + make_imaginary(x) = im*x @scalar_rule make_imaginary(x) im # note: the === will make sure that these are Float64, not ComplexF64 - @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0 * im) + @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im) @test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0) - @test (NoTangent(), 1.0 + 0.0im) === rrule(make_imaginary, 2.0im)[2](1.0 * im) - @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) + @test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im) + @test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end @testset "Regression tests against #276 and #265" begin @@ -264,16 +257,16 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 # Symptom of these problems is creation of global variables and type instability - num_globals_before = length(names(ChainRulesCore; all = true)) + num_globals_before = length(names(ChainRulesCore; all=true)) simo2(x) = (x, 2x) @scalar_rule(simo2(x), 1.0, 2.0) _, simo2_pb = rrule(simo2, 43.0) # make sure it infers: inferability implies type stability - @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) + @inferred simo2_pb(Tangent{Tuple{Float64, Float64}}(3.0, 6.0)) # Test no new globals were created - @test length(names(ChainRulesCore; all = true)) == num_globals_before + @test length(names(ChainRulesCore; all=true)) == num_globals_before # Example in #265 simo3(x) = sincos(x) @@ -286,60 +279,60 @@ end module IsolatedModuleForTestingScoping -# check that rules can be defined by macros without any additional imports -using ChainRulesCore: @scalar_rule, @non_differentiable - -# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved -const ChainRulesCore = nothing - -# this is -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 -fixed(x) = :abc -@non_differentiable fixed(x) - -# check name collision between a primal input called `kwargs` and the actual keyword -# arguments -fixed_kwargs(x; kwargs...) = :abc -@non_differentiable fixed_kwargs(kwargs) - -my_id(x) = x -@scalar_rule(my_id(x), 1.0) - -module IsolatedSubmodule -# check that rules defined in isolated module without imports can be called -# without errors -using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output -using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id -using Test - -@testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((ZeroTangent(), randn()), f, randn()) - @test y === :abc - @test ẏ === NoTangent() - - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end + # check that rules can be defined by macros without any additional imports + using ChainRulesCore: @scalar_rule, @non_differentiable + + # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved + const ChainRulesCore = nothing + + # this is + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 + fixed(x) = :abc + @non_differentiable fixed(x) + + # check name collision between a primal input called `kwargs` and the actual keyword + # arguments + fixed_kwargs(x; kwargs...) = :abc + @non_differentiable fixed_kwargs(kwargs) + + my_id(x) = x + @scalar_rule(my_id(x), 1.0) + + module IsolatedSubmodule + # check that rules defined in isolated module without imports can be called + # without errors + using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output + using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + using Test + + @testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((ZeroTangent(), randn()), f, randn()) + @test y === :abc + @test ẏ === NoTangent() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword = randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) -end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end -@testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ + @testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (NoTangent(), Δy) + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (NoTangent(), Δy) - @test derivatives_given_output(y, my_id, x) == ((1.0,),) -end -end + @test derivatives_given_output(y, my_id, x) == ((1.0,),) + end + end end diff --git a/test/rules.jl b/test/rules.jl index 267b23005..d43ca42d2 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,11 +28,8 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any,Any,Any,Vararg}, - ::typeof(mixed_vararg), - x, - y, - z..., + dargs::Tuple{Any, Any, Any, Vararg}, + ::typeof(mixed_vararg), x, y, z..., ) Δx = dargs[2] Δy = dargs[3] @@ -42,21 +39,16 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any,Int,Float64}, - ::typeof(type_constraints), - x::Int, - y::Float64, + (_, Δx, Δy)::Tuple{Any, Int, Float64}, + ::typeof(type_constraints), x::Int, y::Float64, ) return type_constraints(x, y), Δx + Δy end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any,Float64,Real,Vararg{Float64}}, - ::typeof(mixed_vararg_type_constaint), - x::Float64, - y::Real, - z::Vararg{Float64}, + dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, ) Δx = dargs[2] Δy = dargs[3] @@ -73,9 +65,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = ZeroTangent() @test frule((dself, 1), cool, 1) === nothing - @test frule((dself, 1), cool, 1; iscool = true) === nothing + @test frule((dself, 1), cool, 1; iscool=true) === nothing @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool = true) === nothing + @test rrule(cool, 1; iscool=true) === nothing # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @@ -84,10 +76,8 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([ - Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}, - ]) + only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, + Tuple{typeof(rrule),typeof(cool),String}]) @test cool_methods == only_methods frx, cool_pushforward = frule((dself, 1), cool, 1) @@ -108,26 +98,21 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) - @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == - (10.0, 10.0) + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing - @test( - frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, - 3.0, - 2.0, - 1.0, - 0.0, - ) == (6.0, 6.0) - ) + @test(frule( + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, + ) == (6.0, 6.0)) # violates type constraints, thus an frule should not be found. - @test frule((nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0) == - nothing + @test frule( + (nothing, 3, 2.0, 1.0, 5.0), + mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, + ) == nothing @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) @@ -168,29 +153,27 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) - @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where {T<:Float32} + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 @opt_out( - ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float32} + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 ) @testset "rrule" begin @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) - @test rrule(first_oa, 3.0f0, 4.0f0) === nothing + @test rrule(first_oa, 3f0, 4f0) === nothing @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} + m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 end) end @testset "frule" begin - @test frule((NoTangent(), 1, 0), first_oa, 3.0, 4.0) == (3.0, 1) - @test frule((NoTangent(), 1, 0), first_oa, 3.0f0, 4.0f0) === nothing - - @test !isempty( - Iterators.filter(methods(ChainRulesCore.no_frule)) do m - m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float32} - end, - ) + @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing + + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 + end) end end end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index fdbb92f55..7e0ec9398 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -9,7 +9,7 @@ @test view(NoTangent(), 1, 2) == NoTangent() @test sum(ZeroTangent()) == ZeroTangent() - @test sum(NoTangent(); dims = 2) == NoTangent() + @test sum(NoTangent(); dims=2) == NoTangent() end @testset "ZeroTangent" begin @@ -55,7 +55,7 @@ @test muladd(x, ZeroTangent(), ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), x, ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), ZeroTangent(), ZeroTangent()) === ZeroTangent() - + @test reim(z) === (ZeroTangent(), ZeroTangent()) @test real(z) === ZeroTangent() @test imag(z) === ZeroTangent() diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 2b7c6347e..2fd337979 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -1,14 +1,10 @@ @testset "NotImplemented" begin @testset "NotImplemented" begin ni = ChainRulesCore.NotImplemented( - @__MODULE__, - LineNumberNode(@__LINE__, @__FILE__), - "error", + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" ) ni2 = ChainRulesCore.NotImplemented( - @__MODULE__, - LineNumberNode(@__LINE__, @__FILE__), - "error2", + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" ) x = rand() thunk = @thunk(x^2) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index f3ff1085a..e276b6a5a 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -1,6 +1,6 @@ # For testing Tangent struct Foo - x::Any + x y::Float64 end @@ -12,81 +12,81 @@ end # For testing Tangent: it is an invarient of the type that x2 = 2x # so simple addition can not be defined struct StructWithInvariant - x::Any - x2::Any + x + x2 StructWithInvariant(x) = new(x, 2x) end @testset "Tangent" begin @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{}, Tuple{}} end @testset "==" begin - @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(x = 0.1, y = 2.5) - @test Tangent{Foo}(x = 0.1, y = 2.5) == Tangent{Foo}(y = 2.5, x = 0.1) - @test Tangent{Foo}(y = 2.5, x = ZeroTangent()) == Tangent{Foo}(y = 2.5) + @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(x=0.1, y=2.5) + @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(y=2.5, x=0.1) + @test Tangent{Foo}(y=2.5, x=ZeroTangent()) == Tangent{Foo}(y=2.5) - @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) + @test Tangent{Tuple{Float64,}}(2.0) == Tangent{Tuple{Float64,}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2*1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(; y = 2.0) == Tangent{Foo}(; x = ZeroTangent(), y = Float32(2.0)) + @test Tangent{Foo}(;y=2.0,) == Tangent{Foo}(;x=ZeroTangent(), y=Float32(2.0),) end @testset "hash" begin - @test hash(Tangent{Foo}(x = 0.1, y = 2.5)) == hash(Tangent{Foo}(y = 2.5, x = 0.1)) - @test hash(Tangent{Foo}(y = 2.5, x = ZeroTangent())) == hash(Tangent{Foo}(y = 2.5)) + @test hash(Tangent{Foo}(x=0.1, y=2.5)) == hash(Tangent{Foo}(y=2.5, x=0.1)) + @test hash(Tangent{Foo}(y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(y=2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x = 2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x = 2.5)) == (:x,) - @test haskey(Tangent{Foo}(x = 2.5), :x) == true + @test keys(Tangent{Foo}(x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(x=2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x = 2.5), :y) == false + @test hasproperty(Tangent{Foo}(x=2.5), :y) == false end - @test Tangent{Foo}(x = 2.5).x == 2.5 - - @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - - NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 - - @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(a = (@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b = (@thunk 2.0^2)), 2) == 4.0 + @test Tangent{Foo}(x=2.5).x == 2.5 + + @test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 + + NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}} + @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 + @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() + @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() + @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 + + @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 + @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x = 2.5)) == 1 - @test length(Tangent{Tuple{Float64}}(2.0)) == 1 + @test length(Tangent{Foo}(x=2.5)) == 1 + @test length(Tangent{Tuple{Float64,}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x = 2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 + @test eltype(Tangent{Foo}(x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64,}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x = 2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] + @test collect(Tangent{Foo}(x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64,}}(2.0)) == [2.0] # Test indexed_iterate ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function (tangent) + _unpack2tuple = function(tangent) a, b = tangent return (a, b) end @@ -96,33 +96,33 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x = 2, y = 3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y = 3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(y=3.0)) end end @testset "reverse" begin - c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") - cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) + c = Tangent{Tuple{Int, Int, String}}(1, 2, "something") + cr = Tangent{Tuple{String, Int, Int}}("something", 2, 1) @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x = 1.0, y = 2.0)) + @test_throws MethodError reverse(Tangent{Foo}(;x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo,typeof(d)}(d) + cdict = Tangent{Foo, typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end @testset "unset properties" begin - @test Tangent{Foo}(; x = 1.4).y === ZeroTangent() + @test Tangent{Foo}(; x=1.4).y === ZeroTangent() end @testset "conj" begin - @test conj(Tangent{Foo}(x = 2.0 + 3.0im)) == Tangent{Foo}(x = 2.0 - 3.0im) + @test conj(Tangent{Foo}(x=2.0+3.0im)) == Tangent{Foo}(x=2.0-3.0im) @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), - Tangent{Tuple{Float64}}(2.0 - 3.0im), + conj(Tangent{Tuple{Float64,}}(2.0+3.0im)), + Tangent{Tuple{Float64,}}(2.0-3.0im) ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -132,20 +132,26 @@ end @testset "canonicalize" begin # Testing iterate via collect - @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) + @test ==( + canonicalize(Tangent{Tuple{Float64,}}(2.0)), + Tangent{Tuple{Float64,}}(2.0) + ) - @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) + @test ==( + canonicalize(Tangent{Dict}(Dict(4 => 3))), + Tangent{Dict}(Dict(4 => 3)), + ) # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x = 2.5, y = 10)) == CFoo(x = 2.5, y = 10) - @test canonicalize(CFoo(y = 10, x = 2.5)) == CFoo(x = 2.5, y = 10) - @test canonicalize(CFoo(y = 10)) == CFoo(x = ZeroTangent(), y = 10) + @test canonicalize(CFoo(x=2.5, y=10)) == CFoo(x=2.5, y=10) + @test canonicalize(CFoo(y=10, x=2.5)) == CFoo(x=2.5, y=10) + @test canonicalize(CFoo(y=10)) == CFoo(x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(q = 99.0, x = 2.5)) + @test_throws ArgumentError canonicalize(CFoo(q=99.0, x=2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a = 1, b = 2) + c1 = Tangent{Any}(;a=1, b=2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -158,20 +164,20 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x = 1.5) + CFoo(x = 2.5) == CFoo(x = 4.0) - @test CFoo(y = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 2.5) - @test CFoo(y = 1.5, x = 1.5) + CFoo(x = 2.5) == CFoo(y = 1.5, x = 4.0) + @test CFoo(x=1.5) + CFoo(x=2.5) == CFoo(x=4.0) + @test CFoo(y=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=2.5) + @test CFoo(y=1.5, x=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=4.0) end @testset "Tuples" begin @test ==( typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{},Tuple{}}, + Tangent{Tuple{}, Tuple{}} ) @test ( - Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64,Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64, Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64, Float64}}(2.0, 3.0) end @testset "NamedTuples" begin @@ -191,8 +197,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x = 1.5) - b = CFoo(x = @not_implemented("")) + a = CFoo(x=1.5) + b = CFoo(x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -207,8 +213,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x = 1.5) - b = Tangent{NamedTuple{(:x,)}}(x = @not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(x=1.5) + b = Tangent{NamedTuple{(:x,)}}(x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -227,35 +233,35 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x = 2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x = 2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x = 0.5)) == 0 + @test Foo(3.5, 1.5) + Tangent{Foo}(x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 end @testset "Tuples" begin @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + @test ((1.0, 2.0) + Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end @testset "NamedTuple" begin - ntx = (; a = 1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a = 3.0) + ntx = (; a=1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) - nty = (; a = 1.5, b = 0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a = 3.0, b = 1.0) + nty = (; a=1.5, b=0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) end @testset "Dicts" begin d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 =>5.0)) @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x = 2.0, x2 = 6.0) + diff = Tangent{StructWithInvariant}(x=2.0, x2=6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -272,7 +278,7 @@ end # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2 / 2) / 2 + x = (nt.x + nt.x2/2)/2 return StructWithInvariant(x) end @test value + diff == StructWithInvariant(12.5) @@ -280,7 +286,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y = 1.5, x = 2.5) + c = Tangent{Foo}(y=1.5, x=2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -302,14 +308,14 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y = 1.5, x = 2.5) == - Tangent{Foo}(y = 3.0, x = 5.0) == - Tangent{Foo}(y = 1.5, x = 2.5) * 2 + 2 * Tangent{Foo}(y=1.5, x=2.5) + == Tangent{Foo}(y=3.0, x=5.0) + == Tangent{Foo}(y=1.5, x=2.5) * 2 ) @test ( - 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == - Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == - Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 + 2 * Tangent{Tuple{Float64, Float64}}(2.0, 4.0) + == Tangent{Tuple{Float64, Float64}}(4.0, 8.0) + == Tangent{Tuple{Float64, Float64}}(2.0, 4.0) * 2 ) d = Tangent{Dict}(Dict(4 => 3.0)) two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) @@ -317,7 +323,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x = 1)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(x=1,)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -334,9 +340,8 @@ end end @testset "Internals don't allocate a ton" begin - bk = (; x = 1.0, y = 2.0) - VERSION >= v"1.5" && - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + bk = (; x=1.0, y=2.0) + VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 # weaker version of the above (which should pass on all versions) @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 @@ -345,8 +350,8 @@ end end @testset "non-same-typed differential arithmetic" begin - nt = (; a = 1, b = 2.0) - c = Tangent{typeof(nt)}(; a = NoTangent(), b = 0.1) - @test nt + c == (; a = 1, b = 2.1) + nt = (; a=1, b=2.0) + c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) + @test nt + c == (; a=1, b=2.1); end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index af4a747d1..89461caa1 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -141,7 +141,7 @@ # Check against accidential type piracy # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472 @test Base.which(diagm, Tuple{}()).module != ChainRulesCore - @test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore + @test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore end @test tril(a) == tril(t) @test tril(a, 1) == tril(t, 1) From 9473c65bf74b15202d674417545bb7982086ff8f Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 19:25:45 +0300 Subject: [PATCH 11/20] add .JuliaFormatter.toml with style = blue --- .JuliaFormatter.toml | 1 + 1 file changed, 1 insertion(+) create mode 100644 .JuliaFormatter.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..323237bab --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" From 904312bd1a6ebc222785dcd2811ebccf02e7d3ea Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Oct 2021 19:26:03 +0300 Subject: [PATCH 12/20] format(".") - hopefully actually with BlueStyle this time around!!! --- docs/make.jl | 13 +- docs/src/assets/make_logo.jl | 53 ++++---- src/accumulation.jl | 6 +- src/compat.jl | 4 +- src/config.jl | 1 - src/deprecated.jl | 1 + src/ignore_derivatives.jl | 8 +- src/projection.jl | 36 +++--- src/rule_definition_tools.jl | 78 +++++++----- src/tangent_arithmetic.jl | 14 +-- src/tangent_types/abstract_zero.jl | 6 +- src/tangent_types/notimplemented.jl | 14 ++- src/tangent_types/tangent.jl | 109 +++++++++-------- src/tangent_types/thunks.jl | 3 +- test/accumulation.jl | 24 ++-- test/config.jl | 38 +++--- test/deprecated.jl | 1 + test/ignore_derivatives.jl | 8 +- test/projection.jl | 40 +++---- test/rule_definition_tools.jl | 155 ++++++++++++------------ test/rules.jl | 66 +++++----- test/tangent_types/abstract_zero.jl | 2 +- test/tangent_types/tangent.jl | 180 +++++++++++++--------------- test/tangent_types/thunks.jl | 2 +- 24 files changed, 445 insertions(+), 417 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1ef3a62a7..42e39a4c0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,20 +16,20 @@ DocMeta.setdocmeta!( @scalar_rule(sin(x), cos(x)) # frule and rrule doctest @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end + end, ) indigo = DocThemeIndigo.install(ChainRulesCore) -makedocs( +makedocs(; modules=[ChainRulesCore], - format=Documenter.HTML( + format=Documenter.HTML(; prettyurls=false, assets=[indigo], mathengine=MathJax3( Dict( :tex => Dict( - "inlineMath" => [["\$","\$"], ["\\(","\\)"]], + "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], "tags" => "ams", # TODO: remove when using physics package "macros" => Dict( @@ -67,7 +67,4 @@ makedocs( checkdocs=:exports, ) -deploydocs( - repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", - push_preview=true, -) +deploydocs(; repo="github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview=true) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 5bbfd36c1..c023c308f 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -9,79 +9,76 @@ using Random const bridge_len = 50 function chain(jiggle=0) - shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) - + shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) + ### 1 shaky_rotate(0) sethue(Luxor.julia_red) link() m1 = getmatrix() - - + ### 2 sethue(Luxor.julia_green) - translate(-50, 130); - shaky_rotate(π/3); + translate(-50, 130) + shaky_rotate(π / 3) link() m2 = getmatrix() - + setmatrix(m1) sethue(Luxor.julia_red) overlap(-1.3π) setmatrix(m2) - + ### 3 - shaky_rotate(-π/3); - translate(-120,80); + shaky_rotate(-π / 3) + translate(-120, 80) sethue(Luxor.julia_purple) link() - + setmatrix(m2) setcolor(Luxor.julia_green) - overlap(-1.5π) + return overlap(-1.5π) end - function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - - - rect(50,-3,40, bridge_len+6, :fill) - rect(-50-40,-3,40, bridge_len+6, :fill) - + + rect(50, -3, 40, bridge_len + 6, :fill) + rect(-50 - 40, -3, 40, bridge_len + 6, :fill) + sethue("black") move(Point(-50, bridge_len)) - arc(Point(0,0), 50, π, 0, :stoke) + arc(Point(0, 0), 50, π, 0, :stoke) arc(Point(0, bridge_len), 50, 0, -π, :stroke) - + move(Point(-90, bridge_len)) - arc(Point(0,0), 90, π, 0, :stoke) + arc(Point(0, 0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) - strokepath() + return strokepath() end function overlap(ang_end) - sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) + sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill) sethue("black") arc(Point(0, bridge_len), 50, 0, ang_end, :stoke) move(Point(90, bridge_len)) arc(Point(0, bridge_len), 90, 0, ang_end, :stoke) - strokepath() + return strokepath() end # Actually draw it function save_logo(filename) Random.seed!(16) - Drawing(450,450, filename) + Drawing(450, 450, filename) origin() - translate(50, -130); + translate(50, -130) chain(0.5) finish() - preview() + return preview() end save_logo("logo.svg") -save_logo("logo.png") \ No newline at end of file +save_logo("logo.png") diff --git a/src/accumulation.jl b/src/accumulation.jl index 4bcc5c33f..c9a38956a 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,7 +26,7 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) -function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N +function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} return if is_inplaceable_destination(x) x .+= y else @@ -34,7 +34,6 @@ function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N end end - """ is_inplaceable_destination(x) -> Bool @@ -64,7 +63,6 @@ end is_inplaceable_destination(::LinearAlgebra.Hermitian) = false is_inplaceable_destination(::LinearAlgebra.Symmetric) = false - function debug_add!(accumuland, t::InplaceableThunk) returned_value = t.add!(accumuland) if returned_value !== accumuland @@ -88,7 +86,7 @@ function Base.showerror(io::IO, err::BadInplaceException) if err.accumuland == err.returned_value println( io, - "Which in this case happenned to be equal. But they are not the same object." + "Which in this case happenned to be equal. But they are not the same object.", ) end end diff --git a/src/compat.jl b/src/compat.jl index 8204b66d5..fa66b1d0f 100644 --- a/src/compat.jl +++ b/src/compat.jl @@ -5,7 +5,7 @@ end if VERSION < v"1.1" # Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4 # See: https://github.com/JuliaLang/julia/issues/34292 - function fieldtypes(::Type{T}) where T + function fieldtypes(::Type{T}) where {T} if @generated ntuple(i -> fieldtype(T, i), fieldcount(T)) else @@ -13,7 +13,7 @@ if VERSION < v"1.1" end end - function fieldnames(::Type{T}) where T + function fieldnames(::Type{T}) where {T} if @generated ntuple(i -> fieldname(T, i), fieldcount(T)) else diff --git a/src/config.jl b/src/config.jl index 347e05c51..04757e838 100644 --- a/src/config.jl +++ b/src/config.jl @@ -64,7 +64,6 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar """ struct NoForwardsMode <: ForwardsModeCapability end - """ frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) diff --git a/src/deprecated.jl b/src/deprecated.jl index e69de29bb..8b1378917 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -0,0 +1 @@ + diff --git a/src/ignore_derivatives.jl b/src/ignore_derivatives.jl index c66d89d7e..18865f2c9 100644 --- a/src/ignore_derivatives.jl +++ b/src/ignore_derivatives.jl @@ -45,7 +45,9 @@ ignore_derivatives(x) = x Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`. """ macro ignore_derivatives(ex) - return :(ChainRulesCore.ignore_derivatives() do - $(esc(ex)) - end) + return :( + ChainRulesCore.ignore_derivatives() do + $(esc(ex)) + end + ) end diff --git a/src/projection.jl b/src/projection.jl index 4b07b2762..2e1a9340e 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -32,7 +32,7 @@ ProjectTo{P}() where {P} = ProjectTo{P}(EMPTY_NT) const Type_kwfunc = Core.kwftype(Type).instance function (::typeof(Type_kwfunc))(kws::Any, ::Type{ProjectTo{P}}) where {P} - ProjectTo{P}(NamedTuple(kws)) + return ProjectTo{P}(NamedTuple(kws)) end Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) @@ -131,7 +131,9 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}() +function ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} + return ProjectTo{NoTangent}() +end # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -164,12 +166,16 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) end # In these cases we can just `convert` as we know we are dealing with plain and simple types -(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx) -(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity +(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) +(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +function (::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} + return convert(T, dx) +end (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +function (::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} + return convert(T, dx) +end (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -244,9 +250,11 @@ end # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers if !(project.axes isa Tuple{}) - throw(DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - )) + throw( + DimensionMismatch( + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number" + ), + ) end return fill(project.element(dx)) end @@ -298,7 +306,9 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +function ProjectTo(x::LinearAlgebra.TransposeAbsVec) + return ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +end function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -316,10 +326,8 @@ ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) # Symmetric -for (SymHerm, chk, fun) in ( - (:Symmetric, :issymmetric, :transpose), - (:Hermitian, :ishermitian, :adjoint), - ) +for (SymHerm, chk, fun) in + ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) @eval begin function ProjectTo(x::$SymHerm) sub = ProjectTo(parent(x)) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 6f217a3ac..fd32fbbbd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -99,11 +99,13 @@ macro scalar_rule(call, maybe_setup, partials...) rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro - code = quote + return code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw(ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))" - )) + throw( + ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))" + ), + ) end $(derivative_expr) @@ -112,7 +114,6 @@ macro scalar_rule(call, maybe_setup, partials...) end end - """ _normalize_scalarrules_macro_input(call, maybe_setup, partials) @@ -175,7 +176,9 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote - function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + function ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), ::Core.Typeof($f), $(inputs...) + ) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) @@ -210,7 +213,9 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) return $(esc(:Ω)), $pushforward_returns end end @@ -225,7 +230,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) Δs = _propagator_inputs(n_outputs) # Make a projector for each argument - projs, psetup = _make_projectors(call.args[2:end]) + projs, psetup = _make_projectors(call.args[2:end]) append!(setup_stmts, psetup) # 1 partial derivative per input @@ -248,7 +253,9 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) return $(esc(:Ω)), $pullback end end @@ -262,7 +269,7 @@ _propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) projs = map(x -> Symbol(:proj_, x.args[1]), xs) - setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs) + setups = map((x, p) -> :($p = ProjectTo($x)), xs, projs) return projs, setups end @@ -288,9 +295,10 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) - end + summed_∂_mul_Δs = + foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) + :(muladd($∂s_i, $Δs_i, $ex)) + end return :($proj($summed_∂_mul_Δs)) end @@ -366,7 +374,7 @@ macro non_differentiable(sig_expr) primal_invoke = if !has_vararg :($(primal_name)($(unconstrained_args...))) else - normal_args = unconstrained_args[1:end-1] + normal_args = unconstrained_args[1:(end - 1)] var_arg = unconstrained_args[end] :($(primal_name)($(normal_args...), $(var_arg)...)) end @@ -381,7 +389,10 @@ end function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... + :call, + call_expr.args[1], + Expr(:parameters, :($(kwargs)...)), + call_expr.args[2:end]..., ) end @@ -389,11 +400,17 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), - frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function (::Core.kwftype(typeof(ChainRulesCore.frule)))( + @nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end - function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function ChainRulesCore.frule( + @nospecialize(::Any), $(map(esc, primal_sig_parts)...) + ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() return ($(esc(primal_invoke)), NoTangent()) @@ -408,7 +425,8 @@ function tuple_expression(primal_sig_parts) Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + length_expr = + :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) end end @@ -426,7 +444,9 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) + function (::Core.kwftype(typeof(rrule)))( + $(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...) + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) @@ -436,7 +456,6 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end end - ############################################################################################ # @opt_out @@ -481,7 +500,7 @@ end "Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) - length(expr.args)===0 && error("Malformed method expression. $expr") + length(expr.args) === 0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore @@ -502,8 +521,6 @@ function _no_rule_target_rewrite!(call_target::Symbol) end end - - ############################################################################################ # Helpers @@ -555,9 +572,9 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) - + if primal_name isa Symbol || + Meta.isexpr(primal_name, :(.)) || + Meta.isexpr(primal_name, :curly) primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name elseif Meta.isexpr(primal_name, :(::)) # e.g. (::T)(x, y) @@ -574,14 +591,15 @@ _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. Meta.isexpr(arg, :(...), 1) && return _unconstrain(arg.args[1]) - error("malformed arguments: $arg") + return error("malformed arguments: $arg") end "turn both `a` and `::constraint` into `a::constraint` etc" function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) - error("malformed arguments: $arg") + Meta.isexpr(arg, :(...), 1) && + return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) + return error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 9c1378aab..c2bad7a77 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -81,7 +81,7 @@ LinearAlgebra.dot(::ZeroTangent, ::NoTangent) = ZeroTangent() Base.muladd(::ZeroTangent, x, y) = y Base.muladd(x, ::ZeroTangent, y) = y -Base.muladd(x, y, ::ZeroTangent) = x*y +Base.muladd(x, y, ::ZeroTangent) = x * y Base.muladd(::ZeroTangent, ::ZeroTangent, y) = y Base.muladd(x, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() @@ -125,11 +125,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where P +function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P, typeof(data)}(data) + return Tangent{P,typeof(data)}(data) end -function Base.:+(a::P, d::Tangent{P}) where P +function Base.:+(a::P, d::Tangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -142,12 +142,12 @@ function Base.:+(a::P, d::Tangent{P}) where P end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where P = b + a +Base.:+(a::Tangent{P}, b::P) where {P} = b + a # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 differentials # Only of a differential and a scaling factor (generally `Real`) for T in (:Any,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent) + @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 216357e91..5993d32b4 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -17,15 +17,15 @@ Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) -Base.Broadcast.broadcasted(::Type{T}) where T<:AbstractZero = T() +Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z -Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) -(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T) +Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) +(::Type{T})(xs::AbstractZero...) where {T<:Number} = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index a2044fbe1..a6b9cc5f9 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,9 +44,15 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw(NotImplementedException(@not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" -))) +function Base.zero(::Type{<:NotImplemented}) + return throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" + ) + ), + ) +end Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) @@ -75,5 +81,5 @@ function Base.showerror(io::IO, e::NotImplementedException) if e.info !== nothing print(io, "\nInfo: ", e.info) end - return + return nothing end diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index e4bbfb8c8..bb91e431e 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -21,42 +21,42 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P, T} <: AbstractTangent +struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T end -function Tangent{P}(; kwargs...) where P +function Tangent{P}(; kwargs...) where {P} backing = (; kwargs...) # construct as NamedTuple - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end -function Tangent{P}(args...) where P - return Tangent{P, typeof(args)}(args) +function Tangent{P}(args...) where {P} + return Tangent{P,typeof(args)}(args) end -function Tangent{P}() where P<:Tuple +function Tangent{P}() where {P<:Tuple} backing = () - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P, typeof(d)}(d) + return Tangent{P,typeof(d)}(d) end -function Base.:(==)(a::Tangent{P, T}, b::Tangent{P, T}) where {P, T} +function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P, T} +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P,T} all_fields = union(keys(backing(a)), keys(backing(b))) return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where P +function Base.show(io::IO, tangent::Tangent{P}) where {P} print(io, "Tangent{") show(io, P) print(io, "}") @@ -68,15 +68,15 @@ function Base.show(io::IO, tangent::Tangent{P}) where P end end -function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}} +function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end -function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict} +function Base.getindex(tangent::Tangent, idx) where {P,T<:AbstractDict} return unthunk(getindex(backing(tangent), idx)) end @@ -84,7 +84,7 @@ function Base.getproperty(tangent::Tangent, idx::Int) back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end @@ -99,26 +99,26 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) -Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T) +Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing) + return Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) end function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end -function Base.map(f, tangent::Tangent{P, <:Tuple}) where P +function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} vals::Tuple = map(f, backing(tangent)) - return Tangent{P, typeof(vals)}(vals) + return Tangent{P,typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L} +function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L, typeof(vals)}(vals) - return Tangent{P, typeof(named_vals)}(named_vals) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return Tangent{P,typeof(named_vals)}(named_vals) end -function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict} +function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end @@ -140,26 +140,28 @@ backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) # For generic structs -function backing(x::T)::NamedTuple where T +function backing(x::T)::NamedTuple where {T} # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. # see https://github.com/JuliaLang/julia/issues/34283 if @generated - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names, Tuple{$(types...)}}($vals)) + vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) else - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = ntuple(ii->getfield(x, ii), nfields) - return NamedTuple{names, Tuple{types...}}(vals) + vals = ntuple(ii -> getfield(x, ii), nfields) + return NamedTuple{names,Tuple{types...}}(vals) end end @@ -170,36 +172,38 @@ Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. """ -function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L} +function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} nil = _zeroed_backing(P) combined = merge(nil, backing(tangent)) if length(combined) !== fieldcount(P) - throw(ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))" - )) + throw( + ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", + ), + ) end - return Tangent{P, typeof(combined)}(combined) + return Tangent{P,typeof(combined)}(combined) end # Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent +canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent # Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent +canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent # Tangents of unspecified primal types (indicated by specifying exactly `Any`) # all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:Tuple}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:AbstractDict}) where {L} = tangent """ _zeroed_backing(P) Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. """ -@generated function _zeroed_backing(::Type{P}) where P +@generated function _zeroed_backing(::Type{P}) where {P} nil_base = ntuple(fieldcount(P)) do i (fieldname(P, i), ZeroTangent()) end @@ -218,7 +222,7 @@ after an operation such as the addition of a primal to a tangent It should be overloaded, if `T` does not have a default constructor, or if `T` needs to maintain some invarients between its fields. """ -function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} +function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} # Tested and verified that that this avoids a ton of allocations if length(L) !== fieldcount(T) # if length is equal but names differ then we will catch that below anyway. @@ -233,12 +237,12 @@ function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} end end -construct(::Type{T}, fields::T) where T<:NamedTuple = fields -construct(::Type{T}, fields::T) where T<:Tuple = fields +construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields +construct(::Type{T}, fields::T) where {T<:Tuple} = fields elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} # Rule of Tangent addition: any fields not present are implict hard Zeros # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. @@ -281,7 +285,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} end field => value end - return (;vals...) + return (; vals...) end end @@ -297,15 +301,16 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")"; color=:blue) println(io, "\nor overload") - printstyled(io, + printstyled( + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color=:blue + color=:blue, ) println(io, "\nor overload") printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) println(io, "\nOriginal Exception:") printstyled(io, err.original; color=:yellow) - println(io) + return println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 16384d69e..e065bea62 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -67,7 +67,7 @@ end function LinearAlgebra.diagm( m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) - return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) + return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a)) @@ -197,7 +197,6 @@ Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a) - """ InplaceableThunk(add!::Function, val::Thunk) diff --git a/test/accumulation.jl b/test/accumulation.jl index 1b41fea55..a796b5289 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -27,7 +27,7 @@ end @testset "misc AbstractTangent subtypes" begin - @test 16 == add!!(12, @thunk(2*2)) + @test 16 == add!!(12, @thunk(2 * 2)) @test 16 == add!!(16, ZeroTangent()) @test 16 == add!!(16, NoTangent()) # Should this be an error? @@ -37,15 +37,15 @@ @testset "LHS Array (inplace)" begin @testset "RHS Array" begin A = [1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] end @testset "RHS StaticArray" begin - A = @SMatrix[1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + A = @SMatrix [1.0 2.0; 3.0 4.0] + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] @@ -53,7 +53,7 @@ @testset "RHS Diagonal" begin A = Diagonal([1.0, 2.0]) - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 -1.0; -1.0 1.0] @@ -79,17 +79,17 @@ @testset "Unhappy Path" begin # wrong length - @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(2, 2)) # wrong shape - @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(16)) # wrong type (adding scalar to array) @test_throws MethodError add!!(ones(4), 21.0) end end @testset "AbstractThunk $(typeof(thunk))" for thunk in ( - @thunk(-1.0*ones(2, 2)), - InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))), + @thunk(-1.0 * ones(2, 2)), + InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -111,12 +111,12 @@ @testset "not actually inplace but said it was" begin # thunk should never be used in this test ithunk = InplaceableThunk(@thunk(@assert false)) do x - 77*ones(2, 2) # not actually inplace (also wrong) + 77 * ones(2, 2) # not actually inplace (also wrong) end accumuland = ones(2, 2) @assert ChainRulesCore.debug_mode() == false # without debug being enabled should return the result, not error - @test 77*ones(2, 2) == add!!(accumuland, ithunk) + @test 77 * ones(2, 2) == add!!(accumuland, ithunk) ChainRulesCore.debug_mode() = true # enable debug mode # with debug being enabled should error @@ -127,7 +127,7 @@ @testset "showerror BadInplaceException" begin BadInplaceException = ChainRulesCore.BadInplaceException - ithunk = InplaceableThunk(x̄->nothing, @thunk(@assert false)) + ithunk = InplaceableThunk(x̄ -> nothing, @thunk(@assert false)) msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) @test occursin("22", msg) diff --git a/test/config.jl b/test/config.jl index 466baed9a..58d943252 100644 --- a/test/config.jl +++ b/test/config.jl @@ -1,7 +1,7 @@ # Define a bunch of configs for testing purposes struct MostBoringConfig <: RuleConfig{Union{}} end -struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} forward_calls::Vector end MockForwardsConfig() = MockForwardsConfig([]) @@ -11,7 +11,7 @@ function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args. return f(args...; kws...), ȧrgs end -struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} reverse_calls::Vector end MockReverseConfig() = MockReverseConfig([]) @@ -22,8 +22,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. return f(args...; kws...), pullback_via_ad end - -struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} forward_calls::Vector reverse_calls::Vector end @@ -47,18 +46,18 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, + MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig ) counting_id_count = Ref(0) function counting_id(x) - counting_id_count[]+=1 + counting_id_count[] += 1 return x end function ChainRulesCore.rrule(::typeof(counting_id), x) counting_id_pullback(x̄) = x̄ return counting_id(x), counting_id_pullback end - function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) + function ChainRulesCore.frule((dself, dx), ::typeof(counting_id), x) return counting_id(x), dx end @testset "rrule" begin @@ -88,7 +87,7 @@ end end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== frule( bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) @@ -104,13 +103,12 @@ end return (NoTangent(), rrule_via_ad(config, f, x)...) end - @testset "$Config" for Config in (MostBoringConfig, MockForwardsConfig) @test nothing === rrule(Config(), do_thing_3, identity, 32.1) end @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) @test bconfig.reverse_calls == [(identity, (32.1,))] end @@ -130,14 +128,14 @@ end ẋ = one(x) y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) - pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ * ȳ return y, pullback_via_forwards_ad end function ChainRulesCore.rrule( - config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, + config::RuleConfig{>:Union{HasReverseMode,NoForwardsMode}}, ::typeof(do_thing_4), f, - x + x, ) y, f_pullback = rrule_via_ad(config, f, x) do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) @@ -147,18 +145,18 @@ end @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) @test bconfig.forward_calls == [(identity, (32.1,))] end - rconfig= MockReverseConfig() + rconfig = MockReverseConfig() @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) @test rconfig.reverse_calls == [(identity, (32.1,))] end @testset "RuleConfig broadcasts like a scaler" begin - @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} + @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} end @testset "fallbacks" begin @@ -174,16 +172,16 @@ end # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(; kw="hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(sin; kw="hello") @test_throws MethodError frule(MostBoringConfig()) @test_throws MethodError frule(MostBoringConfig(); kw="hello") @test_throws MethodError frule(MostBoringConfig(), sin) @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(; kw="hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig();kw="hello") + @test_throws MethodError rrule(MostBoringConfig(); kw="hello") end end diff --git a/test/deprecated.jl b/test/deprecated.jl index e69de29bb..8b1378917 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -0,0 +1 @@ + diff --git a/test/ignore_derivatives.jl b/test/ignore_derivatives.jl index 825287b9a..ad4fece9f 100644 --- a/test/ignore_derivatives.jl +++ b/test/ignore_derivatives.jl @@ -7,7 +7,7 @@ end @testset "function" begin f() = return 4.0 - y, ẏ = frule((1.0, ), ignore_derivatives, f) + y, ẏ = frule((1.0,), ignore_derivatives, f) @test y == f() @test ẏ == NoTangent() @@ -19,7 +19,7 @@ end @testset "argument" begin arg = 2.1 - y, ẏ = frule((1.0, ), ignore_derivatives, arg) + y, ẏ = frule((1.0,), ignore_derivatives, arg) @test y == arg @test ẏ == NoTangent() @@ -41,11 +41,11 @@ end @test pb(1.0) == (NoTangent(), NoTangent()) # when called - y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0)) + y, ẏ = frule((1.0,), ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test ẏ == NoTangent() - y, pb = rrule(ignore_derivatives, ()->mf(3.0)) + y, pb = rrule(ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test pb(1.0) == (NoTangent(), NoTangent()) end diff --git a/test/projection.jl b/test/projection.jl index ba61fb8da..cbfdcf6da 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,9 +24,9 @@ struct NoSuperType end # real / complex @test ProjectTo(1.0)(2.0 + 3im) === 2.0 @test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im - @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im - @test ProjectTo(2.0)(1+1im) === 1.0 - + @test ProjectTo(2.0 + 3.0im)(1 + 1im) === 1.0 + 1.0im + @test ProjectTo(2.0)(1 + 1im) === 1.0 + # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) @@ -37,7 +37,8 @@ struct NoSuperType end @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(; re=1, im=NoTangent())) === + 1.0f0 + 0.0f0im end @testset "Dual" begin # some weird Real subtype that we should basically leave alone @@ -46,9 +47,8 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} - @test ProjectTo(1.0 + 1im)( - Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)) - ) isa Complex{<:Dual} + @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa + Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent @@ -99,7 +99,7 @@ struct NoSuperType end # arrays of other things @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} - @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} + @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) @@ -126,18 +126,18 @@ struct NoSuperType end @testset "Base: Ref" begin pref = ProjectTo(Ref(2.0)) @test pref(Ref(3 + im)).x === 3.0 - @test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0 + @test pref(Tangent{Base.RefValue}(; x=3 + im)).x === 3.0 @test pref(4).x === 4.0 # also re-wraps scalars @test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue} pref2 = ProjectTo(Ref{Any}(6 + 7im)) @test pref2(Ref(8)).x === 8.0 + 0.0im - @test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im + @test pref2(Tangent{Base.RefValue}(; x=8)).x === 8.0 + 0.0im prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents @test prefvec(Ref(1:3)).x isa Vector{ComplexF64} - @test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64} - @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5)) + @test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64} + @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5)) @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()); + dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()) @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3+4im) + dy = Tangent{typeof(Ref(2.0))}(; x=3 + 4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end @@ -365,21 +365,21 @@ struct NoSuperType end # Each "@test 33 > ..." is zero on nightly, 32 on 1.5. pvec = ProjectTo(rand(10^3)) - @test 0 == @ballocated $pvec(dx) setup=(dx = rand(10^3)) # pass through - @test 90 > @ballocated $pvec(dx) setup=(dx = rand(10^3, 1)) # reshape + @test 0 == @ballocated $pvec(dx) setup = (dx = rand(10^3)) # pass through + @test 90 > @ballocated $pvec(dx) setup = (dx = rand(10^3, 1)) # reshape @test 33 > @ballocated ProjectTo(x)(dx) setup = (x = rand(10^3); dx = rand(10^3)) # including construction padj = ProjectTo(adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = transpose(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = adjoint(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = transpose(rand(10^3))) @test 33 > @ballocated ProjectTo(x')(dx') setup = (x = rand(10^3); dx = rand(10^3)) pdiag = ProjectTo(Diagonal(rand(10^3))) - @test 0 == @ballocated $pdiag(dx) setup=(dx = Diagonal(rand(10^3))) + @test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3))) psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) - @test_broken 0 == @ballocated $psymm(dx) setup=(dx = Symmetric(rand(10^3, 10^3))) # 64 + @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 0d6d98535..674cbb361 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -19,7 +19,7 @@ macro test_macro_throws(err_expr, expr) end end # Reuse `@test_throws` logic - if err!==nothing + if err !== nothing @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) @@ -37,7 +37,7 @@ struct NonDiffCounterExample end module NonDiffModuleExample - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] +nondiff_2_1(x, y) = fill(7.5, 100)[x + y] end @testset "rule_definition_tools.jl" begin @@ -58,7 +58,7 @@ end res, pullback = rrule(nondiff_1_2, 3.1) @test res == (5.0, 3.0) @test isequal( - pullback(Tangent{Tuple{Float64, Float64}}(1.2, 3.2)), + pullback(Tangent{Tuple{Float64,Float64}}(1.2, 3.2)), (NoTangent(), NoTangent()), ) end @@ -81,7 +81,8 @@ end pointy_identity(x) = x @non_differentiable pointy_identity(::Vector{<:AbstractString}) - @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == (["2"], NoTangent()) + @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == + (["2"], NoTangent()) @test frule((ZeroTangent(), 1.2), pointy_identity, 2.0) == nothing res, pullback = rrule(pointy_identity, ["2"]) @@ -112,7 +113,8 @@ end @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == + (4.5, NoTangent()) end end @@ -121,7 +123,7 @@ end @test isequal( frule((ZeroTangent(), 1.2), NonDiffExample, 2.0), - (NonDiffExample(2.0), NoTangent()) + (NonDiffExample(2.0), NoTangent()), ) res, pullback = rrule(NonDiffExample, 2.0) @@ -151,7 +153,7 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), NoTangent()) @test frule((1, 1), fvarargs, 1, 2) == nothing - @test rrule(fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @@ -196,8 +198,8 @@ end @testset "Functors" begin (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x=1.2), 2.3), NonDiffExample(3), 2) == - (7.5, NoTangent()) + @test frule((Tangent{NonDiffExample}(; x=1.2), 2.3), NonDiffExample(3), 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent()) @@ -205,8 +207,9 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) - @test frule((ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2) == - (7.5, NoTangent()) + @test frule( + (ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2 + ) == (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) @@ -216,7 +219,7 @@ end # Where clauses are not supported. @test_macro_throws( ErrorException, - (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) + (@non_differentiable where_identity(::Vector{T}) where {T<:AbstractString}) ) end end @@ -224,32 +227,33 @@ end @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) - @scalar_rule(simo(x), 1f0, 2f0) + @scalar_rule(simo(x), 1.0f0, 2.0f0) y, simo_pb = rrule(simo, π) - @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) + @test simo_pb((10.0f0, 20.0f0)) == (NoTangent(), 50.0f0) - y, ẏ = frule((NoTangent(), 50f0), simo, π) + y, ẏ = frule((NoTangent(), 50.0f0), simo, π) @test y == (π, 2π) - @test ẏ == Tangent{typeof(y)}(50f0, 100f0) + @test ẏ == Tangent{typeof(y)}(50.0f0, 100.0f0) # make sure type is exactly as expected: - @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} + @test ẏ isa Tangent{Tuple{Irrational{:π},Float64},Tuple{Float32,Float32}} xs, Ω = (3,), (3, 6) - @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == + ((1.0f0,), (2.0f0,)) end @testset "@scalar_rule projection" begin - make_imaginary(x) = im*x + make_imaginary(x) = im * x @scalar_rule make_imaginary(x) im # note: the === will make sure that these are Float64, not ComplexF64 - @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im) + @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0 * im) @test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0) - @test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im) - @test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) + @test (NoTangent(), 1.0 + 0.0im) === rrule(make_imaginary, 2.0im)[2](1.0 * im) + @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end @testset "Regression tests against #276 and #265" begin @@ -263,7 +267,7 @@ end @scalar_rule(simo2(x), 1.0, 2.0) _, simo2_pb = rrule(simo2, 43.0) # make sure it infers: inferability implies type stability - @inferred simo2_pb(Tangent{Tuple{Float64, Float64}}(3.0, 6.0)) + @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) # Test no new globals were created @test length(names(ChainRulesCore; all=true)) == num_globals_before @@ -277,62 +281,61 @@ end end end - module IsolatedModuleForTestingScoping - # check that rules can be defined by macros without any additional imports - using ChainRulesCore: @scalar_rule, @non_differentiable - - # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved - const ChainRulesCore = nothing - - # this is - # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 - fixed(x) = :abc - @non_differentiable fixed(x) - - # check name collision between a primal input called `kwargs` and the actual keyword - # arguments - fixed_kwargs(x; kwargs...) = :abc - @non_differentiable fixed_kwargs(kwargs) - - my_id(x) = x - @scalar_rule(my_id(x), 1.0) - - module IsolatedSubmodule - # check that rules defined in isolated module without imports can be called - # without errors - using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output - using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id - using Test - - @testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((ZeroTangent(), randn()), f, randn()) - @test y === :abc - @test ẏ === NoTangent() - - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end +# check that rules can be defined by macros without any additional imports +using ChainRulesCore: @scalar_rule, @non_differentiable + +# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved +const ChainRulesCore = nothing + +# this is +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 +fixed(x) = :abc +@non_differentiable fixed(x) + +# check name collision between a primal input called `kwargs` and the actual keyword +# arguments +fixed_kwargs(x; kwargs...) = :abc +@non_differentiable fixed_kwargs(kwargs) + +my_id(x) = x +@scalar_rule(my_id(x), 1.0) + +module IsolatedSubmodule +# check that rules defined in isolated module without imports can be called +# without errors +using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output +using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id +using Test + +@testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((ZeroTangent(), randn()), f, randn()) + @test y === :abc + @test ẏ === NoTangent() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) +end - @testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ +@testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (NoTangent(), Δy) + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (NoTangent(), Δy) - @test derivatives_given_output(y, my_id, x) == ((1.0,),) - end - end + @test derivatives_given_output(y, my_id, x) == ((1.0,),) +end +end end diff --git a/test/rules.jl b/test/rules.jl index d43ca42d2..54c10b160 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,8 +28,7 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Any, Any, Vararg}, - ::typeof(mixed_vararg), x, y, z..., + dargs::Tuple{Any,Any,Any,Vararg}, ::typeof(mixed_vararg), x, y, z... ) Δx = dargs[2] Δy = dargs[3] @@ -39,16 +38,18 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any, Int, Float64}, - ::typeof(type_constraints), x::Int, y::Float64, + (_, Δx, Δy)::Tuple{Any,Int,Float64}, ::typeof(type_constraints), x::Int, y::Float64 ) return type_constraints(x, y), Δx + Δy end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, - ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, + dargs::Tuple{Any,Float64,Real,Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), + x::Float64, + y::Real, + z::Vararg{Float64}, ) Δx = dargs[2] Δy = dargs[3] @@ -76,8 +77,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}]) + only_methods = Set([ + Tuple{typeof(rrule),typeof(cool),Number}, Tuple{typeof(rrule),typeof(cool),String} + ]) @test cool_methods == only_methods frx, cool_pushforward = frule((dself, 1), cool, 1) @@ -94,25 +96,24 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) rrx, nice_pullback = rrule(nice, 1) @test (NoTangent(), ZeroTangent()) === nice_pullback(1) - # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) - @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == + (10.0, 10.0) @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing - @test(frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, - ) == (6.0, 6.0)) + @test( + frule( + (nothing, 3.0, 2.0, 1.0, 0.0), mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0 + ) == (6.0, 6.0) + ) # violates type constraints, thus an frule should not be found. - @test frule( - (nothing, 3, 2.0, 1.0, 5.0), - mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, - ) == nothing + @test frule((nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0) == + nothing @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) @@ -149,31 +150,34 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test_skip ∂xr ≈ real(∂x) end - @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) - @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where {T<:Float32} @opt_out( - ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float32} ) @testset "rrule" begin @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) - @test rrule(first_oa, 3f0, 4f0) === nothing + @test rrule(first_oa, 3.0f0, 4.0f0) === nothing - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 - end) + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_rrule)) do m + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end @testset "frule" begin - @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) - @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing - - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m - m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 - end) + @test frule((NoTangent(), 1, 0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1, 0), first_oa, 3.0f0, 4.0f0) === nothing + + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end end end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7e0ec9398..f8222d942 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -55,7 +55,7 @@ @test muladd(x, ZeroTangent(), ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), x, ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), ZeroTangent(), ZeroTangent()) === ZeroTangent() - + @test reim(z) === (ZeroTangent(), ZeroTangent()) @test real(z) === ZeroTangent() @test imag(z) === ZeroTangent() diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index e276b6a5a..26e6a2422 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -20,73 +20,73 @@ end @testset "Tangent" begin @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{}, Tuple{}} + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} end @testset "==" begin - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(x=0.1, y=2.5) - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(y=2.5, x=0.1) - @test Tangent{Foo}(y=2.5, x=ZeroTangent()) == Tangent{Foo}(y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) + @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) - @test Tangent{Tuple{Float64,}}(2.0) == Tangent{Tuple{Float64,}}(2.0) + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2*1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(;y=2.0,) == Tangent{Foo}(;x=ZeroTangent(), y=Float32(2.0),) + @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) end @testset "hash" begin - @test hash(Tangent{Foo}(x=0.1, y=2.5)) == hash(Tangent{Foo}(y=2.5, x=0.1)) - @test hash(Tangent{Foo}(y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(y=2.5)) + @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) + @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(x=2.5), :x) == true + @test keys(Tangent{Foo}(; x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(; x=2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x=2.5), :y) == false + @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false end - @test Tangent{Foo}(x=2.5).x == 2.5 - - @test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - - NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}} - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 - - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 + @test Tangent{Foo}(; x=2.5).x == 2.5 + + @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + + NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x=2.5)) == 1 - @test length(Tangent{Tuple{Float64,}}(2.0)) == 1 + @test length(Tangent{Foo}(; x=2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64,}}(2.0)) == Float64 + @test eltype(Tangent{Foo}(; x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64,}}(2.0)) == [2.0] + @test collect(Tangent{Foo}(; x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] # Test indexed_iterate ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function(tangent) + _unpack2tuple = function (tangent) a, b = tangent return (a, b) end @@ -96,21 +96,21 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) end end @testset "reverse" begin - c = Tangent{Tuple{Int, Int, String}}(1, 2, "something") - cr = Tangent{Tuple{String, Int, Int}}("something", 2, 1) + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(;x=1.0, y=2.0)) + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo, typeof(d)}(d) + cdict = Tangent{Foo,typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end @@ -119,10 +119,9 @@ end end @testset "conj" begin - @test conj(Tangent{Foo}(x=2.0+3.0im)) == Tangent{Foo}(x=2.0-3.0im) + @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) @test ==( - conj(Tangent{Tuple{Float64,}}(2.0+3.0im)), - Tangent{Tuple{Float64,}}(2.0-3.0im) + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -132,26 +131,20 @@ end @testset "canonicalize" begin # Testing iterate via collect - @test ==( - canonicalize(Tangent{Tuple{Float64,}}(2.0)), - Tangent{Tuple{Float64,}}(2.0) - ) + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - @test ==( - canonicalize(Tangent{Dict}(Dict(4 => 3))), - Tangent{Dict}(Dict(4 => 3)), - ) + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x=2.5, y=10)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10, x=2.5)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10)) == CFoo(x=ZeroTangent(), y=10) + @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(q=99.0, x=2.5)) + @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(;a=1, b=2) + c1 = Tangent{Any}(; a=1, b=2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -164,27 +157,26 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x=1.5) + CFoo(x=2.5) == CFoo(x=4.0) - @test CFoo(y=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=2.5) - @test CFoo(y=1.5, x=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=4.0) + @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) + @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) + @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) end @testset "Tuples" begin @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{}, Tuple{}} + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} ) @test ( - Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64, Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64, Float64}}(2.0, 3.0) + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) end @testset "NamedTuples" begin make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) - t1 = make_tangent((; a = 1.5, b = 0.0)) - t2 = make_tangent((; a = 0.0, b = 2.5)) - t_sum = make_tangent((a = 1.5, b = 2.5)) + t1 = make_tangent((; a=1.5, b=0.0)) + t2 = make_tangent((; a=0.0, b=2.5)) + t_sum = make_tangent((a=1.5, b=2.5)) @test t1 + t2 == t_sum end @@ -197,8 +189,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x=1.5) - b = CFoo(x=@not_implemented("")) + a = CFoo(; x=1.5) + b = CFoo(; x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -213,8 +205,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x=1.5) - b = Tangent{NamedTuple{(:x,)}}(x=@not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(; x=1.5) + b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -233,15 +225,15 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 end @testset "Tuples" begin @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end @testset "NamedTuple" begin @@ -254,14 +246,14 @@ end @testset "Dicts" begin d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 =>5.0)) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x=2.0, x2=6.0) + diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -274,11 +266,10 @@ end ChainRulesCore.debug_mode() = false # disable it again end - # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2/2)/2 + x = (nt.x + nt.x2 / 2) / 2 return StructWithInvariant(x) end @test value + diff == StructWithInvariant(12.5) @@ -286,7 +277,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y=1.5, x=2.5) + c = Tangent{Foo}(; y=1.5, x=2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -308,14 +299,14 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y=1.5, x=2.5) - == Tangent{Foo}(y=3.0, x=5.0) - == Tangent{Foo}(y=1.5, x=2.5) * 2 + 2 * Tangent{Foo}(; y=1.5, x=2.5) == + Tangent{Foo}(; y=3.0, x=5.0) == + Tangent{Foo}(; y=1.5, x=2.5) * 2 ) @test ( - 2 * Tangent{Tuple{Float64, Float64}}(2.0, 4.0) - == Tangent{Tuple{Float64, Float64}}(4.0, 8.0) - == Tangent{Tuple{Float64, Float64}}(2.0, 4.0) * 2 + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 ) d = Tangent{Dict}(Dict(4 => 3.0)) two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) @@ -323,7 +314,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x=1,)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -341,7 +332,8 @@ end @testset "Internals don't allocate a ton" begin bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 # weaker version of the above (which should pass on all versions) @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 @@ -352,6 +344,6 @@ end @testset "non-same-typed differential arithmetic" begin nt = (; a=1, b=2.0) c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1); + @test nt + c == (; a=1, b=2.1) end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 89461caa1..af4a747d1 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -141,7 +141,7 @@ # Check against accidential type piracy # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472 @test Base.which(diagm, Tuple{}()).module != ChainRulesCore - @test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore + @test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore end @test tril(a) == tril(t) @test tril(a, 1) == tril(t, 1) From 04edaf45b213a93cd19fb17f343a37cc0a07fd73 Mon Sep 17 00:00:00 2001 From: st-- Date: Thu, 7 Oct 2021 19:34:58 +0300 Subject: [PATCH 13/20] Update .github/workflows/format.yml Co-authored-by: Lyndon White --- .github/workflows/format.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 06b8dbe46..f6f268c0e 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -24,3 +24,4 @@ jobs: with: tool_name: JuliaFormatter fail_on_error: true + filter_mode: added From b01cd28fb8282551f2e9dcc424233da416241277 Mon Sep 17 00:00:00 2001 From: st-- Date: Thu, 7 Oct 2021 23:51:38 +0300 Subject: [PATCH 14/20] Apply suggestions from code review Co-authored-by: Lyndon White --- src/projection.jl | 3 ++- src/rule_definition_tools.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 2e1a9340e..3e84a0fb9 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -131,7 +131,8 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -function ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{<:_PZ}}}) where {P,T} +const _PZ_Tuple = Tuple{_PZ,Vararg{<:_PZ} # 1 or more ProjectTo{<:AbstractZeros} +function ProjectTo{P}(::NamedTuple{T,<:_PZ_Tuple}) where {P,T} return ProjectTo{NoTangent}() end diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index fd32fbbbd..bef0b8bb1 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -99,7 +99,7 @@ macro scalar_rule(call, maybe_setup, partials...) rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro - return code = quote + return quote if !($f isa Type) && fieldcount(typeof($f)) > 0 throw( ArgumentError( From 8345e17e4107989132cdf7c0bd76fdfa509a42a6 Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 8 Oct 2021 08:31:22 +0300 Subject: [PATCH 15/20] fix suggestion --- src/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index 3e84a0fb9..a017493e4 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -131,7 +131,7 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -const _PZ_Tuple = Tuple{_PZ,Vararg{<:_PZ} # 1 or more ProjectTo{<:AbstractZeros} +const _PZ_Tuple = Tuple{_PZ,Vararg{<:_PZ}} # 1 or more ProjectTo{<:AbstractZeros} function ProjectTo{P}(::NamedTuple{T,<:_PZ_Tuple}) where {P,T} return ProjectTo{NoTangent}() end From f0d78e2c4d862c35758c94f4929b4f910e6152db Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 8 Oct 2021 08:35:40 +0300 Subject: [PATCH 16/20] manual formatting workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484 --- test/rule_definition_tools.jl | 107 +++++++++++++++++----------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 674cbb361..ec9549e7e 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -281,61 +281,64 @@ end end end +#! format: off +# workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484 module IsolatedModuleForTestingScoping -# check that rules can be defined by macros without any additional imports -using ChainRulesCore: @scalar_rule, @non_differentiable - -# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved -const ChainRulesCore = nothing - -# this is -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 -fixed(x) = :abc -@non_differentiable fixed(x) - -# check name collision between a primal input called `kwargs` and the actual keyword -# arguments -fixed_kwargs(x; kwargs...) = :abc -@non_differentiable fixed_kwargs(kwargs) - -my_id(x) = x -@scalar_rule(my_id(x), 1.0) - -module IsolatedSubmodule -# check that rules defined in isolated module without imports can be called -# without errors -using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output -using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id -using Test - -@testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((ZeroTangent(), randn()), f, randn()) - @test y === :abc - @test ẏ === NoTangent() - - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) - end + # check that rules can be defined by macros without any additional imports + using ChainRulesCore: @scalar_rule, @non_differentiable + + # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved + const ChainRulesCore = nothing + + # this is + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 + fixed(x) = :abc + @non_differentiable fixed(x) + + # check name collision between a primal input called `kwargs` and the actual keyword + # arguments + fixed_kwargs(x; kwargs...) = :abc + @non_differentiable fixed_kwargs(kwargs) + + my_id(x) = x + @scalar_rule(my_id(x), 1.0) + + module IsolatedSubmodule + # check that rules defined in isolated module without imports can be called + # without errors + using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output + using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + using Test + + @testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((ZeroTangent(), randn()), f, randn()) + @test y === :abc + @test ẏ === NoTangent() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) - @test y === :abc - @test f_pullback(randn()) === (NoTangent(), NoTangent()) -end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + @test y === :abc + @test f_pullback(randn()) === (NoTangent(), NoTangent()) + end -@testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ + @testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((ZeroTangent(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (NoTangent(), Δy) + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (NoTangent(), Δy) - @test derivatives_given_output(y, my_id, x) == ((1.0,),) -end -end + @test derivatives_given_output(y, my_id, x) == ((1.0,),) + end + end end +#! format: on From 0f858bb48e1d15bcdf28352503f9448353085338 Mon Sep 17 00:00:00 2001 From: st-- Date: Fri, 8 Oct 2021 08:36:16 +0300 Subject: [PATCH 17/20] Update src/rule_definition_tools.jl Co-authored-by: Lyndon White --- src/rule_definition_tools.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index bef0b8bb1..2d246b394 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -295,10 +295,10 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = - foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) - end + _∂s_Δs_tail = Iterators.drop(zip(_∂s, Δs), 1) + summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) + :(muladd($∂s_i, $Δs_i, $ex)) + end return :($proj($summed_∂_mul_Δs)) end From 33946237868297a9f4eff448d5c550df689d158e Mon Sep 17 00:00:00 2001 From: st-- Date: Fri, 8 Oct 2021 13:14:37 +0300 Subject: [PATCH 18/20] Update src/rule_definition_tools.jl --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 2d246b394..c3ffe14f6 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -295,7 +295,7 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - _∂s_Δs_tail = Iterators.drop(zip(_∂s, Δs), 1) + _∂s_Δs_tail = Iterators.drop(zip(_∂s, Δs), 1) summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) :(muladd($∂s_i, $Δs_i, $ex)) end From ddec127545707352c7e15b175dcc037b16e10e13 Mon Sep 17 00:00:00 2001 From: st-- Date: Fri, 8 Oct 2021 13:21:54 +0300 Subject: [PATCH 19/20] Update src/rule_definition_tools.jl --- src/rule_definition_tools.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index c3ffe14f6..7f2fb3375 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -572,9 +572,10 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || - Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) + is_plain = primal_name isa Symbol + is_qualified = Meta.isexpr(primal_name, :(.)) + is_parameterized = Meta.isexpr(primal_name, :curly) + if is_plain || is_qualified || is_parameterized primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name elseif Meta.isexpr(primal_name, :(::)) # e.g. (::T)(x, y) From ffddfe4a8222310d4cf345113be99921e6d1981e Mon Sep 17 00:00:00 2001 From: st-- Date: Fri, 8 Oct 2021 13:22:28 +0300 Subject: [PATCH 20/20] Update src/rule_definition_tools.jl --- src/rule_definition_tools.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 7f2fb3375..56e02b02a 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -572,10 +572,10 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - is_plain = primal_name isa Symbol - is_qualified = Meta.isexpr(primal_name, :(.)) - is_parameterized = Meta.isexpr(primal_name, :curly) - if is_plain || is_qualified || is_parameterized + is_plain = primal_name isa Symbol + is_qualified = Meta.isexpr(primal_name, :(.)) + is_parameterized = Meta.isexpr(primal_name, :curly) + if is_plain || is_qualified || is_parameterized primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name elseif Meta.isexpr(primal_name, :(::)) # e.g. (::T)(x, y)