Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.4.0"
version = "1.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
39 changes: 39 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
#####
##### constructors
#####

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end

#####
##### `vect`
#####

@non_differentiable Base.vect()

# Case of uniform type `T`: the data passes straight through,
# so no projection should be required.
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
return Base.vect(X...), vect_pullback
end

# Numbers and arrays are often promoted, to make a uniform vector.
# ProjectTo here reverses this
function rrule(
::typeof(Base.vect),
X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
) where {N}
projects = map(ProjectTo, X)
function vect_pullback(ȳ)
= ntuple(n -> projects[n](ȳ[n]), N)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering whether projects[n] gets handled well. @code_warntype seems happy... and my attempts to make things easier to unroll all make it slower:

julia> @btime rrule(Base.vect, 1,2,3)[2]($(rand(3)))
  28.684 ns (1 allocation: 112 bytes)
(NoTangent(), 0.7437540971290453, 0.6835525631785602, 0.29678387383966687)

julia> @btime rrule(Base.vect, 1+im,2+im,3+im)[2]($(rand(3)))
  235.140 ns (6 allocations: 416 bytes)
(NoTangent(), 0.9212083670665245 + 0.0im, 0.989459216141123 + 0.0im, 0.8454719840778347 + 0.0im)

julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
  609.760 ns (6 allocations: 320 bytes)
(NoTangent(), 0.2914057312235363, 0.23309219863512798 + 0.0im, 0.08023319383991401)

julia> struct StaticGetter{i} end; @inline (::StaticGetter{i})(v) where {i} = v[i]; # from Zygote

julia> function rrule(
            ::typeof(Base.vect),
            X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
        ) where {N}
            valN = Val(N)
            projects = map(ProjectTo, X)
            function vect_pullback(ȳ)
                X̄ = ntuple(n -> StaticGetter{n}()(projects)(ȳ[n]), valN)
                return (NoTangent(), X̄...)
            end
            return Base.vect(X...), vect_pullback
        end
rrule (generic function with 723 methods)

julia> @btime rrule(Base.vect, 1, 2+3im, 4.0)[2]($(rand(3)))
  1.442 μs (11 allocations: 448 bytes)
(NoTangent(), 0.7692288886268137, 0.39993377443044065 + 0.0im, 0.6341039234276757)

return (NoTangent(), X̄...)
end
return Base.vect(X...), vect_pullback
end

#####
##### `reshape`
#####
Expand Down
37 changes: 37 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,40 @@
@testset "constructors" begin

# We can't use test_rrule here (as it's currently implemented) because the elements of
# the array have arbitrary values. The only thing we can do is ensure that we're getting
# `ZeroTangent`s back, and that the forwards pass produces the correct thing still.
# Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202
@testset "undef" begin
val, pullback = rrule(Array{Float64}, undef, 5)
@test size(val) == (5, )
@test val isa Array{Float64, 1}
@test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent())
end
@testset "from existing array" begin
test_rrule(Array, randn(2, 5))
test_rrule(Array, Diagonal(randn(5)))
test_rrule(Matrix, Diagonal(randn(5)))
test_rrule(Matrix, transpose(randn(4)))
test_rrule(Array{ComplexF64}, randn(3))
end
end

@testset "vect" begin
test_rrule(Base.vect)
@testset "homogeneous type" begin
test_rrule(Base.vect, (5.0, ), (4.0, ))
test_rrule(Base.vect, 5.0, 4.0, 3.0)
test_rrule(Base.vect, randn(2, 2), randn(3, 3))
end
@testset "inhomogeneous type" begin
test_rrule(
Base.vect, 5.0, 3f0;
atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6",
) # tolerance due to Float32.
test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false)
end
end

@testset "reshape" begin
test_rrule(reshape, rand(4, 5), (2, 10))
test_rrule(reshape, rand(4, 5), 2, 10)
Expand Down