From db391a256f783392760ecdfeafbceea9d926c586 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 25 Oct 2024 03:05:52 +0100 Subject: [PATCH 1/2] Special-case ReshapeTransform for singleton inputs --- Project.toml | 2 +- src/utils.jl | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index eab8c362c..92fb67ddd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30.1" +version = "0.30.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/utils.jl b/src/utils.jl index bd5d365fc..1a2d1ffd7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -286,8 +286,15 @@ function (f::ReshapeTransform)(x) if size(x) != f.input_size throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))")) end - # The call to `tovec` is only needed in case `x` is a scalar. - return reshape(tovec(x), f.output_size) + if f.output_size == () + # Specially handle the case where x is a singleton array, see + # https://github.com/JuliaDiff/ReverseDiff.jl/issues/265 and + # https://github.com/TuringLang/DynamicPPL.jl/issues/698 + return x[] + else + # The call to `tovec` is only needed in case `x` is a scalar. + return reshape(tovec(x), f.output_size) + end end function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x) From 8c9cc716d1211bf7f45cb8dc0dbc07bf23b7a4c4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 25 Oct 2024 11:31:23 +0100 Subject: [PATCH 2/2] Use fill(x[], ()) instead --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 1a2d1ffd7..a809fda17 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -290,7 +290,7 @@ function (f::ReshapeTransform)(x) # Specially handle the case where x is a singleton array, see # https://github.com/JuliaDiff/ReverseDiff.jl/issues/265 and # https://github.com/TuringLang/DynamicPPL.jl/issues/698 - return x[] + return fill(x[], ()) else # The call to `tovec` is only needed in case `x` is a scalar. return reshape(tovec(x), f.output_size)