Skip to content

Commit c634ec9

Browse files
authored
Merge b3bb9af into c768c0b
2 parents c768c0b + b3bb9af commit c634ec9

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ function jacobian_batched(
66
y = f(xs)
77
z = similar(xs)
88
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
9-
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
9+
res = Zygote.Buffer(
10+
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
11+
size(xs, 1),
12+
size(xs, 1),
13+
size(xs, 2),
14+
)
1015
for i in axes(xs, 1)
1116
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
1217
res[i, :, :] =
@@ -24,7 +29,12 @@ function jacobian_batched(
2429
y = f(xs)
2530
z = similar(xs)
2631
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
27-
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
32+
res = Zygote.Buffer(
33+
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
34+
size(xs, 1),
35+
size(xs, 1),
36+
size(xs, 2),
37+
)
2838
for i in axes(xs, 1)
2939
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
3040
res[:, i, :] = only(

0 commit comments

Comments
 (0)