Skip to content

Commit 62266d9

Browse files
committed
Reuse Klement Code
1 parent a5c6195 commit 62266d9

File tree

4 files changed

+61
-136
lines changed

4 files changed

+61
-136
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2626
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2828
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
29-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
29+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3030
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3131

3232
[weakdeps]
@@ -75,7 +75,6 @@ SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using
7575
SparseArrays = "<0.0.1, 1"
7676
SparseDiffTools = "2.14"
7777
StaticArrays = "1"
78-
StaticArraysCore = "1.4"
7978
Symbolics = "5"
8079
Test = "1"
8180
UnPack = "1.0"
@@ -99,7 +98,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
9998
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
10099
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
101100
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
102-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
103101
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
104102
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
105103
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/NonlinearSolve.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase,
12-
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
13-
SparseDiffTools
11+
using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf,
12+
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays
1413

1514
import ADTypes: AbstractFiniteDifferencesMode
1615
import ArrayInterface: undefmatrix, restructure, can_setindex,
@@ -26,10 +25,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
2625
AbstractVectorOfArray, recursivecopy!, recursivefill!
2726
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
2827
import SciMLOperators: FunctionOperator
29-
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
28+
import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
3029
import UnPack: @unpack
31-
32-
using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
3330
end
3431

3532
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve

src/klement.jl

Lines changed: 53 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -74,38 +74,44 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
7474
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
7575
linsolve_kwargs = (;), kwargs...) where {uType, iip, F}
7676
@unpack f, u0, p = prob
77-
u = alias_u0 ? u0 : deepcopy(u0)
77+
u = __maybe_unaliased(u0, alias_u0)
7878
fu = evaluate_f(prob, u)
7979
J = __init_identity_jacobian(u, fu)
80-
du = _mutable_zero(u)
80+
@bb du = similar(u)
8181

8282
if u isa Number
83-
linsolve = nothing
83+
linsolve = FakeLinearSolveJLCache(J, fu)
8484
alg = alg_
8585
else
8686
# For General Julia Arrays default to LU Factorization
87-
linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() :
88-
nothing
87+
linsolve_alg = (alg_.linsolve === nothing && (u isa Array || u isa StaticArray)) ?
88+
LUFactorization() : nothing
8989
alg = set_linsolve(alg_, linsolve_alg)
90-
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg)
90+
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg; linsolve_kwargs)
9191
end
9292

9393
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
9494
termination_condition)
9595
trace = init_nonlinearsolve_trace(alg, u, fu, J, du; kwargs...)
9696

97-
return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve,
98-
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
99-
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
100-
NLStats(1, 0, 0, 0, 0),
97+
@bb u_prev = copy(u)
98+
@bb fu2 = similar(fu)
99+
@bb J_cache = similar(J)
100+
@bb J_cache2 = similar(J)
101+
@bb Jᵀ²du = similar(fu)
102+
@bb Jdu = similar(fu)
103+
104+
return GeneralKlementCache{iip}(f, alg, u, u_prev, fu, fu2, du, p, linsolve, J, J_cache,
105+
J_cache2, Jᵀ²du, Jdu, 0, false, maxiters, internalnorm, ReturnCode.Default, abstol,
106+
reltol, prob, NLStats(1, 0, 0, 0, 0),
101107
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
102108
end
103109

104-
function perform_step!(cache::GeneralKlementCache{true})
105-
@unpack u, u_prev, fu, f, p, alg, J, linsolve, du = cache
106-
T = eltype(J)
107-
108-
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
110+
function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
111+
@unpack linsolve, alg = cache
112+
# @unpack fu, f, p, alg, J, linsolve = cache
113+
T = eltype(cache.J)
114+
singular, fact_done = __try_factorize_and_check_singular!(linsolve, cache.J)
109115

110116
if singular
111117
if cache.resets == alg.max_resets
@@ -114,135 +120,61 @@ function perform_step!(cache::GeneralKlementCache{true})
114120
return nothing
115121
end
116122
fact_done = false
117-
fill!(J, zero(T))
118-
J[diagind(J)] .= T(1)
123+
cache.J = __reinit_identity_jacobian!!(cache.J)
119124
cache.resets += 1
120125
end
121126

122127
# u = u - J \ fu
123-
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
124-
b = _vec(fu), linu = _vec(du), p, reltol = cache.abstol)
128+
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
129+
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
125130
cache.linsolve = linres.cache
126131

127-
# Line Search
128-
α = perform_linesearch!(cache.ls_cache, u, du)
129-
_axpy!(-α, du, u)
130-
f(cache.fu2, u, p)
131-
132-
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, J,
133-
cache.du, α)
134-
135-
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
136-
cache.stats.nf += 1
137-
cache.stats.nsolve += 1
138-
cache.stats.nfactors += 1
139-
140-
cache.force_stop && return nothing
141-
142-
# Update the Jacobian
143-
cache.du .*= -1
144-
cache.J_cache .= cache.J' .^ 2
145-
cache.Jdu .= _vec(du) .^ 2
146-
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
147-
mul!(cache.Jdu, J, _vec(du))
148-
cache.fu .= cache.fu2 .- cache.fu
149-
cache.fu .= _restructure(cache.fu,
150-
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(real(T))))
151-
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
152-
cache.J_cache .*= J
153-
mul!(cache.J_cache2, cache.J_cache, J)
154-
J .+= cache.J_cache2
155-
156-
@. u_prev = u
157-
cache.fu .= cache.fu2
158-
159-
return nothing
160-
end
161-
162-
function perform_step!(cache::GeneralKlementCache{false})
163-
@unpack fu, f, p, alg, J, linsolve = cache
132+
!iip && (cache.du = linres.u)
164133

