@@ -36,6 +36,9 @@ Test.@testset "Smoke Tests" begin
3636 ContinuousNormalizingFlows. LuxVecJacMatrixMode (ADTypes. AutoZygote ()),
3737 ContinuousNormalizingFlows. DIVecJacVectorMode (ADTypes. AutoZygote ()),
3838 ContinuousNormalizingFlows. DIVecJacMatrixMode (ADTypes. AutoZygote ()),
39+ ContinuousNormalizingFlows. LuxJacVecMatrixMode (ADTypes. AutoForwardDiff ()),
40+ ContinuousNormalizingFlows. DIJacVecVectorMode (ADTypes. AutoForwardDiff ()),
41+ ContinuousNormalizingFlows. DIJacVecMatrixMode (ADTypes. AutoForwardDiff ()),
3942 ContinuousNormalizingFlows. DIVecJacVectorMode (
4043 ADTypes. AutoEnzyme (;
4144 mode = Enzyme. set_runtime_activity (Enzyme. Reverse),
@@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin
6063 function_annotation = Enzyme. Const,
6164 ),
6265 ),
63- ContinuousNormalizingFlows. LuxJacVecMatrixMode (ADTypes. AutoForwardDiff ()),
64- ContinuousNormalizingFlows. DIJacVecVectorMode (ADTypes. AutoForwardDiff ()),
65- ContinuousNormalizingFlows. DIJacVecMatrixMode (ADTypes. AutoForwardDiff ()),
6666 ]
6767
6868 Test. @testset " $device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt " for device in
@@ -194,8 +194,12 @@ Test.@testset "Smoke Tests" begin
194194 Test. @test ! isnothing (rand (d, ndata))
195195
196196 Test. @testset " $adtype on loss" for adtype in adtypes
197- Test. @test ! isnothing (DifferentiationInterface. gradient (diff_loss, adtype, ps))
198- Test. @test ! isnothing (DifferentiationInterface. gradient (diff2_loss, adtype, r))
197+ Test. @test ! isnothing (DifferentiationInterface. gradient (diff_loss, adtype, ps)) broken =
198+ GROUP != " All" &&
199+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
200+ Test. @test ! isnothing (DifferentiationInterface. gradient (diff2_loss, adtype, r)) broken =
201+ GROUP != " All" &&
202+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
199203
200204 Test. @testset " $n_epochs for fit" for n_epochs in n_epochs_
201205 if cond
@@ -208,14 +212,22 @@ Test.@testset "Smoke Tests" begin
208212 )
209213 mach = MLJBase. machine (model, (df, df2))
210214
211- Test. @test ! isnothing (MLJBase. fit! (mach))
212- Test. @test ! isnothing (MLJBase. transform (mach, (df, df2)))
213- Test. @test ! isnothing (MLJBase. fitted_params (mach))
215+ Test. @test ! isnothing (MLJBase. fit! (mach)) broken =
216+ GROUP != " All" &&
217+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
218+ Test. @test ! isnothing (MLJBase. transform (mach, (df, df2))) broken =
219+ GROUP != " All" &&
220+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
221+ Test. @test ! isnothing (MLJBase. fitted_params (mach)) broken =
222+ GROUP != " All" &&
223+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
214224 Test. @test ! isnothing (MLJBase. serializable (mach))
215225
216226 Test. @test ! isnothing (
217227 ContinuousNormalizingFlows. CondICNFDist (mach, omode, r2),
218- )
228+ ) broken =
229+ GROUP != " All" &&
230+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
219231 else
220232 model = ContinuousNormalizingFlows. ICNFModel (
221233 icnf;
@@ -226,12 +238,20 @@ Test.@testset "Smoke Tests" begin
226238 )
227239 mach = MLJBase. machine (model, df)
228240
229- Test. @test ! isnothing (MLJBase. fit! (mach))
230- Test. @test ! isnothing (MLJBase. transform (mach, df))
231- Test. @test ! isnothing (MLJBase. fitted_params (mach))
241+ Test. @test ! isnothing (MLJBase. fit! (mach)) broken =
242+ GROUP != " All" &&
243+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
244+ Test. @test ! isnothing (MLJBase. transform (mach, df)) broken =
245+ GROUP != " All" &&
246+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
247+ Test. @test ! isnothing (MLJBase. fitted_params (mach)) broken =
248+ GROUP != " All" &&
249+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
232250 Test. @test ! isnothing (MLJBase. serializable (mach))
233251
234- Test. @test ! isnothing (ContinuousNormalizingFlows. ICNFDist (mach, omode))
252+ Test. @test ! isnothing (ContinuousNormalizingFlows. ICNFDist (mach, omode)) broken =
253+ GROUP != " All" &&
254+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
235255 end
236256 end
237257 end
0 commit comments