From 6febb777de2ef83b508886900fd2d70c6c99b297 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 6 Aug 2025 14:17:04 -0400 Subject: [PATCH 01/61] add SciMLStructures --- lib/NonlinearSolveBase/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index da80f2ec4..424a8dfa2 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -22,6 +22,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" @@ -71,6 +72,7 @@ RecursiveArrayTools = "3" SciMLBase = "2.116" SciMLJacobianOperators = "0.1.1" SciMLOperators = "1.7" +SciMLStructures = "1.5" SparseArrays = "1.10" SparseMatrixColorings = "0.4.5" StaticArraysCore = "1.4" From a83b26f9601a42f3354452b1d3adc781269c502d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 6 Aug 2025 18:48:46 -0400 Subject: [PATCH 02/61] add extension functions --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index e3efc9ef5..bddf9c674 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -15,15 +15,17 @@ using StaticArraysCore: StaticArray, SMatrix, SArray, MArray using CommonSolve: CommonSolve, init using EnzymeCore: EnzymeCore using MaybeInplace: @bb -using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition +using RecursiveArrayTools: RecursiveArrayTools, AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, AbstractNonlinearAlgorithm, NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction, NLStats, LinearProblem, - LinearAliasSpecifier, ImmutableNonlinearProblem + LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier +import SciMLBase: solve, init, solve!, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface +import SciMLStructures using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul! using Markdown: @doc_str @@ -32,6 +34,11 @@ using Printf: @printf const DI = DifferentiationInterface const SII = SymbolicIndexingInterface +# Extension Functions +eltypedual(x) = false +promote_u0(::Nothing, p, t0) = nothing +isdualtype(::Type{T}) where {T} = false + include("public.jl") include("utils.jl") From 5d65a3793562d50f1af8440f00d0556319ab7428 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 6 Aug 2025 18:49:15 -0400 Subject: [PATCH 03/61] make solve us solve_up, init_up, utilities --- lib/NonlinearSolveBase/src/solve.jl | 666 ++++++++++++++++++++++++++++ 1 file changed, 666 insertions(+) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 91b7a6aa6..d129442e7 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -1,3 +1,449 @@ +const allowedkeywords = (:dense, + :saveat, + :save_idxs, + :tstops, + :tspan, + :d_discontinuities, + :save_everystep, + :save_on, + :save_start, + :save_end, + :initialize_save, + :adaptive, + :abstol, + :reltol, + :dt, + :dtmax, + :dtmin, + :force_dtmin, + :internalnorm, + :controller, + :gamma, + :beta1, + :beta2, + :qmax, + :qmin, + :qsteady_min, + :qsteady_max, + :qoldinit, + :failfactor, + :calck, + :alias_u0, + :maxiters, + :maxtime, + :callback, + :isoutofdomain, + :unstable_check, + :verbose, + :merge_callbacks, + :progress, + :progress_steps, + :progress_name, + :progress_message, + :progress_id, + :timeseries_errors, + :dense_errors, + :weak_timeseries_errors, + :weak_dense_errors, + :wrap, + :calculate_error, + :initializealg, + :alg, + :save_noise, + :delta, + :seed, + :alg_hints, + :kwargshandle, + :trajectories, + :batch_size, + :sensealg, + :advance_to_tstop, + :stop_at_next_tstop, + :u0, + :p, + # These two are from the default algorithm handling + :default_set, + :second_time, + # This is for DiffEqDevTools + :prob_choice, + # Jump problems + :alias_jump, + # This is for copying/deepcopying noise in StochasticDiffEq + :alias_noise, + # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves + :batch, + # Shooting method in BVP needs to differentiate between these two categories + :nlsolve_kwargs, + :odesolve_kwargs, + # If Solvers which internally use linsolve + :linsolve_kwargs, + # Solvers internally using EnsembleProblem + :ensemblealg, + # Fine Grained Control of Tracing (Storing and Logging) during Solve + :show_trace, + :trace_level, + :store_trace, + # Termination condition for solvers + :termination_condition, + # For AbstractAliasSpecifier + :alias, + # Parameter estimation with BVP + :fit_parameters) + +const KWARGWARN_MESSAGE = """ +Unrecognized keyword arguments found. +The only allowed keyword arguments to `solve` are: +$allowedkeywords + +See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. + +Set kwargshandle=KeywordArgError for an error message. +Set kwargshandle=KeywordArgSilent to ignore this message. +""" + +const KWARGERROR_MESSAGE = """ + Unrecognized keyword arguments found. + The only allowed keyword arguments to `solve` are: + $allowedkeywords + + See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. + """ + +struct CommonKwargError <: Exception + kwargs::Any +end + +function Base.showerror(io::IO, e::CommonKwargError) + println(io, KWARGERROR_MESSAGE) + notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) + unrecognized = collect(keys(e.kwargs))[notin] + print(io, "Unrecognized keyword arguments: ") + printstyled(io, unrecognized; bold = true, color = :red) + print(io, "\n\n") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +@enum KeywordArgError KeywordArgWarn KeywordArgSilent + +const INCOMPATIBLE_U0_MESSAGE = """ + Initial condition incompatible with functional form. + Detected an in-place function with an initial condition of type Number or SArray. + This is incompatible because Numbers cannot be mutated, i.e. + `x = 2.0; y = 2.0; x .= y` will error. + + If using a immutable initial condition type, please use the out-of-place form. + I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. + + If your differential equation function was defined with multiple dispatches and one is + in-place, then the automatic detection will choose in-place. In this case, override the + choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. + + For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: + https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation + """ + +struct IncompatibleInitialConditionError <: Exception end + +function Base.showerror(io::IO, e::IncompatibleInitialConditionError) + print(io, INCOMPATIBLE_U0_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NO_DEFAULT_ALGORITHM_MESSAGE = """ + Default algorithm choices require NonlinearSolve.jl. + Please specify an algorithm (e.g., `solve(prob, NewtonRaphson())` or + init(prob, NewtonRaphson()) or + import NonlinearSolve.jl directly. + + You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ + and its associated pages. + """ + +struct NoDefaultAlgorithmError <: Exception end + +function Base.showerror(io::IO, e::NoDefaultAlgorithmError) + print(io, NO_DEFAULT_ALGORITHM_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NON_SOLVER_MESSAGE = """ + The arguments to solve are incorrect. + The second argument must be a solver choice, `solve(prob,alg)` + where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. + + Please double check the arguments being sent to the solver. + + You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ + and its associated pages. + """ + +struct NonSolverError <: Exception end + +function Base.showerror(io::IO, e::NonSolverError) + print(io, NON_SOLVER_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ + Incompatible solver + automatic differentiation pairing. + The chosen automatic differentiation algorithm requires the ability + for compiler transforms on the code which is only possible on pure-Julia + solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods + which require this ability include: + + - Direct use of ForwardDiff.jl on the solver + - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` + sensealg choices for adjoint differentiation. + + Either switch the choice of solver to a pure Julia method, or change the automatic + differentiation method to one that does not require such transformations. + + For more details on automatic differentiation, adjoint, and sensitivity analysis + of differential equations, see the documentation page: + + https://diffeq.sciml.ai/stable/analysis/sensitivity/ + """ + +struct DirectAutodiffError <: Exception end + +function Base.showerror(io::IO, e::DirectAutodiffError) + println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +struct EvalFunc{F} <: Function + f::F +end +(f::EvalFunc)(args...) = f.f(args...) + +""" +```julia +solve(prob::NonlinearProblem, alg::Union{AbstractNonlinearAlgorithm,Nothing}; kwargs...) +``` + +## Arguments + +The only positional argument is `alg` which is optional. By default, `alg = nothing`. +If `alg = nothing`, then `solve` dispatches to the NonlinearSolve.jl automated +algorithm selection (if `using NonlinearSolve` was done, otherwise it will +error with a `MethodError`). + +## Keyword Arguments + +The NonlinearSolve.jl universe has a large set of common arguments available +for the `solve` function. These arguments apply to `solve` on any problem type and +are only limited by limitations of the specific implementations. + +Many of the defaults depend on the algorithm or the package the algorithm derives +from. Not all of the interface is provided by every algorithm. +For more detailed information on the defaults and the available options +for specific algorithms / packages, see the manual pages for the solvers of specific +problems. + +#### Error Control + +* `abstol`: Absolute tolerance. +* `reltol`: Relative tolerance. + +### Miscellaneous + +* `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5. +* `verbose`: Toggles whether warnings are thrown when the solver exits early. + Defaults to true. + +### Sensitivity Algorithms (`sensealg`) + +`sensealg` is used for choosing the way the automatic differentiation is performed. + For more information, see the documentation for SciMLSensitivity: + https://docs.sciml.ai/SciMLSensitivity/stable/ +""" +function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing, + u0 = nothing, p = nothing, wrap = Val(true), kwargs...) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + + if haskey(prob.kwargs, :alias_u0) + @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." + alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0]) + elseif haskey(kwargs, :alias_u0) + @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." + alias_spec = NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0]) + end + + if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool + alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias]) + elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool + alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias]) + end + + if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier + alias_spec = prob.kwargs[:alias] + elseif haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier + alias_spec = kwargs[:alias] + else + alias_spec = NonlinearAliasSpecifier(alias_u0 = false) + end + + alias_u0 = alias_spec.alias_u0 + + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + + if wrap isa Val{true} + wrap_sol(solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = SciMLBase.ChainRulesOriginator(), + kwargs...)) + else + solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + end +end + +function solve_up(prob::AbstractNonlinearProblem, sensealg, u0, p, + args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + if isnothing(alg) || !(alg isa AbstractNonlinearSolveAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, + p = p, kwargs...) + solve_call(_prob, args...; kwargs...) + else + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + #check_prob_alg_pairing(_prob, alg) # use alg for improved inference + if length(args) > 1 + solve_call(_prob, alg, Base.tail(args)...; kwargs...) + else + solve_call(_prob, alg; kwargs...) + end + end +end + +function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, + kwargs...) + kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? + _prob.kwargs[:kwargshandle] : kwargshandle + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + checkkwargs(kwargshandle; kwargs...) + if isdefined(_prob, :u0) + if _prob.u0 isa Array + if !isconcretetype(RecursiveArrayTools.recursive_unitless_eltype(_prob.u0)) + throw(NonConcreteEltypeError(RecursiveArrayTools.recursive_unitless_eltype(_prob.u0))) + end + + if !(eltype(_prob.u0) <: Number) && !(eltype(_prob.u0) <: Enum) && + !(_prob.u0 isa AbstractVector{<:AbstractArray} && _prob isa BVProblem) + # Allow Enums for FunctionMaps, make into a trait in the future + # BVPs use Vector of Arrays for initial guesses + throw(NonNumberEltypeError(eltype(_prob.u0))) + end + end + + if _prob.u0 === nothing + return build_null_solution(_prob, args...; kwargs...) + end + end + + if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && + _prob.f.f isa EvalFunc + Base.invokelatest(__solve, _prob, args...; kwargs...)#::T + else + __solve(_prob, args...; kwargs...)#::T + end +end + +function init( + prob::AbstractNonlinearProblem, args...; sensealg = nothing, + u0 = nothing, p = nothing, kwargs...) + if sensealg === nothing && has_kwargs(prob) && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + + init_up(prob, sensealg, u0, p, args...; kwargs...) +end + +function init_up(prob::AbstractNonlinearProblem, sensealg, u0, p, args...; kwargs...) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + p = p, kwargs...) + init_call(_prob, args...; kwargs...) + else + tstops = get(kwargs, :tstops, nothing) + if tstops === nothing && has_kwargs(prob) + tstops = get(prob.kwargs, :tstops, nothing) + end + if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && + !SciMLBase.allows_late_binding_tstops(alg) + throw(LateBindingTstopsNotSupportedError()) + end + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + #check_prob_alg_pairing(_prob, alg) # alg for improved inference + if length(args) > 1 + init_call(_prob, alg, Base.tail(args)...; kwargs...) + else + init_call(_prob, alg; kwargs...) + end + end +end + +function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, + kwargs...) + kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? + _prob.kwargs[:kwargshandle] : kwargshandle + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + checkkwargs(kwargshandle; kwargs...) + + if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && + _prob.f.f isa EvalFunc + Base.invokelatest(__init, _prob, args...; kwargs...)#::T + else + __init(_prob, args...; kwargs...)#::T + end +end + function SciMLBase.__solve( prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; kwargs... @@ -127,6 +573,18 @@ function SciMLBase.__solve( __generated_polysolve(prob, alg, args...; kwargs...) end +function SciMLBase.__solve( + prob::AbstractNonlinearProblem, args...; default_set = false, second_time = false, + kwargs...) + if second_time + throw(NoDefaultAlgorithmError()) + elseif length(args) > 0 && !(first(args) isa AbstractNonlinearAlgorithm) + throw(NonSolverError()) + else + __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) + end +end + @generated function __generated_polysolve( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, @@ -297,6 +755,10 @@ SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) +has_kwargs(_prob::AbstractNonlinearProblem) = has_kwargs(typeof(_prob)) +Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) +has_kwargs(::Type{T}) where {T} = __has_kwargs(T) + function SciMLBase.reinit!( cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs... ) @@ -328,3 +790,207 @@ function CommonSolve.solve!(cache::NonlinearSolveNoInitCache) end return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) end + +function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +""" +Given the index provider `indp` used to construct the problem `prob` being solved, return +an updated `prob` to be used for solving. All implementations should accept arbitrary +keyword arguments. + +Should be called before the problem is solved, after performing type-promotion on the +problem. If the returned problem is not `===` the provided `prob`, it is assumed to +contain the `u0` and `p` passed as keyword arguments. + +# Keyword Arguments + +- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which + should be used instead of the ones in `prob`. +""" +function get_updated_symbolic_problem(indp, prob; kw...) + return prob +end + +function build_null_solution( + prob::NonlinearProblem, + args...; + saveat = (), + save_everystep = true, + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = true, + kwargs...) + prob, success = hack_null_solution_init(prob) + retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure + SciMLBase.build_solution(prob, nothing, Float64[], nothing; retcode) +end + +function build_null_solution( + prob::NonlinearLeastSquaresProblem, + args...; abstol = 1e-6, kwargs...) + prob, success = hack_null_solution_init(prob) + retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure + + if isinplace(prob) + resid = isnothing(prob.f.resid_prototype) ? Float64[] : copy(prob.f.resid_prototype) + prob.f(resid, prob.u0, prob.p) + else + resid = prob.f(prob.f.resid_prototype, prob.p) + end + + if success + retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure + end + + SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) +end + +@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) + if isempty(solve_args) || isnothing(first(solve_args)) + if haskey(solve_kwargs, :alg) + solve_kwargs[:alg] + elseif haskey(prob_kwargs, :alg) + prob_kwargs[:alg] + else + nothing + end + elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && + !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) + first(solve_args) + else + nothing + end +end + +function get_concrete_u0(prob, isadapt, t0, kwargs) + if eval_u0(prob.u0) + u0 = prob.u0(prob.p, t0) + elseif haskey(kwargs, :u0) + u0 = kwargs[:u0] + else + u0 = prob.u0 + end + + isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) + + _u0 = handle_distribution_u0(u0) + + if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) + throw(IncompatibleInitialConditionError()) + end + + if _u0 isa Tuple + throw(TupleStateError()) + end + + _u0 +end + +function get_concrete_p(prob, kwargs) + if haskey(kwargs, :p) + p = kwargs[:p] + else + p = prob.p + end +end + +eval_u0(u0::Function) = true +eval_u0(u0) = false + +handle_distribution_u0(_u0) = _u0 + +anyeltypedual(x) = anyeltypedual(x, Val{0}) +anyeltypedual(x, counter) = Any + +function promote_u0(u0, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if !isequal(_p, p) + return promote_u0(u0, _p, t0) + end + end + Tu = eltype(u0) + if isdualtype(Tu) + return u0 + end + Tp = anyeltypedual(p, Val{0}) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0, Val{0}) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(Tu, Tp, Tt) + return if isdualtype(Tcommon) + Tcommon.(u0) + else + u0 + end +end + +function promote_u0(u0::AbstractArray{<:Complex}, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if !isequal(_p, p) + return promote_u0(u0, _p, t0) + end + end + Tu = real(eltype(u0)) + if isdualtype(Tu) + return u0 + end + Tp = anyeltypedual(p, Val{0}) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0, Val{0}) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(eltype(u0), Tp, Tt) + return if isdualtype(real(Tcommon)) + Tcommon.(u0) + else + u0 + end +end + +function checkkwargs(kwargshandle; kwargs...) + if any(x -> x ∉ allowedkeywords, keys(kwargs)) + if kwargshandle == KeywordArgError + throw(CommonKwargError(kwargs)) + elseif kwargshandle == KeywordArgWarn + @warn KWARGWARN_MESSAGE + unrecognized = setdiff(keys(kwargs), allowedkeywords) + print("Unrecognized keyword arguments: ") + printstyled(unrecognized; bold = true, color = :red) + print("\n\n") + else + @assert kwargshandle == KeywordArgSilent + end + end +end \ No newline at end of file From cc8565200fab69da181f3874ac9393d72c8c662c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 6 Aug 2025 18:49:45 -0400 Subject: [PATCH 04/61] get rid of DiffEqBase --- lib/NonlinearSolveBase/src/utils.jl | 2 ++ lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl | 2 +- lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl | 2 +- .../src/NonlinearSolveSpectralMethods.jl | 2 +- .../ext/SimpleNonlinearSolveDiffEqBaseExt.jl | 2 +- src/NonlinearSolve.jl | 2 +- 6 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/NonlinearSolveBase/src/utils.jl b/lib/NonlinearSolveBase/src/utils.jl index 05bc71158..18d78c451 100644 --- a/lib/NonlinearSolveBase/src/utils.jl +++ b/lib/NonlinearSolveBase/src/utils.jl @@ -320,4 +320,6 @@ function clean_sprint_struct(x, indent::Int) return "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))" end +set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x + end diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 57a1f0105..79bc2faac 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -12,7 +12,7 @@ using LineSearch: BackTracking using StaticArraysCore: SArray using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl index 167f1fa85..fd55ca034 100644 --- a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl +++ b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl @@ -8,7 +8,7 @@ using ArrayInterface: ArrayInterface using StaticArraysCore: StaticArray, Size, MArray using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra, Diagonal, dot, diag using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb diff --git a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl index 93a620761..c0a6bf2e9 100644 --- a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl +++ b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl @@ -5,7 +5,7 @@ using Reexport: @reexport using PrecompileTools: @compile_workload, @setup_workload using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LineSearch: RobustNonMonotoneLineSearch using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl index 4954ffb26..5326b0a88 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl @@ -1,6 +1,6 @@ module SimpleNonlinearSolveDiffEqBaseExt -using DiffEqBase: DiffEqBase +#using DiffEqBase: DiffEqBase using SimpleNonlinearSolve: SimpleNonlinearSolve diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 8eddf0712..a4eb06042 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,7 +8,7 @@ using FastClosures: @closure using ADTypes: ADTypes using ArrayInterface: ArrayInterface using CommonSolve: CommonSolve, init, solve, solve! -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches +#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, From a79e12a8e8950f8cbbe1e5b5a5e05251264a90a7 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 7 Aug 2025 13:41:27 -0400 Subject: [PATCH 05/61] add the ChainRulesCore extension --- lib/NonlinearSolveBase/Project.toml | 3 ++ .../NonlinearSolveBaseChainRulesCoreExt.jl | 33 +++++++++++++++++++ .../ext/NonlinearSolveBaseForwardDiffExt.jl | 9 +++++ 3 files changed, 45 insertions(+) create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 424a8dfa2..663ff2e62 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -29,6 +29,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" @@ -44,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" +NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore" [compat] ADTypes = "1.9" @@ -51,6 +53,7 @@ Adapt = "4.1.0" Aqua = "0.8.7" ArrayInterface = "7.9" BandedMatrices = "1.5" +ChainRulesCore = "1" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl new file mode 100644 index 000000000..b15f139be --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl @@ -0,0 +1,33 @@ +module NonlinearSolveBaseChainRulesCoreExt + +using NonlinearSolveBase +using NonlinearSolveBase: AbstractNonlinearProblem +using SciMLBase +using SciMLBase: AbstractSensitivityAlgorithm + +import ChainRulesCore +import ChainRulesCore: NoTangent + +ChainRulesCore.@non_differentiable NonlinearSolveBase.checkkwargs(kwargshandle) + +function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + NonlinearSolveBase._solve_forward( + prob, sensealg, u0, p, + originator, args...; + kwargs...) +end + +function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + NonlinearSolveBase._solve_adjoint( + prob, sensealg, u0, p, + originator, args...; + kwargs...) +end + +end \ No newline at end of file diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 7cfc792e8..73d0d9075 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -195,4 +195,13 @@ NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.valu @inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) @inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) +eltypedual(x) = eltype(x) <: ForwardDiff.Dual +isdualtype(::Type{<:ForwardDiff.Dual}) = true + +function anyeltypedual( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual((prob.u0, prob.p)) +end + end From 26fb469c441409e49f5ee00ccc96d7807d1f1564 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 8 Aug 2025 10:57:11 -0400 Subject: [PATCH 06/61] add solve_adjoint and solve_forward --- lib/NonlinearSolveBase/src/solve.jl | 63 +++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index d129442e7..c2c5b4774 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -791,6 +791,69 @@ function CommonSolve.solve!(cache::NonlinearSolveNoInitCache) return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) end +function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, + kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + end + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + if length(args) > 1 + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator, + Base.tail(args)...; kwargs...) + else + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...) + end +end + +function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, + kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + end + + if has_kwargs(_prob) + if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + kwargs_temp = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], + values(kwargs).callback),)) + kwargs = merge(kwargs_temp, callbacks) + end + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + if length(args) > 1 + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator, + Base.tail(args)...; kwargs...) + else + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) + end +end + + function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) oldprob = prob prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) From 7b527cd0fa3ecff896e568e9e3e50a8150bbe771 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 22 Aug 2025 14:57:58 -0400 Subject: [PATCH 07/61] add more imports --- .../src/NonlinearSolveBase.jl | 4 +- lib/NonlinearSolveBase/src/solve.jl | 306 +++++++++--------- 2 files changed, 156 insertions(+), 154 deletions(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index bddf9c674..e5abc0814 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -20,7 +20,9 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear AbstractNonlinearAlgorithm, NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction, NLStats, LinearProblem, - LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier + LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, + promote_u0, anyeltypedual, eval_u0, get_concrete_u0, get_concrete_p, + has_kwargs, extract_alg, get_concrete_problem import SciMLBase: solve, init, solve!, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index c2c5b4774..96dc1eb0f 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -393,7 +393,7 @@ end function init_up(prob::AbstractNonlinearProblem, sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) init_call(_prob, args...; kwargs...) else @@ -755,9 +755,9 @@ SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) -has_kwargs(_prob::AbstractNonlinearProblem) = has_kwargs(typeof(_prob)) -Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) -has_kwargs(::Type{T}) where {T} = __has_kwargs(T) +# has_kwargs(_prob::AbstractNonlinearProblem) = has_kwargs(typeof(_prob)) +# Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) +# has_kwargs(::Type{T}) where {T} = __has_kwargs(T) function SciMLBase.reinit!( cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs... @@ -854,29 +854,29 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end -function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) - oldprob = prob - prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) - if prob !== oldprob - kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) - end - p = get_concrete_p(prob, kwargs) - u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) -end - -function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) - oldprob = prob - prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) - if prob !== oldprob - kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) - end - p = get_concrete_p(prob, kwargs) - u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = promote_u0(u0, p, nothing) - remake(prob; u0 = u0, p = p) -end +# function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) +# oldprob = prob +# prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) +# if prob !== oldprob +# kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) +# end +# p = get_concrete_p(prob, kwargs) +# u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) +# u0 = promote_u0(u0, p, nothing) +# remake(prob; u0 = u0, p = p) +# end + +# function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) +# oldprob = prob +# prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) +# if prob !== oldprob +# kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) +# end +# p = get_concrete_p(prob, kwargs) +# u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) +# u0 = promote_u0(u0, p, nothing) +# remake(prob; u0 = u0, p = p) +# end """ Given the index provider `indp` used to construct the problem `prob` being solved, return @@ -931,129 +931,129 @@ function build_null_solution( SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) end -@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) - if isempty(solve_args) || isnothing(first(solve_args)) - if haskey(solve_kwargs, :alg) - solve_kwargs[:alg] - elseif haskey(prob_kwargs, :alg) - prob_kwargs[:alg] - else - nothing - end - elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && - !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) - first(solve_args) - else - nothing - end -end - -function get_concrete_u0(prob, isadapt, t0, kwargs) - if eval_u0(prob.u0) - u0 = prob.u0(prob.p, t0) - elseif haskey(kwargs, :u0) - u0 = kwargs[:u0] - else - u0 = prob.u0 - end - - isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - - _u0 = handle_distribution_u0(u0) - - if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) - throw(IncompatibleInitialConditionError()) - end - - if _u0 isa Tuple - throw(TupleStateError()) - end - - _u0 -end - -function get_concrete_p(prob, kwargs) - if haskey(kwargs, :p) - p = kwargs[:p] - else - p = prob.p - end -end - -eval_u0(u0::Function) = true -eval_u0(u0) = false - -handle_distribution_u0(_u0) = _u0 - -anyeltypedual(x) = anyeltypedual(x, Val{0}) -anyeltypedual(x, counter) = Any - -function promote_u0(u0, p, t0) - if SciMLStructures.isscimlstructure(p) - _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] - if !isequal(_p, p) - return promote_u0(u0, _p, t0) - end - end - Tu = eltype(u0) - if isdualtype(Tu) - return u0 - end - Tp = anyeltypedual(p, Val{0}) - if Tp == Any - Tp = Tu - end - Tt = anyeltypedual(t0, Val{0}) - if Tt == Any - Tt = Tu - end - Tcommon = promote_type(Tu, Tp, Tt) - return if isdualtype(Tcommon) - Tcommon.(u0) - else - u0 - end -end - -function promote_u0(u0::AbstractArray{<:Complex}, p, t0) - if SciMLStructures.isscimlstructure(p) - _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] - if !isequal(_p, p) - return promote_u0(u0, _p, t0) - end - end - Tu = real(eltype(u0)) - if isdualtype(Tu) - return u0 - end - Tp = anyeltypedual(p, Val{0}) - if Tp == Any - Tp = Tu - end - Tt = anyeltypedual(t0, Val{0}) - if Tt == Any - Tt = Tu - end - Tcommon = promote_type(eltype(u0), Tp, Tt) - return if isdualtype(real(Tcommon)) - Tcommon.(u0) - else - u0 - end -end - -function checkkwargs(kwargshandle; kwargs...) - if any(x -> x ∉ allowedkeywords, keys(kwargs)) - if kwargshandle == KeywordArgError - throw(CommonKwargError(kwargs)) - elseif kwargshandle == KeywordArgWarn - @warn KWARGWARN_MESSAGE - unrecognized = setdiff(keys(kwargs), allowedkeywords) - print("Unrecognized keyword arguments: ") - printstyled(unrecognized; bold = true, color = :red) - print("\n\n") - else - @assert kwargshandle == KeywordArgSilent - end - end -end \ No newline at end of file +# @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) +# if isempty(solve_args) || isnothing(first(solve_args)) +# if haskey(solve_kwargs, :alg) +# solve_kwargs[:alg] +# elseif haskey(prob_kwargs, :alg) +# prob_kwargs[:alg] +# else +# nothing +# end +# elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && +# !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) +# first(solve_args) +# else +# nothing +# end +# end + +# function get_concrete_u0(prob, isadapt, t0, kwargs) +# if eval_u0(prob.u0) +# u0 = prob.u0(prob.p, t0) +# elseif haskey(kwargs, :u0) +# u0 = kwargs[:u0] +# else +# u0 = prob.u0 +# end + +# isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) + +# _u0 = handle_distribution_u0(u0) + +# if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) +# throw(IncompatibleInitialConditionError()) +# end + +# if _u0 isa Tuple +# throw(TupleStateError()) +# end + +# _u0 +# end + +# function get_concrete_p(prob, kwargs) +# if haskey(kwargs, :p) +# p = kwargs[:p] +# else +# p = prob.p +# end +# end + +# eval_u0(u0::Function) = true +# eval_u0(u0) = false + +# handle_distribution_u0(_u0) = _u0 + +# anyeltypedual(x) = anyeltypedual(x, Val{0}) +# anyeltypedual(x, counter) = Any + +# function promote_u0(u0, p, t0) +# if SciMLStructures.isscimlstructure(p) +# _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] +# if !isequal(_p, p) +# return promote_u0(u0, _p, t0) +# end +# end +# Tu = eltype(u0) +# if isdualtype(Tu) +# return u0 +# end +# Tp = anyeltypedual(p, Val{0}) +# if Tp == Any +# Tp = Tu +# end +# Tt = anyeltypedual(t0, Val{0}) +# if Tt == Any +# Tt = Tu +# end +# Tcommon = promote_type(Tu, Tp, Tt) +# return if isdualtype(Tcommon) +# Tcommon.(u0) +# else +# u0 +# end +# end + +# function promote_u0(u0::AbstractArray{<:Complex}, p, t0) +# if SciMLStructures.isscimlstructure(p) +# _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] +# if !isequal(_p, p) +# return promote_u0(u0, _p, t0) +# end +# end +# Tu = real(eltype(u0)) +# if isdualtype(Tu) +# return u0 +# end +# Tp = anyeltypedual(p, Val{0}) +# if Tp == Any +# Tp = Tu +# end +# Tt = anyeltypedual(t0, Val{0}) +# if Tt == Any +# Tt = Tu +# end +# Tcommon = promote_type(eltype(u0), Tp, Tt) +# return if isdualtype(real(Tcommon)) +# Tcommon.(u0) +# else +# u0 +# end +# end + +# function checkkwargs(kwargshandle; kwargs...) +# if any(x -> x ∉ allowedkeywords, keys(kwargs)) +# if kwargshandle == KeywordArgError +# throw(CommonKwargError(kwargs)) +# elseif kwargshandle == KeywordArgWarn +# @warn KWARGWARN_MESSAGE +# unrecognized = setdiff(keys(kwargs), allowedkeywords) +# print("Unrecognized keyword arguments: ") +# printstyled(unrecognized; bold = true, color = :red) +# print("\n\n") +# else +# @assert kwargshandle == KeywordArgSilent +# end +# end +# end \ No newline at end of file From 1e412bb98624a1b12b857e88d9d02fcbe2846440 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 22 Aug 2025 15:38:13 -0400 Subject: [PATCH 08/61] empty diffeqbase ext --- .../ext/NonlinearSolveBaseDiffEqBaseExt.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl index a1d6c44ce..c5dbb9aec 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl @@ -1,16 +1,3 @@ module NonlinearSolveBaseDiffEqBaseExt -using DiffEqBase: DiffEqBase -using SciMLBase: SciMLBase, remake - -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem - -function DiffEqBase.get_concrete_problem( - prob::ImmutableNonlinearProblem, isadapt; kwargs...) - u0 = SciMLBase.get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = SciMLBase.promote_u0(u0, prob.p, nothing) - p = SciMLBase.get_concrete_p(prob, kwargs) - return remake(prob; u0 = u0, p = p) -end - end From 56ff49780231bb189e1f42c340520250cbef9154 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 22 Aug 2025 15:38:40 -0400 Subject: [PATCH 09/61] add get_concrete_problem --- lib/NonlinearSolveBase/src/solve.jl | 52 +++++++++++++++++------------ 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 96dc1eb0f..bdd41ae47 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -854,29 +854,37 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end -# function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) -# oldprob = prob -# prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) -# if prob !== oldprob -# kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) -# end -# p = get_concrete_p(prob, kwargs) -# u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) -# u0 = promote_u0(u0, p, nothing) -# remake(prob; u0 = u0, p = p) -# end +function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end -# function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) -# oldprob = prob -# prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) -# if prob !== oldprob -# kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) -# end -# p = get_concrete_p(prob, kwargs) -# u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) -# u0 = promote_u0(u0, p, nothing) -# remake(prob; u0 = u0, p = p) -# end +function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +function get_concrete_problem( + prob::ImmutableNonlinearProblem, isadapt; kwargs...) + u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = DiffEqBase.promote_u0(u0, prob.p, nothing) + p = DiffEqBase.get_concrete_p(prob, kwargs) + return remake(prob; u0 = u0, p = p) +end """ Given the index provider `indp` used to construct the problem `prob` being solved, return From 57e3072767cc9534b441dcb6eaf5a8af8f95846f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 22 Aug 2025 15:38:55 -0400 Subject: [PATCH 10/61] get rid of unused imports --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index e5abc0814..ddd50532d 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -21,8 +21,8 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction, NLStats, LinearProblem, LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, - promote_u0, anyeltypedual, eval_u0, get_concrete_u0, get_concrete_p, - has_kwargs, extract_alg, get_concrete_problem + promote_u0, get_concrete_u0, get_concrete_p, + has_kwargs, extract_alg, promote_u0 import SciMLBase: solve, init, solve!, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator @@ -36,11 +36,6 @@ using Printf: @printf const DI = DifferentiationInterface const SII = SymbolicIndexingInterface -# Extension Functions -eltypedual(x) = false -promote_u0(::Nothing, p, t0) = nothing -isdualtype(::Type{T}) where {T} = false - include("public.jl") include("utils.jl") From 5c2c3b6389daab0cb341af6d31fa045538affbfc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 22 Aug 2025 15:39:18 -0400 Subject: [PATCH 11/61] get rid of unused things in ForwardDiffExt --- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 73d0d9075..7cfc792e8 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -195,13 +195,4 @@ NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.valu @inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) @inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) -eltypedual(x) = eltype(x) <: ForwardDiff.Dual -isdualtype(::Type{<:ForwardDiff.Dual}) = true - -function anyeltypedual( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, - ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual((prob.u0, prob.p)) -end - end From 3a57aebe03f4fca51dc1b7e7a565aee4e1bdb122 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Sun, 24 Aug 2025 23:27:41 -0400 Subject: [PATCH 12/61] add solve_call for SteadyStateProblems --- lib/NonlinearSolveBase/src/solve.jl | 49 +++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index bdd41ae47..0e2086085 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -377,6 +377,14 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi end end +function solve_call(prob::SteadyStateProblem, + alg::AbstractNonlinearAlgorithm, args...; + kwargs...) + solve_call(NonlinearProblem(prob), + alg, args...; + kwargs...) +end + function init( prob::AbstractNonlinearProblem, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) @@ -420,7 +428,6 @@ function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle - if has_kwargs(_prob) if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) kwargs_temp = NamedTuple{ @@ -435,7 +442,6 @@ function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, end checkkwargs(kwargshandle; kwargs...) - if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && _prob.f.f isa EvalFunc Base.invokelatest(__init, _prob, args...; kwargs...)#::T @@ -585,6 +591,18 @@ function SciMLBase.__solve( end end +function __init(prob::AbstractNonlinearProblem, args...; default_set = false, second_time = false, + kwargs...) + if second_time + throw(NoDefaultAlgorithmError()) + elseif length(args) > 0 && !(first(args) isa + Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) + throw(NonSolverError()) + else + __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) + end +end + @generated function __generated_polysolve( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, @@ -886,6 +904,19 @@ function get_concrete_problem( return remake(prob; u0 = u0, p = p) end +function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, Inf, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + + """ Given the index provider `indp` used to construct the problem `prob` being solved, return an updated `prob` to be used for solving. All implementations should accept arbitrary @@ -939,6 +970,20 @@ function build_null_solution( SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) end + +function solve(prob::EnsembleProblem, args...; kwargs...) + alg = extract_alg(args, kwargs, kwargs) + if length(args) > 1 + __solve(prob, alg, Base.tail(args)...; kwargs...) + else + __solve(prob, alg; kwargs...) + end +end + +function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) + SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) +end + # @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) # if isempty(solve_args) || isnothing(first(solve_args)) # if haskey(solve_kwargs, :alg) From 144fe20c0a683fd0b566730dca56e8006ea0fc16 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Sun, 24 Aug 2025 23:28:02 -0400 Subject: [PATCH 13/61] no need for rule for checkkwargs now --- .../ext/NonlinearSolveBaseChainRulesCoreExt.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl index b15f139be..d60be6211 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl @@ -8,8 +8,6 @@ using SciMLBase: AbstractSensitivityAlgorithm import ChainRulesCore import ChainRulesCore: NoTangent -ChainRulesCore.@non_differentiable NonlinearSolveBase.checkkwargs(kwargshandle) - function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob, sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), From 44dea4d809f15b233e67c3dd2b5ad9db1967e973 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Sun, 24 Aug 2025 23:28:24 -0400 Subject: [PATCH 14/61] more imports from SciMLBase --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index ddd50532d..c347bdcb7 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -22,7 +22,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear NonlinearFunction, NLStats, LinearProblem, LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, - has_kwargs, extract_alg, promote_u0 + has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, EnsembleProblem import SciMLBase: solve, init, solve!, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator From 92f3607d5eca30ad9352986408d1067b6728087a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Sun, 24 Aug 2025 23:30:45 -0400 Subject: [PATCH 15/61] remove stale import --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index c347bdcb7..cc0baf4d4 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -23,7 +23,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, EnsembleProblem -import SciMLBase: solve, init, solve!, __init, __solve, wrap_sol, get_root_indp, isinplace, remake +import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface From 8ea632359762aa68ef828ac3baf2a8f2691477f6 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:37:49 -0400 Subject: [PATCH 16/61] no ensembleproblem --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index cc0baf4d4..7f4d21044 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -22,7 +22,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear NonlinearFunction, NLStats, LinearProblem, LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, - has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, EnsembleProblem + has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator From 431a4728f647089d02e11f4187683980f42e59c5 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:39:11 -0400 Subject: [PATCH 17/61] don't need to merge callback kwargs --- lib/NonlinearSolveBase/src/solve.jl | 104 ++++++++++++++-------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 0e2086085..8d420e7d7 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -337,15 +337,15 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi _prob.kwargs[:kwargshandle] : kwargshandle if has_kwargs(_prob) - if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - kwargs_temp = NamedTuple{ - Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - _prob.kwargs[:callback], - values(kwargs).callback),)) - kwargs = merge(kwargs_temp, callbacks) - end + # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + # kwargs_temp = NamedTuple{ + # Base.diff_names(Base._nt_names(values(kwargs)), + # (:callback,))}(values(kwargs)) + # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + # _prob.kwargs[:callback], + # values(kwargs).callback),)) + # kwargs = merge(kwargs_temp, callbacks) + # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -429,15 +429,15 @@ function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle if has_kwargs(_prob) - if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - kwargs_temp = NamedTuple{ - Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - _prob.kwargs[:callback], - values(kwargs).callback),)) - kwargs = merge(kwargs_temp, callbacks) - end + # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + # kwargs_temp = NamedTuple{ + # Base.diff_names(Base._nt_names(values(kwargs)), + # (:callback,))}(values(kwargs)) + # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + # _prob.kwargs[:callback], + # values(kwargs).callback),)) + # kwargs = merge(kwargs_temp, callbacks) + # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -820,15 +820,15 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba end if has_kwargs(_prob) - if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - kwargs_temp = NamedTuple{ - Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - _prob.kwargs[:callback], - values(kwargs).callback),)) - kwargs = merge(kwargs_temp, callbacks) - end + # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + # kwargs_temp = NamedTuple{ + # Base.diff_names(Base._nt_names(values(kwargs)), + # (:callback,))}(values(kwargs)) + # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + # _prob.kwargs[:callback], + # values(kwargs).callback),)) + # kwargs = merge(kwargs_temp, callbacks) + # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -851,15 +851,15 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end if has_kwargs(_prob) - if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - kwargs_temp = NamedTuple{ - Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - _prob.kwargs[:callback], - values(kwargs).callback),)) - kwargs = merge(kwargs_temp, callbacks) - end + # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) + # kwargs_temp = NamedTuple{ + # Base.diff_names(Base._nt_names(values(kwargs)), + # (:callback,))}(values(kwargs)) + # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + # _prob.kwargs[:callback], + # values(kwargs).callback),)) + # kwargs = merge(kwargs_temp, callbacks) + # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -898,9 +898,9 @@ end function get_concrete_problem( prob::ImmutableNonlinearProblem, isadapt; kwargs...) - u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = DiffEqBase.promote_u0(u0, prob.p, nothing) - p = DiffEqBase.get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, prob.p, nothing) + p = get_concrete_p(prob, kwargs) return remake(prob; u0 = u0, p = p) end @@ -936,7 +936,7 @@ function get_updated_symbolic_problem(indp, prob; kw...) end function build_null_solution( - prob::NonlinearProblem, + prob::Union{NonlinearProblem, SteadyStateProblem}, args...; saveat = (), save_everystep = true, @@ -970,19 +970,19 @@ function build_null_solution( SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) end +#TODO: THIS SHOULD GO IN SCIMLBASE. THIS IS TEMPORARY FOR TESTING PURPOSES. REMOVE +# function solve(prob::EnsembleProblem, args...; kwargs...) +# alg = extract_alg(args, kwargs, kwargs) +# if length(args) > 1 +# __solve(prob, alg, Base.tail(args)...; kwargs...) +# else +# __solve(prob, alg; kwargs...) +# end +# end -function solve(prob::EnsembleProblem, args...; kwargs...) - alg = extract_alg(args, kwargs, kwargs) - if length(args) > 1 - __solve(prob, alg, Base.tail(args)...; kwargs...) - else - __solve(prob, alg; kwargs...) - end -end - -function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) - SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) -end +# function solve(prob::WeightedEnsembleProblem, args...; kwargs...) +# SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) +# end # @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) # if isempty(solve_args) || isnothing(first(solve_args)) From 398b93df5e078d65ddcabc8a819edcbfefb0eaf7 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:40:36 -0400 Subject: [PATCH 18/61] treat NonlinearProblems as own for piracy in aqua tests --- lib/NonlinearSolveBase/test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index 95ae283cc..d9f702347 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -7,12 +7,14 @@ using InteractiveUtils, Test @testset "NonlinearSolveBase.jl" begin @testset "Aqua" begin using Aqua, NonlinearSolveBase + using NonlinearSolveBase: AbstractNonlinearProblem, NonlinearProblem Aqua.test_all( NonlinearSolveBase; piracies = false, ambiguities = false, stale_deps = false ) Aqua.test_stale_deps(NonlinearSolveBase; ignore = [:TimerOutputs]) - Aqua.test_piracies(NonlinearSolveBase) + #ENSEMBLE PROBLEM SHOULD BE REMOVED, THIS IS TEMPORARY FOR TESTS + Aqua.test_piracies(NonlinearSolveBase, treat_as_own = [AbstractNonlinearProblem, NonlinearProblem]) Aqua.test_ambiguities(NonlinearSolveBase; recursive = false) end From 2492ab32aa7340db55e1bc1e8d25589e2b4f4bbf Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:41:51 -0400 Subject: [PATCH 19/61] rm DiffEqBase, update SparseConnectivityTracer compat --- lib/NonlinearSolveFirstOrder/Project.toml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index 37e5bd1ba..9f134628e 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -1,5 +1,6 @@ name = "NonlinearSolveFirstOrder" uuid = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d" +version = "1.7.0" authors = ["Avik Pal and contributors"] version = "1.8.1" @@ -8,12 +9,11 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -35,8 +35,6 @@ BandedMatrices = "1.7.5" BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DifferentiationInterface = "0.7.3" -DiffEqBase = "6.188" Enzyme = "0.13.12" ExplicitImports = "1.5" FiniteDiff = "2.24" @@ -59,7 +57,7 @@ SciMLBase = "2.116" SciMLJacobianOperators = "0.1.0" Setfield = "1.1.1" SparseArrays = "1.10" -SparseConnectivityTracer = "1" +SparseConnectivityTracer = "1, 1" SparseMatrixColorings = "0.4.5" StableRNGs = "1" StaticArrays = "1.9.8" @@ -73,9 +71,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" From 7272bb205552a972c069c6b56227be0c650001e2 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:42:46 -0400 Subject: [PATCH 20/61] remove DiffEqBase from Project.toml QuasiNewton --- lib/NonlinearSolveQuasiNewton/Project.toml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 7a6194560..1562f1d96 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -1,5 +1,6 @@ name = "NonlinearSolveQuasiNewton" uuid = "9a2c21bd-3a47-402d-9113-8faf9a0ee114" +version = "1.8.0" authors = ["Avik Pal and contributors"] version = "1.8.1" @@ -7,7 +8,6 @@ version = "1.8.1" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" @@ -21,8 +21,8 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -[sources.NonlinearSolveBase] -path = "../NonlinearSolveBase" +[sources] +NonlinearSolveBase = {path = "../NonlinearSolveBase"} [extensions] NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff" @@ -34,7 +34,6 @@ ArrayInterface = "7.16.0" BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" Enzyme = "0.13.12" ExplicitImports = "1.5" FiniteDiff = "2.24" From 5a38f54f72d2026e27aa3230ffd5bbe7e43c4a21 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:48:31 -0400 Subject: [PATCH 21/61] remove DiffEqBase dependency in NonlinearSolveBase --- lib/NonlinearSolveBase/Project.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 663ff2e62..69fe84f6c 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,5 +1,6 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" +version = "1.14.0" authors = ["Avik Pal and contributors"] version = "1.14.1" @@ -30,7 +31,6 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -39,13 +39,12 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [extensions] NonlinearSolveBaseBandedMatricesExt = "BandedMatrices" -NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" +NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore" NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" -NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore" [compat] ADTypes = "1.9" From db1bd2fc07f233199c90945d7f5eb7f39f8ec11e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 25 Aug 2025 15:48:42 -0400 Subject: [PATCH 22/61] get rid of callback comments --- lib/NonlinearSolveBase/src/solve.jl | 50 ----------------------------- 1 file changed, 50 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 8d420e7d7..e57895465 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -337,15 +337,6 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi _prob.kwargs[:kwargshandle] : kwargshandle if has_kwargs(_prob) - # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - # kwargs_temp = NamedTuple{ - # Base.diff_names(Base._nt_names(values(kwargs)), - # (:callback,))}(values(kwargs)) - # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - # _prob.kwargs[:callback], - # values(kwargs).callback),)) - # kwargs = merge(kwargs_temp, callbacks) - # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -429,15 +420,6 @@ function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle if has_kwargs(_prob) - # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - # kwargs_temp = NamedTuple{ - # Base.diff_names(Base._nt_names(values(kwargs)), - # (:callback,))}(values(kwargs)) - # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - # _prob.kwargs[:callback], - # values(kwargs).callback),)) - # kwargs = merge(kwargs_temp, callbacks) - # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -820,15 +802,6 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba end if has_kwargs(_prob) - # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - # kwargs_temp = NamedTuple{ - # Base.diff_names(Base._nt_names(values(kwargs)), - # (:callback,))}(values(kwargs)) - # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - # _prob.kwargs[:callback], - # values(kwargs).callback),)) - # kwargs = merge(kwargs_temp, callbacks) - # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -851,15 +824,6 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end if has_kwargs(_prob) - # if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) - # kwargs_temp = NamedTuple{ - # Base.diff_names(Base._nt_names(values(kwargs)), - # (:callback,))}(values(kwargs)) - # callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( - # _prob.kwargs[:callback], - # values(kwargs).callback),)) - # kwargs = merge(kwargs_temp, callbacks) - # end kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) end @@ -970,20 +934,6 @@ function build_null_solution( SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) end -#TODO: THIS SHOULD GO IN SCIMLBASE. THIS IS TEMPORARY FOR TESTING PURPOSES. REMOVE -# function solve(prob::EnsembleProblem, args...; kwargs...) -# alg = extract_alg(args, kwargs, kwargs) -# if length(args) > 1 -# __solve(prob, alg, Base.tail(args)...; kwargs...) -# else -# __solve(prob, alg; kwargs...) -# end -# end - -# function solve(prob::WeightedEnsembleProblem, args...; kwargs...) -# SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) -# end - # @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) # if isempty(solve_args) || isnothing(first(solve_args)) # if haskey(solve_kwargs, :alg) From 34b05d0249acd60e09f38f70df5771d9738e098b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 27 Aug 2025 10:11:33 -0400 Subject: [PATCH 23/61] add extensions for autodiff --- .../ext/NonlinearSolveBaseEnzymeExt.jl | 61 +++++++++++ .../ext/NonlinearSolveBaseMooncakeExt.jl | 31 ++++++ .../ext/NonlinearSolveBaseReverseDiffExt.jl | 103 ++++++++++++++++++ .../ext/NonlinearSolveBaseTrackerExt.jl | 49 +++++++++ 4 files changed, 244 insertions(+) create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl new file mode 100644 index 000000000..95122874c --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -0,0 +1,61 @@ +module NonlinearSolveBaseEnzymeExt + +@static if isempty(VERSION.prerelease) + using NonlinearSolveBase + import SciMLBase: value + using Enzyme + import Enzyme: Const + using ChainRulesCore + + function Enzyme.EnzymeRules.augmented_primal( + config::Enzyme.EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{Duplicated{RT}}, prob, + sensealg::Union{ + Const{Nothing}, Const{<:SciMLBase.AbstractSensitivityAlgorithm}}, + u0, p, args...; kwargs...) where {RT} + @inline function copy_or_reuse(val, idx) + if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) + return deepcopy(val) + else + return val + end + end + + @inline function arg_copy(i) + copy_or_reuse(args[i].val, i + 5) + end + + res = DiffEqBase._solve_adjoint( + copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), + copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), + SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; + kwargs...) + + dres = Enzyme.make_zero(res[1])::RT + tup = (dres, res[2]) + return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) + end + + function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, + sensealg::Union{ + Const{Nothing}, Const{<:SciMLBase.AbstractSensitivityAlgorithm}}, + u0, p, args...; kwargs...) where {RT} + dres, clos = tape + dres = dres::RT + dargs = clos(dres) + for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) + if ptr isa Enzyme.Const + continue + end + if darg == ChainRulesCore.NoTangent() + continue + end + ptr.dval .+= darg + end + Enzyme.make_zero!(dres.u) + return ntuple(_ -> nothing, Val(length(args) + 4)) + end +end + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl new file mode 100644 index 000000000..3737a16cc --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl @@ -0,0 +1,31 @@ +module NonlinearSolveBaseMooncakeExt + +using NonlinearSolveBase, Mooncake +using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator +import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, + @from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx, + NoPullback + +@from_rrule(MinimalCtx, + Tuple{ + typeof(NonlinearSolveBase.solve_up), + SciMLBase.AbstractDEProblem, + Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm}, + Any, + Any, + Any + }, + true,) + +# Dispatch for auto-alg +@from_rrule(MinimalCtx, + Tuple{ + typeof(NonlinearSolveBase.solve_up), + SciMLBase.AbstractDEProblem, + Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm}, + Any, + Any + }, + true,) + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl new file mode 100644 index 000000000..63cfc5098 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl @@ -0,0 +1,103 @@ +module NonlinearSolveBaseReverseDiffExt + +using NonlinearSolveBase +import SciMLBase: value +import ReverseDiff +import ArrayInterface + +# `ReverseDiff.TrackedArray` +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::ReverseDiff.TrackedArray, args...; kwargs...) + ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::ReverseDiff.TrackedArray, + args...; kwargs...) + ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, p, + args...; kwargs...) + ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +# `AbstractArray{<:ReverseDiff.TrackedReal}` +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; + kwargs...) + SciMLBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), + ArrayInterface.aos_to_soa(p), args...; + kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) + SciMLBase.solve_up( + prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) + SciMLBase.solve_up( + prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, + args...; kwargs...) + SciMLBase.solve_up( + prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, + args...; kwargs...) + SciMLBase.solve_up( + prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) +end + +# Required becase ReverseDiff.@grad function SciMLBase.solve_up is not supported! +import SciMLBase: solve_up +ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...) + out = SciMLBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), + ReverseDiff.value(p), + SciMLBase.ReverseDiffOriginator(), args...; kwargs...) + function actual_adjoint(_args...) + original_adjoint = out[2](_args...) + if isempty(args) # alg is missing + tuple(original_adjoint[1:4]..., original_adjoint[6:end]...) + else + original_adjoint + end + end + Array(out[1]), actual_adjoint +end + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl new file mode 100644 index 000000000..e62a2cc65 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl @@ -0,0 +1,49 @@ +module NonlinearSolveBaseTrackerExt + +using NonlinearSolveBase +import SciMLBase: value +import Tracker + +function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, + p::Tracker.TrackedArray, args...; kwargs...) + Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, p, args...; + kwargs...) + Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::Tracker.TrackedArray, args...; + kwargs...) + Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +Tracker.@grad function SciMLBase.solve_up(prob, + sensealg::Union{Nothing, + SciMLBase.AbstractOverloadingSensitivityAlgorithm + }, + u0, p, args...; + kwargs...) + sol, + pb_f = SciMLBase._solve_adjoint( + prob, sensealg, Tracker.data(u0), Tracker.data(p), + SciMLBase.TrackerOriginator(), args...; kwargs...) + + if sol isa AbstractArray + !hasfield(typeof(sol), :u) && return sol, pb_f # being safe here + return sol.u, pb_f # AbstractNoTimeSolution isa AbstractArray + end + return convert(AbstractArray, sol), pb_f +end + +end From 3f94489160f518c78ef907f6f52d9e9a4f13d35c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 21:41:08 -0400 Subject: [PATCH 24/61] fix adjoints --- lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl | 2 +- .../ext/NonlinearSolveBaseTrackerExt.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index 95122874c..f8128854c 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -25,7 +25,7 @@ module NonlinearSolveBaseEnzymeExt copy_or_reuse(args[i].val, i + 5) end - res = DiffEqBase._solve_adjoint( + res = NonlinearSolveBase._solve_adjoint( copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl index e62a2cc65..8fe2a10be 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl @@ -4,7 +4,7 @@ using NonlinearSolveBase import SciMLBase: value import Tracker -function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::Tracker.TrackedArray, @@ -12,7 +12,7 @@ function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::Tracker.TrackedArray, p, args...; @@ -20,7 +20,7 @@ function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function SciMLBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, p::Tracker.TrackedArray, args...; @@ -35,7 +35,7 @@ Tracker.@grad function SciMLBase.solve_up(prob, u0, p, args...; kwargs...) sol, - pb_f = SciMLBase._solve_adjoint( + pb_f = NonlinearSolveBase._solve_adjoint( prob, sensealg, Tracker.data(u0), Tracker.data(p), SciMLBase.TrackerOriginator(), args...; kwargs...) From bc32e6d3fcedd685b7c5821f458c8616bf457792 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 21:42:12 -0400 Subject: [PATCH 25/61] imports for concrete solves --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 7f4d21044..5b8549dd8 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -17,7 +17,7 @@ using EnzymeCore: EnzymeCore using MaybeInplace: @bb using RecursiveArrayTools: RecursiveArrayTools, AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, - AbstractNonlinearAlgorithm, + AbstractNonlinearAlgorithm, _concrete_solve_adjoint, _concrete_solve_forward, NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction, NLStats, LinearProblem, LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, From c2ea203be899c111de6d84bc5b960ee6e0d31c68 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 21:51:56 -0400 Subject: [PATCH 26/61] fix project.toml --- lib/NonlinearSolveFirstOrder/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index 9f134628e..d066c3794 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -1,6 +1,5 @@ name = "NonlinearSolveFirstOrder" uuid = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d" -version = "1.7.0" authors = ["Avik Pal and contributors"] version = "1.8.1" From 7d75320729865aad7e6b1021a757b59ecafab5c7 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 21:58:28 -0400 Subject: [PATCH 27/61] get rid of stale DiffEqBase deps --- Project.toml | 1 - lib/NonlinearSolveSpectralMethods/Project.toml | 2 -- 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index 90c87d393..c045fd2e9 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index f4e17f1bc..378bf65f0 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -6,7 +6,6 @@ version = "1.3.1" [deps] CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" @@ -28,7 +27,6 @@ Aqua = "0.8" BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" ExplicitImports = "1.5" ForwardDiff = "0.10.36, 1" Hwloc = "3" From 60654d7026a66a20fc1ce66080f4094f54cfc430 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 29 Aug 2025 09:17:20 -0400 Subject: [PATCH 28/61] add missing function --- lib/NonlinearSolveBase/src/solve.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index e57895465..23d1729c4 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -934,6 +934,21 @@ function build_null_solution( SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) end +function hack_null_solution_init(prob::Union{NonlinearProblem, NonlinearLeastSquareProblem, SteadyStateProblem}) + if SciMLBase.has_initialization_data(prob.f) + initializeprob = prob.f.initialization_data.initializeprob + nlsol = solve(initializeprob) + success = SciMLBase.successful_retcode(nlsol) + if prob.f.initialization_data.initializeprobpmap !== nothing + @set! prob.p = prob.f.initializeprobpmap(prob, nlsol) + end + else + success = true + end + return prob, success +end + + # @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) # if isempty(solve_args) || isnothing(first(solve_args)) # if haskey(solve_kwargs, :alg) From 7f7cf66c45063bcc9b67ec5c4449f1b385caf4a1 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 29 Aug 2025 09:38:10 -0400 Subject: [PATCH 29/61] use Setfield --- lib/NonlinearSolveBase/Project.toml | 2 ++ lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 1 + lib/NonlinearSolveBase/src/solve.jl | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 69fe84f6c..5eb7fde58 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -24,6 +24,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" @@ -75,6 +76,7 @@ SciMLBase = "2.116" SciMLJacobianOperators = "0.1.1" SciMLOperators = "1.7" SciMLStructures = "1.5" +Setfield = "1.1.2" SparseArrays = "1.10" SparseMatrixColorings = "0.4.5" StaticArraysCore = "1.4" diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 5b8549dd8..4c7764562 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -28,6 +28,7 @@ using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface import SciMLStructures +using Setfield: @set! using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul! using Markdown: @doc_str diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 23d1729c4..4d7d152a4 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -934,7 +934,7 @@ function build_null_solution( SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) end -function hack_null_solution_init(prob::Union{NonlinearProblem, NonlinearLeastSquareProblem, SteadyStateProblem}) +function hack_null_solution_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, SteadyStateProblem}) if SciMLBase.has_initialization_data(prob.f) initializeprob = prob.f.initialization_data.initializeprob nlsol = solve(initializeprob) From f8099300033384f31398af01566fcdf8221f18b0 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 2 Sep 2025 11:55:04 -0400 Subject: [PATCH 30/61] remove commented code --- lib/NonlinearSolveBase/src/solve.jl | 128 ---------------------------- 1 file changed, 128 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 4d7d152a4..a4492f589 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -947,131 +947,3 @@ function hack_null_solution_init(prob::Union{NonlinearProblem, NonlinearLeastSqu end return prob, success end - - -# @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) -# if isempty(solve_args) || isnothing(first(solve_args)) -# if haskey(solve_kwargs, :alg) -# solve_kwargs[:alg] -# elseif haskey(prob_kwargs, :alg) -# prob_kwargs[:alg] -# else -# nothing -# end -# elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && -# !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) -# first(solve_args) -# else -# nothing -# end -# end - -# function get_concrete_u0(prob, isadapt, t0, kwargs) -# if eval_u0(prob.u0) -# u0 = prob.u0(prob.p, t0) -# elseif haskey(kwargs, :u0) -# u0 = kwargs[:u0] -# else -# u0 = prob.u0 -# end - -# isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - -# _u0 = handle_distribution_u0(u0) - -# if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) -# throw(IncompatibleInitialConditionError()) -# end - -# if _u0 isa Tuple -# throw(TupleStateError()) -# end - -# _u0 -# end - -# function get_concrete_p(prob, kwargs) -# if haskey(kwargs, :p) -# p = kwargs[:p] -# else -# p = prob.p -# end -# end - -# eval_u0(u0::Function) = true -# eval_u0(u0) = false - -# handle_distribution_u0(_u0) = _u0 - -# anyeltypedual(x) = anyeltypedual(x, Val{0}) -# anyeltypedual(x, counter) = Any - -# function promote_u0(u0, p, t0) -# if SciMLStructures.isscimlstructure(p) -# _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] -# if !isequal(_p, p) -# return promote_u0(u0, _p, t0) -# end -# end -# Tu = eltype(u0) -# if isdualtype(Tu) -# return u0 -# end -# Tp = anyeltypedual(p, Val{0}) -# if Tp == Any -# Tp = Tu -# end -# Tt = anyeltypedual(t0, Val{0}) -# if Tt == Any -# Tt = Tu -# end -# Tcommon = promote_type(Tu, Tp, Tt) -# return if isdualtype(Tcommon) -# Tcommon.(u0) -# else -# u0 -# end -# end - -# function promote_u0(u0::AbstractArray{<:Complex}, p, t0) -# if SciMLStructures.isscimlstructure(p) -# _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] -# if !isequal(_p, p) -# return promote_u0(u0, _p, t0) -# end -# end -# Tu = real(eltype(u0)) -# if isdualtype(Tu) -# return u0 -# end -# Tp = anyeltypedual(p, Val{0}) -# if Tp == Any -# Tp = Tu -# end -# Tt = anyeltypedual(t0, Val{0}) -# if Tt == Any -# Tt = Tu -# end -# Tcommon = promote_type(eltype(u0), Tp, Tt) -# return if isdualtype(real(Tcommon)) -# Tcommon.(u0) -# else -# u0 -# end -# end - -# function checkkwargs(kwargshandle; kwargs...) -# if any(x -> x ∉ allowedkeywords, keys(kwargs)) -# if kwargshandle == KeywordArgError -# throw(CommonKwargError(kwargs)) -# elseif kwargshandle == KeywordArgWarn -# @warn KWARGWARN_MESSAGE -# unrecognized = setdiff(keys(kwargs), allowedkeywords) -# print("Unrecognized keyword arguments: ") -# printstyled(unrecognized; bold = true, color = :red) -# print("\n\n") -# else -# @assert kwargshandle == KeywordArgSilent -# end -# end -# end \ No newline at end of file From 23bb2d703201e9c451c357fdb139700ae01696bb Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 2 Sep 2025 14:41:09 -0400 Subject: [PATCH 31/61] no errors --- lib/NonlinearSolveBase/src/solve.jl | 416 ++++++++++++++-------------- 1 file changed, 208 insertions(+), 208 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index a4492f589..b4d8c13fb 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -1,215 +1,215 @@ -const allowedkeywords = (:dense, - :saveat, - :save_idxs, - :tstops, - :tspan, - :d_discontinuities, - :save_everystep, - :save_on, - :save_start, - :save_end, - :initialize_save, - :adaptive, - :abstol, - :reltol, - :dt, - :dtmax, - :dtmin, - :force_dtmin, - :internalnorm, - :controller, - :gamma, - :beta1, - :beta2, - :qmax, - :qmin, - :qsteady_min, - :qsteady_max, - :qoldinit, - :failfactor, - :calck, - :alias_u0, - :maxiters, - :maxtime, - :callback, - :isoutofdomain, - :unstable_check, - :verbose, - :merge_callbacks, - :progress, - :progress_steps, - :progress_name, - :progress_message, - :progress_id, - :timeseries_errors, - :dense_errors, - :weak_timeseries_errors, - :weak_dense_errors, - :wrap, - :calculate_error, - :initializealg, - :alg, - :save_noise, - :delta, - :seed, - :alg_hints, - :kwargshandle, - :trajectories, - :batch_size, - :sensealg, - :advance_to_tstop, - :stop_at_next_tstop, - :u0, - :p, - # These two are from the default algorithm handling - :default_set, - :second_time, - # This is for DiffEqDevTools - :prob_choice, - # Jump problems - :alias_jump, - # This is for copying/deepcopying noise in StochasticDiffEq - :alias_noise, - # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves - :batch, - # Shooting method in BVP needs to differentiate between these two categories - :nlsolve_kwargs, - :odesolve_kwargs, - # If Solvers which internally use linsolve - :linsolve_kwargs, - # Solvers internally using EnsembleProblem - :ensemblealg, - # Fine Grained Control of Tracing (Storing and Logging) during Solve - :show_trace, - :trace_level, - :store_trace, - # Termination condition for solvers - :termination_condition, - # For AbstractAliasSpecifier - :alias, - # Parameter estimation with BVP - :fit_parameters) - -const KWARGWARN_MESSAGE = """ -Unrecognized keyword arguments found. -The only allowed keyword arguments to `solve` are: -$allowedkeywords - -See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. - -Set kwargshandle=KeywordArgError for an error message. -Set kwargshandle=KeywordArgSilent to ignore this message. -""" - -const KWARGERROR_MESSAGE = """ - Unrecognized keyword arguments found. - The only allowed keyword arguments to `solve` are: - $allowedkeywords - - See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. - """ - -struct CommonKwargError <: Exception - kwargs::Any -end - -function Base.showerror(io::IO, e::CommonKwargError) - println(io, KWARGERROR_MESSAGE) - notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) - unrecognized = collect(keys(e.kwargs))[notin] - print(io, "Unrecognized keyword arguments: ") - printstyled(io, unrecognized; bold = true, color = :red) - print(io, "\n\n") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -@enum KeywordArgError KeywordArgWarn KeywordArgSilent - -const INCOMPATIBLE_U0_MESSAGE = """ - Initial condition incompatible with functional form. - Detected an in-place function with an initial condition of type Number or SArray. - This is incompatible because Numbers cannot be mutated, i.e. - `x = 2.0; y = 2.0; x .= y` will error. - - If using a immutable initial condition type, please use the out-of-place form. - I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. - - If your differential equation function was defined with multiple dispatches and one is - in-place, then the automatic detection will choose in-place. In this case, override the - choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. - - For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: - https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation - """ - -struct IncompatibleInitialConditionError <: Exception end - -function Base.showerror(io::IO, e::IncompatibleInitialConditionError) - print(io, INCOMPATIBLE_U0_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NO_DEFAULT_ALGORITHM_MESSAGE = """ - Default algorithm choices require NonlinearSolve.jl. - Please specify an algorithm (e.g., `solve(prob, NewtonRaphson())` or - init(prob, NewtonRaphson()) or - import NonlinearSolve.jl directly. - - You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ - and its associated pages. - """ - -struct NoDefaultAlgorithmError <: Exception end +# const allowedkeywords = (:dense, +# :saveat, +# :save_idxs, +# :tstops, +# :tspan, +# :d_discontinuities, +# :save_everystep, +# :save_on, +# :save_start, +# :save_end, +# :initialize_save, +# :adaptive, +# :abstol, +# :reltol, +# :dt, +# :dtmax, +# :dtmin, +# :force_dtmin, +# :internalnorm, +# :controller, +# :gamma, +# :beta1, +# :beta2, +# :qmax, +# :qmin, +# :qsteady_min, +# :qsteady_max, +# :qoldinit, +# :failfactor, +# :calck, +# :alias_u0, +# :maxiters, +# :maxtime, +# :callback, +# :isoutofdomain, +# :unstable_check, +# :verbose, +# :merge_callbacks, +# :progress, +# :progress_steps, +# :progress_name, +# :progress_message, +# :progress_id, +# :timeseries_errors, +# :dense_errors, +# :weak_timeseries_errors, +# :weak_dense_errors, +# :wrap, +# :calculate_error, +# :initializealg, +# :alg, +# :save_noise, +# :delta, +# :seed, +# :alg_hints, +# :kwargshandle, +# :trajectories, +# :batch_size, +# :sensealg, +# :advance_to_tstop, +# :stop_at_next_tstop, +# :u0, +# :p, +# # These two are from the default algorithm handling +# :default_set, +# :second_time, +# # This is for DiffEqDevTools +# :prob_choice, +# # Jump problems +# :alias_jump, +# # This is for copying/deepcopying noise in StochasticDiffEq +# :alias_noise, +# # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves +# :batch, +# # Shooting method in BVP needs to differentiate between these two categories +# :nlsolve_kwargs, +# :odesolve_kwargs, +# # If Solvers which internally use linsolve +# :linsolve_kwargs, +# # Solvers internally using EnsembleProblem +# :ensemblealg, +# # Fine Grained Control of Tracing (Storing and Logging) during Solve +# :show_trace, +# :trace_level, +# :store_trace, +# # Termination condition for solvers +# :termination_condition, +# # For AbstractAliasSpecifier +# :alias, +# # Parameter estimation with BVP +# :fit_parameters) + +# const KWARGWARN_MESSAGE = """ +# Unrecognized keyword arguments found. +# The only allowed keyword arguments to `solve` are: +# $allowedkeywords + +# See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. + +# Set kwargshandle=KeywordArgError for an error message. +# Set kwargshandle=KeywordArgSilent to ignore this message. +# """ + +# const KWARGERROR_MESSAGE = """ +# Unrecognized keyword arguments found. +# The only allowed keyword arguments to `solve` are: +# $allowedkeywords + +# See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. +# """ + +# struct CommonKwargError <: Exception +# kwargs::Any +# end + +# function Base.showerror(io::IO, e::CommonKwargError) +# println(io, KWARGERROR_MESSAGE) +# notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) +# unrecognized = collect(keys(e.kwargs))[notin] +# print(io, "Unrecognized keyword arguments: ") +# printstyled(io, unrecognized; bold = true, color = :red) +# print(io, "\n\n") +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# @enum KeywordArgError KeywordArgWarn KeywordArgSilent + +# const INCOMPATIBLE_U0_MESSAGE = """ +# Initial condition incompatible with functional form. +# Detected an in-place function with an initial condition of type Number or SArray. +# This is incompatible because Numbers cannot be mutated, i.e. +# `x = 2.0; y = 2.0; x .= y` will error. + +# If using a immutable initial condition type, please use the out-of-place form. +# I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. + +# If your differential equation function was defined with multiple dispatches and one is +# in-place, then the automatic detection will choose in-place. In this case, override the +# choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. + +# For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: +# https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation +# """ + +# struct IncompatibleInitialConditionError <: Exception end + +# function Base.showerror(io::IO, e::IncompatibleInitialConditionError) +# print(io, INCOMPATIBLE_U0_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NO_DEFAULT_ALGORITHM_MESSAGE = """ +# Default algorithm choices require NonlinearSolve.jl. +# Please specify an algorithm (e.g., `solve(prob, NewtonRaphson())` or +# init(prob, NewtonRaphson()) or +# import NonlinearSolve.jl directly. + +# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ +# and its associated pages. +# """ + +# struct NoDefaultAlgorithmError <: Exception end + +# function Base.showerror(io::IO, e::NoDefaultAlgorithmError) +# print(io, NO_DEFAULT_ALGORITHM_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NON_SOLVER_MESSAGE = """ +# The arguments to solve are incorrect. +# The second argument must be a solver choice, `solve(prob,alg)` +# where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. + +# Please double check the arguments being sent to the solver. + +# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ +# and its associated pages. +# """ + +# struct NonSolverError <: Exception end + +# function Base.showerror(io::IO, e::NonSolverError) +# print(io, NON_SOLVER_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ +# Incompatible solver + automatic differentiation pairing. +# The chosen automatic differentiation algorithm requires the ability +# for compiler transforms on the code which is only possible on pure-Julia +# solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods +# which require this ability include: + +# - Direct use of ForwardDiff.jl on the solver +# - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` +# sensealg choices for adjoint differentiation. -function Base.showerror(io::IO, e::NoDefaultAlgorithmError) - print(io, NO_DEFAULT_ALGORITHM_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end +# Either switch the choice of solver to a pure Julia method, or change the automatic +# differentiation method to one that does not require such transformations. + +# For more details on automatic differentiation, adjoint, and sensitivity analysis +# of differential equations, see the documentation page: -const NON_SOLVER_MESSAGE = """ - The arguments to solve are incorrect. - The second argument must be a solver choice, `solve(prob,alg)` - where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. +# https://diffeq.sciml.ai/stable/analysis/sensitivity/ +# """ - Please double check the arguments being sent to the solver. - - You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ - and its associated pages. - """ - -struct NonSolverError <: Exception end - -function Base.showerror(io::IO, e::NonSolverError) - print(io, NON_SOLVER_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end +# struct DirectAutodiffError <: Exception end -const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ - Incompatible solver + automatic differentiation pairing. - The chosen automatic differentiation algorithm requires the ability - for compiler transforms on the code which is only possible on pure-Julia - solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods - which require this ability include: - - - Direct use of ForwardDiff.jl on the solver - - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` - sensealg choices for adjoint differentiation. - - Either switch the choice of solver to a pure Julia method, or change the automatic - differentiation method to one that does not require such transformations. - - For more details on automatic differentiation, adjoint, and sensitivity analysis - of differential equations, see the documentation page: - - https://diffeq.sciml.ai/stable/analysis/sensitivity/ - """ - -struct DirectAutodiffError <: Exception end - -function Base.showerror(io::IO, e::DirectAutodiffError) - println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end +# function Base.showerror(io::IO, e::DirectAutodiffError) +# println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end struct EvalFunc{F} <: Function f::F From 77cb606039565a299414b9d982f74d75b5edefea Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 2 Sep 2025 14:49:52 -0400 Subject: [PATCH 32/61] import SciMLBase error messages --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 4c7764562..43ef1e2e3 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -22,7 +22,11 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear NonlinearFunction, NLStats, LinearProblem, LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, - has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem + has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, + allowedkeywords, KWARGWARN_MESSAGE, KWARGERROR_MESSAGE, CommonKwargError, + KeywordArgError, KeywordArgWarn, KeywordArgSilent, INCOMPATIBLE_U0_MESSAGE, + IncompatibleInitialConditionError, NO_DEFAULT_ALGORITHM_MESSAGE, NoDefaultAlgorithmError, + NON_SOLVER_MESSAGE, NonSolverError, DIRECT_AUTODIFF_INCOMPATIBILITY_MESSAGE, DirectAutodiffError import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator From 9d9648a03c63887236707b44dbe17e5845239ba4 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 2 Sep 2025 16:34:28 -0400 Subject: [PATCH 33/61] fix error imports --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 43ef1e2e3..751271f34 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -28,6 +28,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear IncompatibleInitialConditionError, NO_DEFAULT_ALGORITHM_MESSAGE, NoDefaultAlgorithmError, NON_SOLVER_MESSAGE, NonSolverError, DIRECT_AUTODIFF_INCOMPATIBILITY_MESSAGE, DirectAutodiffError import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake + using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface From fc1336fb93ef1f73d0486f0f0390d970f288752b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 2 Sep 2025 16:43:20 -0400 Subject: [PATCH 34/61] remove stale error imports --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 751271f34..8f660379f 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -23,10 +23,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, - allowedkeywords, KWARGWARN_MESSAGE, KWARGERROR_MESSAGE, CommonKwargError, - KeywordArgError, KeywordArgWarn, KeywordArgSilent, INCOMPATIBLE_U0_MESSAGE, - IncompatibleInitialConditionError, NO_DEFAULT_ALGORITHM_MESSAGE, NoDefaultAlgorithmError, - NON_SOLVER_MESSAGE, NonSolverError, DIRECT_AUTODIFF_INCOMPATIBILITY_MESSAGE, DirectAutodiffError + NoDefaultAlgorithmError, NonSolverError import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator From 6671263549bba51dcd7fd8c3a54552302e8627b2 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 2 Sep 2025 18:24:54 -0400 Subject: [PATCH 35/61] put KeywordArgError back --- lib/NonlinearSolveBase/src/NonlinearSolveBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 8f660379f..1e215e275 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -23,7 +23,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, - NoDefaultAlgorithmError, NonSolverError + NoDefaultAlgorithmError, NonSolverError, KeywordArgError import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator From cd09c2c488719bb2905839589b3634d0d1c5e425 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 4 Sep 2025 13:39:16 -0400 Subject: [PATCH 36/61] fix adjoints --- .../src/NonlinearSolveBase.jl | 2 +- lib/NonlinearSolveBase/src/solve.jl | 6 +++--- .../src/termination_conditions.jl | 2 +- .../SimpleNonlinearSolveChainRulesCoreExt.jl | 7 +++---- .../ext/SimpleNonlinearSolveDiffEqBaseExt.jl | 10 +++++----- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 6 +++--- .../ext/SimpleNonlinearSolveTrackerExt.jl | 6 +++--- .../src/SimpleNonlinearSolve.jl | 10 +++++----- .../test/core/adjoint_tests.jl | 2 +- test/adjoint_tests.jl | 19 +++++++++++++++++++ 10 files changed, 44 insertions(+), 26 deletions(-) create mode 100644 test/adjoint_tests.jl diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 1e215e275..eac8dc1e9 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -23,7 +23,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, promote_u0, get_concrete_u0, get_concrete_p, has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, - NoDefaultAlgorithmError, NonSolverError, KeywordArgError + NoDefaultAlgorithmError, NonSolverError, KeywordArgError, AbstractDEAlgorithm import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index b4d8c13fb..020a04da6 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -794,11 +794,11 @@ end function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) - if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) else - _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) end if has_kwargs(_prob) diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index 771c2517f..8b050b3b0 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -82,7 +82,7 @@ function CommonSolve.init( length(saved_value_prototype) == 0 && (saved_value_prototype = nothing) leastsq = typeof(prob) <: NonlinearLeastSquaresProblem - + Main.@infiltrate return NonlinearTerminationModeCache( u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0, saved_value_prototype, diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl index a9d86ea84..efb4cace7 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -2,11 +2,10 @@ module SimpleNonlinearSolveChainRulesCoreExt using ChainRulesCore: ChainRulesCore, NoTangent -using NonlinearSolveBase: ImmutableNonlinearProblem +using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem -using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up, - solve_adjoint +using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up function ChainRulesCore.rrule( ::typeof(simplenonlinearsolve_solve_up), @@ -14,7 +13,7 @@ function ChainRulesCore.rrule( sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs... ) out, - ∇internal = solve_adjoint( + ∇internal = _solve_adjoint( prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs... ) function ∇simplenonlinearsolve_solve_up(Δ) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl index 5326b0a88..70bd1b9fd 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl @@ -2,12 +2,12 @@ module SimpleNonlinearSolveDiffEqBaseExt #using DiffEqBase: DiffEqBase -using SimpleNonlinearSolve: SimpleNonlinearSolve +# using SimpleNonlinearSolve: SimpleNonlinearSolve -SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true +# SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true -function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...) - return DiffEqBase._solve_adjoint(args...; kwargs...) -end +# function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...) +# return DiffEqBase._solve_adjoint(args...; kwargs...) +# end end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index dca55621a..27e1cc1ac 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -1,12 +1,12 @@ module SimpleNonlinearSolveReverseDiffExt -using NonlinearSolveBase: ImmutableNonlinearProblem +using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake using ArrayInterface: ArrayInterface using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal -using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint +using SimpleNonlinearSolve: SimpleNonlinearSolve import SimpleNonlinearSolve: simplenonlinearsolve_solve_up for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) @@ -27,7 +27,7 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp) prob = remake(tprob; u0, p) out, - ∇internal = solve_adjoint( + ∇internal = _solve_adjoint( prob, sensealg, u0, p, ReverseDiffOriginator(), alg, args...; kwargs...) function ∇simplenonlinearsolve_solve_up(Δ...) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index 551d7080a..9f71c4f55 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -1,12 +1,12 @@ module SimpleNonlinearSolveTrackerExt -using NonlinearSolveBase: ImmutableNonlinearProblem +using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake using ArrayInterface: ArrayInterface using Tracker: Tracker, TrackedArray, TrackedReal -using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint +using SimpleNonlinearSolve: SimpleNonlinearSolve for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any) @@ -26,7 +26,7 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) u0, p = Tracker.data(tu0), Tracker.data(tp) prob = remake(tprob; u0, p) out, - ∇internal = solve_adjoint( + ∇internal = _solve_adjoint( prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...) function ∇simplenonlinearsolve_solve_up(Δ) diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 7a8c5a308..9f0e1c46f 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -126,12 +126,12 @@ end # NOTE: This is defined like this so that we don't have to keep have 2 args for the # extensions -function solve_adjoint(args...; kws...) - is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...) - error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") -end +# function solve_adjoint(args...; kws...) +# is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...) +# error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") +# end -function solve_adjoint_internal end +# function solve_adjoint_internal end @setup_workload begin for T in (Float64,) diff --git a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl index 1580ade60..c56850eb5 100644 --- a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl @@ -1,5 +1,5 @@ @testitem "Simple Adjoint Test" tags=[:adjoint] begin - using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, DiffEqBase + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote ff(u, p) = u .^ 2 .- p diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl new file mode 100644 index 000000000..ddfabed45 --- /dev/null +++ b/test/adjoint_tests.jl @@ -0,0 +1,19 @@ +using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test + +ff(u, p) = u .^ 2 .- p + +function solve_nlprob(p) + prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) + sol = solve(prob, NewtonRaphson()) + res = sol isa AbstractArray ? sol : sol.u + return sum(abs2, res) +end + +p = [3.0, 2.0] + +∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) +∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) +∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) +∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) +@test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff +@test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff \ No newline at end of file From ac5c542c7c2f16c935f4a33c4b5e6a3aac6bb90f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 4 Sep 2025 13:51:20 -0400 Subject: [PATCH 37/61] rm infiltrate --- lib/NonlinearSolveBase/src/termination_conditions.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index 8b050b3b0..5442f22dd 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -82,7 +82,6 @@ function CommonSolve.init( length(saved_value_prototype) == 0 && (saved_value_prototype = nothing) leastsq = typeof(prob) <: NonlinearLeastSquaresProblem - Main.@infiltrate return NonlinearTerminationModeCache( u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0, saved_value_prototype, From aa4bdc1b8f39449fa9d8dbe1bfde2ee00230bfef Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 4 Sep 2025 15:05:10 -0400 Subject: [PATCH 38/61] add weakdeps, fix AD extensions --- lib/NonlinearSolveBase/Project.toml | 8 +++ .../ext/NonlinearSolveBaseEnzymeExt.jl | 2 +- .../ext/NonlinearSolveBaseReverseDiffExt.jl | 68 +++++++++---------- .../ext/NonlinearSolveBaseTrackerExt.jl | 16 ++--- 4 files changed, 51 insertions(+), 43 deletions(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 5eb7fde58..3c7965fce 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -32,20 +32,27 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] NonlinearSolveBaseBandedMatricesExt = "BandedMatrices" NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore" +NonlinearSolveBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" +NonlinearSolveBaseReverseDiffExt = "ReverseDiff" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" +NonlinearSolveBaseTrackerExt = "Tracker" + [compat] ADTypes = "1.9" @@ -82,6 +89,7 @@ SparseMatrixColorings = "0.4.5" StaticArraysCore = "1.4" SymbolicIndexingInterface = "0.3.43" Test = "1.10" +Tracker = "0.2.35" TimerOutputs = "0.5.23" julia = "1.10" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl index f8128854c..3bdc7f42e 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -2,7 +2,7 @@ module NonlinearSolveBaseEnzymeExt @static if isempty(VERSION.prerelease) using NonlinearSolveBase - import SciMLBase: value + import SciMLBase: SciMLBase, value using Enzyme import Enzyme: Const using ChainRulesCore diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl index 63cfc5098..fdfe774db 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl @@ -1,92 +1,92 @@ module NonlinearSolveBaseReverseDiffExt -using NonlinearSolveBase -import SciMLBase: value +using NonlinearSolveBase +import SciMLBase: SciMLBase, value import ReverseDiff import ArrayInterface # `ReverseDiff.TrackedArray` -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, args...; kwargs...) - ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, p::ReverseDiff.TrackedArray, args...; kwargs...) - ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::ReverseDiff.TrackedArray, p, args...; kwargs...) - ReverseDiff.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end # `AbstractArray{<:ReverseDiff.TrackedReal}` -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; kwargs...) - SciMLBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), + NonlinearSolveBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), ArrayInterface.aos_to_soa(p), args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; kwargs...) - SciMLBase.solve_up( + NonlinearSolveBase.solve_up( prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::ReverseDiff.TrackedArray, p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; kwargs...) - SciMLBase.solve_up( + NonlinearSolveBase.solve_up( prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, - u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, - args...; kwargs...) - SciMLBase.solve_up( - prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) -end +# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, +# sensealg::Union{ +# SciMLBase.AbstractOverloadingSensitivityAlgorithm, +# Nothing}, +# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, +# args...; kwargs...) +# NonlinearSolveBase.solve_up( +# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) +# end -function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, - u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, - args...; kwargs...) - SciMLBase.solve_up( - prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) -end +# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, +# sensealg::Union{ +# SciMLBase.AbstractOverloadingSensitivityAlgorithm, +# Nothing}, +# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, +# args...; kwargs...) +# NonlinearSolveBase.solve_up( +# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) +# end # Required becase ReverseDiff.@grad function SciMLBase.solve_up is not supported! -import SciMLBase: solve_up +import NonlinearSolveBase: solve_up ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...) - out = SciMLBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), + out = NonlinearSolveBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p), SciMLBase.ReverseDiffOriginator(), args...; kwargs...) function actual_adjoint(_args...) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl index 8fe2a10be..dd73531c9 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl @@ -1,34 +1,34 @@ module NonlinearSolveBaseTrackerExt using NonlinearSolveBase -import SciMLBase: value +import SciMLBase: SciMLBase, value import Tracker -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::Tracker.TrackedArray, p::Tracker.TrackedArray, args...; kwargs...) - Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) + Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0::Tracker.TrackedArray, p, args...; kwargs...) - Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) + Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -function NonlinearSolveBase.solve_up(prob::SciMLBase.AbstractDEProblem, +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, sensealg::Union{ SciMLBase.AbstractOverloadingSensitivityAlgorithm, Nothing}, u0, p::Tracker.TrackedArray, args...; kwargs...) - Tracker.track(SciMLBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) + Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end -Tracker.@grad function SciMLBase.solve_up(prob, +Tracker.@grad function NonlinearSolveBase.solve_up(prob, sensealg::Union{Nothing, SciMLBase.AbstractOverloadingSensitivityAlgorithm }, From 7ec5a826528243f8e886db593d4099a522dd5255 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 4 Sep 2025 15:05:29 -0400 Subject: [PATCH 39/61] add adjoints test item --- test/adjoint_tests.jl | 50 +++++++++++++++++++++++++++++++------------ test/runtests.jl | 3 +++ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl index ddfabed45..6ea4ace08 100644 --- a/test/adjoint_tests.jl +++ b/test/adjoint_tests.jl @@ -1,19 +1,41 @@ -using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test +@testitem "Adjoint Tests" tags = [:adjoint] begin + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test -ff(u, p) = u .^ 2 .- p + ff(u, p) = u .^ 2 .- p -function solve_nlprob(p) - prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) - sol = solve(prob, NewtonRaphson()) - res = sol isa AbstractArray ? sol : sol.u - return sum(abs2, res) + function solve_nlprob(p) + prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) + sol = solve(prob, NewtonRaphson()) + res = sol isa AbstractArray ? sol : sol.u + return sum(abs2, res) + end + + p = [3.0, 2.0] + + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) + ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) + ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) + ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) + ∂p_enzyme = Enzyme.gradient(Enzyme.Reverse, solve_nlprob, p)[1] + @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme + @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme end -p = [3.0, 2.0] +@testitem "Simple Adjoint Test" tags=[:adjoint] begin + using ForwardDiff, Zygote, BracketingNonlinearSolve + + ff(u, p) = u^2 .- p[1] -∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) -∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) -∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) -∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) -@test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff -@test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff \ No newline at end of file + function solve_nlprob(p) + prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p) + sol = solve(prob, Bisection()) + res = sol isa AbstractArray ? sol : sol.u + return sum(abs2, res) + end + + p = [2.0, 2.0] + + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) + ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) + @test ∂p_zygote ≈ ∂p_forwarddiff +end diff --git a/test/runtests.jl b/test/runtests.jl index c946e93ef..90e3b7273 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,9 @@ if GROUP == "all" || GROUP == "cuda" push!(EXTRA_PKGS, Pkg.PackageSpec("CUDA")) end end + +(GROUP == "all" || GROUP == "adjoint") && Pkg.add(["SciMLSensitivity"]) + length(EXTRA_PKGS) ≥ 1 && Pkg.add(EXTRA_PKGS) # Use sequential execution for wrapper tests to avoid parallel initialization issues From 34937a4ce236bbe5a821f72cab3fab8dc1f33748 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 4 Sep 2025 15:38:56 -0400 Subject: [PATCH 40/61] fix up adjoint tests --- Project.toml | 6 +++++- test/adjoint_tests.jl | 21 +-------------------- test/runtests.jl | 2 -- 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index c045fd2e9..133e1305c 100644 --- a/Project.toml +++ b/Project.toml @@ -137,6 +137,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" @@ -154,6 +155,8 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -162,7 +165,8 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme"] diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl index 6ea4ace08..5e9f84f31 100644 --- a/test/adjoint_tests.jl +++ b/test/adjoint_tests.jl @@ -1,5 +1,5 @@ @testitem "Adjoint Tests" tags = [:adjoint] begin - using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme ff(u, p) = u .^ 2 .- p @@ -20,22 +20,3 @@ @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme end - -@testitem "Simple Adjoint Test" tags=[:adjoint] begin - using ForwardDiff, Zygote, BracketingNonlinearSolve - - ff(u, p) = u^2 .- p[1] - - function solve_nlprob(p) - prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p) - sol = solve(prob, Bisection()) - res = sol isa AbstractArray ? sol : sol.u - return sum(abs2, res) - end - - p = [2.0, 2.0] - - ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) - ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) - @test ∂p_zygote ≈ ∂p_forwarddiff -end diff --git a/test/runtests.jl b/test/runtests.jl index 90e3b7273..75a9e8ffc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,8 +42,6 @@ if GROUP == "all" || GROUP == "cuda" end end -(GROUP == "all" || GROUP == "adjoint") && Pkg.add(["SciMLSensitivity"]) - length(EXTRA_PKGS) ≥ 1 && Pkg.add(EXTRA_PKGS) # Use sequential execution for wrapper tests to avoid parallel initialization issues From 50a395564d2ff037886a2e289fde544db0ed478d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 4 Sep 2025 15:45:58 -0400 Subject: [PATCH 41/61] fix compat bounds --- Project.toml | 1 + lib/NonlinearSolveBase/Project.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 133e1305c..6db2a2d82 100644 --- a/Project.toml +++ b/Project.toml @@ -85,6 +85,7 @@ ConcreteStructs = "0.2.3" DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" ExplicitImports = "1.5" +Enzyme = "0.13.12" FastClosures = "0.3.2" FastLevenbergMarquardt = "0.1" FiniteDiff = "2.24" diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 3c7965fce..262ca7beb 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -67,6 +67,7 @@ ConcreteStructs = "0.2.3" DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" EnzymeCore = "0.8" +Enzyme = "0.13.12" ExplicitImports = "1.10.1" FastClosures = "0.3" ForwardDiff = "0.10.36, 1" From e8b6aea89a1be070bea6522df044b4a69ee4a5fa Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 5 Sep 2025 12:05:46 -0400 Subject: [PATCH 42/61] add MooncakeExt --- lib/NonlinearSolveBase/Project.toml | 3 +++ lib/NonlinearSolveBase/src/solve.jl | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 262ca7beb..7e2d4a14b 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -36,6 +36,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -48,6 +49,7 @@ NonlinearSolveBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" +NonlinearSolveBaseMooncakeExt = "Mooncake" NonlinearSolveBaseReverseDiffExt = "ReverseDiff" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" @@ -77,6 +79,7 @@ LinearAlgebra = "1.10" LinearSolve = "3.15" Markdown = "1.10" MaybeInplace = "0.1.4" +Mooncake = "0.4" Preferences = "1.4" Printf = "1.10" RecursiveArrayTools = "3" diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 020a04da6..b4d8c13fb 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -794,11 +794,11 @@ end function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) - if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, true; u0 = u0, + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, p = p, kwargs...) else - _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) end if has_kwargs(_prob) From 0671fe303ff29e310096d02a01d5d8137bb41b85 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 5 Sep 2025 13:37:30 -0400 Subject: [PATCH 43/61] add Mooncake to adjoint tests --- Project.toml | 4 +++- .../ext/NonlinearSolveBaseMooncakeExt.jl | 2 +- lib/NonlinearSolveBase/src/solve.jl | 4 ++-- test/adjoint_tests.jl | 11 ++++++++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 6db2a2d82..a720cb41d 100644 --- a/Project.toml +++ b/Project.toml @@ -99,6 +99,7 @@ LineSearches = "7.3" LinearAlgebra = "1.10" LinearSolve = "2.36.1, 3" MINPACK = "1.2" +Mooncake = "0.4" MPI = "0.20.22" NLSolvers = "0.5" NLsolve = "4.5" @@ -146,6 +147,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" @@ -170,4 +172,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme", "Mooncake"] diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl index 3737a16cc..91b901099 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl @@ -1,7 +1,7 @@ module NonlinearSolveBaseMooncakeExt using NonlinearSolveBase, Mooncake -using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator +using SciMLBase: SciMLBase import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, @from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx, NoPullback diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index b4d8c13fb..179d7d7d8 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -795,7 +795,7 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) else _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) @@ -817,7 +817,7 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling - _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) else _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl index 5e9f84f31..cbde35ebe 100644 --- a/test/adjoint_tests.jl +++ b/test/adjoint_tests.jl @@ -1,5 +1,5 @@ -@testitem "Adjoint Tests" tags = [:adjoint] begin - using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme +@testitem "Adjoint Tests" tags = [:nopre] begin + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake ff(u, p) = u .^ 2 .- p @@ -17,6 +17,11 @@ ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) ∂p_enzyme = Enzyme.gradient(Enzyme.Reverse, solve_nlprob, p)[1] - @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme + + cache = Mooncake.prepare_gradient_cache(solve_nlprob, p) + ∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2] + + @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme + @test_broken ∂p_forwarddiff ≈ ∂p_mooncake end From bba01be84cef18f9092161ba8f3759251d7afe3b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 5 Sep 2025 16:47:47 -0400 Subject: [PATCH 44/61] set runtime activity for enzyme --- test/adjoint_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl index cbde35ebe..8882c1916 100644 --- a/test/adjoint_tests.jl +++ b/test/adjoint_tests.jl @@ -16,7 +16,7 @@ ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) - ∂p_enzyme = Enzyme.gradient(Enzyme.Reverse, solve_nlprob, p)[1] + ∂p_enzyme = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), solve_nlprob, p)[1] cache = Mooncake.prepare_gradient_cache(solve_nlprob, p) ∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2] From 08087ce1f44c55cb4b1c650acb29735affc47de5 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 5 Sep 2025 16:48:41 -0400 Subject: [PATCH 45/61] move from Union to AbstractNonlinearProblem --- lib/NonlinearSolveBase/src/solve.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 179d7d7d8..0f9e410e3 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -389,7 +389,8 @@ function init( init_up(prob, sensealg, u0, p, args...; kwargs...) end -function init_up(prob::AbstractNonlinearProblem, sensealg, u0, p, args...; kwargs...) +function init_up(prob::AbstractNonlinearProblem, + sensealg, u0, p, args...; kwargs...) alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, true; u0 = u0, From 78ebecd51d94ca6f43bed8e3c05c643425ff731f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 5 Sep 2025 17:37:45 -0400 Subject: [PATCH 46/61] make sure Mooncake and SciMLSensitivity are only on nopre --- Project.toml | 44 ++++++++++++++++++++------------------------ test/runtests.jl | 2 ++ 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index a720cb41d..66051c1e5 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -43,24 +44,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" -[sources.BracketingNonlinearSolve] -path = "lib/BracketingNonlinearSolve" - -[sources.NonlinearSolveBase] -path = "lib/NonlinearSolveBase" - -[sources.NonlinearSolveFirstOrder] -path = "lib/NonlinearSolveFirstOrder" - -[sources.NonlinearSolveQuasiNewton] -path = "lib/NonlinearSolveQuasiNewton" - -[sources.NonlinearSolveSpectralMethods] -path = "lib/NonlinearSolveSpectralMethods" - -[sources.SimpleNonlinearSolve] -path = "lib/SimpleNonlinearSolve" - [extensions] NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration" @@ -85,7 +68,6 @@ ConcreteStructs = "0.2.3" DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" ExplicitImports = "1.5" -Enzyme = "0.13.12" FastClosures = "0.3.2" FastLevenbergMarquardt = "0.1" FiniteDiff = "2.24" @@ -99,7 +81,6 @@ LineSearches = "7.3" LinearAlgebra = "1.10" LinearSolve = "2.36.1, 3" MINPACK = "1.2" -Mooncake = "0.4" MPI = "0.20.22" NLSolvers = "0.5" NLsolve = "4.5" @@ -139,7 +120,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" @@ -147,7 +127,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" @@ -159,7 +138,6 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -171,5 +149,23 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources.BracketingNonlinearSolve] +path = "lib/BracketingNonlinearSolve" + +[sources.NonlinearSolveBase] +path = "lib/NonlinearSolveBase" + +[sources.NonlinearSolveFirstOrder] +path = "lib/NonlinearSolveFirstOrder" + +[sources.NonlinearSolveQuasiNewton] +path = "lib/NonlinearSolveQuasiNewton" + +[sources.NonlinearSolveSpectralMethods] +path = "lib/NonlinearSolveSpectralMethods" + +[sources.SimpleNonlinearSolve] +path = "lib/SimpleNonlinearSolve" + [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme", "Mooncake"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker"] diff --git a/test/runtests.jl b/test/runtests.jl index 75a9e8ffc..3ed6985dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,8 @@ if GROUP == "all" || GROUP == "nopre" # Only add Enzyme for nopre group if not on prerelease Julia if isempty(VERSION.prerelease) push!(EXTRA_PKGS, Pkg.PackageSpec("Enzyme")) + push!(EXTRA_PKGS, Pkg.PackageSpec("Mooncake")) + push!(EXTRA_PKGS, Pkg.PackageSpec("SciMLSensitivity")) end end if GROUP == "all" || GROUP == "cuda" From 74b766f0abfcb8876a96c024e6c442fc9788d052 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 09:43:22 -0400 Subject: [PATCH 47/61] remove diffeqbase as dep --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 66051c1e5..b526bc0b9 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" From 692123617b15a9c4e05337edf7bb4adf3e988e9d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 09:46:30 -0400 Subject: [PATCH 48/61] fixes to deps, remove SimpleNonlinearSolveDiffEqBaseExt --- Project.toml | 1 - lib/SimpleNonlinearSolve/Project.toml | 2 -- .../ext/SimpleNonlinearSolveDiffEqBaseExt.jl | 13 ------------- 3 files changed, 16 deletions(-) delete mode 100644 lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl diff --git a/Project.toml b/Project.toml index b526bc0b9..c78fb2c6b 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,6 @@ BenchmarkTools = "1.4" BracketingNonlinearSolve = "1" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" ExplicitImports = "1.5" FastClosures = "0.3.2" diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index e73fbfc42..a2dcd7b93 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -25,7 +25,6 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -37,7 +36,6 @@ path = "../NonlinearSolveBase" [extensions] SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" -SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase" SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" SimpleNonlinearSolveTrackerExt = "Tracker" diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl deleted file mode 100644 index 70bd1b9fd..000000000 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module SimpleNonlinearSolveDiffEqBaseExt - -#using DiffEqBase: DiffEqBase - -# using SimpleNonlinearSolve: SimpleNonlinearSolve - -# SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true - -# function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...) -# return DiffEqBase._solve_adjoint(args...; kwargs...) -# end - -end From aa02f756f53407385e8524e1dd97a61ca4264d00 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 10:06:13 -0400 Subject: [PATCH 49/61] fix Project.toml --- lib/NonlinearSolveBase/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 7e2d4a14b..a3a8584b3 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,6 +1,5 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" -version = "1.14.0" authors = ["Avik Pal and contributors"] version = "1.14.1" From 42376cc5a944be0f256ee6fc8bda75c9d4d3aa9b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 10:19:50 -0400 Subject: [PATCH 50/61] another Project fix --- lib/NonlinearSolveQuasiNewton/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 1562f1d96..02e7829d6 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -1,6 +1,5 @@ name = "NonlinearSolveQuasiNewton" uuid = "9a2c21bd-3a47-402d-9113-8faf9a0ee114" -version = "1.8.0" authors = ["Avik Pal and contributors"] version = "1.8.1" From 0065f975796b9d771d5157be6a5647349b15d9f2 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 10:54:55 -0400 Subject: [PATCH 51/61] add compat bounds for ReverseDiff --- Project.toml | 1 + lib/NonlinearSolveBase/Project.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index c78fb2c6b..ab42049ff 100644 --- a/Project.toml +++ b/Project.toml @@ -97,6 +97,7 @@ Preferences = "1.4.3" Random = "1.10" ReTestItems = "1.24" Reexport = "1.2.2" +ReverseDiff = "1.15" SIAMFANLEquations = "1.0.1" SciMLBase = "2.116" SimpleNonlinearSolve = "2.1" diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index a3a8584b3..6bc86e3e8 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -82,6 +82,7 @@ Mooncake = "0.4" Preferences = "1.4" Printf = "1.10" RecursiveArrayTools = "3" +ReverseDiff = "1.15" SciMLBase = "2.116" SciMLJacobianOperators = "0.1.1" SciMLOperators = "1.7" From 2c6d1b939b3c3eda1e8ea8e2df01791cb09d520d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 11:16:31 -0400 Subject: [PATCH 52/61] use sources to test DiffEqBase without solve for NonlinearProblem --- Project.toml | 21 ++++----------------- lib/NonlinearSolveBase/Project.toml | 3 +++ lib/SimpleNonlinearSolve/Project.toml | 3 +++ 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index ab42049ff..06a6ab486 100644 --- a/Project.toml +++ b/Project.toml @@ -65,6 +65,7 @@ BracketingNonlinearSolve = "1" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" DifferentiationInterface = "0.7.3" +DiffEqBase = "6.188" ExplicitImports = "1.5" FastClosures = "0.3.2" FastLevenbergMarquardt = "0.1" @@ -118,6 +119,7 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" @@ -148,23 +150,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -[sources.BracketingNonlinearSolve] -path = "lib/BracketingNonlinearSolve" - -[sources.NonlinearSolveBase] -path = "lib/NonlinearSolveBase" - -[sources.NonlinearSolveFirstOrder] -path = "lib/NonlinearSolveFirstOrder" - -[sources.NonlinearSolveQuasiNewton] -path = "lib/NonlinearSolveQuasiNewton" - -[sources.NonlinearSolveSpectralMethods] -path = "lib/NonlinearSolveSpectralMethods" - -[sources.SimpleNonlinearSolve] -path = "lib/SimpleNonlinearSolve" +[sources] +DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} [targets] test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker"] diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 6bc86e3e8..cd91eb4ca 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -97,6 +97,9 @@ Tracker = "0.2.35" TimerOutputs = "0.5.23" julia = "1.10" +[sources] +DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index a2dcd7b93..975ac7135 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -76,6 +76,9 @@ Tracker = "0.2.35" Zygote = "0.6.70, 0.7" julia = "1.10" +[sources] +DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" From dfe1c75c22389388dd937e80131c4b16c3c03f10 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 11:33:13 -0400 Subject: [PATCH 53/61] add DiffEqBase to test targets --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 06a6ab486..7b1149efb 100644 --- a/Project.toml +++ b/Project.toml @@ -154,4 +154,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "DiffEqBase"] From bc97510ec67ebd9c54d74f3856e2c557e51bdfe9 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 12:06:12 -0400 Subject: [PATCH 54/61] add other sources back --- Project.toml | 18 ++++++++++++++++++ lib/NonlinearSolveBase/src/solve.jl | 1 - 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7b1149efb..5959ed99a 100644 --- a/Project.toml +++ b/Project.toml @@ -150,6 +150,24 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources.BracketingNonlinearSolve] +path = "lib/BracketingNonlinearSolve" + +[sources.NonlinearSolveBase] +path = "lib/NonlinearSolveBase" + +[sources.NonlinearSolveFirstOrder] +path = "lib/NonlinearSolveFirstOrder" + +[sources.NonlinearSolveQuasiNewton] +path = "lib/NonlinearSolveQuasiNewton" + +[sources.NonlinearSolveSpectralMethods] +path = "lib/NonlinearSolveSpectralMethods" + +[sources.SimpleNonlinearSolve] +path = "lib/SimpleNonlinearSolve" + [sources] DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 0f9e410e3..b67454933 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -836,7 +836,6 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end end - function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) oldprob = prob prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) From 568e9a5a021dd70f30b8c9a0be59e3a92ea841f8 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 13:12:13 -0400 Subject: [PATCH 55/61] fix compat --- lib/NonlinearSolveFirstOrder/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index d066c3794..6f5a2f0f4 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -34,6 +34,7 @@ BandedMatrices = "1.7.5" BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" +DifferentiationInterface = "0.7.3" Enzyme = "0.13.12" ExplicitImports = "1.5" FiniteDiff = "2.24" From 14689ced98e1fcf65ca13b60ff126aee24f21b4e Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 8 Sep 2025 17:42:01 -0400 Subject: [PATCH 56/61] clean up --- .../ext/NonlinearSolveBaseDiffEqBaseExt.jl | 3 --- .../src/NonlinearSolveFirstOrder.jl | 1 - .../src/NonlinearSolveQuasiNewton.jl | 1 - .../src/NonlinearSolveSpectralMethods.jl | 1 - lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl | 9 --------- src/NonlinearSolve.jl | 1 - 6 files changed, 16 deletions(-) delete mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl deleted file mode 100644 index c5dbb9aec..000000000 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl +++ /dev/null @@ -1,3 +0,0 @@ -module NonlinearSolveBaseDiffEqBaseExt - -end diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 79bc2faac..0a6c009c6 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -12,7 +12,6 @@ using LineSearch: BackTracking using StaticArraysCore: SArray using CommonSolve: CommonSolve -#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl index fd55ca034..a7f3d93c3 100644 --- a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl +++ b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl @@ -8,7 +8,6 @@ using ArrayInterface: ArrayInterface using StaticArraysCore: StaticArray, Size, MArray using CommonSolve: CommonSolve -#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra, Diagonal, dot, diag using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb diff --git a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl index c0a6bf2e9..2706d5670 100644 --- a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl +++ b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl @@ -5,7 +5,6 @@ using Reexport: @reexport using PrecompileTools: @compile_workload, @setup_workload using CommonSolve: CommonSolve -#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LineSearch: RobustNonMonotoneLineSearch using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 9f0e1c46f..782de6468 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -124,15 +124,6 @@ function simplenonlinearsolve_solve_up( return SciMLBase.__solve(prob, alg, args...; kwargs...) end -# NOTE: This is defined like this so that we don't have to keep have 2 args for the -# extensions -# function solve_adjoint(args...; kws...) -# is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...) -# error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") -# end - -# function solve_adjoint_internal end - @setup_workload begin for T in (Float64,) prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index a4eb06042..d11f99749 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,7 +8,6 @@ using FastClosures: @closure using ADTypes: ADTypes using ArrayInterface: ArrayInterface using CommonSolve: CommonSolve, init, solve, solve! -#using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, From 069605de437b65934319f9c46a129dffed46221b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 9 Sep 2025 12:09:55 -0400 Subject: [PATCH 57/61] remove large comment --- lib/NonlinearSolveBase/src/solve.jl | 213 ---------------------------- 1 file changed, 213 deletions(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index b67454933..0fe813dbd 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -1,216 +1,3 @@ -# const allowedkeywords = (:dense, -# :saveat, -# :save_idxs, -# :tstops, -# :tspan, -# :d_discontinuities, -# :save_everystep, -# :save_on, -# :save_start, -# :save_end, -# :initialize_save, -# :adaptive, -# :abstol, -# :reltol, -# :dt, -# :dtmax, -# :dtmin, -# :force_dtmin, -# :internalnorm, -# :controller, -# :gamma, -# :beta1, -# :beta2, -# :qmax, -# :qmin, -# :qsteady_min, -# :qsteady_max, -# :qoldinit, -# :failfactor, -# :calck, -# :alias_u0, -# :maxiters, -# :maxtime, -# :callback, -# :isoutofdomain, -# :unstable_check, -# :verbose, -# :merge_callbacks, -# :progress, -# :progress_steps, -# :progress_name, -# :progress_message, -# :progress_id, -# :timeseries_errors, -# :dense_errors, -# :weak_timeseries_errors, -# :weak_dense_errors, -# :wrap, -# :calculate_error, -# :initializealg, -# :alg, -# :save_noise, -# :delta, -# :seed, -# :alg_hints, -# :kwargshandle, -# :trajectories, -# :batch_size, -# :sensealg, -# :advance_to_tstop, -# :stop_at_next_tstop, -# :u0, -# :p, -# # These two are from the default algorithm handling -# :default_set, -# :second_time, -# # This is for DiffEqDevTools -# :prob_choice, -# # Jump problems -# :alias_jump, -# # This is for copying/deepcopying noise in StochasticDiffEq -# :alias_noise, -# # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves -# :batch, -# # Shooting method in BVP needs to differentiate between these two categories -# :nlsolve_kwargs, -# :odesolve_kwargs, -# # If Solvers which internally use linsolve -# :linsolve_kwargs, -# # Solvers internally using EnsembleProblem -# :ensemblealg, -# # Fine Grained Control of Tracing (Storing and Logging) during Solve -# :show_trace, -# :trace_level, -# :store_trace, -# # Termination condition for solvers -# :termination_condition, -# # For AbstractAliasSpecifier -# :alias, -# # Parameter estimation with BVP -# :fit_parameters) - -# const KWARGWARN_MESSAGE = """ -# Unrecognized keyword arguments found. -# The only allowed keyword arguments to `solve` are: -# $allowedkeywords - -# See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. - -# Set kwargshandle=KeywordArgError for an error message. -# Set kwargshandle=KeywordArgSilent to ignore this message. -# """ - -# const KWARGERROR_MESSAGE = """ -# Unrecognized keyword arguments found. -# The only allowed keyword arguments to `solve` are: -# $allowedkeywords - -# See https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/ for more details. -# """ - -# struct CommonKwargError <: Exception -# kwargs::Any -# end - -# function Base.showerror(io::IO, e::CommonKwargError) -# println(io, KWARGERROR_MESSAGE) -# notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) -# unrecognized = collect(keys(e.kwargs))[notin] -# print(io, "Unrecognized keyword arguments: ") -# printstyled(io, unrecognized; bold = true, color = :red) -# print(io, "\n\n") -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# @enum KeywordArgError KeywordArgWarn KeywordArgSilent - -# const INCOMPATIBLE_U0_MESSAGE = """ -# Initial condition incompatible with functional form. -# Detected an in-place function with an initial condition of type Number or SArray. -# This is incompatible because Numbers cannot be mutated, i.e. -# `x = 2.0; y = 2.0; x .= y` will error. - -# If using a immutable initial condition type, please use the out-of-place form. -# I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. - -# If your differential equation function was defined with multiple dispatches and one is -# in-place, then the automatic detection will choose in-place. In this case, override the -# choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. - -# For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: -# https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation -# """ - -# struct IncompatibleInitialConditionError <: Exception end - -# function Base.showerror(io::IO, e::IncompatibleInitialConditionError) -# print(io, INCOMPATIBLE_U0_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NO_DEFAULT_ALGORITHM_MESSAGE = """ -# Default algorithm choices require NonlinearSolve.jl. -# Please specify an algorithm (e.g., `solve(prob, NewtonRaphson())` or -# init(prob, NewtonRaphson()) or -# import NonlinearSolve.jl directly. - -# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ -# and its associated pages. -# """ - -# struct NoDefaultAlgorithmError <: Exception end - -# function Base.showerror(io::IO, e::NoDefaultAlgorithmError) -# print(io, NO_DEFAULT_ALGORITHM_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NON_SOLVER_MESSAGE = """ -# The arguments to solve are incorrect. -# The second argument must be a solver choice, `solve(prob,alg)` -# where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. - -# Please double check the arguments being sent to the solver. - -# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ -# and its associated pages. -# """ - -# struct NonSolverError <: Exception end - -# function Base.showerror(io::IO, e::NonSolverError) -# print(io, NON_SOLVER_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ -# Incompatible solver + automatic differentiation pairing. -# The chosen automatic differentiation algorithm requires the ability -# for compiler transforms on the code which is only possible on pure-Julia -# solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods -# which require this ability include: - -# - Direct use of ForwardDiff.jl on the solver -# - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` -# sensealg choices for adjoint differentiation. - -# Either switch the choice of solver to a pure Julia method, or change the automatic -# differentiation method to one that does not require such transformations. - -# For more details on automatic differentiation, adjoint, and sensitivity analysis -# of differential equations, see the documentation page: - -# https://diffeq.sciml.ai/stable/analysis/sensitivity/ -# """ - -# struct DirectAutodiffError <: Exception end - -# function Base.showerror(io::IO, e::DirectAutodiffError) -# println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - struct EvalFunc{F} <: Function f::F end From 2e7eec0d48976b52d3fb1be8beb2c6f549ed22bc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 9 Sep 2025 14:11:52 -0400 Subject: [PATCH 58/61] bump NonlinearSolveBase, and compat --- Project.toml | 2 +- lib/NonlinearSolveBase/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5959ed99a..1a5a4700c 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ NLSolvers = "0.5" NLsolve = "4.5" NaNMath = "1" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" NonlinearSolveFirstOrder = "1.2" NonlinearSolveQuasiNewton = "1.8" NonlinearSolveSpectralMethods = "1.1" diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index cd91eb4ca..312df9ec4 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" authors = ["Avik Pal and contributors"] -version = "1.14.1" +version = "1.15.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 968a976186e984cf4c50380b92164689d7c53ae0 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Tue, 9 Sep 2025 15:05:50 -0400 Subject: [PATCH 59/61] bump compats for NonlinearSolveBase --- docs/Project.toml | 2 +- lib/BracketingNonlinearSolve/Project.toml | 2 +- lib/NonlinearSolveFirstOrder/Project.toml | 2 +- lib/NonlinearSolveHomotopyContinuation/Project.toml | 2 +- lib/NonlinearSolveQuasiNewton/Project.toml | 2 +- lib/NonlinearSolveSciPy/Project.toml | 2 +- lib/NonlinearSolveSpectralMethods/Project.toml | 2 +- lib/SCCNonlinearSolve/Project.toml | 2 +- lib/SimpleNonlinearSolve/Project.toml | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 2ecbbf3cc..f1c875de5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -79,7 +79,7 @@ InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1" LinearSolve = "2, 3" NonlinearSolve = "4" -NonlinearSolveBase = "1" +NonlinearSolveBase = "1.15" NonlinearSolveFirstOrder = "1" NonlinearSolveHomotopyContinuation = "0.1" NonlinearSolveQuasiNewton = "1" diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index a4976d168..bac48ac19 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -30,7 +30,7 @@ ConcreteStructs = "0.2.3" ExplicitImports = "1.10.1" ForwardDiff = "0.10.36, 1" InteractiveUtils = "<0.0.1, 1" -NonlinearSolveBase = "1.1" +NonlinearSolveBase = "1.15" PrecompileTools = "1.2" Reexport = "1.2.2" SciMLBase = "2.116" diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index 6f5a2f0f4..72059a11b 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -47,7 +47,7 @@ LinearAlgebra = "1.10" LinearSolve = "2.36.1, 3" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PrecompileTools = "1.2" Random = "1.10" diff --git a/lib/NonlinearSolveHomotopyContinuation/Project.toml b/lib/NonlinearSolveHomotopyContinuation/Project.toml index f011bb485..21a839d47 100644 --- a/lib/NonlinearSolveHomotopyContinuation/Project.toml +++ b/lib/NonlinearSolveHomotopyContinuation/Project.toml @@ -31,7 +31,7 @@ HomotopyContinuation = "2.12.0" LinearAlgebra = "1.10" NaNMath = "1.1" NonlinearSolve = "4.10" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" SciMLBase = "2.116" SymbolicIndexingInterface = "0.3.43" TaylorDiff = "0.3.1" diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 02e7829d6..cb2eaa8f8 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -45,7 +45,7 @@ LinearAlgebra = "1.10" LinearSolve = "2.36.1, 3" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PrecompileTools = "1.2" ReTestItems = "1.24" diff --git a/lib/NonlinearSolveSciPy/Project.toml b/lib/NonlinearSolveSciPy/Project.toml index 2745b8006..4726d30d7 100644 --- a/lib/NonlinearSolveSciPy/Project.toml +++ b/lib/NonlinearSolveSciPy/Project.toml @@ -18,7 +18,7 @@ path = "../NonlinearSolveBase" ConcreteStructs = "0.2.3" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" PrecompileTools = "1.2" PythonCall = "0.9" ReTestItems = "1.24" diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index 378bf65f0..bc58d7030 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -34,7 +34,7 @@ InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1.4" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PrecompileTools = "1.2" ReTestItems = "1.24" diff --git a/lib/SCCNonlinearSolve/Project.toml b/lib/SCCNonlinearSolve/Project.toml index 1d39e5fed..fa3e615f1 100644 --- a/lib/SCCNonlinearSolve/Project.toml +++ b/lib/SCCNonlinearSolve/Project.toml @@ -19,7 +19,7 @@ Hwloc = "3" InteractiveUtils = "<0.0.1, 1" NonlinearProblemLibrary = "0.1.2" NonlinearSolve = "4.8" -NonlinearSolveBase = "1.5.1" +NonlinearSolveBase = "1.15" NonlinearSolveFirstOrder = "1" Pkg = "1.10" PrecompileTools = "1.2" diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 975ac7135..e6f1d38f7 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -59,7 +59,7 @@ LineSearch = "0.1.3" LinearAlgebra = "1.10" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PolyesterForwardDiff = "0.1.3" PrecompileTools = "1.2" From 446262a938e3f864b6f1da84f9f961d0be7cbf76 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 10 Sep 2025 11:40:01 -0400 Subject: [PATCH 60/61] remove more of DiffEqBase --- Project.toml | 7 +------ lib/NonlinearSolveBase/Project.toml | 7 +------ lib/NonlinearSolveBase/test/runtests.jl | 3 +-- lib/SimpleNonlinearSolve/Project.toml | 7 +------ 4 files changed, 4 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 1a5a4700c..7772ea7d5 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,6 @@ BracketingNonlinearSolve = "1" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" DifferentiationInterface = "0.7.3" -DiffEqBase = "6.188" ExplicitImports = "1.5" FastClosures = "0.3.2" FastLevenbergMarquardt = "0.1" @@ -119,7 +118,6 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" @@ -168,8 +166,5 @@ path = "lib/NonlinearSolveSpectralMethods" [sources.SimpleNonlinearSolve] path = "lib/SimpleNonlinearSolve" -[sources] -DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} - [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "DiffEqBase"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker"] diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 312df9ec4..fe0f84721 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -65,7 +65,6 @@ ChainRulesCore = "1" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" EnzymeCore = "0.8" Enzyme = "0.13.12" @@ -97,13 +96,9 @@ Tracker = "0.2.35" TimerOutputs = "0.5.23" julia = "1.10" -[sources] -DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} - [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -112,4 +107,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "BandedMatrices", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"] +test = ["Aqua", "BandedMatrices", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"] diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index d9f702347..86eb95730 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -13,13 +13,12 @@ using InteractiveUtils, Test NonlinearSolveBase; piracies = false, ambiguities = false, stale_deps = false ) Aqua.test_stale_deps(NonlinearSolveBase; ignore = [:TimerOutputs]) - #ENSEMBLE PROBLEM SHOULD BE REMOVED, THIS IS TEMPORARY FOR TESTS Aqua.test_piracies(NonlinearSolveBase, treat_as_own = [AbstractNonlinearProblem, NonlinearProblem]) Aqua.test_ambiguities(NonlinearSolveBase; recursive = false) end @testset "Explicit Imports" begin - import ForwardDiff, SparseArrays, DiffEqBase + import ForwardDiff, SparseArrays using ExplicitImports, NonlinearSolveBase @test check_no_implicit_imports(NonlinearSolveBase; skip = (Base, Core)) === nothing diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index e6f1d38f7..f0cbbc99c 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -47,7 +47,6 @@ BracketingNonlinearSolve = "1.1" ChainRulesCore = "1.24" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" Enzyme = "0.13.11" ExplicitImports = "1.9" @@ -76,12 +75,8 @@ Tracker = "0.2.35" Zygote = "0.6.70, 0.7" julia = "1.10" -[sources] -DiffEqBase = {url = "https://github.com/jClugstor/DiffEqBase.jl", rev = "remove_nonlinear"} - [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -97,4 +92,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"] +test = ["Aqua", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"] From 8acbbf56e78bbffd6fe4b70d7bfdfa42b6b2bdd1 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 10 Sep 2025 14:32:57 -0400 Subject: [PATCH 61/61] remove DiffEqBase from makedocs --- docs/Project.toml | 2 -- docs/make.jl | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index f1c875de5..bd67a099f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,7 +4,6 @@ AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" @@ -69,7 +68,6 @@ AlgebraicMultigrid = "0.5, 0.6, 1" ArrayInterface = "6, 7" BenchmarkTools = "1" BracketingNonlinearSolve = "1" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" Documenter = "1" DocumenterCitations = "1" diff --git a/docs/make.jl b/docs/make.jl index cce131634..42cbb4b56 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,8 +1,7 @@ using Documenter, DocumenterCitations, DocumenterInterLinks -import DiffEqBase using Sundials -using NonlinearSolveBase, SciMLBase, DiffEqBase +using NonlinearSolveBase, SciMLBase using SimpleNonlinearSolve, BracketingNonlinearSolve using NonlinearSolveFirstOrder, NonlinearSolveQuasiNewton, NonlinearSolveSpectralMethods using NonlinearSolveHomotopyContinuation, NonlinearSolveSciPy @@ -33,7 +32,7 @@ makedocs(; sitename = "NonlinearSolve.jl", authors = "SciML", modules = [ - NonlinearSolveBase, SciMLBase, DiffEqBase, + NonlinearSolveBase, SciMLBase, SimpleNonlinearSolve, BracketingNonlinearSolve, NonlinearSolveFirstOrder, NonlinearSolveQuasiNewton, NonlinearSolveSpectralMethods, NonlinearSolveHomotopyContinuation,