Skip to content

Commit 8bce094

Browse files
committed
Run Formatter
1 parent 8f9a432 commit 8bce094

13 files changed

+163
-120
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.5.4"
4+
version = "3.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/devdocs/internal_interfaces.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ NonlinearSolve.AbstractNonlinearSolveCache
1313
```@docs
1414
NonlinearSolve.AbstractDescentAlgorithm
1515
NonlinearSolve.AbstractDescentCache
16+
NonlinearSolve.DescentResult
1617
```
1718

1819
## Approximate Jacobian

src/NonlinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ include("adtypes.jl")
4545
include("timer_outputs.jl")
4646
include("internal/helpers.jl")
4747

48+
include("descent/common.jl")
4849
include("descent/newton.jl")
4950
include("descent/steepest.jl")
5051
include("descent/dogleg.jl")
5152
include("descent/damped_newton.jl")
5253
include("descent/geodesic_acceleration.jl")
54+
include("descent/multistep.jl")
5355

5456
include("internal/operators.jl")
5557
include("internal/jacobian.jl")

src/abstract_types.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ Abstract Type for all Descent Caches.
6666
### `__internal_solve!` specification
6767
6868
```julia
69-
δu, success, intermediates = __internal_solve!(cache::AbstractDescentCache, J, fu, u,
70-
idx::Val; skip_solve::Bool = false, kwargs...)
69+
descent_result = __internal_solve!(cache::AbstractDescentCache, J, fu, u, idx::Val;
70+
skip_solve::Bool = false, kwargs...)
7171
```
7272
7373
- `J`: Jacobian or Inverse Jacobian (if `pre_inverted = Val(true)`).
@@ -79,14 +79,7 @@ Abstract Type for all Descent Caches.
7979
direction was rejected and we want to try with a modified trust region.
8080
- `kwargs`: keyword arguments to pass to the linear solver if there is one.
8181
82-
#### Returned values
83-
84-
- `δu`: the descent direction.
85-
- `success`: Certain Descent Algorithms can reject a descent direction for example
86-
`GeodesicAcceleration`.
87-
- `intermediates`: A named tuple containing intermediates computed during the solve.
88-
For example, `GeodesicAcceleration` returns `NamedTuple{(:v, :a)}` containing the
89-
"velocity" and "acceleration" terms.
82+
Returns a result of type [`DescentResult`](@ref).
9083
9184
### Interface Functions
9285

src/core/approximate_jacobian.jl

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
163163

164164
linsolve = get_linear_solver(alg.descent)
165165
initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p;
166-
linsolve,
167-
maxiters, internalnorm)
166+
linsolve, maxiters, internalnorm)
168167

169168
abstol, reltol, termination_cache = init_termination_cache(abstol, reltol, fu, u,
170169
termination_condition)
@@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
222221
new_jacobian = true
223222
@static_timeit cache.timer "jacobian init/reinit" begin
224223
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
225-
J_init = __internal_solve!(cache.initialization_cache,
226-
cache.fu,
227-
cache.u,
224+
J_init = __internal_solve!(cache.initialization_cache, cache.fu, cache.u,
228225
Val(false))
229226
if INV
230227
if jacobian_initialized_preinverted(cache.initialization_cache.alg)
@@ -283,54 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
283280
@static_timeit cache.timer "descent" begin
284281
if cache.trustregion_cache !== nothing &&
285282
hasfield(typeof(cache.trustregion_cache), :trust_region)
286-
δu, descent_success, descent_intermediates = __internal_solve!(
287-
cache.descent_cache,
288-
J, cache.fu, cache.u; new_jacobian,
289-
trust_region = cache.trustregion_cache.trust_region)
283+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
284+
new_jacobian, trust_region = cache.trustregion_cache.trust_region)
290285
else
291-
δu, descent_success, descent_intermediates = __internal_solve!(
292-
cache.descent_cache,
293-
J, cache.fu, cache.u; new_jacobian)
286+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
287+
new_jacobian)
294288
end
295289
end
296290

