Skip to content

Commit

Permalink
Merge pull request #185 from JuliaDiff/ox/fastmath
Browse files Browse the repository at this point in the history
Add rules for FastMath
  • Loading branch information
oxinabox authored May 7, 2020
2 parents 860c4d2 + 026b0c5 commit 2d8b265
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 183 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.5.2"
version = "0.5.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ if VERSION < v"1.3.0-DEV.142"
end

include("rulesets/Base/base.jl")
include("rulesets/Base/fastmath_able.jl")
include("rulesets/Base/array.jl")
include("rulesets/Base/mapreduce.jl")

Expand Down
186 changes: 64 additions & 122 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
@@ -1,139 +1,81 @@
@scalar_rule(one(x), Zero())
@scalar_rule(zero(x), Zero())
@scalar_rule(sign(x), Zero())

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(exp(x::Real), Ω)
@scalar_rule(exp10(x), Ω * log(oftype(x, 10)))
@scalar_rule(exp2(x), Ω * log(oftype(x, 2)))
@scalar_rule(expm1(x), exp(x))
@scalar_rule(log(x), inv(x))
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
@scalar_rule(log1p(x), inv(x + 1))
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))

@scalar_rule(cos(x), -sin(x))
@scalar_rule(cosd(x), -/ oftype(x, 180)) * sind(x))
@scalar_rule(cospi(x), -π * sinpi(x))
@scalar_rule(sin(x), cos(x))
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx)
@scalar_rule(sind(x), (π / oftype(x, 180)) * cosd(x))
@scalar_rule(sinpi(x), π * cospi(x))

@scalar_rule(acos(x), -inv(sqrt(1 - x^2)))
@scalar_rule(acot(x), -inv(1 + x^2))
@scalar_rule(acsc(x), -inv(x^2 * sqrt(1 - x^-2)))
@scalar_rule(acsc(x::Real), -inv(abs(x) * sqrt(x^2 - 1)))
@scalar_rule(asec(x), inv(x^2 * sqrt(1 - x^-2)))
@scalar_rule(asec(x::Real), inv(abs(x) * sqrt(x^2 - 1)))
@scalar_rule(asin(x), inv(sqrt(1 - x^2)))
@scalar_rule(atan(x), inv(1 + x^2))
@scalar_rule(atan(y, x), @setup(u = x^2 + y^2), (x / u, -y / u))

@scalar_rule(acosd(x), -oftype(x, 180) / π / sqrt(1 - x^2))
@scalar_rule(acotd(x), -oftype(x, 180) / π / (1 + x^2))
@scalar_rule(acscd(x), -oftype(x, 180) / π / x^2 / sqrt(1 - x^-2))
@scalar_rule(acscd(x::Real), -oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
@scalar_rule(asecd(x), oftype(x, 180) / π / x^2 / sqrt(1 - x^-2))
@scalar_rule(asecd(x::Real), oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
@scalar_rule(asind(x), oftype(x, 180) / π / sqrt(1 - x^2))
@scalar_rule(atand(x), oftype(x, 180) / π / (1 + x^2))

@scalar_rule(cosh(x), sinh(x))
@scalar_rule(coth(x), -(csch(x)^2))
@scalar_rule(sinh(x), cosh(x))
@scalar_rule(tanh(x), 1-Ω^2)
# See also fastmath_able.jl for where rules are defined simple base functions
# that also have FastMath versions.

# Can't multiply though sqrt in acosh because of negative complex case for x
@scalar_rule(acosh(x), inv(sqrt(x - 1) * sqrt(x + 1)))
@scalar_rule(acoth(x), inv(1 - x^2))
@scalar_rule(acsch(x), -inv(x^2 * sqrt(1 + x^-2)))
@scalar_rule(acsch(x::Real), -inv(abs(x) * sqrt(1 + x^2)))
@scalar_rule(asech(x), -inv(x * sqrt(1 - x^2)))
@scalar_rule(asinh(x), inv(sqrt(x^2 + 1)))
@scalar_rule(atanh(x), inv(1 - x^2))

@scalar_rule(deg2rad(x), π / oftype(x, 180))
@scalar_rule(rad2deg(x), oftype(x, 180) / π)

@scalar_rule(adjoint(x::Real), One())
@scalar_rule(conj(x::Real), One())
@scalar_rule(transpose(x), One())

@scalar_rule(+(x), One())
@scalar_rule(-(x), -1)
@scalar_rule(+(x, y), (One(), One()))
@scalar_rule(-(x, y), (One(), -1))
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))

