diff --git a/Project.toml b/Project.toml index 64b63ffaf..8081d31a5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.8.23" +version = "1.0.0-DEV" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -10,8 +10,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.10.12" -ChainRulesTestUtils = "0.7.9" +ChainRulesCore = "1" +ChainRulesTestUtils = "1" Compat = "3.31" FiniteDifferences = "0.12.8" StaticArrays = "1.2" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 34942cfa5..ffd97464c 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -81,10 +81,11 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability sizes = map(size, Xs) # this avoids closing over Xs + project_Xs = map(ProjectTo, Xs) function hcat_pullback(ȳ) dY = unthunk(ȳ) hi = Ref(0) # Ref avoids hi::Core.Box - dXs = map(sizes) do sizeX + dXs = map(project_Xs, sizes) do project, sizeX ndimsX = length(sizeX) lo = hi[] + 1 hi[] += get(sizeX, 2, 1) @@ -95,7 +96,7 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) d > ndimsX ? 1 : (:) end end - if ndimsX > 0 + dX = if ndimsX > 0 # Here InplaceableThunk breaks @inferred, removed for now # InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...])) dY[ind...] @@ -103,6 +104,7 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...) # This is a hack to perhaps avoid GPU scalar indexing sum(view(dY, ind...)) end + return project(dX) end return (NoTangent(), dXs...) end @@ -141,10 +143,11 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) Y = vcat(Xs...) ndimsY = Val(ndims(Y)) sizes = map(size, Xs) + project_Xs = map(ProjectTo, Xs) function vcat_pullback(ȳ) dY = unthunk(ȳ) hi = Ref(0) - dXs = map(sizes) do sizeX + dXs = map(project_Xs, sizes) do project, sizeX ndimsX = length(sizeX) lo = hi[] + 1 hi[] += get(sizeX, 1, 1) @@ -155,12 +158,13 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...) d > ndimsX ? 1 : (:) end end - if ndimsX > 0 + dX = if ndimsX > 0 # InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...)) dY[ind...] else sum(view(dY, ind...)) end + return project(dX) end return (NoTangent(), dXs...) end @@ -195,10 +199,11 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims) cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims) ndimsY = Val(ndims(Y)) sizes = map(size, Xs) + project_Xs = map(ProjectTo, Xs) function cat_pullback(ȳ) dY = unthunk(ȳ) prev = fill(0, _val(ndimsY)) # note that Y always has 1-based indexing, even if X isa OffsetArray - dXs = map(sizes) do sizeX + dXs = map(project_Xs, sizes) do project, sizeX ndimsX = length(sizeX) index = ntuple(ndimsY) do d if d in cdims @@ -210,12 +215,13 @@ function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims) for d in cdims prev[d] += get(sizeX, d, 1) end - if ndimsX > 0 + dX = if ndimsX > 0 # InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...)) dY[index...] else sum(view(dY, index...)) end + return project(dX) end return (NoTangent(), dXs...) end @@ -231,9 +237,10 @@ function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...) cols = size(Y,2) ndimsY = Val(ndims(Y)) sizes = map(size, values) + project_Vs = map(ProjectTo, values) function hvcat_pullback(dY) prev = fill(0, 2) - dXs = map(sizes) do sizeX + dXs = map(project_Vs, sizes) do project, sizeX ndimsX = length(sizeX) index = ntuple(ndimsY) do d if d in (1, 2) @@ -247,7 +254,7 @@ function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...) prev[2] = 0 prev[1] += get(sizeX, 1, 1) end - dY[index...] + project(dY[index...]) end return (NoTangent(), NoTangent(), dXs...) end diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 8de70a7d5..4aa77ef3d 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -24,19 +24,19 @@ function rrule( A::AbstractVecOrMat{<:CommutativeMulNumber}, B::AbstractVecOrMat{<:CommutativeMulNumber}, ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) function times_pullback(ȳ) Ȳ = unthunk(ȳ) - return ( - NoTangent(), - InplaceableThunk( - X̄ -> mul!(X̄, Ȳ, B', true, true), - @thunk(Ȳ * B'), - ), - InplaceableThunk( - X̄ -> mul!(X̄, A', Ȳ, true, true), - @thunk(A' * Ȳ), - ) + dA = InplaceableThunk( + X̄ -> mul!(X̄, Ȳ, B', true, true), + @thunk(project_A(Ȳ * B')), + ) + dB = InplaceableThunk( + X̄ -> mul!(X̄, A', Ȳ, true, true), + @thunk(project_B(A' * Ȳ)), ) + return NoTangent(), dA, dB end return A * B, times_pullback end @@ -46,6 +46,8 @@ function rrule( A::AbstractVector{<:CommutativeMulNumber}, B::AbstractMatrix{<:CommutativeMulNumber}, ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) function times_pullback(ȳ) Ȳ = unthunk(ȳ) @assert size(B, 1) === 1 # otherwise primal would have failed. @@ -53,11 +55,11 @@ function rrule( NoTangent(), InplaceableThunk( X̄ -> mul!(X̄, Ȳ, vec(B'), true, true), - @thunk(Ȳ * vec(B')), + @thunk(project_A(Ȳ * vec(B'))), ), InplaceableThunk( X̄ -> mul!(X̄, A', Ȳ, true, true), - @thunk(A' * Ȳ), + @thunk(project_B(A' * Ȳ)), ) ) end @@ -67,14 +69,16 @@ end function rrule( ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) function times_pullback(ȳ) Ȳ = unthunk(ȳ) return ( NoTangent(), - @thunk(dot(Ȳ, B)'), + @thunk(project_A(dot(Ȳ, B)')), InplaceableThunk( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), - @thunk(A' * Ȳ), + @thunk(project_B(A' * Ȳ)), ) ) end @@ -84,15 +88,17 @@ end function rrule( ::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) function times_pullback(ȳ) Ȳ = unthunk(ȳ) return ( NoTangent(), InplaceableThunk( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), - @thunk(A' * Ȳ), + @thunk(project_B(A' * Ȳ)), ), - @thunk(dot(Ȳ, B)'), + @thunk(project_A(dot(Ȳ, B)')), ) end return A * B, times_pullback @@ -109,27 +115,31 @@ function rrule( B::AbstractVecOrMat{<:CommutativeMulNumber}, z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}, ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) + project_z = ProjectTo(z) + # The useful case, mul! fused with + function muladd_pullback_1(ȳ) Ȳ = unthunk(ȳ) matmul = ( InplaceableThunk( dA -> mul!(dA, Ȳ, B', true, true), - @thunk(Ȳ * B'), + @thunk(project_A(Ȳ * B')), ), InplaceableThunk( dB -> mul!(dB, A', Ȳ, true, true), - @thunk(A' * Ȳ), + @thunk(project_B(A' * Ȳ)), ) ) addon = if z isa Bool NoTangent() elseif z isa Number - @thunk(sum(Ȳ)) + @thunk(project_z(sum(Ȳ))) else InplaceableThunk( dz -> sum!(dz, Ȳ; init=false), - @thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)), + @thunk(project_z(sum!(similar(z, eltype(Ȳ)), Ȳ))), ) end (NoTangent(), matmul..., addon) @@ -143,18 +153,22 @@ function rrule( v::AbstractVector{<:CommutativeMulNumber}, z::CommutativeMulNumber, ) + project_ut = ProjectTo(ut) + project_v = ProjectTo(v) + project_z = ProjectTo(z) + # This case is dot(u,v)+z, but would also match signature above. function muladd_pullback_2(ȳ) dy = unthunk(ȳ) ut_thunk = InplaceableThunk( dut -> dut .+= v' .* dy, - @thunk(v' .* dy), + @thunk(project_ut(v' .* dy)), ) v_thunk = InplaceableThunk( dv -> dv .+= ut' .* dy, - @thunk(ut' .* dy), + @thunk(project_v(ut' .* dy)), ) - (NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : dy) + (NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : project_z(dy)) end return muladd(ut, v, z), muladd_pullback_2 end @@ -165,21 +179,25 @@ function rrule( vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber}, z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}, ) + project_u = ProjectTo(u) + project_vt = ProjectTo(vt) + project_z = ProjectTo(z) + # Outer product, just broadcasting function muladd_pullback_3(ȳ) Ȳ = unthunk(ȳ) proj = ( - @thunk(vec(sum(Ȳ .* conj.(vt), dims=2))), - @thunk(vec(sum(u .* conj.(Ȳ), dims=1))'), + @thunk(project_u(vec(sum(Ȳ .* conj.(vt), dims=2)))), + @thunk(project_vt(vec(sum(u .* conj.(Ȳ), dims=1))')), ) addon = if z isa Bool NoTangent() elseif z isa Number - @thunk(sum(Ȳ)) + @thunk(project_z(sum(Ȳ))) else InplaceableThunk( dz -> sum!(dz, Ȳ; init=false), - @thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)), + @thunk(project_z(sum!(similar(z, eltype(Ȳ)), Ȳ))), ) end (NoTangent(), proj..., addon) @@ -202,7 +220,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) ∂A = last(dA_pb(unthunk(dAᵀ))) - ∂B = last(dA_pb(unthunk(dBᵀ))) + ∂B = last(dB_pb(unthunk(dBᵀ))) (NoTangent(), ∂A, ∂B) end @@ -214,6 +232,9 @@ end ##### function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) + project_A = ProjectTo(A) + project_B = ProjectTo(B) + Y = A \ B function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) @@ -222,9 +243,9 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Ā = -B̄ * Y' Ā = add!!(Ā, (B - A * Y) * B̄' / A') Ā = add!!(Ā, A' \ Y * (Ȳ' - B̄'A)) - Ā + project_A(Ā) end - ∂B = @thunk A' \ Ȳ + ∂B = @thunk project_B(A' \ Ȳ) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 032f9c4dc..4b4ab265b 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -53,9 +53,12 @@ function rrule(::Type{T}, x::Real) where {T<:Complex} return (T(x), Complex_pullback) end function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function Complex_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - return (NoTangent(), real(ΔΩ), imag(ΔΩ)) + return (NoTangent(), project_x(real(ΔΩ)), project_y(imag(ΔΩ))) end return (T(x, y), Complex_pullback) end diff --git a/src/rulesets/Base/evalpoly.jl b/src/rulesets/Base/evalpoly.jl index 419b82ae8..59f5873f4 100644 --- a/src/rulesets/Base/evalpoly.jl +++ b/src/rulesets/Base/evalpoly.jl @@ -13,6 +13,17 @@ if VERSION ≥ v"1.4" end function rrule(::typeof(evalpoly), x, p) + y, ys = _evalpoly_intermediates(x, p) + project_x = ProjectTo(x) + project_p = p isa Tuple ? identity : ProjectTo(p) + function evalpoly_pullback(Δy) + ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) + return NoTangent(), project_x(∂x), project_p(∂p) + end + return y, evalpoly_pullback + end + + function rrule(::typeof(evalpoly), x, p::Vector{<:Matrix}) # does not type infer with ProjectTo y, ys = _evalpoly_intermediates(x, p) function evalpoly_pullback(Δy) ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 31034ecdc..01fd2ff43 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -213,9 +213,11 @@ let end function rrule(::typeof(*), x::Number, y::Number) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function times_pullback(Ω̇) ΔΩ = unthunk(Ω̇) - return (NoTangent(), ΔΩ * y', x' * ΔΩ) + return (NoTangent(), project_x(ΔΩ * y'), project_y(x' * ΔΩ)) end return x * y, times_pullback end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 316ead4f2..c007687b3 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -39,7 +39,6 @@ function rrule( return y, covector_sum_pb end - function rrule( config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=: ) @@ -48,6 +47,8 @@ function rrule( pullbacks = last.(fx_and_pullbacks) + project = ProjectTo(xs) + function sum_pullback(ȳ) call(f, x) = f(x) # we need to broadcast this to handle dims kwarg f̄_and_x̄s = call.(pullbacks, ȳ) @@ -57,8 +58,8 @@ function rrule( else sum(first, f̄_and_x̄s) end - x̄s = map(last, f̄_and_x̄s) - return NoTangent(), f̄, x̄s + x̄s = map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks + return NoTangent(), f̄, project(x̄s) end return y, sum_pullback end @@ -118,6 +119,7 @@ end function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:CommutativeMulNumber} y = prod(x; dims=dims) + project_x = ProjectTo(x) # vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims)) function prod_pullback(ȳ) dy = unthunk(ȳ) @@ -132,15 +134,15 @@ function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:Commutativ dx .+= conj.(y ./ x) .* dy end, # Out-of-place versions - @thunk if dims === (:) + @thunk project_x(if dims === (:) ∇prod(x, dy, y) elseif any(iszero, x) # Then, and only then, will ./x lead to NaN vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims)) ∇prod_dims(vald, x, dy, y) # val(Int(dims)) is about 2x faster than Val(Tuple(dims)) else conj.(y ./ x) .* dy - end - ) + end) + ) return (NoTangent(), x_thunk) end return y, prod_pullback diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 897c0e11c..3cb362900 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -7,17 +7,13 @@ function frule((_, Δx, Δy), ::typeof(dot), x, y) end function rrule(::typeof(dot), x::AbstractArray, y::AbstractArray) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function dot_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - xthunk = InplaceableThunk( - dx -> dx .+= reshape(y, axes(x)) .* ΔΩ', - @thunk(reshape(y .* ΔΩ', axes(x))), - ) - ythunk = InplaceableThunk( - dy -> dy .+= reshape(x, axes(y)) .* ΔΩ, - @thunk(reshape(x .* ΔΩ, axes(y))), - ) - return (NoTangent(), xthunk, ythunk) + x̄ = @thunk(project_x(reshape(y .* ΔΩ', axes(x)))) + ȳ = @thunk(project_y(reshape(x .* ΔΩ, axes(y)))) + return (NoTangent(), x̄, ȳ) end return dot(x, y), dot_pullback end @@ -31,13 +27,16 @@ function frule((_, Δx, ΔA, Δy), ::typeof(dot), x::AbstractVector{<:Number}, A end function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) + project_x = ProjectTo(x) + project_A = ProjectTo(A) + project_y = ProjectTo(y) Ay = A * y z = adjoint(x) * Ay function dot_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - dx = @thunk conj(ΔΩ) .* Ay - dA = @thunk ΔΩ .* x .* adjoint(y) - dy = @thunk ΔΩ .* (adjoint(A) * x) + dx = @thunk project_x(conj(ΔΩ) .* Ay) + dA = @thunk project_A(ΔΩ .* x .* adjoint(y)) + dy = @thunk project_y(ΔΩ .* (adjoint(A) * x)) return (NoTangent(), dx, dA, dy) end dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) @@ -45,12 +44,15 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N end function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number}, y::AbstractVector{<:Number}) + project_x = ProjectTo(x) + project_A = ProjectTo(A) + project_y = ProjectTo(y) z = dot(x,A,y) function dot_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused - dA = @thunk Diagonal(ΔΩ .* x .* conj(y)) # calculate N not N^2 elements - dy = @thunk ΔΩ .* conj.(A.diag) .* x + dx = @thunk project_x(conj(ΔΩ) .* A.diag .* y) # A*y is this broadcast, can be fused + dA = @thunk project_A(Diagonal(ΔΩ .* x .* conj(y))) # calculate N not N^2 elements + dy = @thunk project_y(ΔΩ .* conj.(A.diag) .* x) return (NoTangent(), dx, dA, dy) end dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) @@ -67,10 +69,14 @@ end # TODO: support complex vectors function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) + project_a = ProjectTo(a) + project_b = ProjectTo(b) Ω = cross(a, b) function cross_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - return (NoTangent(), @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) + da = @thunk(project_a(cross(b, ΔΩ))) + db = @thunk(project_b(cross(ΔΩ, a))) + return (NoTangent(), da, db) end return Ω, cross_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index d3ba72703..fcb983b49 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -7,10 +7,12 @@ const SquareMatrix{T} = Union{Diagonal{T}, AbstractTriangular{T}} function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} Y = A / B + project_A = ProjectTo(A) + project_B = ProjectTo(B) function slash_pullback(ȳ) Ȳ = unthunk(ȳ) - ∂A = @thunk Ȳ / B' - ∂B = @thunk _unionall_wrapper(T)(-Y' * (Ȳ / B')) + ∂A = @thunk project_A(Ȳ / B') + ∂B = @thunk project_B(-Y' * (Ȳ / B')) return (NoTangent(), ∂A, ∂B) end return Y, slash_pullback @@ -18,10 +20,12 @@ end function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} Y = A \ B + project_A = ProjectTo(A) + project_B = ProjectTo(B) function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) - ∂A = @thunk _unionall_wrapper(T)(-(A' \ Ȳ) * Y') - ∂B = @thunk A' \ Ȳ + ∂A = @thunk project_A(-(A' \ Ȳ) * Y') + ∂B = @thunk project_B(A' \ Ȳ) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback @@ -33,7 +37,8 @@ end # these functions are defined outside the rrule because otherwise type inference breaks # see https://github.com/JuliaLang/julia/issues/40990 -_Diagonal_pullback(ȳ::AbstractMatrix) = return (NoTangent(), diag(ȳ)) +_Diagonal_pullback(ȳ::AbstractMatrix) = return (NoTangent(), diag(ȳ)) # should we emit a warning here? this shouldn't be called if project works right +_Diagonal_pullback(ȳ::Diagonal) = return (NoTangent(), diag(ȳ)) function _Diagonal_pullback(ȳ::Tangent) # TODO: Assert about the primal type in the Tangent, It should be Diagonal # infact it should be exactly the type of `Diagonal(d)` @@ -82,9 +87,13 @@ function _diagm_back(p, ȳ) end function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) + project_D = ProjectTo(D) + project_V = ProjectTo(V) function times_pullback(ȳ) Ȳ = unthunk(ȳ) - return (NoTangent(), @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) + dD = @thunk(project_D(Diagonal(Ȳ .* V))) + dV = @thunk(project_V(D * Ȳ)) + return (NoTangent(), dD, dV) end return D * V, times_pullback end @@ -95,10 +104,12 @@ end # these functions are defined outside the rrule because otherwise type inference breaks # see https://github.com/JuliaLang/julia/issues/40990 -Adjoint_mat_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) -Adjoint_mat_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) -Adjoint_mat_pullback(ȳ::AbstractThunk) = return Adjoint_mat_pullback(unthunk(ȳ)) +_Adjoint_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) +_Adjoint_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(adjoint(ȳ))) +_Adjoint_mat_pullback(ȳ::AbstractThunk, proj) = return _Adjoint_mat_pullback(unthunk(ȳ), proj) function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) + project_A = ProjectTo(A) + Adjoint_mat_pullback(ȳ) = _Adjoint_mat_pullback(ȳ, project_A) return Adjoint(A), Adjoint_mat_pullback end @@ -109,11 +120,13 @@ function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) return Adjoint(A), _Adjoint_vec_pullback end -_adjoint_mat_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) -_adjoint_mat_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) -_adjoint_mat_pullback(ȳ::AbstractThunk) = return _adjoint_mat_pullback(unthunk(ȳ)) +_adjoint_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) +_adjoint_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(adjoint(ȳ))) +_adjoint_mat_pullback(ȳ::AbstractThunk, proj) = return _adjoint_mat_pullback(unthunk(ȳ), proj) function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) - return adjoint(A), _adjoint_mat_pullback + project_A = ProjectTo(A) + adjoint_mat_pullback(ȳ) = _adjoint_mat_pullback(ȳ, project_A) + return adjoint(A), adjoint_mat_pullback end _adjoint_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) @@ -129,11 +142,13 @@ end # these functions are defined outside the rrule because otherwise type inference breaks # see https://github.com/JuliaLang/julia/issues/40990 -_Transpose_mat_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) -_Transpose_mat_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), Transpose(ȳ)) -_Transpose_mat_pullback(ȳ::AbstractThunk) = return _Transpose_mat_pullback(unthunk(ȳ)) +_Transpose_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) +_Transpose_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(Transpose(ȳ))) +_Transpose_mat_pullback(ȳ::AbstractThunk, proj) = return _Transpose_mat_pullback(unthunk(ȳ), proj) function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) - return Transpose(A), _Transpose_mat_pullback + project_A = ProjectTo(A) + Transpose_mat_pullback(ȳ) = _Transpose_mat_pullback(ȳ, project_A) + return Transpose(A), Transpose_mat_pullback end _Transpose_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) @@ -143,11 +158,13 @@ function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) return Transpose(A), _Transpose_vec_pullback end -_transpose_mat_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) -_transpose_mat_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), transpose(ȳ)) -_transpose_mat_pullback(ȳ::AbstractThunk) = return _transpose_mat_pullback(unthunk(ȳ)) +_transpose_mat_pullback(ȳ::Tangent, proj) = (NoTangent(), proj(ȳ.parent)) +_transpose_mat_pullback(ȳ::AbstractVecOrMat, proj) = (NoTangent(), proj(transpose(ȳ))) +_transpose_mat_pullback(ȳ::AbstractThunk, proj) = return _transpose_mat_pullback(unthunk(ȳ), proj) function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) - return transpose(A), _transpose_mat_pullback + project_A = ProjectTo(A) + transpose_mat_pullback(ȳ) = _transpose_mat_pullback(ȳ, project_A) + return transpose(A), transpose_mat_pullback end _transpose_vec_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index f223ac63d..573d4fcf4 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -8,8 +8,9 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) + project_A = ProjectTo(A) @inline function HermOrSym_pullback(ΔΩ) - return (NoTangent(), _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) + return (NoTangent(), project_A(_symherm_back(typeof(Ω), ΔΩ, uplo)), NoTangent()) end return Ω, HermOrSym_pullback end diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 9e10aacab..3d8ad923f 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -30,22 +30,6 @@ end _extract_imag(x) = complex(0, imag(x)) -""" - _unionall_wrapper(T::Type) -> UnionAll - -Return the most general `UnionAll` type union associated with the concrete type `T`. - -# Example -```julia -julia> _unionall_wrapper(typeof(Diagonal(1:3))) -Diagonal - -julia> _unionall_wrapper(typeof(Symmetric(randn(3, 3)))) -Symmetric -```` -""" -_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper - """ WithSomeZeros{T} diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index eb8e8cd0f..cb9a56940 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,5 +1,5 @@ @testset "reshape" begin - test_rrule(reshape, rand(4, 5), (2, 10) ⊢ NoTangent()) + test_rrule(reshape, rand(4, 5), (2, 10)) test_rrule(reshape, rand(4, 5), 2, 10) test_rrule(reshape, rand(4, 5), 2, :) end @@ -36,22 +36,29 @@ end test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3); check_inferred=VERSION>v"1.1") test_rrule(hcat, rand(), rand(1,2), rand(1,2,1); check_inferred=VERSION>v"1.1") test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2); check_inferred=VERSION>v"1.1") + + # mix types + test_rrule(hcat, rand(2, 2), rand(2, 2)'; check_inferred=VERSION>v"1.1") end @testset "reduce hcat" begin mats = [randn(3, 2), randn(3, 1), randn(3, 3)] - test_rrule(reduce, hcat ⊢ NoTangent(), mats) + test_rrule(reduce, hcat, mats) vecs = [rand(3) for _ in 1:4] - test_rrule(reduce, hcat ⊢ NoTangent(), vecs) + test_rrule(reduce, hcat, vecs) mix = AbstractVecOrMat[rand(4,2), rand(4)] # this is weird, but does hit the fast path - test_rrule(reduce, hcat ⊢ NoTangent(), mix) + test_rrule(reduce, hcat, mix) adjs = vec([randn(2, 4), randn(1, 4), randn(3, 4)]') # not a Vector - # test_rrule(reduce, hcat ⊢ NoTangent(), adjs ⊢ map(m -> rand(size(m)), adjs)) + # test_rrule(reduce, hcat, adjs ⊢ map(m -> rand(size(m)), adjs)) dy = 1 ./ reduce(hcat, adjs) @test rrule(reduce, hcat, adjs)[2](dy)[3] ≈ rrule(reduce, hcat, collect.(adjs))[2](dy)[3] + + # mix types + mats = [randn(2, 2), rand(2, 2)'] + test_rrule(reduce, hcat, mats; check_inferred=VERSION>v"1.1") end @testset "vcat" begin @@ -59,17 +66,20 @@ end test_rrule(vcat, rand(), rand(); check_inferred=VERSION>v"1.1") test_rrule(vcat, rand(), rand(3), rand(3,1,1); check_inferred=VERSION>v"1.1") test_rrule(vcat, rand(3,1,2), rand(4,1,2); check_inferred=VERSION>v"1.1") + + # mix types + test_rrule(vcat, rand(2, 2), rand(2, 2)'; check_inferred=VERSION>v"1.1") end @testset "reduce vcat" begin mats = [randn(2, 4), randn(1, 4), randn(3, 4)] - test_rrule(reduce, vcat ⊢ NoTangent(), mats) + test_rrule(reduce, vcat, mats) vecs = [rand(2), rand(3), rand(4)] - test_rrule(reduce, vcat ⊢ NoTangent(), vecs) + test_rrule(reduce, vcat, vecs) mix = AbstractVecOrMat[rand(4,1), rand(4)] - test_rrule(reduce, vcat ⊢ NoTangent(), mix) + test_rrule(reduce, vcat, mix) end @testset "cat" begin @@ -77,16 +87,21 @@ end test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),), check_inferred=VERSION>v"1.1") test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],), check_inferred=VERSION>v"1.1") test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any} + + test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,), check_inferred=VERSION>v"1.1") end @testset "hvcat" begin - test_rrule(hvcat, 2 ⊢ NoTangent(), rand(ComplexF64, 6)...; check_inferred=VERSION>v"1.1") - test_rrule(hvcat, (2, 1) ⊢ NoTangent(), rand(), rand(1,1), rand(2,2); check_inferred=VERSION>v"1.1") - test_rrule(hvcat, 1 ⊢ NoTangent(), rand(3)' ⊢ rand(1,3), transpose(rand(3)) ⊢ rand(1,3); check_inferred=VERSION>v"1.1") - test_rrule(hvcat, 1 ⊢ NoTangent(), rand(0,3), rand(2,3), rand(1,3,1); check_inferred=VERSION>v"1.1") + test_rrule(hvcat, 2, rand(ComplexF64, 6)...; check_inferred=VERSION>v"1.1") + test_rrule(hvcat, (2, 1), rand(), rand(1,1), rand(2,2); check_inferred=VERSION>v"1.1") + test_rrule(hvcat, 1, rand(3)' ⊢ rand(1,3), transpose(rand(3)) ⊢ rand(1,3); check_inferred=VERSION>v"1.1") + test_rrule(hvcat, 1, rand(0,3), rand(2,3), rand(1,3,1); check_inferred=VERSION>v"1.1") + + # mix types (adjoint and transpose) + test_rrule(hvcat, 1, rand(3)', transpose(rand(3)) ⊢ rand(1,3); check_inferred=VERSION>v"1.1") end @testset "fill" begin test_rrule(fill, 44.0, 4; check_inferred=false) - test_rrule(fill, 2.0, (3, 3, 3) ⊢ NoTangent()) + test_rrule(fill, 2.0, (3, 3, 3)) end diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index e4982855b..cad583f16 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -109,13 +109,13 @@ for n in 3:5, m in 3:5 A = randn(m, n) B = randn(m, n) - test_rrule(f, A, B) + test_rrule(f, A, B; check_inferred=false) # ChainRulesCore #407 end end @testset "Vector" begin x = randn(10) y = randn(10) - test_rrule(f, x, y) + test_rrule(f, x, y; check_inferred=false) # ChainRulesCore #407 end if f == (\) @testset "Matrix $f Vector" begin @@ -128,6 +128,10 @@ Y = randn(10, 4) test_rrule(f, x, Y; output_tangent=Transpose(rand(4))) end + else + A = rand(2, 4) + B = rand(4, 4) + test_rrule(f, A, B; check_inferred=false) # ChainRulesCore #407 end end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 167c477cb..872b82c17 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -69,8 +69,8 @@ test_scalar(Complex, randn()) test_scalar(Complex, randn(ComplexF64)) - test_frule(Complex, randn(), randn()) - test_rrule(Complex, randn(), randn()) + test_frule(Complex, randn(), randn(Float32)) + test_rrule(Complex, randn(), randn(Float32); rtol=1.0e-7, atol=1.0e-7) end @testset "*(x, y) (scalar)" begin @@ -110,19 +110,11 @@ @testset "identity" for T in (Float64, ComplexF64) test_frule(identity, randn(T)) test_frule(identity, randn(T, 4)) - test_frule( - identity, - #Tuple(randn(T, 3)) ⊢ Tangent{Tuple{T, T, T}}(randn(T, 3)...) - Tuple(randn(T, 3)) - ) + test_frule(identity, Tuple(randn(T, 3))) test_rrule(identity, randn(T)) test_rrule(identity, randn(T, 4)) - test_rrule( - identity, - Tuple(randn(T, 3)) ⊢ Tangent{Tuple{T, T, T}}(randn(T, 3)...); - output_tangent = Tangent{Tuple{T, T, T}}(randn(T, 3)...) - ) + test_rrule(identity, Tuple(randn(T, 3))) end @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im) diff --git a/test/rulesets/Base/evalpoly.jl b/test/rulesets/Base/evalpoly.jl index 5d6bb76d8..34e0a7f46 100644 --- a/test/rulesets/Base/evalpoly.jl +++ b/test/rulesets/Base/evalpoly.jl @@ -25,4 +25,7 @@ VERSION ≥ v"1.4" && @testset "evalpoly" begin test_rrule(evalpoly, x, p) test_rrule(evalpoly, x, Tuple(p)) end + + # mixed type inputs + test_rrule(evalpoly, rand(ComplexF64), rand(3)) end diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 4b78608ca..653fd3f4e 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -151,11 +151,11 @@ const FASTABLE_AST = quote # Issue #233 @test frule((ZeroTangent(), Δx, Δy), f, x, 2) isa Tuple{T, T} _, ∂x, ∂y = rrule(f, x, 2)[2](Δz) - @test (∂x, ∂y) isa Tuple{T, T} + @test (∂x, ∂y) isa Tuple{T, Float64} @test frule((ZeroTangent(), Δx, Δy), f, 2, y) isa Tuple{T, T} _, ∂x, ∂y = rrule(f, 2, y)[2](Δz) - @test (∂x, ∂y) isa Tuple{T, T} + @test (∂x, ∂y) isa Tuple{Float64, T} end end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 95e0c2d6a..0ae4de338 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -7,14 +7,14 @@ test_rrule(getindex, x, 2, 1) test_rrule(getindex, x, 2, 2) - test_rrule(getindex, x, CartesianIndex(2, 3) ⊢ NoTangent()) + test_rrule(getindex, x, CartesianIndex(2, 3)) end @testset "slice/index postions" begin test_rrule(getindex, x, 2:3) test_rrule(getindex, x, 3:-1:2) - test_rrule(getindex, x, [3,2] ⊢ NoTangent()) - test_rrule(getindex, x, [2,3] ⊢ NoTangent()) + test_rrule(getindex, x, [3,2]) + test_rrule(getindex, x, [2,3]) test_rrule(getindex, x, 1:2, 2:3) test_rrule(getindex, x, (:), 2:3) @@ -30,21 +30,21 @@ end @testset "masking" begin - test_rrule(getindex, x, trues(size(x)) ⊢ NoTangent()) - test_rrule(getindex, x, trues(length(x)) ⊢ NoTangent()) + test_rrule(getindex, x, trues(size(x))) + test_rrule(getindex, x, trues(length(x))) mask = falses(size(x)) mask[2,3] = true mask[1,2] = true - test_rrule(getindex, x, mask ⊢ NoTangent()) + test_rrule(getindex, x, mask) - test_rrule(getindex, x, [true, false] ⊢ NoTangent(), (:)) + test_rrule(getindex, x, [true, false], (:)) end @testset "By position with repeated elements" begin - test_rrule(getindex, x, [2, 2] ⊢ NoTangent()) - test_rrule(getindex, x, [2, 2, 2] ⊢ NoTangent()) - test_rrule(getindex, x, [2,2] ⊢ NoTangent(), [3,3] ⊢ NoTangent()) + test_rrule(getindex, x, [2, 2]) + test_rrule(getindex, x, [2, 2, 2]) + test_rrule(getindex, x, [2,2], [3,3]) end end end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 54828da0f..37b9c2e0b 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -16,7 +16,7 @@ @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64) x = randn(T, sizes[1:N]...) test_frule(sum, abs2, x; fkwargs=(;dims=dims)) - test_rrule(sum, abs2 ⊢ NoTangent(), x; fkwargs=(;dims=dims)) + test_rrule(sum, abs2, x; fkwargs=(;dims=dims)) end end end # sum abs2 @@ -26,7 +26,8 @@ test_rrule(sum, abs, [-4.0, 2.0, 2.0]) test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0]) - test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]) # array of arrays + # inference fails for array of arrays + test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) # dims kwarg test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) @@ -45,11 +46,9 @@ _, pb = rrule(ADviaRuleConfig(), sum, abs, @SVector[1.0, -3.0]) @test pb(1.0) isa Tuple{NoTangent, NoTangent, SVector{2, Float64}} - - # For structured sparse matrixes we screw it up, getting dense back - # see https://github.com/JuliaDiff/ChainRules.jl/issues/232 etc + # make sure we preserve type for Diagonal _, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0])) - @test_broken pb(1.0)[3] isa Diagonal + @test pb(1.0)[3] isa Diagonal end @testset "prod" begin @@ -70,16 +69,18 @@ if ndims(x) == 3 xp = PermutedDimsArray(x, (3,2,1)) # not a StridedArray - xpdot, xpbar = permutedims(rand(T, sz), (3,2,1)), permutedims(rand(T, sz), (3,2,1)) - test_rrule(prod, xp ⊢ xpbar; fkwargs=(dims=dims,), check_inferred=true) + test_rrule(prod, xp; fkwargs=(dims=dims,), check_inferred=true) end end @testset "structured wrappers" begin # Adjoint -- like PermutedDimsArray this may actually be used xa = adjoint(rand(T,4,4)) - test_rrule(prod, xa ⊢ rand(T,4,4)) - test_rrule(prod, xa ⊢ rand(T,4,4), fkwargs=(dims=2,)) + test_rrule(prod, xa) + test_rrule(prod, xa, fkwargs=(dims=2,)) + # use Adjoint for tangent + test_rrule(prod, xa ⊢ rand(T,4,4)') + test_rrule(prod, xa ⊢ rand(T,4,4)', fkwargs=(dims=2,)) @test unthunk(rrule(prod, adjoint(rand(T,3,3)))[2](1.0)[2]) isa Matrix @test unthunk(rrule(prod, adjoint(rand(T,3,3)), dims=1)[2](ones(1,3))[2]) isa Matrix @@ -92,17 +93,19 @@ hcat(1); rtol=T <: Complex ? 2eps() : 0.0, ) - @test unthunk(rrule(prod, Diagonal(ones(T,2)), dims=1)[2](ones(1,2))[2]) == [0 1; 1 0] + @test unthunk(rrule(prod, Diagonal(ones(T,2)), dims=1)[2](ones(1,2))[2]) == Diagonal([0.0 1; 1 0]) # Triangular -- almost equally stupud @test iszero(unthunk(rrule(prod, UpperTriangular(rand(T,3,3)))[2](1.0)[2])) - @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == [0 0; 1 0] + @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0]) # Symmetric -- at least this doesn't have zeros, still an unlikely combination + xs = Symmetric(rand(T,4,4)) - @test_skip test_rrule(prod, xs ⊢ rand(T,4,4)) - @test_skip test_rrule(prod, xs ⊢ rand(T,4,4), fkwargs=(dims=2,)) @test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4] + # TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475 + @test_skip test_rrule(prod, xs) + @test_skip test_rrule(prod, xs, fkwargs=(dims=2,)) end end @testset "Array{Float32}, no zero entries" begin diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index fe80e2e13..da6fbf0d4 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -26,6 +26,13 @@ test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) end + @testset "different types" begin + test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2)) + test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2)) + + # Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407 + test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false) + end end @testset "cross" begin @@ -33,6 +40,9 @@ test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) test_rrule(cross, randn(3), randn(3)) # No complex support for rrule(cross,... + + # mix types + test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7) end @testset "pinv" begin @testset "$T" for T in (Float64, ComplexF64) @@ -53,7 +63,7 @@ end @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) - test_frule(pinv, F(randn(T, 3)) ⊢ F(randn(T, 3))) + test_frule(pinv, F(randn(T, 3))) check_inferred = VERSION ≥ v"1.5" test_rrule(pinv, F(randn(T, 3)); check_inferred=check_inferred) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 291af25bd..982141568 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -390,7 +390,7 @@ end D = Diagonal(rand(5) .+ 0.1) C = cholesky(D) test_rrule( - cholesky, D ⊢ Diagonal(randn(5)), Val(false) ⊢ NoTangent(); + cholesky, D ⊢ Diagonal(randn(5)), Val(false); output_tangent=Tangent{typeof(C)}(factors=Diagonal(randn(5))) ) end diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 752c5c89e..07e80c2cf 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -49,7 +49,7 @@ PermutedDimsArray(x, (1,2,3)) end @test !(xp isa StridedArray) - test_rrule(fnorm, xp ⊢ rand(T, size(xp))) + test_rrule(fnorm, xp) end T == Float64 && ndims(x) == 1 && @testset "Integer input" begin x = [1,2,3] @@ -92,9 +92,9 @@ elseif x isa Array{T,3} PermutedDimsArray(x, (1,2,3)) end - @test !(xp isa StridedArray) - test_frule(norm, xp ⊢ rand(T, size(xp))) - test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here because eltype(xp)==Int + @assert !(xp isa StridedArray) + test_frule(norm, xp) + test_rrule(norm, xp) end end @testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index fb4335975..f2c3bd7c9 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -2,19 +2,18 @@ @testset "/ and \\ on Square Matrixes" begin @testset "//, $T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular) RHS = T(randn(T == Diagonal ? 10 : (10, 10))) - test_rrule(/, randn(5, 10), RHS) + test_rrule(/, randn(Float32, 5, 10), RHS; rtol = 1.0e-4, atol = 1.0e-4) end @testset "\\ $T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular) LHS = T(randn(T == Diagonal ? 10 : (10, 10))) - test_rrule(\, LHS, randn(10)) - test_rrule(\, LHS, randn(10, 10)) + test_rrule(\, LHS, randn(Float32, 10); rtol = 1.0e-4, atol = 1.0e-4) + test_rrule(\, LHS, randn(Float32, 10, 10); rtol = 1.0e-4, atol = 1.0e-4) end end @testset "Diagonal" begin N = 3 - test_rrule(Diagonal, randn(N); output_tangent=randn(N, N)) D = Diagonal(randn(N)) test_rrule(Diagonal, randn(N); output_tangent=D) # Concrete type instead of UnionAll @@ -33,7 +32,7 @@ end @testset "::Diagonal * ::AbstractVector" begin N = 3 - test_rrule(*, Diagonal(randn(N)), randn(N)) + test_rrule(*, Diagonal(randn(Float32, N)), randn(N); rtol = 1.0e-4, atol = 1.0e-4) end @testset "diag" begin N = 7 @@ -100,7 +99,7 @@ Ȳ_mat = randn(T, m, n) Ȳ_composite = Tangent{typeof(Y)}(parent=collect(f(Ȳ_mat))) - test_rrule(f, A; output_tangent=Ȳ_mat) + test_rrule(f, A; output_tangent=Ȳ_mat, check_inferred=false) # ChainRulesCore #407 _, pb = rrule(f, A) @test pb(Ȳ_mat) == pb(Ȳ_composite) @@ -120,14 +119,12 @@ @testset "$f(::Adjoint{$T, Vector{$T})" begin a = randn(T, n)' - ā = randn(T, n)' - test_rrule(f, a ⊢ ā; output_tangent=randn(T, n)) + test_rrule(f, a; output_tangent=randn(T, n)) end @testset "$f(::Transpose{$T, Vector{$T})" begin a = transpose(randn(T, n)) - ā = transpose(randn(T, n)) - test_rrule(f, a ⊢ ā; output_tangent=randn(T, n)) + test_rrule(f, a; output_tangent=randn(T, n)) end end @testset "$T" for T in (UpperTriangular, LowerTriangular) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 394c04577..a007d1272 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -22,7 +22,7 @@ # so we need to test this more carefully below check_inferred=false, ) - if check_inferred + if check_inferred && false # ChainRulesCore #407 @maybe_inferred (function (SymHerm, x, ΔΩ, ::Val) return rrule(SymHerm, x, uplo)[2](ΔΩ) end)(SymHerm, x, ΔΩ, Val(uplo)) @@ -36,7 +36,7 @@ check_inferred=false, output_tangent = ΔΩ, ) - if check_inferred + if check_inferred && false # ChainRulesCore #407 @maybe_inferred (function (SymHerm, x, ΔΩ, ::Val) return rrule(SymHerm, x, uplo)[2](ΔΩ) end)(SymHerm, x, ΔΩ, Val(uplo)) @@ -53,9 +53,9 @@ x = SymHerm(randn(T, 3, 3), uplo) test_rrule(f, x) - # intentionally specifying tangents here to test both Matrix and SymHerm tangents + # intentionally specifying tangents here to test both SymHerm (default) and Matrix + test_frule(f, x) test_frule(f, x ⊢ randn(T, 3, 3)) - test_frule(f, x ⊢ SymHerm(randn(T, 3, 3), uplo)) end # symmetric/hermitian eigendecomposition follows the sign convention @@ -477,7 +477,7 @@ frule((ZeroTangent(), ΔA), cos, A)[2], ) # not exact because evaluated in a more efficient way - @test ∂Y_ad ≈ ∂Y_ad2 + ChainRulesTestUtils.test_approx(∂Y_ad, ∂Y_ad2) end end diff --git a/test/runtests.jl b/test/runtests.jl index 0670230e4..5570a1bdb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,6 @@ using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm using Compat: hasproperty, only, cispi, eachcol using FiniteDifferences -using FiniteDifferences: rand_tangent using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot @@ -14,10 +13,6 @@ using StaticArrays using Statistics using Test -# Transitional feature, see -# https://juliadiff.org/ChainRulesTestUtils.jl/dev/api.html#ChainRulesTestUtils.enable_tangent_transform! -ChainRulesTestUtils.enable_tangent_transform!(Thunk) - Random.seed!(1) # Set seed that all testsets should reset to. function include_test(path)