diff --git a/Project.toml b/Project.toml index 1651aab2a..419a0ee09 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.15.1" +version = "3.15.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl index 190c80645..bf0abbc77 100644 --- a/src/internal/forward_diff.jl +++ b/src/internal/forward_diff.jl @@ -2,16 +2,33 @@ import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_∂p, __nlsolve_∂f_∂u -function SciMLBase.solve( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +for (uType, pType) in [ + (Union{<:Number, <:AbstractArray}, Union{<:Dual, <:AbstractArray{<:Dual}}), + (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Dual, <:AbstractArray{<:Dual}}), + (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Number, <:AbstractArray}) +] + @eval begin + function SciMLBase.solve( + prob::NonlinearProblem{<:$(uType), iip, <:$(pType)}, + alg::Union{Nothing, AbstractNonlinearAlgorithm}, + args...; kwargs...) where {iip} + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) + end + + function SciMLBase.init( + prob::NonlinearProblem{<:$(uType), iip, <:$(pType)}, + alg::Union{Nothing, AbstractNonlinearAlgorithm}, + args...; kwargs...) where {iip} + p = __value(prob.p) + newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) + end + end end @concrete mutable struct NonlinearSolveForwardDiffCache @@ -35,19 +52,6 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache; return cache end -function SciMLBase.init( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - p = __value(prob.p) - newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) - cache = init(newprob, alg, args...; kwargs...) - return NonlinearSolveForwardDiffCache( - cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) -end - function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache) sol = solve!(cache.cache) prob = cache.prob