#log(complex(x)) is require so it give correct complex answer for x<0
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(complex(x))))

@scalar_rule(cbrt(x), inv(3 * Ω^2))
@scalar_rule(inv(x), -Ω^2)
@scalar_rule(sqrt(x), inv(2 * Ω))

@scalar_rule(cot(x), -(1 + Ω^2))
@scalar_rule(cotd(x), -/ oftype(x, 180)) * (1 + Ω^2))
@scalar_rule(csc(x), -Ω * cot(x))
@scalar_rule(cscd(x), -/ oftype(x, 180)) * Ω * cotd(x))
@scalar_rule(csch(x), -coth(x) * Ω)
@scalar_rule(sec(x), Ω * tan(x))
@scalar_rule(secd(x), (π / oftype(x, 180)) * Ω * tand(x))
@scalar_rule(sech(x), -tanh(x) * Ω)
@scalar_rule(tan(x), 1 + Ω^2)
@scalar_rule(tand(x), (π / oftype(x, 180)) * (1 + Ω^2))

@scalar_rule(angle(x::Real), Zero())
@scalar_rule(hypot(x::Real), sign(x))
@scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω))
@scalar_rule(imag(x::Real), Zero())

@scalar_rule(fma(x, y, z), (y, x, One()))
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
@scalar_rule(muladd(x, y, z), (y, x, One()))
@scalar_rule one(x) Zero()
@scalar_rule zero(x) Zero()
@scalar_rule adjoint(x::Real) One()
@scalar_rule transpose(x) One()
@scalar_rule imag(x::Real) Zero()
@scalar_rule hypot(x::Real) sign(x)


@scalar_rule fma(x, y, z) (y, x, One())
@scalar_rule muladd(x, y, z) (y, x, One())
@scalar_rule real(x::Real) One()
@scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist())
@scalar_rule(
mod(x, y),
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u))),
)
@scalar_rule(real(x::Real), One())
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))
@scalar_rule(
rem(x, y),
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))),
)

# product rule requires special care for arguments where `mul` is non-commutative

function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
# Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more
# accurate on machines with FMA instructions, since there are only two
# rounding operations, one in `muladd/fma` and the other in `*`.
∂xy = muladd.(Δx, y, x .* Δy)
return x * y, ∂xy
end
@scalar_rule deg2rad(x) π / oftype(x, 180)
@scalar_rule rad2deg(x) oftype(x, 180) / π

function rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * y), @thunk(x * ΔΩ))
end
return x * y, times_pullback
end

function frule((_, ẏ), ::typeof(identity), x)
return x, ẏ
# Can't multiply though sqrt in acosh because of negative complex case for x
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x) inv(1 - x ^ 2)
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2)))
@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2)))
@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1))
@scalar_rule atanh(x) inv(1 - x ^ 2)


@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2)
@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2)
@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1)
@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1)
@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2)
@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2)

@scalar_rule cot(x) -((1 + Ω ^ 2))
@scalar_rule coth(x) -(csch(x) ^ 2)
@scalar_rule cotd(x) -/ oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule csc(x) -Ω * cot(x)
@scalar_rule cscd(x) -/ oftype(x, 180)) * Ω * cotd(x)
@scalar_rule csch(x) -(coth(x)) * Ω
@scalar_rule sec(x) Ω * tan(x)
@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x)
@scalar_rule sech(x) -(tanh(x)) * Ω

