Skip to content

Commit a8fd5d5

Browse files
authored
Merge 7dc8f71 into 51e57fa
2 parents 51e57fa + 7dc8f71 commit a8fd5d5

File tree

4 files changed

+281
-126
lines changed

4 files changed

+281
-126
lines changed

.github/workflows/CI-CheckBy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
version:
2727
- release
2828
- lts
29-
- nightly
29+
# - nightly
3030
os:
3131
- ubuntu-latest
3232
# - macOS-latest

src/icnf.jl

Lines changed: 18 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function augmented_f(
120120
n_aug = n_augment(icnf, mode)
121121
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
122122
z = u[begin:(end - n_aug - 1)]
123-
ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z)
123+
ż, J = icnf_jacobian(icnf, mode, snn, z)
124124
l̇ = -LinearAlgebra.tr(J)
125125
return vcat(ż, l̇)
126126
end
@@ -139,7 +139,7 @@ function augmented_f(
139139
n_aug = n_augment(icnf, mode)
140140
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
141141
z = u[begin:(end - n_aug - 1)]
142-
ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z)
142+
ż, J = icnf_jacobian(icnf, mode, snn, z)
143143
du[begin:(end - n_aug - 1)] .= ż
144144
du[(end - n_aug)] = -LinearAlgebra.tr(J)
145145
return nothing
@@ -158,8 +158,8 @@ function augmented_f(
158158
n_aug = n_augment(icnf, mode)
159159
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
160160
z = u[begin:(end - n_aug - 1), :]
161-
ż, J = jacobian_batched(icnf, snn, z)
162-
l̇ = -transpose(LinearAlgebra.tr.(J))
161+
ż, J = icnf_jacobian(icnf, mode, snn, z)
162+
l̇ = -transpose(LinearAlgebra.tr.(eachslice(J; dims = 3)))
163163
return vcat(ż, l̇)
164164
end
165165

@@ -177,9 +177,9 @@ function augmented_f(
177177
n_aug = n_augment(icnf, mode)
178178
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
179179
z = u[begin:(end - n_aug - 1), :]
180-
ż, J = jacobian_batched(icnf, snn, z)
180+
ż, J = icnf_jacobian(icnf, mode, snn, z)
181181
du[begin:(end - n_aug - 1), :] .= ż
182-
du[(end - n_aug), :] .= -(LinearAlgebra.tr.(J))
182+
du[(end - n_aug), :] .= -(LinearAlgebra.tr.(eachslice(J; dims = 3)))
183183
return nothing
184184
end
185185

@@ -196,9 +196,7 @@ function augmented_f(
196196
n_aug = n_augment(icnf, mode)
197197
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
198198
z = u[begin:(end - n_aug - 1)]
199-
ż, ϵJ =
200-
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
201-
ϵJ = only(ϵJ)
199+
ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
202200
l̇ = -LinearAlgebra.dot(ϵJ, ϵ)
203201
= if NORM_Z
204202
LinearAlgebra.norm(ż)
@@ -227,9 +225,7 @@ function augmented_f(
227225
n_aug = n_augment(icnf, mode)
228226
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
229227
z = u[begin:(end - n_aug - 1)]
230-
ż, ϵJ =
231-
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
232-
ϵJ = only(ϵJ)
228+
ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
233229
du[begin:(end - n_aug - 1)] .= ż
234230
du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
235231
du[(end - n_aug + 1)] = if NORM_Z
@@ -258,13 +254,7 @@ function augmented_f(
258254
n_aug = n_augment(icnf, mode)
259255
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
260256
z = u[begin:(end - n_aug - 1)]
261-
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
262-
snn,
263-
icnf.compute_mode.adback,
264-
z,
265-
(ϵ,),
266-
)
267-
Jϵ = only(Jϵ)
257+
ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
268258
l̇ = -LinearAlgebra.dot(ϵ, Jϵ)
269259
= if NORM_Z
270260
LinearAlgebra.norm(ż)
@@ -293,13 +283,7 @@ function augmented_f(
293283
n_aug = n_augment(icnf, mode)
294284
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
295285
z = u[begin:(end - n_aug - 1)]
296-
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
297-
snn,
298-
icnf.compute_mode.adback,
299-
z,
300-
(ϵ,),
301-
)
302-
Jϵ = only(Jϵ)
286+
ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
303287
du[begin:(end - n_aug - 1)] .= ż
304288
du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
305289
du[(end - n_aug + 1)] = if NORM_Z
@@ -328,9 +312,7 @@ function augmented_f(
328312
n_aug = n_augment(icnf, mode)
329313
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
330314
z = u[begin:(end - n_aug - 1), :]
331-
ż, ϵJ =
332-
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
333-
ϵJ = only(ϵJ)
315+
ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
334316
l̇ = -sum(ϵJ .* ϵ; dims = 1)
335317
= transpose(if NORM_Z
336318
LinearAlgebra.norm.(eachcol(ż))
@@ -363,9 +345,7 @@ function augmented_f(
363345
n_aug = n_augment(icnf, mode)
364346
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
365347
z = u[begin:(end - n_aug - 1), :]
366-
ż, ϵJ =
367-
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
368-
ϵJ = only(ϵJ)
348+
ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
369349
du[begin:(end - n_aug - 1), :] .= ż
370350
du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
371351
du[(end - n_aug + 1), :] .= if NORM_Z
@@ -394,13 +374,7 @@ function augmented_f(
394374
n_aug = n_augment(icnf, mode)
395375
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
396376
z = u[begin:(end - n_aug - 1), :]
397-
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
398-
snn,
399-
icnf.compute_mode.adback,
400-
z,
401-
(ϵ,),
402-
)
403-
Jϵ = only(Jϵ)
377+
ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
404378
l̇ = -sum(ϵ .* Jϵ; dims = 1)
405379
= transpose(if NORM_Z
406380
LinearAlgebra.norm.(eachcol(ż))
@@ -433,13 +407,7 @@ function augmented_f(
433407
n_aug = n_augment(icnf, mode)
434408
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
435409
z = u[begin:(end - n_aug - 1), :]
436-
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
437-
snn,
438-
icnf.compute_mode.adback,
439-
z,
440-
(ϵ,),
441-
)
442-
Jϵ = only(Jϵ)
410+
ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
443411
du[begin:(end - n_aug - 1), :] .= ż
444412
du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
445413
du[(end - n_aug + 1), :] .= if NORM_Z
@@ -468,8 +436,7 @@ function augmented_f(
468436
n_aug = n_augment(icnf, mode)
469437
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
470438
z = u[begin:(end - n_aug - 1), :]
471-
ż = snn(z)
472-
ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
439+
ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
473440
l̇ = -sum(ϵJ .* ϵ; dims = 1)
474441
= transpose(if NORM_Z
475442
LinearAlgebra.norm.(eachcol(ż))
@@ -502,8 +469,7 @@ function augmented_f(
502469
n_aug = n_augment(icnf, mode)
503470
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
504471
z = u[begin:(end - n_aug - 1), :]
505-
ż = snn(z)
506-
ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
472+
ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
507473
du[begin:(end - n_aug - 1), :] .= ż
508474
du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
509475
du[(end - n_aug + 1), :] .= if NORM_Z
@@ -532,8 +498,7 @@ function augmented_f(
532498
n_aug = n_augment(icnf, mode)
533499
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
534500
z = u[begin:(end - n_aug - 1), :]
535-
ż = snn(z)
536-
Jϵ = Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
501+
ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
537502
l̇ = -sum(ϵ .* Jϵ; dims = 1)
538503
= transpose(if NORM_Z
539504
LinearAlgebra.norm.(eachcol(ż))
@@ -566,8 +531,7 @@ function augmented_f(
566531
n_aug = n_augment(icnf, mode)
567532
snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
568533
z = u[begin:(end - n_aug - 1), :]
569-
ż = snn(z)
570-
Jϵ = Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
534+
ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
571535
du[begin:(end - n_aug - 1), :] .= ż
572536
du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
573537
du[(end - n_aug + 1), :] .= if NORM_Z

src/utils.jl

Lines changed: 125 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,165 @@
1-
function jacobian_batched(
1+
function icnf_jacobian(
2+
icnf::AbstractICNF{<:AbstractFloat, <:DIVectorMode},
3+
::TestMode,
4+
f::LuxCore.StatefulLuxLayer,
5+
xs::AbstractVector{<:Real},
6+
)
7+
y = f(xs)
8+
return y,
9+
oftype(hcat(y), DifferentiationInterface.jacobian(f, icnf.compute_mode.adback, xs))
10+
end
11+
12+
function icnf_jacobian(
13+
icnf::AbstractICNF{<:AbstractFloat, <:DIMatrixMode},
14+
::TestMode,
15+
f::LuxCore.StatefulLuxLayer,
16+
xs::AbstractMatrix{<:Real},
17+
)
18+
y = f(xs)
19+
J = DifferentiationInterface.jacobian(f, icnf.compute_mode.adback, xs)
20+
return y,
21+
oftype(
22+
cat(y; dims = Val(3)),
23+
cat(
24+
(
25+
J[i:j, i:j] for (i, j) in zip(
26+
firstindex(J, 1):size(y, 1):lastindex(J, 1),
27+
(firstindex(J, 1) + size(y, 1) - 1):size(y, 1):lastindex(J, 1),
28+
)
29+
)...;
30+
dims = Val(3),
31+
),
32+
)
33+
end
34+
35+
function icnf_jacobian(
236
icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
37+
::TestMode,
338
f::LuxCore.StatefulLuxLayer,
439
xs::AbstractMatrix{<:Real},
5-
) where {T}
40+
) where {T <: AbstractFloat}
641
y = f(xs)
742
z = similar(xs)
843
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
9-
res = Zygote.Buffer(
10-
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
11-
size(xs, 1),
12-
size(xs, 1),
13-
size(xs, 2),
14-
)
44+
res = Zygote.Buffer(y, size(xs, 1), size(xs, 1), size(xs, 2))
1545
for i in axes(xs, 1)
1646
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
1747
res[i, :, :] =
1848
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,)))
1949
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
2050
end
21-
return y, eachslice(copy(res); dims = 3)
51+
return y, oftype(cat(y; dims = Val(3)), copy(res))
2252
end
2353

24-
function jacobian_batched(
54+
function icnf_jacobian(
2555
icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
56+
::TestMode,
2657
f::LuxCore.StatefulLuxLayer,
2758
xs::AbstractMatrix{<:Real},
28-
) where {T}
59+
) where {T <: AbstractFloat}
2960
y = f(xs)
3061
z = similar(xs)
3162
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
32-
res = Zygote.Buffer(
33-
convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
34-
size(xs, 1),
35-
size(xs, 1),
36-
size(xs, 2),
37-
)
63+
res = Zygote.Buffer(y, size(xs, 1), size(xs, 1), size(xs, 2))
3864
for i in axes(xs, 1)
3965
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
4066
res[:, i, :] = only(
4167
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,)),
4268
)
4369
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
4470
end
45-
return y, eachslice(copy(res); dims = 3)
71+
return y, oftype(cat(y; dims = Val(3)), copy(res))
4672
end
4773

48-
function jacobian_batched(
49-
icnf::AbstractICNF{T, <:DIMatrixMode},
74+
function icnf_jacobian(
75+
icnf::AbstractICNF{<:AbstractFloat, <:LuxMatrixMode},
76+
::TestMode,
5077
f::LuxCore.StatefulLuxLayer,
5178
xs::AbstractMatrix{<:Real},
52-
) where {T}
53-
y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs)
54-
return y, split_jac(J, size(xs, 1))
79+
)
80+
y = f(xs)
81+
return y,
82+
oftype(cat(y; dims = Val(3)), Lux.batched_jacobian(f, icnf.compute_mode.adback, xs))
5583
end
5684

57-
function split_jac(x::AbstractMatrix{<:Real}, sz::Integer)
58-
return (
59-
x[i:j, i:j] for (i, j) in zip(
60-
firstindex(x, 1):sz:lastindex(x, 1),
61-
(firstindex(x, 1) + sz - 1):sz:lastindex(x, 1),
62-
)
85+
function icnf_jacobian(
86+
icnf::AbstractICNF{T, <:DIVecJacVectorMode},
87+
::TrainMode,
88+
f::LuxCore.StatefulLuxLayer,
89+
xs::AbstractVector{<:Real},
90+
ϵ::AbstractVector{T},
91+
) where {T <: AbstractFloat}
92+
y = f(xs)
93+
return y,
94+
oftype(
95+
y,
96+
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (ϵ,))),
97+
)
98+
end
99+
100+
function icnf_jacobian(
101+
icnf::AbstractICNF{T, <:DIJacVecVectorMode},
102+
::TrainMode,
103+
f::LuxCore.StatefulLuxLayer,
104+
xs::AbstractVector{<:Real},
105+
ϵ::AbstractVector{T},
106+
) where {T <: AbstractFloat}
107+
y = f(xs)
108+
return y,
109+
oftype(
110+
y,
111+
only(DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (ϵ,))),
112+
)
113+
end
114+
115+
function icnf_jacobian(
116+
icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
117+
::TrainMode,
118+
f::LuxCore.StatefulLuxLayer,
119+
xs::AbstractMatrix{<:Real},
120+
ϵ::AbstractMatrix{T},
121+
) where {T <: AbstractFloat}
122+
y = f(xs)
123+
return y,
124+
oftype(
125+
y,
126+
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (ϵ,))),
127+
)
128+
end
129+
130+
function icnf_jacobian(
131+
icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
132+
::TrainMode,
133+
f::LuxCore.StatefulLuxLayer,
134+
xs::AbstractMatrix{<:Real},
135+
ϵ::AbstractMatrix{T},
136+
) where {T <: AbstractFloat}
137+
y = f(xs)
138+
return y,
139+
oftype(
140+
y,
141+
only(DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (ϵ,))),
63142
)
64143
end
65144

66-
function jacobian_batched(
67-
icnf::AbstractICNF{T, <:LuxMatrixMode},
145+
function icnf_jacobian(
146+
icnf::AbstractICNF{T, <:LuxVecJacMatrixMode},
147+
::TrainMode,
148+
f::LuxCore.StatefulLuxLayer,
149+
xs::AbstractMatrix{<:Real},
150+
ϵ::AbstractMatrix{T},
151+
) where {T <: AbstractFloat}
152+
y = f(xs)
153+
return y, oftype(y, Lux.vector_jacobian_product(f, icnf.compute_mode.adback, xs, ϵ))
154+
end
155+
156+
function icnf_jacobian(
157+
icnf::AbstractICNF{T, <:LuxJacVecMatrixMode},
158+
::TrainMode,
68159
f::LuxCore.StatefulLuxLayer,
69160
xs::AbstractMatrix{<:Real},
70-
) where {T}
161+
ϵ::AbstractMatrix{T},
162+
) where {T <: AbstractFloat}
71163
y = f(xs)
72-
J = Lux.batched_jacobian(f, icnf.compute_mode.adback, xs)
73-
return y, eachslice(J; dims = 3)
164+
return y, oftype(y, Lux.jacobian_vector_product(f, icnf.compute_mode.adback, xs, ϵ))
74165
end

0 commit comments

Comments
 (0)