Skip to content

Commit 345d05a

Browse files
authored
Merge 5dbc35b into c768c0b
2 parents c768c0b + 5dbc35b commit 345d05a

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

src/utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ function jacobian_batched(
66
y = f(xs)
77
z = similar(xs)
88
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
9-
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
9+
res = Zygote.Buffer(
10+
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
11+
size(xs, 1),
12+
size(xs, 1),
13+
size(xs, 2),
14+
)
1015
for i in axes(xs, 1)
1116
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
1217
res[i, :, :] =
@@ -24,7 +29,12 @@ function jacobian_batched(
2429
y = f(xs)
2530
z = similar(xs)
2631
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
27-
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
32+
res = Zygote.Buffer(
33+
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
34+
size(xs, 1),
35+
size(xs, 1),
36+
size(xs, 2),
37+
)
2838
for i in axes(xs, 1)
2939
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
3040
res[:, i, :] = only(

test/smoke_tests.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ Test.@testset "Smoke Tests" begin
3636
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
3737
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
3838
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
39+
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
40+
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
41+
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
3942
ContinuousNormalizingFlows.DIVecJacVectorMode(
4043
ADTypes.AutoEnzyme(;
4144
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
@@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin
6063
function_annotation = Enzyme.Const,
6164
),
6265
),
63-
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
64-
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
65-
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
6666
]
6767

6868
Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in
@@ -194,8 +194,12 @@ Test.@testset "Smoke Tests" begin
194194
Test.@test !isnothing(rand(d, ndata))
195195

196196
Test.@testset "$adtype on loss" for adtype in adtypes
197-
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
198-
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))
197+
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken =
198+
GROUP != "All" &&
199+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
200+
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken =
201+
GROUP != "All" &&
202+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
199203

200204
Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_
201205
if cond

test/speed_tests.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ Test.@testset "Speed Tests" begin
22
compute_modes = ContinuousNormalizingFlows.ComputeMode[
33
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
44
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
5+
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
6+
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
57
ContinuousNormalizingFlows.DIVecJacMatrixMode(
68
ADTypes.AutoEnzyme(;
79
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
@@ -14,8 +16,6 @@ Test.@testset "Speed Tests" begin
1416
function_annotation = Enzyme.Const,
1517
),
1618
),
17-
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
18-
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
1919
]
2020

2121
Test.@testset "$compute_mode" for compute_mode in compute_modes
@@ -57,7 +57,9 @@ Test.@testset "Speed Tests" begin
5757
model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5)
5858

5959
mach = MLJBase.machine(model, df)
60-
Test.@test !isnothing(MLJBase.fit!(mach))
60+
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
61+
GROUP != "All" &&
62+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
6163

6264
@show only(MLJBase.report(mach).stats).time
6365

0 commit comments

Comments
 (0)