diff --git a/Project.toml b/Project.toml index a5b16dc76..52fd73087 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.4.0" +version = "1.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index ffd97464c..928c0cf57 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -1,3 +1,42 @@ +##### +##### constructors +##### + +ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...) + +function rrule(::Type{T}, x::AbstractArray) where {T<:Array} + project_x = ProjectTo(x) + Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return T(x), Array_pullback +end + +##### +##### `vect` +##### + +@non_differentiable Base.vect() + +# Case of uniform type `T`: the data passes straight through, +# so no projection should be required. +function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N} + vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...) + return Base.vect(X...), vect_pullback +end + +# Numbers and arrays are often promoted, to make a uniform vector. +# ProjectTo here reverses this +function rrule( + ::typeof(Base.vect), + X::Vararg{Union{Number,AbstractArray{<:Number}}, N}, +) where {N} + projects = map(ProjectTo, X) + function vect_pullback(ȳ) + X̄ = ntuple(n -> projects[n](ȳ[n]), N) + return (NoTangent(), X̄...) + end + return Base.vect(X...), vect_pullback +end + ##### ##### `reshape` ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index cb9a56940..8a0bf7d85 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,3 +1,40 @@ +@testset "constructors" begin + + # We can't use test_rrule here (as it's currently implemented) because the elements of + # the array have arbitrary values. The only thing we can do is ensure that we're getting + # `ZeroTangent`s back, and that the forwards pass produces the correct thing still. + # Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202 + @testset "undef" begin + val, pullback = rrule(Array{Float64}, undef, 5) + @test size(val) == (5, ) + @test val isa Array{Float64, 1} + @test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent()) + end + @testset "from existing array" begin + test_rrule(Array, randn(2, 5)) + test_rrule(Array, Diagonal(randn(5))) + test_rrule(Matrix, Diagonal(randn(5))) + test_rrule(Matrix, transpose(randn(4))) + test_rrule(Array{ComplexF64}, randn(3)) + end +end + +@testset "vect" begin + test_rrule(Base.vect) + @testset "homogeneous type" begin + test_rrule(Base.vect, (5.0, ), (4.0, )) + test_rrule(Base.vect, 5.0, 4.0, 3.0) + test_rrule(Base.vect, randn(2, 2), randn(3, 3)) + end + @testset "inhomogeneous type" begin + test_rrule( + Base.vect, 5.0, 3f0; + atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6", + ) # tolerance due to Float32. + test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) + end +end + @testset "reshape" begin test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10)