@@ -120,7 +120,7 @@ function augmented_f(
120120 n_aug = n_augment (icnf, mode)
121121 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
122122 z = u[begin : (end - n_aug - 1 )]
123- ż, J = DifferentiationInterface . value_and_jacobian (snn, icnf . compute_mode . adback , z)
123+ ż, J = icnf_jacobian (icnf, mode, snn , z)
124124 l̇ = - LinearAlgebra. tr (J)
125125 return vcat (ż, l̇)
126126end
@@ -139,7 +139,7 @@ function augmented_f(
139139 n_aug = n_augment (icnf, mode)
140140 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
141141 z = u[begin : (end - n_aug - 1 )]
142- ż, J = DifferentiationInterface . value_and_jacobian (snn, icnf . compute_mode . adback , z)
142+ ż, J = icnf_jacobian (icnf, mode, snn , z)
143143 du[begin : (end - n_aug - 1 )] .= ż
144144 du[(end - n_aug)] = - LinearAlgebra. tr (J)
145145 return nothing
@@ -158,8 +158,8 @@ function augmented_f(
158158 n_aug = n_augment (icnf, mode)
159159 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
160160 z = u[begin : (end - n_aug - 1 ), :]
161- ż, J = jacobian_batched (icnf, snn, z)
162- l̇ = - transpose (LinearAlgebra. tr .(J ))
161+ ż, J = icnf_jacobian (icnf, mode , snn, z)
162+ l̇ = - transpose (LinearAlgebra. tr .(eachslice (J; dims = 3 ) ))
163163 return vcat (ż, l̇)
164164end
165165
@@ -177,9 +177,9 @@ function augmented_f(
177177 n_aug = n_augment (icnf, mode)
178178 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
179179 z = u[begin : (end - n_aug - 1 ), :]
180- ż, J = jacobian_batched (icnf, snn, z)
180+ ż, J = icnf_jacobian (icnf, mode , snn, z)
181181 du[begin : (end - n_aug - 1 ), :] .= ż
182- du[(end - n_aug), :] .= - (LinearAlgebra. tr .(J ))
182+ du[(end - n_aug), :] .= - (LinearAlgebra. tr .(eachslice (J; dims = 3 ) ))
183183 return nothing
184184end
185185
@@ -196,9 +196,7 @@ function augmented_f(
196196 n_aug = n_augment (icnf, mode)
197197 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
198198 z = u[begin : (end - n_aug - 1 )]
199- ż, ϵJ =
200- DifferentiationInterface. value_and_pullback (snn, icnf. compute_mode. adback, z, (ϵ,))
201- ϵJ = only (ϵJ)
199+ ż, ϵJ = icnf_jacobian (icnf, mode, snn, z, ϵ)
202200 l̇ = - LinearAlgebra. dot (ϵJ, ϵ)
203201 Ė = if NORM_Z
204202 LinearAlgebra. norm (ż)
@@ -227,9 +225,7 @@ function augmented_f(
227225 n_aug = n_augment (icnf, mode)
228226 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
229227 z = u[begin : (end - n_aug - 1 )]
230- ż, ϵJ =
231- DifferentiationInterface. value_and_pullback (snn, icnf. compute_mode. adback, z, (ϵ,))
232- ϵJ = only (ϵJ)
228+ ż, ϵJ = icnf_jacobian (icnf, mode, snn, z, ϵ)
233229 du[begin : (end - n_aug - 1 )] .= ż
234230 du[(end - n_aug)] = - LinearAlgebra. dot (ϵJ, ϵ)
235231 du[(end - n_aug + 1 )] = if NORM_Z
@@ -258,13 +254,7 @@ function augmented_f(
258254 n_aug = n_augment (icnf, mode)
259255 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
260256 z = u[begin : (end - n_aug - 1 )]
261- ż, Jϵ = DifferentiationInterface. value_and_pushforward (
262- snn,
263- icnf. compute_mode. adback,
264- z,
265- (ϵ,),
266- )
267- Jϵ = only (Jϵ)
257+ ż, Jϵ = icnf_jacobian (icnf, mode, snn, z, ϵ)
268258 l̇ = - LinearAlgebra. dot (ϵ, Jϵ)
269259 Ė = if NORM_Z
270260 LinearAlgebra. norm (ż)
@@ -293,13 +283,7 @@ function augmented_f(
293283 n_aug = n_augment (icnf, mode)
294284 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
295285 z = u[begin : (end - n_aug - 1 )]
296- ż, Jϵ = DifferentiationInterface. value_and_pushforward (
297- snn,
298- icnf. compute_mode. adback,
299- z,
300- (ϵ,),
301- )
302- Jϵ = only (Jϵ)
286+ ż, Jϵ = icnf_jacobian (icnf, mode, snn, z, ϵ)
303287 du[begin : (end - n_aug - 1 )] .= ż
304288 du[(end - n_aug)] = - LinearAlgebra. dot (ϵ, Jϵ)
305289 du[(end - n_aug + 1 )] = if NORM_Z
@@ -328,9 +312,7 @@ function augmented_f(
328312 n_aug = n_augment (icnf, mode)
329313 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
330314 z = u[begin : (end - n_aug - 1 ), :]
331- ż, ϵJ =
332- DifferentiationInterface. value_and_pullback (snn, icnf. compute_mode. adback, z, (ϵ,))
333- ϵJ = only (ϵJ)
315+ ż, ϵJ = icnf_jacobian (icnf, mode, snn, z, ϵ)
334316 l̇ = - sum (ϵJ .* ϵ; dims = 1 )
335317 Ė = transpose (if NORM_Z
336318 LinearAlgebra. norm .(eachcol (ż))
@@ -363,9 +345,7 @@ function augmented_f(
363345 n_aug = n_augment (icnf, mode)
364346 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
365347 z = u[begin : (end - n_aug - 1 ), :]
366- ż, ϵJ =
367- DifferentiationInterface. value_and_pullback (snn, icnf. compute_mode. adback, z, (ϵ,))
368- ϵJ = only (ϵJ)
348+ ż, ϵJ = icnf_jacobian (icnf, mode, snn, z, ϵ)
369349 du[begin : (end - n_aug - 1 ), :] .= ż
370350 du[(end - n_aug), :] .= - vec (sum (ϵJ .* ϵ; dims = 1 ))
371351 du[(end - n_aug + 1 ), :] .= if NORM_Z
@@ -394,13 +374,7 @@ function augmented_f(
394374 n_aug = n_augment (icnf, mode)
395375 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
396376 z = u[begin : (end - n_aug - 1 ), :]
397- ż, Jϵ = DifferentiationInterface. value_and_pushforward (
398- snn,
399- icnf. compute_mode. adback,
400- z,
401- (ϵ,),
402- )
403- Jϵ = only (Jϵ)
377+ ż, Jϵ = icnf_jacobian (icnf, mode, snn, z, ϵ)
404378 l̇ = - sum (ϵ .* Jϵ; dims = 1 )
405379 Ė = transpose (if NORM_Z
406380 LinearAlgebra. norm .(eachcol (ż))
@@ -433,13 +407,7 @@ function augmented_f(
433407 n_aug = n_augment (icnf, mode)
434408 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
435409 z = u[begin : (end - n_aug - 1 ), :]
436- ż, Jϵ = DifferentiationInterface. value_and_pushforward (
437- snn,
438- icnf. compute_mode. adback,
439- z,
440- (ϵ,),
441- )
442- Jϵ = only (Jϵ)
410+ ż, Jϵ = icnf_jacobian (icnf, mode, snn, z, ϵ)
443411 du[begin : (end - n_aug - 1 ), :] .= ż
444412 du[(end - n_aug), :] .= - vec (sum (ϵ .* Jϵ; dims = 1 ))
445413 du[(end - n_aug + 1 ), :] .= if NORM_Z
@@ -468,8 +436,7 @@ function augmented_f(
468436 n_aug = n_augment (icnf, mode)
469437 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
470438 z = u[begin : (end - n_aug - 1 ), :]
471- ż = snn (z)
472- ϵJ = Lux. vector_jacobian_product (snn, icnf. compute_mode. adback, z, ϵ)
439+ ż, ϵJ = icnf_jacobian (icnf, mode, snn, z, ϵ)
473440 l̇ = - sum (ϵJ .* ϵ; dims = 1 )
474441 Ė = transpose (if NORM_Z
475442 LinearAlgebra. norm .(eachcol (ż))
@@ -502,8 +469,7 @@ function augmented_f(
502469 n_aug = n_augment (icnf, mode)
503470 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
504471 z = u[begin : (end - n_aug - 1 ), :]
505- ż = snn (z)
506- ϵJ = Lux. vector_jacobian_product (snn, icnf. compute_mode. adback, z, ϵ)
472+ ż, ϵJ = icnf_jacobian (icnf, mode, snn, z, ϵ)
507473 du[begin : (end - n_aug - 1 ), :] .= ż
508474 du[(end - n_aug), :] .= - vec (sum (ϵJ .* ϵ; dims = 1 ))
509475 du[(end - n_aug + 1 ), :] .= if NORM_Z
@@ -532,8 +498,7 @@ function augmented_f(
532498 n_aug = n_augment (icnf, mode)
533499 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
534500 z = u[begin : (end - n_aug - 1 ), :]
535- ż = snn (z)
536- Jϵ = Lux. jacobian_vector_product (snn, icnf. compute_mode. adback, z, ϵ)
501+ ż, Jϵ = icnf_jacobian (icnf, mode, snn, z, ϵ)
537502 l̇ = - sum (ϵ .* Jϵ; dims = 1 )
538503 Ė = transpose (if NORM_Z
539504 LinearAlgebra. norm .(eachcol (ż))
@@ -566,8 +531,7 @@ function augmented_f(
566531 n_aug = n_augment (icnf, mode)
567532 snn = LuxCore. StatefulLuxLayer {true} (nn, p, st)
568533 z = u[begin : (end - n_aug - 1 ), :]
569- ż = snn (z)
570- Jϵ = Lux. jacobian_vector_product (snn, icnf. compute_mode. adback, z, ϵ)
534+ ż, Jϵ = icnf_jacobian (icnf, mode, snn, z, ϵ)
571535 du[begin : (end - n_aug - 1 ), :] .= ż
572536 du[(end - n_aug), :] .= - vec (sum (ϵ .* Jϵ; dims = 1 ))
573537 du[(end - n_aug + 1 ), :] .= if NORM_Z
0 commit comments