Skip to content

Commit 688af01

Browse files
committed
unrelated test fixes in passing
1 parent 96c3185 commit 688af01

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

test/rulesets/Base/base.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,16 @@
7979
@testset "($x) * ($y)" for
8080
x in test_points, y in test_points
8181

82-
# ensure all complex if any complex for FiniteDifferences
83-
x, y = Base.promote(x, y)
82+
# all complex if any complex, was a limitation of FiniteDifferences?
83+
xx, yy = Base.promote(x, y)
84+
test_frule(*, xx, yy)
85+
test_rrule(*, xx, yy)
8486

87+
# explicitly allow mixed types
8588
test_frule(*, x, y)
8689
test_rrule(*, x, y)
90+
rrule(*, x, y)[2](1)[2] isa typeof(x)
91+
rrule(*, x, y)[2](1)[3] isa typeof(y)
8792
end
8893
end
8994

@@ -117,14 +122,17 @@
117122
test_rrule(identity, Tuple(randn(T, 3)))
118123
end
119124

120-
@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
125+
@testset "one(::Number), zero(::Number)" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
121126
test_scalar(one, x)
122127
test_scalar(zero, x)
128+
129+
rrule(one, x)[2](1) === (NoTangent(), zero(x))
130+
rrule(zero, x)[2](1) === (NoTangent(), zero(x))
123131
end
124132

125133
@testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64)
126-
test_frule(muladd, 10randn(), randn(), randn())
127-
test_rrule(muladd, 10randn(), randn(), randn())
134+
test_frule(muladd, 10randn(T), randn(T), randn(T))
135+
test_rrule(muladd, 10randn(T), randn(T), randn(T))
128136
end
129137

130138
@testset "fma" begin
@@ -144,6 +152,13 @@
144152
# to right
145153
test_frule(clamp, 4., 2., 3.)
146154
test_rrule(clamp, 4., 2., 3.)
155+
156+
# nonzero gradient at the boundaries
157+
@test frule((0,1,0,0), clamp, 2, 2, 3) == (2, 1)
158+
@test rrule(clamp, 2.0, 2, 3)[2](1)[2] == 1.0
159+
160+
@test frule((0,1,0,0), clamp, 3, 2, 3) == (3, 1)
161+
@test rrule(clamp, 3, 2, 3)[2](1)[2] == 1.0
147162
end
148163

149164
@testset "rounding" begin

0 commit comments

Comments
 (0)