@scalar_rule acot(x) -(inv(1 + x ^ 2))
@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2)))
@scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1)))
@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2))
@scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1))

@scalar_rule cosd(x) -/ oftype(x, 180)) * sind(x)
@scalar_rule cospi(x) -π * sinpi(x)
@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x)
@scalar_rule sinpi(x) π * cospi(x)
@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2)

@scalar_rule x \ y (-((y / x) / x), inv(x))

function frule((_, ẏ), ::typeof(identity), x)
return (x, ẏ)
end

function rrule(::typeof(identity), x)
function identity_pullback()
return (NO_FIELDS, )
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
end
return x, identity_pullback
return (x, identity_pullback)
end

function rrule(::typeof(identity), x::Tuple)
Expand Down
90 changes: 90 additions & 0 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
let
# Include inside this quote any rules that should have FastMath versions
fastable_ast = quote
# Trig-Basics
@scalar_rule cos(x) -(sin(x))
@scalar_rule sin(x) cos(x)
@scalar_rule tan(x) 1 + Ω ^ 2


# Trig-Hyperbolic
@scalar_rule cosh(x) sinh(x)
@scalar_rule sinh(x) cosh(x)
@scalar_rule tanh(x) 1 - Ω ^ 2

# Trig- Inverses
@scalar_rule acos(x) -(inv(sqrt(1 - x ^ 2)))
@scalar_rule asin(x) inv(sqrt(1 - x ^ 2))
@scalar_rule atan(x) inv(1 + x ^ 2)

# Trig-Multivariate
@scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u)
@scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx

# exponents
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
@scalar_rule inv(x) -^ 2)
@scalar_rule sqrt(x) inv(2Ω)
@scalar_rule exp(x::Real) Ω
@scalar_rule exp10(x) Ω * log(oftype(x, 10))
@scalar_rule exp2(x) Ω * log(oftype(x, 2))
@scalar_rule expm1(x) exp(x)
@scalar_rule log(x) inv(x)
@scalar_rule log10(x) inv(x) / log(oftype(x, 10))
@scalar_rule log1p(x) inv(x + 1)
@scalar_rule log2(x) inv(x) / log(oftype(x, 2))


# Unary complex functions
@scalar_rule abs(x::Real) sign(x)
@scalar_rule abs2(x) 2x
@scalar_rule angle(x::Real) Zero()
@scalar_rule conj(x::Real) One()

# Binary functions
@scalar_rule hypot(x::Real, y::Real) (x / Ω, y / Ω)
@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (inv(y), -((x / y) / y))
#log(complex(x)) is require so it give correct complex answer for x<0
@scalar_rule(x ^ y,
(ifelse(iszero(y), zero(Ω), y * x ^ (y - 1)), Ω * log(complex(x))),
)
@scalar_rule(
rem(x, y),
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))),
)
@scalar_rule max(x, y) @setup(gt = x > y) (gt, !gt)
@scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt)

# Unary functions
@scalar_rule +x One()
@scalar_rule -x -1


@scalar_rule sign(x) Zero()


# product rule requires special care for arguments where `mul` is non-commutative
function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
# Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more
# accurate on machines with FMA instructions, since there are only two
# rounding operations, one in `muladd/fma` and the other in `*`.
∂xy = muladd.(Δx, y, x .* Δy)
return x * y, ∂xy
end

function rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * y), @thunk(x * ΔΩ))
end
return x * y, times_pullback
end
end

# Rewrite everything to use fast_math functions, including the type-constraints
eval(Base.FastMath.make_fastmath(fastable_ast))
eval(fastable_ast) # Get original definitions
# we do this second so it overwrites anything we included by mistake in the fastable
end
Loading

2 comments on commit 2d8b265

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/14351

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.3 -m "<description of version>" 2d8b2656c16a38df1a01d698a0bd729d4624ac76
git push origin v0.5.3

Please sign in to comment.