Commit c768c0b

switch to LuxCore version of StatefulLuxLayer and setup
1 parent ba21a15 commit c768c0b

File tree

8 files changed: +29 -23 lines changed

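The change is mechanical: every `Lux.setup` and `Lux.StatefulLuxLayer` call site now goes through `LuxCore` directly. A minimal sketch of the before/after call pattern, assuming Lux 1 and LuxCore 1 with a hypothetical toy `Lux.Dense` model that is not part of this commit:

```julia
import Lux, LuxCore, Random

rng = Random.default_rng()
nn = Lux.Dense(2 => 2, tanh)                      # hypothetical toy model

# before: ps, st = Lux.setup(rng, nn)
ps, st = LuxCore.setup(rng, nn)                   # same NamedTuple parameters/state

# before: snn = Lux.StatefulLuxLayer{true}(nn, ps, st)
snn = LuxCore.StatefulLuxLayer{true}(nn, ps, st)  # callable wrapper carrying ps/st
snn(rand(Float32, 2, 4))                          # forward pass on a 2×4 batch
```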

benchmark/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
 PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
@@ -16,6 +17,7 @@ BenchmarkTools = "1"
 ComponentArrays = "0.15"
 DifferentiationInterface = "0.7"
 Lux = "1"
+LuxCore = "1"
 OrdinaryDiffEqDefault = "1"
 PkgBenchmark = "0.2"
 SciMLSensitivity = "7"

benchmark/benchmarks.jl

Lines changed: 2 additions & 1 deletion

@@ -3,6 +3,7 @@ import ADTypes,
     ComponentArrays,
     DifferentiationInterface,
     Lux,
+    LuxCore,
     OrdinaryDiffEqDefault,
     PkgBenchmark,
     SciMLSensitivity,
@@ -48,7 +49,7 @@ icnf = ContinuousNormalizingFlows.construct(
         sensealg = SciMLSensitivity.InterpolatingAdjoint(),
     ),
 )
-ps, st = Lux.setup(icnf.rng, icnf)
+ps, st = LuxCore.setup(icnf.rng, icnf)
 ps = ComponentArrays.ComponentArray(ps)
 r = rand(icnf.rng, Float32, nvars, n)
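For context, the `LuxCore.setup` output feeds `ComponentArrays` exactly as the `Lux.setup` output did. A standalone sketch with a hypothetical toy layer standing in for the benchmark's actual `icnf`:

```julia
import ComponentArrays, Lux, LuxCore, Random

rng = Random.default_rng()
nn = Lux.Dense(2 => 2)                  # hypothetical stand-in for the ICNF's network
ps, st = LuxCore.setup(rng, nn)         # ps is a nested NamedTuple
ps = ComponentArrays.ComponentArray(ps) # flat vector with named blocks

ps.weight                               # the 2×2 weight block, still addressable by name
length(ps)                              # 6 scalars (4 weights + 2 biases), the flat view solvers work with
```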

src/icnf.jl

Lines changed: 16 additions & 16 deletions

@@ -118,7 +118,7 @@ function augmented_f(
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z)
     l̇ = -LinearAlgebra.tr(J)
@@ -137,7 +137,7 @@ function augmented_f(
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z)
     du[begin:(end - n_aug - 1)] .=
@@ -156,7 +156,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, J = jacobian_batched(icnf, snn, z)
     l̇ = -transpose(LinearAlgebra.tr.(J))
@@ -175,7 +175,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, J = jacobian_batched(icnf, snn, z)
     du[begin:(end - n_aug - 1), :] .=
@@ -194,7 +194,7 @@ function augmented_f(
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, ϵJ =
         DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
@@ -225,7 +225,7 @@ function augmented_f(
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, ϵJ =
         DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
@@ -256,7 +256,7 @@ function augmented_f(
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, Jϵ = DifferentiationInterface.value_and_pushforward(
         snn,
@@ -291,7 +291,7 @@ function augmented_f(
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, Jϵ = DifferentiationInterface.value_and_pushforward(
         snn,
@@ -326,7 +326,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, ϵJ =
         DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
@@ -361,7 +361,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, ϵJ =
         DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
@@ -392,7 +392,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, Jϵ = DifferentiationInterface.value_and_pushforward(
         snn,
@@ -431,7 +431,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, Jϵ = DifferentiationInterface.value_and_pushforward(
         snn,
@@ -466,7 +466,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż = snn(z)
     ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
@@ -500,7 +500,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż = snn(z)
     ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
@@ -530,7 +530,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż = snn(z)
     Jϵ = Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
@@ -564,7 +564,7 @@ function augmented_f(
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
-    snn = Lux.StatefulLuxLayer{true}(nn, p, st)
+    snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż = snn(z)
     Jϵ = Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
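The `snn` wrapper is what turns the network into a plain function of `z` for `DifferentiationInterface`. A sketch of the vector-mode pattern above, with a hypothetical toy model and `ADTypes.AutoZygote()` standing in for `icnf.compute_mode.adback`:

```julia
import ADTypes, DifferentiationInterface, LinearAlgebra, Lux, LuxCore, Random, Zygote

rng = Random.default_rng()
nn = Lux.Dense(3 => 3, tanh)                      # hypothetical stand-in for the ICNF network
ps, st = LuxCore.setup(rng, nn)

snn = LuxCore.StatefulLuxLayer{true}(nn, ps, st)  # z ↦ ż with ps/st baked in
z = rand(Float32, 3)
ż, J = DifferentiationInterface.value_and_jacobian(snn, ADTypes.AutoZygote(), z)
l̇ = -LinearAlgebra.tr(J)                          # negative trace term of the flow's log-density ODE
vcat(ż, l̇)                                        # augmented state derivative, as in augmented_f
```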

src/utils.jl

Lines changed: 4 additions & 4 deletions

@@ -1,6 +1,6 @@
 function jacobian_batched(
     icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
-    f::Lux.StatefulLuxLayer,
+    f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y = f(xs)
@@ -18,7 +18,7 @@ end
 
 function jacobian_batched(
     icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
-    f::Lux.StatefulLuxLayer,
+    f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y = f(xs)
@@ -37,7 +37,7 @@ end
 
 function jacobian_batched(
     icnf::AbstractICNF{T, <:DIMatrixMode},
-    f::Lux.StatefulLuxLayer,
+    f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs)
@@ -55,7 +55,7 @@ end
 
 function jacobian_batched(
     icnf::AbstractICNF{T, <:LuxMatrixMode},
-    f::Lux.StatefulLuxLayer,
+    f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y = f(xs)
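A sketch of the `DIMatrixMode`-style branch in isolation, again with a hypothetical toy model and `ADTypes.AutoForwardDiff()` assumed as the backend:

```julia
import ADTypes, DifferentiationInterface, ForwardDiff, Lux, LuxCore, Random

rng = Random.default_rng()
nn = Lux.Dense(2 => 2)                          # hypothetical stand-in
ps, st = LuxCore.setup(rng, nn)
f = LuxCore.StatefulLuxLayer{true}(nn, ps, st)  # matches the f::LuxCore.StatefulLuxLayer dispatch

xs = rand(Float32, 2, 8)                        # 8 samples stored as columns
y, J = DifferentiationInterface.value_and_jacobian(f, ADTypes.AutoForwardDiff(), xs)
size(J)  # (16, 16): Jacobian of vec(f(xs)) w.r.t. vec(xs); per-sample 2×2 Jacobians sit on its block diagonal
```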

test/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
@@ -34,6 +35,7 @@ ExplicitImports = "1"
 ForwardDiff = "1"
 JET = "0.9, 0.10"
 Lux = "1"
+LuxCore = "1"
 MLDataDevices = "1"
 MLJBase = "1"
 OrdinaryDiffEqDefault = "1"

test/checkby_JET_tests.jl

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ Test.@testset "CheckByJET" begin
             sensealg = SciMLSensitivity.InterpolatingAdjoint(),
         ),
     )
-    ps, st = Lux.setup(icnf.rng, icnf)
+    ps, st = LuxCore.setup(icnf.rng, icnf)
     ps = ComponentArrays.ComponentArray(ps)
     r = rand(icnf.rng, Float32, nvars, n)

test/runtests.jl

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ import ADTypes,
     JET,
     Logging,
     Lux,
+    LuxCore,
     MLDataDevices,
     MLJBase,
     OrdinaryDiffEqDefault,

test/smoke_tests.jl

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ Test.@testset "Smoke Tests" begin
                 sensealg = SciMLSensitivity.InterpolatingAdjoint(),
             ),
         )
-        ps, st = Lux.setup(icnf.rng, icnf)
+        ps, st = LuxCore.setup(icnf.rng, icnf)
         ps = ComponentArrays.ComponentArray(ps)
         r = device(r)
         r2 = device(r2)
