Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/exts/dist_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real})
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
@warn "to compute by matrices, data should be a matrix."
first(Distributions._logpdf(d, hcat(x)))
else
error("Not Implemented")
end
end
function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real})
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
Distributions._logpdf.(d, eachcol(A))
@warn "to compute by vectors, data should be a vector."
Distributions._logpdf.(d, collect(collect.(eachcol(A))))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
else
Expand All @@ -41,6 +43,7 @@ function Distributions._rand!(
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
@warn "to compute by matrices, data should be a matrix."
x .= Distributions._rand!(rng, d, hcat(x))
else
error("Not Implemented")
Expand All @@ -52,7 +55,8 @@ function Distributions._rand!(
A::AbstractMatrix{<:Real},
)
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...)
@warn "to compute by vectors, data should be a vector."
A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
else
Expand Down
8 changes: 6 additions & 2 deletions src/exts/dist_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real})
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
first(inference(d.m, d.mode, x, d.ps, d.st))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
@warn "to compute by matrices, data should be a matrix."
first(Distributions._logpdf(d, hcat(x)))
else
error("Not Implemented")
Expand All @@ -22,7 +23,8 @@ end

function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real})
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
Distributions._logpdf.(d, eachcol(A))
@warn "to compute by vectors, data should be a vector."
Distributions._logpdf.(d, collect(collect.(eachcol(A))))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ps, d.st))
else
Expand All @@ -38,6 +40,7 @@ function Distributions._rand!(
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
x .= generate(d.m, d.mode, d.ps, d.st)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
@warn "to compute by matrices, data should be a matrix."
x .= Distributions._rand!(rng, d, hcat(x))
else
error("Not Implemented")
Expand All @@ -49,7 +52,8 @@ function Distributions._rand!(
A::AbstractMatrix{<:Real},
)
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...)
@warn "to compute by vectors, data should be a vector."
A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2))
else
Expand Down
5 changes: 3 additions & 2 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
(ps, st) = fitresult

logp̂x = if model.m.compute_mode isa VectorMode
@warn "to compute by vectors, data should be a vector."
broadcast(
function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
return first(inference(model.m, TestMode(), x, y, ps, st))
end,
eachcol(xnew),
eachcol(ynew),
collect(collect.(eachcol(xnew))),
collect(collect.(eachcol(ynew))),
)
elseif model.m.compute_mode isa MatrixMode
first(inference(model.m, TestMode(), xnew, ynew, ps, st))
Expand Down
8 changes: 6 additions & 2 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
(ps, st) = fitresult

logp̂x = if model.m.compute_mode isa VectorMode
broadcast(function (x::AbstractVector{<:Real})
@warn "to compute by vectors, data should be a vector."
broadcast(
function (x::AbstractVector{<:Real})
return first(inference(model.m, TestMode(), x, ps, st))
end, eachcol(xnew))
end,
collect(collect.(eachcol(xnew))),
)
elseif model.m.compute_mode isa MatrixMode
first(inference(model.m, TestMode(), xnew, ps, st))
else
Expand Down
14 changes: 12 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ function jacobian_batched(
y = f(xs)
z = similar(xs)
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
res = Zygote.Buffer(
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
size(xs, 1),
size(xs, 1),
size(xs, 2),
)
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[i, :, :] =
Expand All @@ -24,7 +29,12 @@ function jacobian_batched(
y = f(xs)
z = similar(xs)
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
res = Zygote.Buffer(
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
size(xs, 1),
size(xs, 1),
size(xs, 2),
)
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[:, i, :] = only(
Expand Down
21 changes: 13 additions & 8 deletions test/smoke_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Test.@testset "Smoke Tests" begin
data_types = Type{<:AbstractFloat}[Float32]
devices = MLDataDevices.AbstractDevice[MLDataDevices.cpu_device()]
adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(),
# ADTypes.AutoForwardDiff(),
# ADTypes.AutoEnzyme(;
# mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
# function_annotation = Enzyme.Const,
Expand All @@ -30,27 +31,29 @@ Test.@testset "Smoke Tests" begin
# mode = Enzyme.set_runtime_activity(Enzyme.Forward),
# function_annotation = Enzyme.Const,
# ),
# ADTypes.AutoForwardDiff(),
]
compute_modes = ContinuousNormalizingFlows.ComputeMode[
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand All @@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
]

Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in
Expand Down Expand Up @@ -193,6 +193,11 @@ Test.@testset "Smoke Tests" begin
Test.@test !isnothing(rand(d))
Test.@test !isnothing(rand(d, ndata))

if GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
continue
end

Test.@testset "$adtype on loss" for adtype in adtypes
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))
Expand Down
12 changes: 9 additions & 3 deletions test/speed_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ Test.@testset "Speed Tests" begin
compute_modes = ContinuousNormalizingFlows.ComputeMode[
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
Expand All @@ -14,8 +16,6 @@ Test.@testset "Speed Tests" begin
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
]

Test.@testset "$compute_mode" for compute_mode in compute_modes
Expand Down Expand Up @@ -54,10 +54,16 @@ Test.@testset "Speed Tests" begin
)

df = DataFrames.DataFrame(transpose(r), :auto)

if GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
continue
end

model = ContinuousNormalizingFlows.ICNFModel(icnf; batch_size = 0, n_epochs = 5)

mach = MLJBase.machine(model, df)
Test.@test !isnothing(MLJBase.fit!(mach))
MLJBase.fit!(mach)

@show only(MLJBase.report(mach).stats).time

Expand Down
Loading