Skip to content

Commit 7dc8f71

Browse files
committed
just gradiant
1 parent 779ff28 commit 7dc8f71

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

src/utils.jl

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ function icnf_jacobian(
44
f::LuxCore.StatefulLuxLayer,
55
xs::AbstractVector{<:Real},
66
)
7-
y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs)
8-
return y, oftype(hcat(y), J)
7+
y = f(xs)
8+
return y,
9+
oftype(hcat(y), DifferentiationInterface.jacobian(f, icnf.compute_mode.adback, xs))
910
end
1011

1112
function icnf_jacobian(
@@ -14,7 +15,8 @@ function icnf_jacobian(
1415
f::LuxCore.StatefulLuxLayer,
1516
xs::AbstractMatrix{<:Real},
1617
)
17-
y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs)
18+
y = f(xs)
19+
J = DifferentiationInterface.jacobian(f, icnf.compute_mode.adback, xs)
1820
return y,
1921
oftype(
2022
cat(y; dims = Val(3)),
@@ -87,9 +89,12 @@ function icnf_jacobian(
8789
xs::AbstractVector{<:Real},
8890
ϵ::AbstractVector{T},
8991
) where {T <: AbstractFloat}
90-
y, ϵJ =
91-
DifferentiationInterface.value_and_pullback(f, icnf.compute_mode.adback, xs, (ϵ,))
92-
return y, oftype(y, only(ϵJ))
92+
y = f(xs)
93+
return y,
94+
oftype(
95+
y,
96+
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (ϵ,))),
97+
)
9398
end
9499

95100
function icnf_jacobian(
@@ -99,13 +104,12 @@ function icnf_jacobian(
99104
xs::AbstractVector{<:Real},
100105
ϵ::AbstractVector{T},
101106
) where {T <: AbstractFloat}
102-
y, Jϵ = DifferentiationInterface.value_and_pushforward(
103-
f,
104-
icnf.compute_mode.adback,
105-
xs,
106-
(ϵ,),
107+
y = f(xs)
108+
return y,
109+
oftype(
110+
y,
111+
only(DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (ϵ,))),
107112
)
108-
return y, oftype(y, only(Jϵ))
109113
end
110114

111115
function icnf_jacobian(
@@ -115,9 +119,12 @@ function icnf_jacobian(
115119
xs::AbstractMatrix{<:Real},
116120
ϵ::AbstractMatrix{T},
117121
) where {T <: AbstractFloat}
118-
y, ϵJ =
119-
DifferentiationInterface.value_and_pullback(f, icnf.compute_mode.adback, xs, (ϵ,))
120-
return y, oftype(y, only(ϵJ))
122+
y = f(xs)
123+
return y,
124+
oftype(
125+
y,
126+
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (ϵ,))),
127+
)
121128
end
122129

123130
function icnf_jacobian(
@@ -127,13 +134,12 @@ function icnf_jacobian(
127134
xs::AbstractMatrix{<:Real},
128135
ϵ::AbstractMatrix{T},
129136
) where {T <: AbstractFloat}
130-
y, Jϵ = DifferentiationInterface.value_and_pushforward(
131-
f,
132-
icnf.compute_mode.adback,
133-
xs,
134-
(ϵ,),
137+
y = f(xs)
138+
return y,
139+
oftype(
140+
y,
141+
only(DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (ϵ,))),
135142
)
136-
return y, oftype(y, only(Jϵ))
137143
end
138144

139145
function icnf_jacobian(

0 commit comments

Comments
 (0)