Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JuliaFormatter & workflow #482

Merged
merged 21 commits into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
style = "blue"
27 changes: 27 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Format suggestions

on:
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1
- run: |
julia -e 'using Pkg; Pkg.add("JuliaFormatter")'
julia -e 'using JuliaFormatter; format("."; verbose=true)'
- uses: reviewdog/action-suggester@v1
with:
tool_name: JuliaFormatter
fail_on_error: true
filter_mode: added
13 changes: 5 additions & 8 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ DocMeta.setdocmeta!(
@scalar_rule(sin(x), cos(x)) # frule and rrule doctest
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest
@scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest
end
end,
)

indigo = DocThemeIndigo.install(ChainRulesCore)

makedocs(
makedocs(;
modules=[ChainRulesCore],
format=Documenter.HTML(
format=Documenter.HTML(;
prettyurls=false,
assets=[indigo],
mathengine=MathJax3(
Dict(
:tex => Dict(
"inlineMath" => [["\$","\$"], ["\\(","\\)"]],
"inlineMath" => [["\$", "\$"], ["\\(", "\\)"]],
"tags" => "ams",
# TODO: remove when using physics package
"macros" => Dict(
Expand Down Expand Up @@ -67,7 +67,4 @@ makedocs(
checkdocs=:exports,
)

deploydocs(
repo = "github.com/JuliaDiff/ChainRulesCore.jl.git",
push_preview=true,
)
deploydocs(; repo="github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview=true)
53 changes: 25 additions & 28 deletions docs/src/assets/make_logo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,79 +9,76 @@ using Random
const bridge_len = 50

function chain(jiggle=0)
shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5))
shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5))

### 1
shaky_rotate(0)
sethue(Luxor.julia_red)
link()
m1 = getmatrix()



### 2
sethue(Luxor.julia_green)
translate(-50, 130);
shaky_rotate(π/3);
translate(-50, 130)
shaky_rotate(π / 3)
link()
m2 = getmatrix()

setmatrix(m1)
sethue(Luxor.julia_red)
overlap(-1.3π)
setmatrix(m2)

### 3
shaky_rotate(-π/3);
translate(-120,80);
shaky_rotate(-π / 3)
translate(-120, 80)
sethue(Luxor.julia_purple)
link()

setmatrix(m2)
setcolor(Luxor.julia_green)
overlap(-1.5π)
return overlap(-1.5π)
end


function link()
sector(50, 90, π, 0, :fill)
sector(Point(0, bridge_len), 50, 90, 0, -π, :fill)


rect(50,-3,40, bridge_len+6, :fill)
rect(-50-40,-3,40, bridge_len+6, :fill)


rect(50, -3, 40, bridge_len + 6, :fill)
rect(-50 - 40, -3, 40, bridge_len + 6, :fill)

sethue("black")
move(Point(-50, bridge_len))
arc(Point(0,0), 50, π, 0, :stoke)
arc(Point(0, 0), 50, π, 0, :stoke)
arc(Point(0, bridge_len), 50, 0, -π, :stroke)

move(Point(-90, bridge_len))
arc(Point(0,0), 90, π, 0, :stoke)
arc(Point(0, 0), 90, π, 0, :stoke)
arc(Point(0, bridge_len), 90, 0, -π, :stroke)
strokepath()
return strokepath()
end

function overlap(ang_end)
sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill)
sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill)
sethue("black")
arc(Point(0, bridge_len), 50, 0, ang_end, :stoke)
move(Point(90, bridge_len))
arc(Point(0, bridge_len), 90, 0, ang_end, :stoke)

strokepath()
return strokepath()
end

# Actually draw it

function save_logo(filename)
Random.seed!(16)
Drawing(450,450, filename)
Drawing(450, 450, filename)
origin()
translate(50, -130);
translate(50, -130)
chain(0.5)
finish()
preview()
return preview()
end

save_logo("logo.svg")
save_logo("logo.png")
save_logo("logo.png")
6 changes: 2 additions & 4 deletions src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ end

add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y))

function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N
function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N}
return if is_inplaceable_destination(x)
x .+= y
else
x + y
end
end


