Skip to content

Commit 33427ca

Browse files
authored
Merge a709393 into c768c0b
2 parents c768c0b + a709393 commit 33427ca

File tree

3 files changed

+50
-18
lines changed

3 files changed

+50
-18
lines changed

src/utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ function jacobian_batched(
66
y = f(xs)
77
z = similar(xs)
88
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
9-
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
9+
res = Zygote.Buffer(
10+
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
11+
size(xs, 1),
12+
size(xs, 1),
13+
size(xs, 2),
14+
)
1015
for i in axes(xs, 1)
1116
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
1217
res[i, :, :] =
@@ -24,7 +29,12 @@ function jacobian_batched(
2429
y = f(xs)
2530
z = similar(xs)
2631
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
27-
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
32+
res = Zygote.Buffer(
33+
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
34+
size(xs, 1),
35+
size(xs, 1),
36+
size(xs, 2),
37+
)
2838
for i in axes(xs, 1)
2939
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
3040
res[:, i, :] = only(

test/smoke_tests.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ Test.@testset "Smoke Tests" begin
3636
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
3737
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
3838
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
39+
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
40+
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
41+
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
3942
ContinuousNormalizingFlows.DIVecJacVectorMode(
4043
ADTypes.AutoEnzyme(;
4144
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
@@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin
6063
function_annotation = Enzyme.Const,
6164
),
6265
),
63-
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
64-
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
65-
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
6666
]
6767

6868
Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in
@@ -194,8 +194,12 @@ Test.@testset "Smoke Tests" begin
194194
Test.@test !isnothing(rand(d, ndata))
195195

196196
Test.@testset "$adtype on loss" for adtype in adtypes
197-
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
198-
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))
197+
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken =
198+
GROUP != "All" &&
199+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
200+
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken =
201+
GROUP != "All" &&
202+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
199203

200204
Test.@testset "$n_epochs for fit" for n_epochs in n_epochs_
201205
if cond
@@ -208,14 +212,22 @@ Test.@testset "Smoke Tests" begin
208212
)
209213
mach = MLJBase.machine(model, (df, df2))
210214

211-
Test.@test !isnothing(MLJBase.fit!(mach))
212-
Test.@test !isnothing(MLJBase.transform(mach, (df, df2)))
213-
Test.@test !isnothing(MLJBase.fitted_params(mach))
215+
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
216+
GROUP != "All" &&
217+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
218+
Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken =
219+
GROUP != "All" &&
220+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
221+
Test.@test !isnothing(MLJBase.fitted_params(mach)) broken =
222+
GROUP != "All" &&
223+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
214224
Test.@test !isnothing(MLJBase.serializable(mach))
215225

216226
Test.@test !isnothing(
217227
ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2),
218-
)
228+
) broken =
229+
GROUP != "All" &&
230+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
219231
else
220232
model = ContinuousNormalizingFlows.ICNFModel(
221233
icnf;
@@ -226,12 +238,20 @@ Test.@testset "Smoke Tests" begin
226238
)
227239
mach = MLJBase.machine(model, df)
228240

229-
Test.@test !isnothing(MLJBase.fit!(mach))
230-
Test.@test !isnothing(MLJBase.transform(mach, df))
231-
Test.@test !isnothing(MLJBase.fitted_params(mach))
241+
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
242+
GROUP != "All" &&
243+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
244+
Test.@test !isnothing(MLJBase.transform(mach, df)) broken =
245+
GROUP != "All" &&
246+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
247+
Test.@test !isnothing(MLJBase.fitted_params(mach)) broken =
248+
GROUP != "All" &&
249+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
232250
Test.@test !isnothing(MLJBase.serializable(mach))
233251

234-
Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode))
252+
Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken =
253+
GROUP != "All" &&
254+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
235255
end
236256
end
237257
end

test/speed_tests.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ Test.@testset "Speed Tests" begin
22
compute_modes = ContinuousNormalizingFlows.ComputeMode[
33
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
44
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
5+
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
6+
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
57
ContinuousNormalizingFlows.DIVecJacMatrixMode(
68
ADTypes.AutoEnzyme(;
79
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
@@ -14,8 +16,6 @@ Test.@testset "Speed Tests" begin
1416
function_annotation = Enzyme.Const,
1517
),
1618
),
17-
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
18-
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
1919
]
2020

2121
Test.@testset "$compute_mode" for compute_mode in compute_modes
@@ -57,7 +57,9 @@ Test.@testset "Speed Tests" begin
5757
model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5)
5858

5959
mach = MLJBase.machine(model, df)
60-
Test.@test !isnothing(MLJBase.fit!(mach))
60+
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
61+
GROUP != "All" &&
62+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
6163

6264
@show only(MLJBase.report(mach).stats).time
6365

0 commit comments

Comments
 (0)