@@ -3,11 +3,12 @@ module NonlinearSolveFastLevenbergMarquardtExt
33using ArrayInterface, NonlinearSolve, SciMLBase
44import ConcreteStructs: @concrete
55import FastLevenbergMarquardt as FastLM
6+ import FiniteDiff, ForwardDiff
67
78function _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , x) where {linsolve}
8- if linsolve == :cholesky
9+ if linsolve === :cholesky
910 return FastLM. CholeskySolver (ArrayInterface. undefmatrix (x))
10- elseif linsolve == :qr
11+ elseif linsolve === :qr
1112 return FastLM. QRSolver (eltype (x), length (x))
1213 else
1314 throw (ArgumentError (" Unknown FastLevenbergMarquardt Linear Solver: $linsolve " ))
3334
3435function SciMLBase. __init (prob:: NonlinearLeastSquaresProblem ,
3536 alg:: FastLevenbergMarquardtJL , args... ; alias_u0 = false , abstol = 1e-8 ,
36- reltol = 1e-8 , verbose = false , maxiters = 1000 , kwargs... )
37+ reltol = 1e-8 , maxiters = 1000 , kwargs... )
3738 iip = SciMLBase. isinplace (prob)
38- u0 = alias_u0 ? prob. u0 : deepcopy (prob. u0)
39-
40- @assert prob. f. jac!= = nothing " FastLevenbergMarquardt requires a Jacobian!"
39+ u = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
40+ fu = NonlinearSolve. evaluate_f (prob, u)
4141
4242 f! = InplaceFunction {iip} (prob. f)
43- J! = InplaceFunction {iip} (prob. f. jac)
4443
45- resid_prototype = prob. f. resid_prototype === nothing ?
46- (! iip ? prob. f (u0, prob. p) : zeros (u0)) :
47- prob. f. resid_prototype
44+ if prob. f. jac === nothing
45+ use_forward_diff = if alg. autodiff === nothing
46+ ForwardDiff. can_dual (eltype (u))
47+ else
48+ alg. autodiff isa AutoForwardDiff
49+ end
50+ uf = SciMLBase. JacobianWrapper {iip} (prob. f, prob. p)
51+ if use_forward_diff
52+ cache = iip ? ForwardDiff. JacobianConfig (uf, fu, u) :
53+ ForwardDiff. JacobianConfig (uf, u)
54+ else
55+ cache = FiniteDiff. JacobianCache (u, fu)
56+ end
57+ J! = if iip
58+ if use_forward_diff
59+ fu_cache = similar (fu)
60+ function (J, x, p)
61+ uf. p = p
62+ ForwardDiff. jacobian! (J, uf, fu_cache, x, cache)
63+ return J
64+ end
65+ else
66+ function (J, x, p)
67+ uf. p = p
68+ FiniteDiff. finite_difference_jacobian! (J, uf, x, cache)
69+ return J
70+ end
71+ end
72+ else
73+ if use_forward_diff
74+ function (J, x, p)
75+ uf. p = p
76+ ForwardDiff. jacobian! (J, uf, x, cache)
77+ return J
78+ end
79+ else
80+ function (J, x, p)
81+ uf. p = p
82+ J_ = FiniteDiff. finite_difference_jacobian (uf, x, cache)
83+ copyto! (J, J_)
84+ return J
85+ end
86+ end
87+ end
88+ else
89+ J! = InplaceFunction {iip} (prob. f. jac)
90+ end
4891
49- J = similar (u0 , length (resid_prototype ), length (u0 ))
92+ J = similar (u , length (fu ), length (u ))
5093
51- solver = _fast_lm_solver (alg, u0 )
52- LM = FastLM. LMWorkspace (u0, resid_prototype , J)
94+ solver = _fast_lm_solver (alg, u )
95+ LM = FastLM. LMWorkspace (u, fu , J)
5396
5497 return FastLevenbergMarquardtJLCache (f!, J!, prob, alg, LM, solver,
5598 (; xtol = abstol, ftol = reltol, maxit = maxiters, alg. factor, alg. factoraccept,
0 commit comments