Skip to content

Commit 9ebdcab

Browse files
Setup NLsolve as an extension too
1 parent 330eb1a commit 9ebdcab

File tree

9 files changed

+372
-8
lines changed

9 files changed

+372
-8
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3030
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3131
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3232
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
33+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
3334

3435
[extensions]
3536
NonlinearSolveBandedMatricesExt = "BandedMatrices"
3637
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
37-
NonlinearSolveMINPACKExt = "MINPACK"
3838
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
39+
NonlinearSolveMINPACKExt = "MINPACK"
40+
NonlinearSolveNLsolveExt = "NLsolve"
3941

4042
[compat]
4143
ADTypes = "0.2"
@@ -78,6 +80,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7880
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
7981
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
8082
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
83+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
8184
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
8285
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8386
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -89,4 +92,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8992
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9093

9194
[targets]
92-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "MINPACK"]
95+
test = ["NLsolve", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "MINPACK"]

docs/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ IncompleteLU = "0.2"
2929
LinearSolve = "2"
3030
ModelingToolkit = "8"
3131
NonlinearSolve = "1, 2"
32-
NonlinearSolveMINPACK = "0.1"
3332
SciMLBase = "2.4"
3433
SciMLNLSolve = "0.1"
3534
SimpleNonlinearSolve = "0.1.5"

docs/src/api/nlsolve.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ the package before using these solvers:
66

77
```julia
88
using Pkg
9-
Pkg.add("SciMLNLSolve")
10-
using SciMLNLSolve
9+
Pkg.add("NLsolve")
10+
using NLsolve
1111
```
1212

1313
## Solver API

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ computationally expensive than direct methods.
113113
- `DynamicSS()` : Uses an ODE solver to find the steady state. Automatically
114114
terminates when close to the steady state.
115115

116-
### SciMLNLSolve.jl
116+
### NLsolve.jl
117117

118118
This is a wrapper package for importing solvers from NLsolve.jl into the SciML interface.
119119

@@ -127,8 +127,11 @@ Submethod choices for this algorithm include:
127127

128128
### MINPACK.jl
129129

130-
MINPACK.jl methods are good for medium-sized nonlinear solves. It does not scale due to
131-
the lack of sparse Jacobian support, though the methods are very robust and stable.
130+
MINPACK.jl methods are fine for medium-sized nonlinear solves. They are the FORTRAN
131+
standard methods which are used in many places, such as SciPy. However, our benchmarks
132+
demonstrate that these methods are not robust or stable. In addition, they are slower
133+
than the standard methods and do not scale due to lack of sparse Jacobian support.
134+
Thus they are only recommended for benchmarking and testing code conversions.
132135

