Skip to content

Commit

Permalink
Broadcasting (#644)
Browse files Browse the repository at this point in the history
* broadcasting, adapted from Diffractor PR68

* many small upgrades

* fixup tuplecast

* re-organise split bc, add forward mode

* fix tests

* add Yota to downstream tests

* fix an ambiguity

* fix tests on 1.6

* testing

* improve unbroadcast

* change generic rule to use BroadcastStyle

* debug

* rename with unzip

* fix for 1.6

* test bugs

* version

* tidy unzipped

* add some GPU tests

* remove fallback unbroadcast method

* re-instate the error which breaks Revise
  • Loading branch information
mcabbott authored Aug 9, 2022
1 parent d53d8d8 commit 5818173
Show file tree
Hide file tree
Showing 13 changed files with 877 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ jobs:
os: [ubuntu-latest]
package:
# - {user: dpsanders, repo: ReversePropagation.jl}
- {user: dfdx, repo: Yota.jl}
- {user: FluxML, repo: Zygote.jl}
# Diffractor needs to run on Julia nightly
# include:
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.42.0"
version = "1.43.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
ChainRulesCore = "1.15.3"
Expand All @@ -25,6 +26,7 @@ JLArrays = "0.1"
JuliaInterpreter = "0.8,0.9"
RealDot = "0.1"
StaticArrays = "1.2"
StructArrays = "0.6.11"
julia = "1.6"

[extras]
Expand Down
6 changes: 6 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
using ChainRulesCore
using Compat
using Distributed
using GPUArraysCore: AbstractGPUArrayStyle
using IrrationalConstants: logtwo, logten
using LinearAlgebra
using LinearAlgebra.BLAS
using Random
using RealDot: realdot
using SparseArrays
using Statistics
using StructArrays

# Basically everything this package does is overloading these, so we make an exception
# to the normal rule of only overload via `ChainRulesCore.rrule`.
Expand All @@ -22,6 +24,9 @@ using ChainRulesCore: derivatives_given_output
# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

# StructArrays
include("unzipped.jl")

include("rulesets/Core/core.jl")

include("rulesets/Base/utils.jl")
Expand All @@ -34,6 +39,7 @@ include("rulesets/Base/arraymath.jl")
include("rulesets/Base/indexing.jl")
include("rulesets/Base/sort.jl")
include("rulesets/Base/mapreduce.jl")
include("rulesets/Base/broadcast.jl")

include("rulesets/Distributed/nondiff.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
return (T(x, y), Complex_pullback)
end

@scalar_rule complex(x) true

# `hypot`

@scalar_rule hypot(x::Real) sign(x)
Expand Down
Loading

2 comments on commit 5818173

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65908

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.43.0 -m "<description of version>" 5818173b31fac9d358acda21f5751978a5dcb2e5
git push origin v1.43.0

Please sign in to comment.