297-
if descent_success
298-
if GB === :LineSearch
299-
@static_timeit cache.timer "linesearch" begin
300-
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
301-
end
302-
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
303-
cache.force_reinit = true
304-
else
305-
@static_timeit cache.timer "step" begin
306-
@bb axpy!(α, δu, cache.u)
307-
evaluate_f!(cache, cache.u, cache.p)
308-
end
309-
end
310-
elseif GB === :TrustRegion
311-
@static_timeit cache.timer "trustregion" begin
312-
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J,
313-
cache.fu, cache.u, δu, descent_intermediates)
314-
if tr_accepted
315-
@bb copyto!(cache.u, u_new)
316-
@bb copyto!(cache.fu, fu_new)
317-
end
318-
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
319-
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
320-
cache.retcode = ReturnCode.ShrinkThresholdExceeded
321-
cache.force_stop = true
322-
end
323-
end
324-
α = true
325-
elseif GB === :None
291+
if descent_result.success
292+
if GB === :None
326293
@static_timeit cache.timer "step" begin
327-
@bb axpy!(1, δu, cache.u)
294+
if descent_result.u !== missing
295+
@bb copyto!(cache.u, descent_result.u)
296+
elseif descent_result.δu !== missing
297+
@bb axpy!(1, descent_result.δu, cache.u)
298+
else
299+
error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \
300+
specified.")
301+
end
328302
evaluate_f!(cache, cache.u, cache.p)
329303
end
330304
α = true
331305
else
332-
error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
333-
:TrustRegion, :None)")
306+
δu = descent_result.δu
307+
@assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`."
308+
309+
if GB === :LineSearch
310+
@static_timeit cache.timer "linesearch" begin
311+
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
312+
end
313+
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
314+
cache.force_reinit = true
315+
else
316+
@static_timeit cache.timer "step" begin
317+
@bb axpy!(α, δu, cache.u)
318+
evaluate_f!(cache, cache.u, cache.p)
319+
end
320+
end
321+
elseif GB === :TrustRegion
322+
@static_timeit cache.timer "trustregion" begin
323+
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache,
324+
J, cache.fu, cache.u, δu, descent_result.extras)
325+
if tr_accepted
326+
@bb copyto!(cache.u, u_new)
327+
@bb copyto!(cache.fu, fu_new)
328+
α = true
329+
else
330+
α = false
331+
end
332+
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
333+
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
334+
cache.retcode = ReturnCode.ShrinkThresholdExceeded
335+
cache.force_stop = true
336+
end
337+
end
338+
else
339+
error("Unknown Globalization Strategy: $(GB). Allowed values are \
340+
(:LineSearch, :TrustRegion, :None)")
341+
end
334342
end
335343
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
336344
else

src/core/generalized_first_order.jl

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -215,59 +215,67 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
215215
@static_timeit cache.timer "descent" begin
216216
if cache.trustregion_cache !== nothing &&
217217
hasfield(typeof(cache.trustregion_cache), :trust_region)
218-
δu, descent_success, descent_intermediates = __internal_solve!(
219-
cache.descent_cache,
220-
J, cache.fu, cache.u; new_jacobian,
221-
trust_region = cache.trustregion_cache.trust_region)
218+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
219+
new_jacobian, trust_region = cache.trustregion_cache.trust_region)
222220
else
223-
δu, descent_success, descent_intermediates = __internal_solve!(
224-
cache.descent_cache,
225-
J, cache.fu, cache.u; new_jacobian)
221+
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
222+
new_jacobian)
226223
end
227224
end
228225

229-
if descent_success
226+
if descent_result.success
230227
cache.make_new_jacobian = true
231-
if GB === :LineSearch
232-
@static_timeit cache.timer "linesearch" begin
233-
linesearch_failed, α = __internal_solve!(cache.linesearch_cache,
234-
cache.u, δu)
235-
end
236-
if linesearch_failed
237-
cache.retcode = ReturnCode.InternalLineSearchFailed
238-
cache.force_stop = true
239-
end
228+
if GB === :None
240229
@static_timeit cache.timer "step" begin
241-
@bb axpy!(α, δu, cache.u)
242-
evaluate_f!(cache, cache.u, cache.p)
243-
end
244-
elseif GB === :TrustRegion
245-
@static_timeit cache.timer "trustregion" begin
246-
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J,
247-
cache.fu, cache.u, δu, descent_intermediates)
248-
if tr_accepted
249-
@bb copyto!(cache.u, u_new)
250-
@bb copyto!(cache.fu, fu_new)
251-
α = true
230+
if descent_result.u !== missing
231+
@bb copyto!(cache.u, descent_result.u)
232+
elseif descent_result.δu !== missing
233+
@bb axpy!(1, descent_result.δu, cache.u)
252234
else
253-
α = false
254-
cache.make_new_jacobian = false
235+
error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \
236+
specified.")
255237
end
256-
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
257-
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
258-
cache.retcode = ReturnCode.ShrinkThresholdExceeded
259-
cache.force_stop = true
260-
end
261-
end
262-
elseif GB === :None
263-
@static_timeit cache.timer "step" begin
264-
@bb axpy!(1, δu, cache.u)
265238
evaluate_f!(cache, cache.u, cache.p)
266239
end
267240
α = true
268241
else
269-
error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
270-
:TrustRegion, :None)")
242+
δu = descent_result.δu
243+
@assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`."
244+
245+
if GB === :LineSearch
246+
@static_timeit cache.timer "linesearch" begin
247+
failed, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
248+
end
249+
if failed
250+
cache.retcode = ReturnCode.InternalLineSearchFailed
251+
cache.force_stop = true
252+
else
253+
@static_timeit cache.timer "step" begin
254+
@bb axpy!(α, δu, cache.u)
255+
evaluate_f!(cache, cache.u, cache.p)
256+
end
257+
end
258+
elseif GB === :TrustRegion
259+
@static_timeit cache.timer "trustregion" begin
260+
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache,
261+
J, cache.fu, cache.u, δu, descent_result.extras)
262+
if tr_accepted
263+
@bb copyto!(cache.u, u_new)
264+
@bb copyto!(cache.fu, fu_new)
265+
α = true
266+
else
267+
α = false
268+
end
269+
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
270+
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
271+
cache.retcode = ReturnCode.ShrinkThresholdExceeded
272+
cache.force_stop = true
273+
end
274+
end
275+
else
276+
error("Unknown Globalization Strategy: $(GB). Allowed values are \
277+
(:LineSearch, :TrustRegion, :None)")
278+
end
271279
end
272280
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
273281
else

src/descent/common.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
3+
4+
Construct a `DescentResult` object.
5+
6+
### Keyword Arguments
7+
8+
* `δu`: The descent direction.
9+
* `u`: The new iterate. This is provided only for multi-step methods currently.
10+
* `success`: Certain Descent Algorithms can reject a descent direction for example
11+
[`GeodesicAcceleration`](@ref).
12+
* `extras`: A named tuple containing intermediates computed during the solve.
13+
For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing
14+
the "velocity" and "acceleration" terms.
15+
"""
16+
@concrete struct DescentResult
17+
δu
18+
u
19+
success::Bool
20+
extras
21+
end
22+
23+
function DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
24+
@assert δu !== missing || u !== missing
25+
return DescentResult(δu, u, success, extras)
26+
end

src/descent/damped_newton.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u,
138138
idx::Val{N} = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true,
139139
kwargs...) where {INV, N, mode}
140140
δu = get_du(cache, idx)
141-
skip_solve && return δu, true, (;)
141+
skip_solve && return DescentResult(; δu)
142142

143143
recompute_A = idx === Val(1)
144144

@@ -203,15 +203,14 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u,
203203
end
204204

205205
@static_timeit cache.timer "linear solve" begin
206-
δu = cache.lincache(; A, b,
207-
reuse_A_if_factorization = !new_jacobian && !recompute_A,
208-
kwargs..., linu = _vec(δu))
206+
δu = cache.lincache(; A, b, linu = _vec(δu),
207+
reuse_A_if_factorization = !new_jacobian && !recompute_A, kwargs...)
209208
δu = _restructure(get_du(cache, idx), δu)
210209
end
211210

212211
@bb @. δu *= -1
213212
set_du!(cache, δu, idx)
214-
return δu, true, (;)
213+
return DescentResult(; δu)
215214
end
216215

217216
# Define special concatenation for certain Array combinations

0 commit comments

Comments
 (0)