Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
60f4efb
Add rules for evalpoly
sethaxen May 8, 2020
25af803
Rename pullback
sethaxen May 8, 2020
d0a95d8
Apply suggestions from code review
sethaxen May 8, 2020
b8f82cf
Make backx its own method for the future
sethaxen May 8, 2020
bbdb729
Update UniformScaling to-dos
sethaxen May 8, 2020
dce3b7a
Add matrix poly tests
sethaxen May 8, 2020
cde6dcb
Use correct indices
sethaxen May 9, 2020
9bb10a0
Reorganize
sethaxen May 9, 2020
77bdea8
Use generated functions with fallbacks
sethaxen May 9, 2020
b902b3e
Reimplement rules for matrices
sethaxen May 9, 2020
5e77b3e
Deactivate complex tests
sethaxen May 9, 2020
e45cb05
Add generated functions for tuple case
sethaxen May 10, 2020
0ecdcba
Add comment
sethaxen May 10, 2020
100b814
Rename defs to exs
sethaxen May 10, 2020
5075354
Refactor and test fallbacks
sethaxen May 10, 2020
87268c9
Simplify indexing
sethaxen May 10, 2020
2c3d1fc
Don't store output as an intermediate
sethaxen May 10, 2020
7d69336
Support scalar x with matrix pi
sethaxen May 10, 2020
89d4144
Make extensible for other ps
sethaxen May 10, 2020
9bc0e63
Move fallback tests under rrule tests
sethaxen May 10, 2020
6a6212a
Reorder args and remove unnecessary product
sethaxen May 10, 2020
852a2a0
Eliminate unneeded mul and reorganize
sethaxen May 10, 2020
b291696
Remove unnecessary product
sethaxen May 10, 2020
0748723
Fix length of ys and wrap lines
sethaxen May 10, 2020
d3f93bc
Place final ∂yi to in loop
sethaxen May 11, 2020
06acb23
Keep other rules consistent with vector
sethaxen May 11, 2020
3c2d54b
Merge branch 'master' into evalpoly
sethaxen Jun 30, 2020
758236a
Unify tests
sethaxen Jun 30, 2020
899c0cd
Increment version number
sethaxen Jun 30, 2020
24700eb
Try equality check outside of tuple
sethaxen Jun 30, 2020
5df7e9e
Approximate check due to muladd
sethaxen Jun 30, 2020
b56e856
Approximate check scalar output too
sethaxen Jun 30, 2020
12b5458
Decrement version number
sethaxen Jun 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.1"
version = "0.7.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
143 changes: 143 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,146 @@ function rrule(::typeof(identity), x)
end
return (x, identity_pullback)
end

#####
##### `evalpoly`
#####

if VERSION ≥ v"1.4"
function frule((_, Δx, Δp), ::typeof(evalpoly), x, p)
N = length(p)
@inbounds y = p[N]
Δy = Δp[N]
@inbounds for i in (N - 1):-1:1
Δy = muladd(Δx, y, muladd(x, Δy, Δp[i]))
y = muladd(x, y, p[i])
end
return y, Δy
end

function rrule(::typeof(evalpoly), x, p)
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NO_FIELDS, ∂x, ∂p
end
return y, evalpoly_pullback
end

# evalpoly but storing intermediates
function _evalpoly_intermediates(x, p::Tuple)
return if @generated
N = length(p.parameters)
exs = []
vars = []
ex = :(p[$N])
for i in 1:(N - 1)
yi = Symbol("y", i)
push!(vars, yi)
push!(exs, :($yi = $ex))
ex = :(muladd(x, $yi, p[$(N - i)]))
end
push!(exs, :(y = $ex))
Expr(:block, exs..., :(y, ($(vars...),)))
else
_evalpoly_intermediates_fallback(x, p)
end
end
function _evalpoly_intermediates_fallback(x, p::Tuple)
N = length(p)
y = p[N]
ys = (y, ntuple(N - 2) do i
return y = muladd(x, y, p[N - i])
end...)
y = muladd(x, y, p[1])
return y, ys
end
function _evalpoly_intermediates(x, p)
N = length(p)
@inbounds yn = one(x) * p[N]
ys = similar(p, typeof(yn), N - 1)
@inbounds ys[1] = yn
@inbounds for i in 2:(N - 1)
ys[i] = muladd(x, ys[i - 1], p[N - i + 1])
end
@inbounds y = muladd(x, ys[N - 1], p[1])
return y, ys
end