165-
T = eltype(J)
166-
167-
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
168-
169-
if singular
170-
if cache.resets == alg.max_resets
171-
cache.force_stop = true
172-
cache.retcode = ReturnCode.ConvergenceFailure
173-
return nothing
174-
end
175-
fact_done = false
176-
cache.J = __init_identity_jacobian(cache.u, fu)
177-
cache.resets += 1
178-
end
134+
# Line Search
135+
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
136+
@bb axpy!(-α, cache.du, cache.u)
179137

180-
# u = u - J \ fu
181-
if linsolve === nothing
182-
cache.du = fu / cache.J
138+
if iip
139+
cache.f(cache.fu2, cache.u, cache.p)
183140
else
184-
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
185-
b = _vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
186-
cache.linsolve = linres.cache
141+
cache.fu2 = cache.f(cache.u, cache.p)
187142
end
188143

189-
# Line Search
190-
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
191-
cache.u = @. cache.u - α * cache.du # `u` might not support mutation
192-
cache.fu2 = f(cache.u, p)
193-
194-
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, J,
144+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, cache.J,
195145
cache.du, α)
196146

197147
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
198-
cache.u_prev = cache.u
148+
@bb copyto!(cache.u_prev, cache.u)
149+
199150
cache.stats.nf += 1
200151
cache.stats.nsolve += 1
201152
cache.stats.nfactors += 1
202153

203154
cache.force_stop && return nothing
204155

205156
# Update the Jacobian
206-
cache.du = -cache.du
207-
cache.J_cache = cache.J' .^ 2
208-
cache.Jdu = _vec(cache.du) .^ 2
209-
cache.Jᵀ²du = cache.J_cache * cache.Jdu
210-
cache.Jdu = J * _vec(cache.du)
211-
cache.fu = cache.fu2 .- cache.fu
212-
cache.fu = _restructure(cache.fu,
213-
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(real(T))))
214-
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
215-
cache.J = J .+ cache.J_cache
216-
217-
cache.fu = cache.fu2
157+
@bb cache.du .*= -1
158+
@bb cache.J_cache .= cache.J' .^ 2
159+
@bb @. cache.Jdu = cache.du ^ 2
160+
@bb cache.Jᵀ²du = cache.J_cache × vec(cache.Jdu)
161+
@bb cache.Jdu = cache.J × vec(cache.du)
162+
@bb @. cache.fu = cache.fu2 - cache.fu
218163

219-
return nothing
220-
end
164+
@bb @. cache.fu = (cache.fu - cache.Jdu) / max(cache.Jᵀ²du, eps(real(T)))
221165

222-
function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,
223-
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
224-
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
225-
cache.p = p
226-
if iip
227-
recursivecopy!(cache.u, u0)
228-
cache.f(cache.fu, cache.u, p)
229-
else
230-
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
231-
cache.u = u0
232-
cache.fu = cache.f(cache.u, p)
233-
end
166+
@bb cache.J_cache = vec(cache.fu) × transpose(_vec(cache.du))
167+
@bb @. cache.J_cache *= cache.J
168+
@bb cache.J_cache2 = cache.J_cache × cache.J
169+
@bb cache.J .+= cache.J_cache2
234170

235-
reset!(cache.trace)
236-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
237-
termination_condition)
171+
@bb copyto!(cache.fu, cache.fu2)
238172

239-
cache.abstol = abstol
240-
cache.reltol = reltol
241-
cache.tc_cache = tc_cache
242-
cache.maxiters = maxiters
243-
cache.stats.nf = 1
244-
cache.stats.nsteps = 1
245-
cache.force_stop = false
246-
cache.retcode = ReturnCode.Default
247-
return cache
173+
return nothing
174+
end
175+
176+
function __reinit_internal!(cache::GeneralKlementCache)
177+
cache.J = __reinit_identity_jacobian!!(cache.J)
178+
cache.resets = 0
179+
return nothing
248180
end

src/utils.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ end
356356

357357
# If factorization is LU then perform that and update the linsolve cache
358358
# else check if the matrix is singular
359-
function _try_factorize_and_check_singular!(linsolve, X)
360-
if linsolve.cacheval isa LU
359+
function __try_factorize_and_check_singular!(linsolve, X)
360+
if linsolve.cacheval isa LU || linsolve.cacheval isa StaticArrays.LU
361361
# LU Factorization was used
362362
linsolve.A = X
363363
linsolve.cacheval = LinearSolve.do_factorization(linsolve.alg, X, linsolve.b,
@@ -368,11 +368,9 @@ function _try_factorize_and_check_singular!(linsolve, X)
368368
end
369369
return _issingular(X), false
370370
end
371-
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false
372-
373-
@inline _reshape(x, args...) = reshape(x, args...)
374-
@inline _reshape(x::Number, args...) = x
371+
__try_factorize_and_check_singular!(::FakeLinearSolveJLCache, x) = _issingular(x), false
375372

373+
# TODO: Remove. handled in MaybeInplace.jl
376374
@generated function _axpy!(α, x, y)
377375
hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
378376
return :(@. y += α * x)

0 commit comments

Comments
 (0)