diff --git a/Project.toml b/Project.toml index 55e922b..0212034 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.3.0" +version = "1.3.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "1.2" +ChainRulesCore = "1.11.2" Compat = "3" FiniteDifferences = "0.12.12" julia = "1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index d4916cc..ab985dc 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -16,21 +16,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "bdc0937269321858ab2a4f288486cb258b9a0af7" +git-tree-sha1 = "f885e7e7c124f8c92650d61b9477b9ac2ee607dd" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.3.0" +version = "1.11.1" [[ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] path = ".." uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.2.1" +version = "1.2.4" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "727e463cfebd0c7b999bbf3e9e7e16f254b94193" +git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.34.0" +version = "3.40.0" [[Dates]] deps = ["Printf"] @@ -46,15 +46,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] deps = ["LibGit2"] -git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.5" +version = "0.8.6" [[Documenter]] deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "350dced36c11f794c6c4da5dc6493ec894e50c16" +git-tree-sha1 = "f425293f7e0acaf9144de6d731772de156676233" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.5" +version = "0.27.10" [[Downloads]] deps = ["ArgTools", "LibCURL", "NetworkOptions"] @@ -62,9 +62,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"] -git-tree-sha1 = "9a586f04a21e6945f4cbee0d0fb6aebd7b86aa8f" +git-tree-sha1 = "c56a261e1a5472f20cbd7aa218840fd203243319" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.18" +version = "0.12.19" [[IOCapture]] deps = ["Logging", "Random"] @@ -127,9 +127,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" [[Parsers]] deps = ["Dates"] -git-tree-sha1 = "438d35d2d95ae2c5e8780b330592b6de8494e779" +git-tree-sha1 = "ae4bbcadb2906ccc085cf52ac286dc1377dceccc" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.0.3" +version = "2.1.2" [[Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -172,9 +172,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "3240808c6d463ac46f1c1cd7638375cd22abbccb" +git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.2.12" +version = "1.2.13" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] diff --git a/src/finite_difference_calls.jl b/src/finite_difference_calls.jl index 72cb7ae..8335319 100644 --- a/src/finite_difference_calls.jl +++ b/src/finite_difference_calls.jl @@ -21,7 +21,7 @@ function _make_jvp_call(fdm, f, y, xs, ẋs, ignores) ignores = collect(ignores) all(ignores) && return ntuple(_ -> NoTangent(), length(xs)) sigargs = zip(xs[.!ignores], ẋs[.!ignores]) - return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...)) + return ProjectTo(y)(jvp(fdm, f2, sigargs...)) end """ @@ -52,7 +52,7 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores) @assert length(fd) == length(arginds) for (dx, ind) in zip(fd, arginds) - args[ind] = _maybe_fix_to_composite(xs[ind], dx) + args[ind] = ProjectTo(xs[ind])(dx) end return (args...,) end @@ -87,10 +87,3 @@ function _wrap_function(f, xs, ignores) end return fnew end - -# TODO: remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97 -# For functions which return a tuple, FD returns a tuple to represent the differential. Tuple -# is not a natural differential, because it doesn't overload +, so make it a Tangent. -_maybe_fix_to_composite(::P, x::Tuple) where {P} = Tangent{P}(x...) -_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Tangent{P}(; x...) -_maybe_fix_to_composite(::Any, x) = x diff --git a/test/testers.jl b/test/testers.jl index ec7e6c6..d695aa0 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -57,6 +57,14 @@ end abstract type MySpecialTrait end struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end +# Type-stable derivative for test below +struct FVecOfTuplesPullback{T} end +function (f::FVecOfTuplesPullback{T})(Δ) where {T} + ΔΩ_first, ΔΩ_last = unthunk(Δ) + Δx = map(z -> Tangent{T}(z, ΔΩ_last), ΔΩ_first) + return NoTangent(), Δx +end + @testset "testers.jl" begin @testset "test_scalar" begin @testset "Ensure correct rules succeed" begin @@ -608,7 +616,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end test_rrule( does_not_accept_thunk_id, [1.0, 2.0]; check_thunked_output_tangent=false ) - @test errors(r"MethodError.*Thunk") do + @test errors(r"MethodError.*Thunk") do test_rrule(does_not_accept_thunk_id, [1.0, 2.0]) end end @@ -736,4 +744,33 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end test_rrule(my_id, 2.0; check_inferred=false) test_rrule(my_id, 2.0; check_thunked_output_tangent=false) end + + # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/224 + @testset "vectors of tuples" begin + function f_vec_of_tuples(x::AbstractVector{<:Tuple{<:Any,<:Any}}) + return map(first, x), sum(last, x) + end + function ChainRulesCore.frule( + (_, Δx), + ::typeof(f_vec_of_tuples), + x::AbstractVector{<:Tuple{<:Any,<:Any}}, + ) + Ω = f_vec_of_tuples(x) + Ω̄ = Tangent{typeof(Ω)}(f_vec_of_tuples(map(ChainRulesCore.backing, Δx))...) + return Ω, Ω̄ + end + function ChainRulesCore.rrule( + ::typeof(f_vec_of_tuples), + x::AbstractVector{<:Tuple{<:Any,<:Any}}, + ) + Ω = f_vec_of_tuples(x) + # We use a functor here to fix type inference + f_vec_of_tuples_pullback = FVecOfTuplesPullback{eltype(x)}() + return Ω, f_vec_of_tuples_pullback + end + + x_tuples = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] + test_frule(f_vec_of_tuples, x_tuples) + test_rrule(f_vec_of_tuples, x_tuples) + end end