1+ module NonlinearSolveNLsolveExt
2+
3+ using NonlinearSolve
4+ using NLsolve
5+ using LineSearches
6+ using DiffEqBase
7+ using SciMLBase
8+
9+ function SciMLBase. __solve (prob:: Union {SciMLBase. AbstractSteadyStateProblem,
10+ SciMLBase. AbstractNonlinearProblem},
11+ alg:: algType ,
12+ args... ;
13+ abstol = 1e-6 ,
14+ maxiters = 1000 ,
15+ kwargs... ) where {algType <: NonlinearSolve.SciMLNLSolveAlgorithm }
16+ if typeof (prob. u0) <: Number
17+ u0 = [prob. u0]
18+ else
19+ u0 = deepcopy (prob. u0)
20+ end
21+
22+ iip = isinplace (prob)
23+
24+ sizeu = size (prob. u0)
25+ p = prob. p
26+
27+ # unwrapping alg params
28+ method = alg. method
29+ autodiff = alg. autodiff
30+ store_trace = alg. store_trace
31+ extended_trace = alg. extended_trace
32+ linesearch = alg. linesearch
33+ linsolve = alg. linsolve
34+ factor = alg. factor
35+ autoscale = alg. autoscale
36+ m = alg. m
37+ beta = alg. beta
38+ show_trace = alg. show_trace
39+
40+ # ## Fix the more general function to Sundials allowed style
41+ if typeof (prob. f) <: ODEFunction
42+ t = Inf
43+ if ! iip && typeof (prob. u0) <: Number
44+ f! = (du, u) -> (du .= prob. f (first (u), p, t); Cint (0 ))
45+ elseif ! iip && typeof (prob. u0) <: Vector{Float64}
46+ f! = (du, u) -> (du .= prob. f (u, p, t); Cint (0 ))
47+ elseif ! iip && typeof (prob. u0) <: AbstractArray
48+ f! = (du, u) -> (du .= vec (prob. f (reshape (u, sizeu), p, t)); Cint (0 ))
49+ elseif typeof (prob. u0) <: Vector{Float64}
50+ f! = (du, u) -> prob. f (du, u, p, t)
51+ else # Then it's an in-place function on an abstract array
52+ f! = (du, u) -> (prob. f (reshape (du, sizeu), reshape (u, sizeu), p, t);
53+ du = vec (du);
54+ 0 )
55+ end
56+ elseif typeof (prob. f) <: NonlinearFunction
57+ if ! iip && typeof (prob. u0) <: Number
58+ f! = (du, u) -> (du .= prob. f (first (u), p); Cint (0 ))
59+ elseif ! iip && typeof (prob. u0) <: Vector{Float64}
60+ f! = (du, u) -> (du .= prob. f (u, p); Cint (0 ))
61+ elseif ! iip && typeof (prob. u0) <: AbstractArray
62+ f! = (du, u) -> (du .= vec (prob. f (reshape (u, sizeu), p)); Cint (0 ))
63+ elseif typeof (prob. u0) <: Vector{Float64}
64+ f! = (du, u) -> prob. f (du, u, p)
65+ else # Then it's an in-place function on an abstract array
66+ f! = (du, u) -> (prob. f (reshape (du, sizeu), reshape (u, sizeu), p);
67+ du = vec (du);
68+ 0 )
69+ end
70+ end
71+
72+ resid = similar (u0)
73+ f! (resid, u0)
74+
75+ if SciMLBase. has_jac (prob. f)
76+ if ! iip && typeof (prob. u0) <: Number
77+ g! = (du, u) -> (du .= prob. f. jac (first (u), p); Cint (0 ))
78+ elseif ! iip && typeof (prob. u0) <: Vector{T} where {T <: Number }
79+ g! = (du, u) -> (du .= prob. f. jac (u, p); Cint (0 ))
80+ elseif ! iip && typeof (prob. u0) <: AbstractArray
81+ g! = (du, u) -> (du .= vec (prob. f. jac (reshape (u, sizeu), p)); Cint (0 ))
82+ elseif typeof (prob. u0) <: Vector{T} where {T <: Number }
83+ g! = (du, u) -> prob. f. jac (du, u, p)
84+ else # Then it's an in-place function on an abstract array
85+ g! = (du, u) -> (prob. f. jac (reshape (du, sizeu), reshape (u, sizeu), p);
86+ du = vec (du);
87+ 0 )
88+ end
89+ if prob. f. jac_prototype != = nothing
90+ J = zero (prob. f. jac_prototype)
91+ df = OnceDifferentiable (f!, g!, u0, resid, J)
92+ else
93+ df = OnceDifferentiable (f!, g!, u0, resid)
94+ end
95+ else
96+ df = OnceDifferentiable (f!, u0, resid, autodiff = autodiff)
97+ end
98+
99+ original = nlsolve (df, u0,
100+ ftol = abstol,
101+ iterations = maxiters,
102+ method = method,
103+ store_trace = store_trace,
104+ extended_trace = extended_trace,
105+ linesearch = linesearch,
106+ linsolve = linsolve,
107+ factor = factor,
108+ autoscale = autoscale,
109+ m = m,
110+ beta = beta,
111+ show_trace = show_trace)
112+
113+ u = reshape (original. zero, size (u0))
114+ f! (resid, u)
115+ retcode = original. x_converged || original. f_converged ? ReturnCode. Success :
116+ ReturnCode. Failure
117+ stats = SciMLBase. NLStats (original. f_calls,
118+ original. g_calls,
119+ original. g_calls,
120+ original. g_calls,
121+ original. iterations)
122+ SciMLBase. build_solution (prob, alg, u, resid; retcode = retcode,
123+ original = original, stats = stats)
124+ end
125+
126+ end
0 commit comments