Skip to content

Commit db40349

Browse files
committed
Test weirder array types
1 parent 4207633 commit db40349

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ ChainRulesCore = "0.10.4"
1414
ChainRulesTestUtils = "0.7.9"
1515
Compat = "3.30"
1616
FiniteDifferences = "0.12.8"
17+
StaticArrays = "1.2"
1718
julia = "1"
1819

1920
[extras]
2021
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2122
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2223
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
24+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2325
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2426

2527
[targets]
26-
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "Test"]
28+
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "StaticArrays", "Test"]

src/rulesets/Base/mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function rrule(
2727

2828
pullbacks = last.(fx_and_pullbacks)
2929
function sum_pullback(ȳ)
30-
f̄_and_x̄s = [pullback(ȳ) for pullback in pullbacks]
30+
f̄_and_x̄s = map(pullback->pullback(ȳ), pullbacks)
3131
# no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
3232
= if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
3333
NoTangent()

test/rulesets/Base/mapreduce.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@
3030

3131
test_rrule(sum, abs, [-4.0 2.0; 2.0 -1.0])
3232
test_rrule(sum, abs, [-4.0 2.0; 2.0 -1.0]')
33+
34+
test_rrule(sum, abs, @SVector[1.0, -3.0])
35+
36+
# Make sure we preserve type for StaticArrays
37+
ADviaRuleConfig = ChainRulesTestUtils.ADviaRuleConfig
38+
_, pb = rrule(ADviaRuleConfig(), sum, abs, @SVector[1.0, -3.0])
39+
@test pb(1.0) isa Tuple{NoTangent, NoTangent, SVector{2, Float64}}
40+
41+
# For structured sparse matrixes we screw it up, getting dense back
42+
# see https://github.com/JuliaDiff/ChainRules.jl/issues/232 etc
43+
_, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
44+
@test_broken pb(1.0)[3] isa Diagonal
3345
end
3446

3547
@testset "prod" begin

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using LinearAlgebra
1010
using LinearAlgebra.BLAS
1111
using LinearAlgebra: dot
1212
using Random
13+
using StaticArrays
1314
using Statistics
1415
using Test
1516

0 commit comments

Comments
 (0)