@@ -74,38 +74,43 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
7474 termination_condition = nothing , internalnorm:: F = DEFAULT_NORM,
7575 linsolve_kwargs = (;), kwargs... ) where {uType, iip, F}
7676 @unpack f, u0, p = prob
77- u = alias_u0 ? u0 : deepcopy (u0)
77+ u = __maybe_unaliased (u0, alias_u0 )
7878 fu = evaluate_f (prob, u)
7979 J = __init_identity_jacobian (u, fu)
80- du = _mutable_zero (u)
80+ @bb du = similar (u)
8181
8282 if u isa Number
83- linsolve = nothing
83+ linsolve = FakeLinearSolveJLCache (J, fu)
8484 alg = alg_
8585 else
8686 # For General Julia Arrays default to LU Factorization
87- linsolve_alg = alg_. linsolve === nothing && u isa Array ? LUFactorization () :
88- nothing
87+ linsolve_alg = ( alg_. linsolve === nothing && ( u isa Array || u isa StaticArray)) ?
88+ LUFactorization () : nothing
8989 alg = set_linsolve (alg_, linsolve_alg)
90- linsolve = linsolve_caches (J, _vec (fu), _vec (du), p, alg)
90+ linsolve = linsolve_caches (J, _vec (fu), _vec (du), p, alg; linsolve_kwargs )
9191 end
9292
9393 abstol, reltol, tc_cache = init_termination_cache (abstol, reltol, fu, u,
9494 termination_condition)
9595 trace = init_nonlinearsolve_trace (alg, u, fu, J, du; kwargs... )
9696
97- return GeneralKlementCache {iip} (f, alg, u, zero (u), fu, zero (fu), du, p, linsolve,
98- J, zero (J), zero (J), _vec (zero (fu)), _vec (zero (fu)), 0 , false ,
99- maxiters, internalnorm, ReturnCode. Default, abstol, reltol, prob,
100- NLStats (1 , 0 , 0 , 0 , 0 ),
97+ @bb u_prev = copy (u)
98+ @bb fu2 = similar (fu)
99+ @bb J_cache = similar (J)
100+ @bb J_cache2 = similar (J)
101+ @bb Jᵀ²du = similar (fu)
102+ @bb Jdu = similar (fu)
103+
104+ return GeneralKlementCache {iip} (f, alg, u, u_prev, fu, fu2, du, p, linsolve, J, J_cache,
105+ J_cache2, Jᵀ²du, Jdu, 0 , false , maxiters, internalnorm, ReturnCode. Default, abstol,
106+ reltol, prob, NLStats (1 , 0 , 0 , 0 , 0 ),
101107 init_linesearch_cache (alg. linesearch, f, u, p, fu, Val (iip)), tc_cache, trace)
102108end
103109
104- function perform_step! (cache:: GeneralKlementCache{true} )
105- @unpack u, u_prev, fu, f, p, alg, J, linsolve, du = cache
106- T = eltype (J)
107-
108- singular, fact_done = _try_factorize_and_check_singular! (linsolve, J)
110+ function perform_step! (cache:: GeneralKlementCache{iip} ) where {iip}
111+ @unpack linsolve, alg = cache
112+ T = eltype (cache. J)
113+ singular, fact_done = __try_factorize_and_check_singular! (linsolve, cache. J)
109114
110115 if singular
111116 if cache. resets == alg. max_resets
@@ -114,135 +119,61 @@ function perform_step!(cache::GeneralKlementCache{true})
114119 return nothing
115120 end
116121 fact_done = false
117- fill! (J, zero (T))
118- J[diagind (J)] .= T (1 )
122+ cache. J = __reinit_identity_jacobian!! (cache. J)
119123 cache. resets += 1
120124 end
121125
122126 # u = u - J \ fu
123- linres = dolinsolve (alg. precs, linsolve; A = ifelse (fact_done, nothing , J ),
124- b = _vec (fu), linu = _vec (du), p, reltol = cache. abstol)
127+ linres = dolinsolve (alg. precs, cache . linsolve; A = cache . J, b = _vec (cache . fu ),
128+ linu = _vec (cache . du), cache . p, reltol = cache. abstol)
125129 cache. linsolve = linres. cache
126130
127- # Line Search
128- α = perform_linesearch! (cache. ls_cache, u, du)
129- _axpy! (- α, du, u)
130- f (cache. fu2, u, p)
131-
132- update_trace! (cache. trace, cache. stats. nsteps + 1 , get_u (cache), cache. fu2, J,
133- cache. du, α)
134-
135- check_and_update! (cache, cache. fu2, cache. u, cache. u_prev)
136- cache. stats. nf += 1
137- cache. stats. nsolve += 1
138- cache. stats. nfactors += 1
139-
140- cache. force_stop && return nothing
141-
142- # Update the Jacobian
143- cache. du .*= - 1
144- cache. J_cache .= cache. J' .^ 2
145- cache. Jdu .= _vec (du) .^ 2
146- mul! (cache. Jᵀ²du, cache. J_cache, cache. Jdu)
147- mul! (cache. Jdu, J, _vec (du))
148- cache. fu .= cache. fu2 .- cache. fu
149- cache. fu .= _restructure (cache. fu,
150- (_vec (cache. fu) .- cache. Jdu) ./ max .(cache. Jᵀ²du, eps (real (T))))
151- mul! (cache. J_cache, _vec (cache. fu), _vec (du)' )
152- cache. J_cache .*= J
153- mul! (cache. J_cache2, cache. J_cache, J)
154- J .+ = cache. J_cache2
155-
156- @. u_prev = u
157- cache. fu .= cache. fu2
158-
159- return nothing
160- end
161-
162- function perform_step! (cache:: GeneralKlementCache{false} )
163- @unpack fu, f, p, alg, J, linsolve = cache
131+ ! iip && (cache. du = linres. u)
164132
165- T = eltype (J)
166-
167- singular, fact_done = _try_factorize_and_check_singular! (linsolve, J)
168-
169- if singular
170- if cache. resets == alg. max_resets
171- cache. force_stop = true
172- cache. retcode = ReturnCode. ConvergenceFailure
173- return nothing
174- end
175- fact_done = false
176- cache. J = __init_identity_jacobian (cache. u, fu)
177- cache. resets += 1
178- end
133+ # Line Search
134+ α = perform_linesearch! (cache. ls_cache, cache. u, cache. du)
135+ @bb axpy! (- α, cache. du, cache. u)
179136
180- # u = u - J \ fu
181- if linsolve === nothing
182- cache. du = fu / cache. J
137+ if iip
138+ cache. f (cache. fu2, cache. u, cache. p)
183139 else
184- linres = dolinsolve (alg. precs, linsolve; A = ifelse (fact_done, nothing , J),
185- b = _vec (fu), linu = _vec (cache. du), p, reltol = cache. abstol)
186- cache. linsolve = linres. cache
140+ cache. fu2 = cache. f (cache. u, cache. p)
187141 end
188142
189- # Line Search
190- α = perform_linesearch! (cache. ls_cache, cache. u, cache. du)
191- cache. u = @. cache. u - α * cache. du # `u` might not support mutation
192- cache. fu2 = f (cache. u, p)
193-
194- update_trace! (cache. trace, cache. stats. nsteps + 1 , get_u (cache), cache. fu2, J,
143+ update_trace! (cache. trace, cache. stats. nsteps + 1 , get_u (cache), cache. fu2, cache. J,
195144 cache. du, α)
196145
197146 check_and_update! (cache, cache. fu2, cache. u, cache. u_prev)
198- cache. u_prev = cache. u
147+ @bb copyto! (cache. u_prev, cache. u)
148+
199149 cache. stats. nf += 1
200150 cache. stats. nsolve += 1
201151 cache. stats. nfactors += 1
202152
203153 cache. force_stop && return nothing
204154
205155 # Update the Jacobian
206- cache. du = - cache. du
207- cache. J_cache = cache. J' .^ 2
208- cache. Jdu = _vec (cache. du) .^ 2
209- cache. Jᵀ²du = cache. J_cache * cache. Jdu
210- cache. Jdu = J * _vec (cache. du)
211- cache. fu = cache. fu2 .- cache. fu
212- cache. fu = _restructure (cache. fu,
213- (_vec (cache. fu) .- cache. Jdu) ./ max .(cache. Jᵀ²du, eps (real (T))))
214- cache. J_cache = ((_vec (cache. fu) * _vec (cache. du)' ) .* J) * J
215- cache. J = J .+ cache. J_cache
216-
217- cache. fu = cache. fu2
156+ @bb cache. du .*= - 1
157+ @bb cache. J_cache .= cache. J' .^ 2
158+ @bb @. cache. Jdu = cache. du ^ 2
159+ @bb cache. Jᵀ²du = cache. J_cache × vec (cache. Jdu)
160+ @bb cache. Jdu = cache. J × vec (cache. du)
161+ @bb @. cache. fu = cache. fu2 - cache. fu
218162
219- return nothing
220- end
163+ @bb @. cache. fu = (cache. fu - cache. Jdu) / max (cache. Jᵀ²du, eps (real (T)))
221164
222- function SciMLBase. reinit! (cache:: GeneralKlementCache{iip} , u0 = cache. u; p = cache. p,
223- abstol = cache. abstol, reltol = cache. reltol, maxiters = cache. maxiters,
224- termination_condition = get_termination_mode (cache. tc_cache)) where {iip}
225- cache. p = p
226- if iip
227- recursivecopy! (cache. u, u0)
228- cache. f (cache. fu, cache. u, p)
229- else
230- # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
231- cache. u = u0
232- cache. fu = cache. f (cache. u, p)
233- end
165+ @bb cache. J_cache = vec (cache. fu) × transpose (_vec (cache. du))
166+ @bb @. cache. J_cache *= cache. J
167+ @bb cache. J_cache2 = cache. J_cache × cache. J
168+ @bb cache. J .+ = cache. J_cache2
234169
235- reset! (cache. trace)
236- abstol, reltol, tc_cache = init_termination_cache (abstol, reltol, cache. fu, cache. u,
237- termination_condition)
170+ @bb copyto! (cache. fu, cache. fu2)
238171
239- cache. abstol = abstol
240- cache. reltol = reltol
241- cache. tc_cache = tc_cache
242- cache. maxiters = maxiters
243- cache. stats. nf = 1
244- cache. stats. nsteps = 1
245- cache. force_stop = false
246- cache. retcode = ReturnCode. Default
247- return cache
172+ return nothing
173+ end
174+
# Reset the solver-specific state of a `GeneralKlementCache`: `cache.J` is replaced by
# whatever `__reinit_identity_jacobian!!` returns for the current `cache.J`, and the
# reset counter `cache.resets` is set back to zero. Always returns `nothing`.
175+ function __reinit_internal! (cache:: GeneralKlementCache )
# NOTE(review): presumably this rebuilds J as an identity matrix (reusing storage when
# it can); the helper is defined outside this view — confirm.
176+ cache. J = __reinit_identity_jacobian!! (cache. J)
# `perform_step!` compares `cache.resets` against `alg.max_resets` when the Jacobian is
# singular, so the counter has to start from 0 again here.
177+ cache. resets = 0
178+ return nothing
248179end
0 commit comments