133136
- `CMINPACK()`: A wrapper for using the classic MINPACK method through [MINPACK.jl](https://github.com/sglyon/MINPACK.jl)
134137

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
module NonlinearSolveNLsolveExt
2+
3+
using NonlinearSolve
4+
using NLsolve
5+
using LineSearches
6+
using DiffEqBase
7+
using SciMLBase
8+
9+
function SciMLBase.__solve(prob::Union{SciMLBase.AbstractSteadyStateProblem,
10+
SciMLBase.AbstractNonlinearProblem},
11+
alg::algType,
12+
args...;
13+
abstol = 1e-6,
14+
maxiters = 1000,
15+
kwargs...) where {algType <: NonlinearSolve.SciMLNLSolveAlgorithm}
16+
if typeof(prob.u0) <: Number
17+
u0 = [prob.u0]
18+
else
19+
u0 = deepcopy(prob.u0)
20+
end
21+
22+
iip = isinplace(prob)
23+
24+
sizeu = size(prob.u0)
25+
p = prob.p
26+
27+
# unwrapping alg params
28+
method = alg.method
29+
autodiff = alg.autodiff
30+
store_trace = alg.store_trace
31+
extended_trace = alg.extended_trace
32+
linesearch = alg.linesearch
33+
linsolve = alg.linsolve
34+
factor = alg.factor
35+
autoscale = alg.autoscale
36+
m = alg.m
37+
beta = alg.beta
38+
show_trace = alg.show_trace
39+
40+
### Fix the more general function to Sundials allowed style
41+
if typeof(prob.f) <: ODEFunction
42+
t = Inf
43+
if !iip && typeof(prob.u0) <: Number
44+
f! = (du, u) -> (du .= prob.f(first(u), p, t); Cint(0))
45+
elseif !iip && typeof(prob.u0) <: Vector{Float64}
46+
f! = (du, u) -> (du .= prob.f(u, p, t); Cint(0))
47+
elseif !iip && typeof(prob.u0) <: AbstractArray
48+
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0))
49+
elseif typeof(prob.u0) <: Vector{Float64}
50+
f! = (du, u) -> prob.f(du, u, p, t)
51+
else # Then it's an in-place function on an abstract array
52+
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t);
53+
du = vec(du);
54+
0)
55+
end
56+
elseif typeof(prob.f) <: NonlinearFunction
57+
if !iip && typeof(prob.u0) <: Number
58+
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
59+
elseif !iip && typeof(prob.u0) <: Vector{Float64}
60+
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
61+
elseif !iip && typeof(prob.u0) <: AbstractArray
62+
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
63+
elseif typeof(prob.u0) <: Vector{Float64}
64+
f! = (du, u) -> prob.f(du, u, p)
65+
else # Then it's an in-place function on an abstract array
66+
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p);
67+
du = vec(du);
68+
0)
69+
end
70+
end
71+
72+
resid = similar(u0)
73+
f!(resid, u0)
74+
75+
if SciMLBase.has_jac(prob.f)
76+
if !iip && typeof(prob.u0) <: Number
77+
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
78+
elseif !iip && typeof(prob.u0) <: Vector{T} where {T <: Number}
79+
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
80+
elseif !iip && typeof(prob.u0) <: AbstractArray
81+
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
82+
elseif typeof(prob.u0) <: Vector{T} where {T <: Number}
83+
g! = (du, u) -> prob.f.jac(du, u, p)
84+
else # Then it's an in-place function on an abstract array
85+
g! = (du, u) -> (prob.f.jac(reshape(du, sizeu), reshape(u, sizeu), p);
86+
du = vec(du);
87+
0)
88+
end
89+
if prob.f.jac_prototype !== nothing
90+
J = zero(prob.f.jac_prototype)
91+
df = OnceDifferentiable(f!, g!, u0, resid, J)
92+
else
93+
df = OnceDifferentiable(f!, g!, u0, resid)
94+
end
95+
else
96+
df = OnceDifferentiable(f!, u0, resid, autodiff = autodiff)
97+
end
98+
99+
original = nlsolve(df, u0,
100+
ftol = abstol,
101+
iterations = maxiters,
102+
method = method,
103+
store_trace = store_trace,
104+
extended_trace = extended_trace,
105+
linesearch = linesearch,
106+
linsolve = linsolve,
107+
factor = factor,
108+
autoscale = autoscale,
109+
m = m,
110+
beta = beta,
111+
show_trace = show_trace)
112+
113+
u = reshape(original.zero, size(u0))
114+
f!(resid, u)
115+
retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
116+
ReturnCode.Failure
117+
stats = SciMLBase.NLStats(original.f_calls,
118+
original.g_calls,
119+
original.g_calls,
120+
original.g_calls,
121+
original.iterations)
122+
SciMLBase.build_solution(prob, alg, u, resid; retcode = retcode,
123+
original = original, stats = stats)
124+
end
125+
126+
end

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, Pseu
148148
GeneralBroyden, GeneralKlement, LimitedMemoryBroyden
149149
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
150150
export RobustMultiNewton, FastShortcutNonlinearPolyalg
151+
export CMINPACK, NLSolveJL
151152

152153
export LineSearch, LiFukushimaLineSearch
153154