# TODO: Handle following cases
# 1) x is a UniformScaling, pᵢ is a matrix
# 2) x is a matrix, pᵢ is a UniformScaling
@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi'
@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x)
@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi))
@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x

@inline _evalpoly_backp(pi, ∂yi) = ∂yi

function _evalpoly_back(x, p::Tuple, ys, Δy)
return if @generated
exs = []
vars = []
N = length(p.parameters)
for i in 2:(N - 1)
∂pi = Symbol("∂p", i)
push!(vars, ∂pi)
push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi)))
push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi)))
push!(exs, :(∂yi = x′ * ∂yi))
end
push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN
Expr(
:block,
:(x′ = x'),
:(∂yi = Δy),
:(∂p1 = _evalpoly_backp(p[1], ∂yi)),
:(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)),
:(∂yi = x′ * ∂yi),
exs...,
:(∂p = (∂p1, $(vars...))),
:(∂x, Composite{typeof(p),typeof(∂p)}(∂p)),
)
else
_evalpoly_back_fallback(x, p, ys, Δy)
end
end
function _evalpoly_back_fallback(x, p::Tuple, ys, Δy)
x′ = x'
∂yi = Δy
N = length(p)
∂p1 = _evalpoly_backp(p[1], ∂yi)
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p = (
∂p1,
ntuple(N - 2) do i
∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi)
∂pi = _evalpoly_backp(p[i+1], ∂yi)
∂yi = x′ * ∂yi
return ∂pi
end...,
_evalpoly_backp(p[N], ∂yi), # ∂pN
)
return ∂x, Composite{typeof(p),typeof(∂p)}(∂p)
end
function _evalpoly_back(x, p, ys, Δy)
x′ = x'
∂yi = one(x′) * Δy
N = length(p)
@inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi)
∂p = similar(p, typeof(∂p1))
@inbounds begin
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p[1] = ∂p1
for i in 2:(N - 1)
∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi)
∂p[i] = _evalpoly_backp(p[i], ∂yi)
∂yi = x′ * ∂yi
end
∂p[N] = _evalpoly_backp(p[N], ∂yi)
end
return ∂x, ∂p
end
end
35 changes: 35 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,41 @@
)
end

VERSION ≥ v"1.4" && @testset "evalpoly" begin
# test fallbacks for when code generation fails
@testset "fallbacks for $T" for T in (Float64, ComplexF64)
x, p = randn(T), Tuple(randn(T, 10))
y_fb, ys_fb = ChainRules._evalpoly_intermediates_fallback(x, p)
y, ys = ChainRules._evalpoly_intermediates(x, p)
@test y_fb ≈ y
@test collect(ys_fb) ≈ collect(ys)

Δy, ys = randn(T), Tuple(randn(T, 9))
∂x_fb, ∂p_fb = ChainRules._evalpoly_back_fallback(x, p, ys, Δy)
∂x, ∂p = ChainRules._evalpoly_back(x, p, ys, Δy)
@test ∂x_fb ≈ ∂x
@test collect(∂p_fb) ≈ collect(∂p)
end

@testset "x dim: $(nx), pi dim: $(np), type: $T" for T in (Float64, ComplexF64), nx in (tuple(), 3), np in (tuple(), 3)
# skip x::Matrix, pi::Number case, which is not supported by evalpoly
isempty(np) && !isempty(nx) && continue
m = 5
sx = (nx..., nx...)
sp = (np..., np...)
x, ẋ, x̄ = randn(T, sx...), randn(T, sx...), randn(T, sx...)
p = [randn(T, sp...) for _ in 1:m]
ṗ = [randn(T, sp...) for _ in 1:m]
p̄ = [randn(T, sp...) for _ in 1:m]
Ω = evalpoly(x, p)
Ω̄ = randn(T, size(Ω)...)
frule_test(evalpoly, (x, ẋ), (p, ṗ))
frule_test(evalpoly, (x, ẋ), (Tuple(p), Tuple(ṗ)))
rrule_test(evalpoly, Ω̄, (x, x̄), (p, p̄))
rrule_test(evalpoly, Ω̄, (x, x̄), (Tuple(p), Tuple(p̄)))
end
end

@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
test_scalar(one, x)
test_scalar(zero, x)
Expand Down