@@ -7,12 +7,6 @@ Rather it is an arbitary value, that is generated using the `rng`.
77"""
88rand_tangent (x) = rand_tangent (Random. GLOBAL_RNG, x)
99
10- rand_tangent (rng:: AbstractRNG , x:: Symbol ) = NoTangent ()
11- rand_tangent (rng:: AbstractRNG , x:: AbstractChar ) = NoTangent ()
12- rand_tangent (rng:: AbstractRNG , x:: AbstractString ) = NoTangent ()
13-
14- rand_tangent (rng:: AbstractRNG , x:: Integer ) = NoTangent ()
15-
1610# Try and make nice numbers with short decimal representations for good error messages
1711# while also not biasing the sample space too much
1812function rand_tangent (rng:: AbstractRNG , x:: T ) where {T<: Number }
@@ -24,25 +18,11 @@ function rand_tangent(rng::AbstractRNG, x::ComplexF64)
2418 return ComplexF64 (rand (rng, - 9 : 0.1 : 9 ), rand (rng, - 9 : 0.1 : 9 ))
2519end
2620
27- # BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should
28-
21+ # BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should
2922# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
3023rand_tangent (rng:: AbstractRNG , :: BigFloat ) = round (big (9 * randn (rng)), digits= 5 , base= 2 )
3124
32-
33- rand_tangent (rng:: AbstractRNG , x:: Array{<:Any, 0} ) = _compress_notangent (fill (rand_tangent (rng, x[])))
34- rand_tangent (rng:: AbstractRNG , x:: Array ) = _compress_notangent (rand_tangent .(Ref (rng), x))
35-
36- # All other AbstractArray's can be handled using the ProjectTo mechanics.
37- # and follow the same requirements
38- function rand_tangent (rng:: AbstractRNG , x:: AbstractArray )
39- return _compress_notangent (ProjectTo (x)(rand_tangent (rng, collect (x))))
40- end
41-
42- # TODO : arguably ProjectTo should handle this for us for AbstactArrays
43- # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/410
44- _compress_notangent (:: AbstractArray{NoTangent} ) = NoTangent ()
45- _compress_notangent (x) = x
25+ rand_tangent (rng:: AbstractRNG , x:: AbstractArray ) = ProjectTo (x)(rand_tangent .(Ref (rng), x))
4626
4727function rand_tangent (rng:: AbstractRNG , x:: T ) where {T}
4828 if ! isstructtype (T)
@@ -65,5 +45,9 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T}
6545 end
6646end
6747
48+ rand_tangent (rng:: AbstractRNG , x:: Symbol ) = NoTangent ()
49+ rand_tangent (rng:: AbstractRNG , x:: AbstractChar ) = NoTangent ()
50+ rand_tangent (rng:: AbstractRNG , x:: AbstractString ) = NoTangent ()
51+ rand_tangent (rng:: AbstractRNG , x:: Integer ) = NoTangent ()
6852rand_tangent (rng:: AbstractRNG , :: Type ) = NoTangent ()
6953rand_tangent (rng:: AbstractRNG , :: Module ) = NoTangent ()
0 commit comments