diff --git a/Project.toml b/Project.toml index 632a6f587..9f4aaf106 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.17" +version = "0.7.18" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index 4d03491bf..1afa2466c 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -12,11 +12,27 @@ end @non_differentiable Random.randexp(::AbstractRNG) @non_differentiable Random.randstring(::AbstractRNG) -@non_differentiable rand(::Any) -@non_differentiable rand(::Any, ::Any) -@non_differentiable rand(::Any, ::Any, ::Any) -@non_differentiable rand(::Any, ::Any, ::Any, ::Any) -@non_differentiable rand(::Any, ::Any, ::Any, ::Any, ::Any) +@non_differentiable rand() +@non_differentiable rand(::AbstractRNG) +@non_differentiable rand(::AbstractRNG, ::Random.Sampler) +@non_differentiable rand(::AbstractRNG, ::Integer) +@non_differentiable rand(::AbstractRNG, ::Integer, ::Integer) +@non_differentiable rand(::AbstractRNG, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::AbstractRNG, ::Integer, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::AbstractRNG, ::Integer, ::Integer, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::Type{<:Real}) +@non_differentiable rand(::Type{<:Real}, ::Tuple) +@non_differentiable rand(::Type{<:Real}, ::Integer) +@non_differentiable rand(::Type{<:Real}, ::Integer, ::Integer) +@non_differentiable rand(::Type{<:Real}, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::Type{<:Real}, ::Integer, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::Type{<:Real}, ::Integer, ::Integer, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::Integer) +@non_differentiable rand(::Integer, ::Integer) +@non_differentiable rand(::Integer, ::Integer, ::Integer) +@non_differentiable rand(::Integer, ::Integer, ::Integer, ::Integer) +@non_differentiable rand(::Integer, ::Integer, ::Integer, ::Integer, ::Integer) + # There are many different 1-3 arg methods, but not varargs @non_differentiable rand!(::Any) diff --git a/test/rulesets/Random/random.jl b/test/rulesets/Random/random.jl index dec6c4ffe..1f5092450 100644 --- a/test/rulesets/Random/random.jl +++ b/test/rulesets/Random/random.jl @@ -1,3 +1,10 @@ +# Simple Distributions like object for testing purposes +struct NormalDistribution + μ + σ +end +Random.rand(d::NormalDistribution) = d.μ + d.σ*randn() + @testset "random" begin @testset "MersenneTwister" begin @testset "no args" begin @@ -19,4 +26,36 @@ @test all(map(x -> x isa Zero, pb(10))) end end + + @testset "rand" begin + non_differentiables = [ + ((), Float64), + ((MersenneTwister(123),), Float64), + ((MersenneTwister(123),2,2), Matrix{<:Float64}), + ((Float32,), Float32), + ((Float32,2,2), Matrix{<:Float32}), + ((Float32,(2,2)), Matrix{<:Float32}), + ((2,2), Matrix{<:Float64}), + ] + + for (args, xType) in non_differentiables + x, dΩ = frule((Zero(), randn(args...)), rand, args...) + @test x isa xType + @test dΩ isa DoesNotExist + + x, pb = rrule(rand, args...) + @test x isa xType + dself, dargs = Iterators.peel(pb(10.0)) + @test dself isa Zero + @test all(darg isa DoesNotExist for darg in dargs) + end + + # Make sure that we do *not* have these set as non_differentiable. as they are differentiable + @test nothing === frule( + (Zero(), Composite{NormalDistribution}(μ=0.5,σ=2.0)), + rand, + NormalDistribution(0.1,1.5), + ) + @test rrule(rand, NormalDistribution(0.1,1.5)) === nothing + end end