From fa1138d0032d5602f38d6e48c6a6b3913edfad9e Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 21 Feb 2022 11:58:25 +0000 Subject: [PATCH] to_vec for inplaceablethunk --- Project.toml | 2 +- src/to_vec.jl | 11 +++++++++++ test/to_vec.jl | 4 +++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0c3f88b..68561cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.23" +version = "0.12.24" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/to_vec.jl b/src/to_vec.jl index 37287f9..6dd1fbc 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -287,3 +287,14 @@ function FiniteDifferences.to_vec(t::Thunk) Thunk_from_vec = v -> @thunk(back(v)) return v, Thunk_from_vec end + +function FiniteDifferences.to_vec(t::InplaceableThunk) + v, back = to_vec(unthunk(t)) + function InplaceableThunk_from_vec(v) + return InplaceableThunk( + Δ -> Δ += back(b), + @thunk(back(v)) + ) + end + return v, InplaceableThunk_from_vec +end diff --git a/test/to_vec.jl b/test/to_vec.jl index 54e18ab..4f01ff1 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -249,7 +249,9 @@ end end @testset "Thunks" begin - test_to_vec(@thunk(3.2+4.3)) + t = @thunk(3.2+4.3) + test_to_vec(t) + test_to_vec(InplaceableThunk(Δ -> Δ += t, t)) end end