src/extension_algs.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,97 @@ end
118118
function CMINPACK(; show_trace::Bool = false, tracing::Bool = false, method::Symbol = :hybr,
119119
io::IO = stdout)
120120
CMINPACK(show_trace, tracing, method, io)
121+
end
122+
123+
abstract type SciMLNLSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
124+
125+
"""
126+
```julia
127+
NLSolveJL(;
128+
method=:trust_region,
129+
autodiff=:central,
130+
store_trace=false,
131+
extended_trace=false,
132+
linesearch=LineSearches.Static(),
133+
linsolve=(x, A, b) -> copyto!(x, A\\b),
134+
factor = one(Float64),
135+
autoscale=true,
136+
m=10,
137+
beta=one(Float64),
138+
show_trace=false,
139+
)
140+
```
141+
142+
### Keyword Arguments
143+
144+
- `method`: the choice of method for solving the nonlinear system.
145+
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or
146+
central differencing via FiniteDiff.jl. The other choices are `:forward`
147+
- `show_trace`: should a trace of the optimization algorithm's state be shown on STDOUT?
148+
Default: false.
149+
- `extended_trace`: should additional algorithm internals be added to the state trace?
150+
Default: false.
151+
- `linesearch`: the line search method to be used within the solver method. The choices
152+
are line search types from
153+
[LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl). Defaults to
154+
`LineSearches.Static()`.
155+
- `linsolve`: a function `linsolve(x, A, b)` that solves `Ax = b`. Defaults to using Julia's
156+
`\\`.
157+
- `factor``: determines the size of the initial trust region. This size is set to the
158+
product of factor and the euclidean norm of `u0` if nonzero, or else to factor itself.
159+
Default: 1.0.
160+
- `autoscale`: if true, then the variables will be automatically rescaled. The scaling
161+
factors are the norms of the Jacobian columns. Default: true.
162+
- `m`: the amount of history in the Anderson method. Naive "Picard"-style iteration can be
163+
achieved by setting m=0, but that isn't advisable for contractions whose Lipschitz
164+
constants are close to 1. If convergence fails, though, you may consider lowering it.
165+
- `beta`: It is also known as DIIS or Pulay mixing, this method is based on the acceleration
166+
of the fixed-point iteration xₙ₊₁ = xₙ + beta*f(xₙ), where by default beta=1.
167+
- `store_trace``: should a trace of the optimization algorithm's state be stored? Default:
168+
false.
169+
170+
### Submethod Choice
171+
172+
Choices for methods in `NLSolveJL`:
173+
174+
- `:anderson`: Anderson-accelerated fixed-point iteration
175+
- `:broyden`: Broyden's quasi-Newton method
176+
- `:newton`: Classical Newton method with an optional line search
177+
- `:trust_region`: Trust region Newton method (the default choice)
178+
179+
For more information on these arguments, consult the
180+
[NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl).
181+
"""
182+
struct NLSolveJL{LSH, LS} <: SciMLNLSolveAlgorithm
183+
# Refer for tuning parameter choices: https://github.com/JuliaNLSolvers/NLsolve.jl#automatic-differentiation
184+
method::Symbol
185+
autodiff::Symbol
186+
store_trace::Bool
187+
extended_trace::Bool
188+
linesearch::LSH
189+
linsolve::LS
190+
factor::Real
191+
autoscale::Bool
192+
m::Int
193+
beta::Real
194+
show_trace::Bool
195+
# aa_start::Int
196+
# droptol::Real
197+
end
198+
199+
function NLSolveJL(;
200+
method = :trust_region,
201+
autodiff = :central,
202+
store_trace = false,
203+
extended_trace = false,
204+
linesearch = LineSearches.Static(),
205+
linsolve = (x, A, b) -> copyto!(x, A \ b),
206+
factor = one(Float64),
207+
autoscale = true,
208+
m = 10,
209+
beta = one(Float64),
210+
show_trace = false)
211+
NLSolveJL{typeof(linesearch), typeof(linsolve)}(method, autodiff, store_trace,
212+
extended_trace, linesearch, linsolve,
213+
factor, autoscale, m, beta, show_trace)
121214
end

0 commit comments

Comments
 (0)