"""
is_inplaceable_destination(x) -> Bool

Expand Down Expand Up @@ -64,7 +63,6 @@ end
is_inplaceable_destination(::LinearAlgebra.Hermitian) = false
is_inplaceable_destination(::LinearAlgebra.Symmetric) = false


function debug_add!(accumuland, t::InplaceableThunk)
returned_value = t.add!(accumuland)
if returned_value !== accumuland
Expand All @@ -88,7 +86,7 @@ function Base.showerror(io::IO, err::BadInplaceException)
if err.accumuland == err.returned_value
println(
io,
"Which in this case happenned to be equal. But they are not the same object."
"Which in this case happenned to be equal. But they are not the same object.",
)
end
end
4 changes: 2 additions & 2 deletions src/compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ end
if VERSION < v"1.1"
# Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4
# See: https://github.com/JuliaLang/julia/issues/34292
function fieldtypes(::Type{T}) where T
function fieldtypes(::Type{T}) where {T}
if @generated
ntuple(i -> fieldtype(T, i), fieldcount(T))
else
ntuple(i -> fieldtype(T, i), fieldcount(T))
end
end

function fieldnames(::Type{T}) where T
function fieldnames(::Type{T}) where {T}
if @generated
ntuple(i -> fieldname(T, i), fieldcount(T))
else
Expand Down
1 change: 0 additions & 1 deletion src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar
"""
struct NoForwardsMode <: ForwardsModeCapability end


"""
frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...)

Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

8 changes: 5 additions & 3 deletions src/ignore_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ ignore_derivatives(x) = x
Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`.
"""
macro ignore_derivatives(ex)
return :(ChainRulesCore.ignore_derivatives() do
$(esc(ex))
end)
return :(
ChainRulesCore.ignore_derivatives() do
$(esc(ex))
end
)
end
37 changes: 23 additions & 14 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ProjectTo{P}() where {P} = ProjectTo{P}(EMPTY_NT)

const Type_kwfunc = Core.kwftype(Type).instance
function (::typeof(Type_kwfunc))(kws::Any, ::Type{ProjectTo{P}}) where {P}
ProjectTo{P}(NamedTuple(kws))
return ProjectTo{P}(NamedTuple(kws))
end

Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name)
Expand Down Expand Up @@ -131,7 +131,10 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
# Also, any explicit construction with fields, where all fields project to zero, itself
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
const _PZ = ProjectTo{<:AbstractZero}
ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}()
const _PZ_Tuple = Tuple{_PZ,Vararg{<:_PZ}} # 1 or more ProjectTo{<:AbstractZeros}
function ProjectTo{P}(::NamedTuple{T,<:_PZ_Tuple}) where {P,T}
return ProjectTo{NoTangent}()
end

# Tangent
# We haven't entirely figured out when to convert Tangents to "natural" representations such as
Expand Down Expand Up @@ -164,12 +167,16 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
end

# In these cases we can just `convert` as we know we are dealing with plain and simple types
(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx)
(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity
(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx)
(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity
# simple Complex{<:AbstractFloat}} cases
(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
function (::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}}
return convert(T, dx)
end
(::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
function (::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}}
return convert(T, dx)
end
(::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)

# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
Expand Down Expand Up @@ -244,9 +251,11 @@ end
# although really Ref() is probably a better structure.
function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers
if !(project.axes isa Tuple{})
throw(DimensionMismatch(
"array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
))
throw(
DimensionMismatch(
"array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number"
),
)
end
return fill(project.element(dx))
end
Expand Down Expand Up @@ -298,7 +307,9 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
return adjoint(project.parent(dy))
end

ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
return ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
end
function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec)
return transpose(project.parent(transpose(dx)))
end
Expand All @@ -316,10 +327,8 @@ ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))

# Symmetric
for (SymHerm, chk, fun) in (
(:Symmetric, :issymmetric, :transpose),
(:Hermitian, :ishermitian, :adjoint),
)
for (SymHerm, chk, fun) in
((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint))
@eval begin
function ProjectTo(x::$SymHerm)
sub = ProjectTo(parent(x))
Expand Down
Loading