Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.4.0"
version = "1.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
37 changes: 37 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,40 @@
#####
##### constructors
#####

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end

#####
##### `vect`
#####

@non_differentiable Base.vect()

# Don't worry about projection here. The data passes straight through, so if a cotangent has
# the wrong type for some reason, it must be the fault of another rule somewhere.
function rrule(::typeof(Base.vect), X...)
function vect_pullback(ȳ)
= ntuple(n -> ȳ[n], length(X))
return (NoTangent(), X̄...)
end
return Base.vect(X...), vect_pullback
end

# # Edge case: Numbers get promoted to other numbers, so we need to project.
# function rrule(::typeof(Base.vect), X::Number...)
# project
# function vect_pullback(ȳ)
# X̄ = ntuple(n -> )
# end
# return Base.vect(X...), vect_pullback
# end

#####
##### `reshape`
#####
Expand Down
26 changes: 26 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
@testset "constructors" begin

# We can't use test_rrule here (as it's currently implemented) because the elements of
# the array have arbitrary values. The only thing we can do is ensure that we're getting
# `ZeroTangent`s back, and that the forwards pass produces the correct thing still.
@testset "undef" begin
val, pullback = rrule(Array{Float64}, undef, 5)
@test size(val) == (5, )
@test val isa Array{Float64, 1}
@test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent())
end
@testset "from existing array" begin
test_rrule(Array, randn(2, 5))
test_rrule(Array, Diagonal(randn(5)))
test_rrule(Matrix, Diagonal(randn(5)))
test_rrule(Matrix, transpose(randn(4)))
test_rrule(Array{ComplexF64}, randn(3))
end
end

@testset "vect" begin
test_rrule(Base.vect)
test_rrule(Base.vect, 5.0, 4.0, 3.0)
test_rrule(Base.vect, randn(2, 2), randn(3, 3); check_inferred=false)
end

@testset "reshape" begin
test_rrule(reshape, rand(4, 5), (2, 10))
test_rrule(reshape, rand(4, 5), 2, 10)
Expand Down