diff --git a/Project.toml b/Project.toml index 5b4f97e815..b41ff59b22 100644 --- a/Project.toml +++ b/Project.toml @@ -150,7 +150,7 @@ StaticArrays = "0.10, 0.11, 0.12, 1.0" StochasticDelayDiffEq = "1.8.1" StochasticDiffEq = "6.72.1" SymbolicIndexingInterface = "0.3.37" -SymbolicUtils = "3.14" +SymbolicUtils = "3.25.1" Symbolics = "6.36" URIs = "1" UnPack = "0.1, 1.0" diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index fb692b4028..df07985f5a 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -535,7 +535,8 @@ end SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true function SymbolicIndexingInterface.observed( - sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true) + sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, + checkbounds = true, cse = true) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing if sym isa Symbol _sym = get(ic.symbol_to_variable, sym, nothing) @@ -559,7 +560,7 @@ function SymbolicIndexingInterface.observed( end end return build_explicit_observed_function( - sys, sym; eval_expression, eval_module, checkbounds) + sys, sym; eval_expression, eval_module, checkbounds, cse) end function SymbolicIndexingInterface.default_values(sys::AbstractSystem) @@ -1774,13 +1775,14 @@ struct ObservedFunctionCache{S} eval_expression::Bool eval_module::Module checkbounds::Bool + cse::Bool end function ObservedFunctionCache( sys; steady_state = false, eval_expression = false, - eval_module = @__MODULE__, checkbounds = true) + eval_module = @__MODULE__, checkbounds = true, cse = true) return ObservedFunctionCache( - sys, Dict(), steady_state, eval_expression, eval_module, checkbounds) + sys, Dict(), steady_state, eval_expression, eval_module, checkbounds, cse) end # This is hit because ensemble problems do a deepcopy @@ -1791,8 +1793,9 @@ function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict) eval_expression = ofc.eval_expression eval_module = ofc.eval_module checkbounds = ofc.checkbounds + cse = ofc.cse newofc = ObservedFunctionCache( - sys, dict, steady_state, eval_expression, eval_module, checkbounds) + sys, dict, steady_state, eval_expression, eval_module, checkbounds, cse) stackdict[ofc] = newofc return newofc end @@ -1801,7 +1804,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...) obs = get!(ofc.dict, value(obsvar)) do SymbolicIndexingInterface.observed( ofc.sys, obsvar; eval_expression = ofc.eval_expression, - eval_module = ofc.eval_module, checkbounds = ofc.checkbounds) + eval_module = ofc.eval_module, checkbounds = ofc.checkbounds, cse = ofc.cse) end if ofc.steady_state obs = let fn = obs diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 937264d083..07809bf611 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -699,7 +699,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no add_integrator_header(sys, integ, outvar), outputidxs = update_inds, create_bindings = false, - kwargs...) + kwargs..., cse = false) # applied user-provided function to the generated expression if postprocess_affect_expr! !== nothing postprocess_affect_expr!(rf_ip, integ) @@ -729,7 +729,7 @@ function generate_single_rootfinding_callback( end rf_oop, rf_ip = generate_custom_function( - sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...) + sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs..., cse = false) affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs) cond = function (u, t, integ) if DiffEqBase.isinplace(integ.sol.prob) @@ -780,7 +780,7 @@ function generate_vector_rootfinding_callback( rhss = map(x -> x.rhs, eqs) _, rf_ip = generate_custom_function( - sys, rhss, dvs, ps; expression = Val{false}, kwargs...) + sys, rhss, dvs, ps; expression = Val{false}, kwargs..., cse = false) affect_functions = @NamedTuple{ affect::Function, diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl index 1eeb7e026b..a3fe53b95d 100644 --- a/src/systems/codegen_utils.jl +++ b/src/systems/codegen_utils.jl @@ -132,7 +132,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, wrap_delays = is_dde(sys), wrap_code = identity, add_observed = true, filter_observed = Returns(true), create_bindings = false, output_type = nothing, mkarray = nothing, - wrap_mtkparameters = true, extra_assignments = Assignment[], kwargs...) + wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true, kwargs...) isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic()) # filter observed equations obs = filter(filter_observed, observed(sys)) @@ -234,7 +234,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic() wrap_code = wrap_code[1] end - return build_function(expr, args...; wrap_code, similarto, kwargs...) + return build_function(expr, args...; wrap_code, similarto, cse, kwargs...) end """ diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 306aa59fc1..23f20b00ec 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -312,12 +312,13 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, analytic = nothing, split_idxs = nothing, initialization_data = nothing, + cse = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`") end f_gen = generate_function(sys, dvs, ps; expression = Val{true}, - expression_module = eval_module, checkbounds = checkbounds, + expression_module = eval_module, checkbounds = checkbounds, cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip) @@ -333,7 +334,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, tgrad_gen = generate_tgrad(sys, dvs, ps; simplify = simplify, expression = Val{true}, - expression_module = eval_module, + expression_module = eval_module, cse, checkbounds = checkbounds, kwargs...) tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module) _tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip) @@ -345,7 +346,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, jac_gen = generate_jacobian(sys, dvs, ps; simplify = simplify, sparse = sparse, expression = Val{true}, - expression_module = eval_module, + expression_module = eval_module, cse, checkbounds = checkbounds, kwargs...) jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module) @@ -365,7 +366,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, end observedfun = ObservedFunctionCache( - sys; steady_state, eval_expression, eval_module, checkbounds) + sys; steady_state, eval_expression, eval_module, checkbounds, cse) jac_prototype = if sparse uElType = u0 === nothing ? Float64 : eltype(u0) @@ -420,12 +421,13 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) eval_module = @__MODULE__, checkbounds = false, initialization_data = nothing, + cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`") end f_gen = generate_function(sys, dvs, ps; implicit_dae = true, - expression = Val{true}, + expression = Val{true}, cse, expression_module = eval_module, checkbounds = checkbounds, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) @@ -435,7 +437,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) jac_gen = generate_dae_jacobian(sys, dvs, ps; simplify = simplify, sparse = sparse, expression = Val{true}, - expression_module = eval_module, + expression_module = eval_module, cse, checkbounds = checkbounds, kwargs...) jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module) @@ -445,7 +447,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) end observedfun = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) jac_prototype = if sparse uElType = u0 === nothing ? Float64 : eltype(u0) @@ -479,6 +481,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) eval_module = @__MODULE__, checkbounds = false, initialization_data = nothing, + cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`") @@ -486,7 +489,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) f_gen = generate_function(sys, dvs, ps; isdde = true, expression = Val{true}, expression_module = eval_module, checkbounds = checkbounds, - kwargs...) + cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip) @@ -503,6 +506,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys eval_module = @__MODULE__, checkbounds = false, initialization_data = nothing, + cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`") @@ -510,12 +514,12 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys f_gen = generate_function(sys, dvs, ps; isdde = true, expression = Val{true}, expression_module = eval_module, checkbounds = checkbounds, - kwargs...) + cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip) g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true}, - isdde = true, kwargs...) + isdde = true, cse, kwargs...) g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module) g = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(g_oop, g_iip) @@ -841,6 +845,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [] warn_initialize_determined = true, eval_expression = false, eval_module = @__MODULE__, + cse = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`") @@ -864,12 +869,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [] _u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses)) f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, guesses, - check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...) + check_length, warn_initialize_determined, eval_expression, eval_module, cse, kwargs...) stidxmap = Dict([v => i for (i, v) in enumerate(sts)]) u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k, v) in u0map] - fns = generate_function_bc(sys, u0, u0_idxs, tspan) + fns = generate_function_bc(sys, u0, u0_idxs, tspan; cse) bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module) bc(sol, p, t) = bc_oop(sol, p, t) bc(resid, u, p, t) = bc_iip(resid, u, p, t) @@ -988,15 +993,16 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], eval_expression = false, eval_module = @__MODULE__, u0_constructor = identity, + cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`") end f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, - symbolic_u0 = true, u0_constructor, + symbolic_u0 = true, u0_constructor, cse, check_length, eval_expression, eval_module, kwargs...) - h_gen = generate_history(sys, u0; expression = Val{true}) + h_gen = generate_history(sys, u0; expression = Val{true}, cse) h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module) h = h_oop u0 = float.(h(p, tspan[1])) @@ -1027,6 +1033,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], eval_expression = false, eval_module = @__MODULE__, u0_constructor = identity, + cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`") @@ -1034,8 +1041,8 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, symbolic_u0 = true, eval_expression, eval_module, u0_constructor, - check_length, kwargs...) - h_gen = generate_history(sys, u0; expression = Val{true}) + check_length, cse, kwargs...) + h_gen = generate_history(sys, u0; expression = Val{true}, cse) h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module) h = h_oop u0 = h(p, tspan[1]) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index bd0a58e10a..57d71b18c8 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -454,8 +454,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys - `checkbounds = true` checks bounds if true when destructuring parameters - `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail. - `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist. -- `mkarray`; only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in -the array and `output_type` is the argument of the same name passed to build_explicit_observed_function. +- `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function. +- `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function. ## Returns @@ -493,6 +493,7 @@ function build_explicit_observed_function(sys, ts; param_only = false, op = Operator, throw = true, + cse = true, mkarray = nothing) is_tuple = ts isa Tuple if is_tuple @@ -579,7 +580,7 @@ function build_explicit_observed_function(sys, ts; p_end = length(dvs) + length(inputs) + length(ps) fns = build_function_wrapper( sys, ts, args...; p_start, p_end, filter_observed = obsfilter, - output_type, mkarray, try_namespaced = true, expression = Val{true}) + output_type, mkarray, try_namespaced = true, expression = Val{true}, cse) if fns isa Tuple if expression return return_inplace ? fns : fns[1] diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index f51529a559..66044427eb 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -595,23 +595,23 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( jac = false, Wfact = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = false, initialization_data = nothing, - kwargs...) where {iip, specialize} + cse = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`") end dvs = scalarize.(dvs) - f_gen = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...) + f_gen = generate_function(sys, dvs, ps; expression = Val{true}, cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true}, - kwargs...) + cse, kwargs...) g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip) g = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(g_oop, g_iip) if tgrad - tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true}, + tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true}, cse, kwargs...) tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module) _tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip) @@ -621,7 +621,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( if jac jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{true}, - sparse = sparse, kwargs...) + sparse = sparse, cse, kwargs...) jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module) _jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip) @@ -631,7 +631,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( if Wfact tmp_Wfact, tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true; - expression = Val{true}, kwargs...) + expression = Val{true}, cse, kwargs...) Wfact_oop, Wfact_iip = eval_or_rgf.(tmp_Wfact; eval_expression, eval_module) Wfact_oop_t, Wfact_iip_t = eval_or_rgf.(tmp_Wfact_t; eval_expression, eval_module) @@ -645,7 +645,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( _M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M) observedfun = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) SDEFunction{iip, specialize}(f, g; sys = sys, diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 40f01769ee..5f7c986659 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -360,13 +360,13 @@ function SciMLBase.DiscreteFunction{iip, specialize}( t = nothing, eval_expression = false, eval_module = @__MODULE__, - analytic = nothing, + analytic = nothing, cse = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `DiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") end f_gen = generate_function(sys, dvs, ps; expression = Val{true}, - expression_module = eval_module, kwargs...) + expression_module = eval_module, cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip) @@ -378,7 +378,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}( end observedfun = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) DiscreteFunction{iip, specialize}(f; sys = sys, diff --git a/src/systems/discrete_system/implicit_discrete_system.jl b/src/systems/discrete_system/implicit_discrete_system.jl index ebae78384a..3956c089d4 100644 --- a/src/systems/discrete_system/implicit_discrete_system.jl +++ b/src/systems/discrete_system/implicit_discrete_system.jl @@ -369,13 +369,13 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}( t = nothing, eval_expression = false, eval_module = @__MODULE__, - analytic = nothing, + analytic = nothing, cse = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`") end f_gen = generate_function(sys, dvs, ps; expression = Val{true}, - expression_module = eval_module, kwargs...) + expression_module = eval_module, cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f(u_next, u, p, t) = f_oop(u_next, u, p, t) f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t) @@ -388,7 +388,7 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}( end observedfun = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) ImplicitDiscreteFunction{iip, specialize}(f; sys = sys, diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 9da32a4305..57a3aee7df 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -398,6 +398,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, parammap = DiffEqBase.NullParameters(); eval_expression = false, eval_module = @__MODULE__, + cse = true, kwargs...) if !iscomplete(sys) error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") @@ -408,11 +409,11 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, end _f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; - t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false) + t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse) f = DiffEqBase.DISCRETE_INPLACE_DEFAULT observedfun = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun, initialization_data = get(_f.kwargs, :initialization_data, nothing)) @@ -488,7 +489,7 @@ oprob = ODEProblem(complete(js), u₀map, tspan, parammap) function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing}, parammap = DiffEqBase.NullParameters(); eval_expression = false, - eval_module = @__MODULE__, + eval_module = @__MODULE__, cse = true, kwargs...) if !iscomplete(sys) error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") @@ -507,10 +508,10 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi else _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; t = tspan === nothing ? nothing : tspan[1], tofloat = false, - check_length = false, build_initializeprob = false) + check_length = false, build_initializeprob = false, cse) f = (du, u, p, t) -> (du .= 0; nothing) observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, - checkbounds = get(kwargs, :checkbounds, false)) + checkbounds = get(kwargs, :checkbounds, false), cse) df = ODEFunction(f; sys, observed = observedfun) return ODEProblem(df, u0, tspan, p; kwargs...) end diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 474224eacd..9a77779103 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -489,7 +489,7 @@ end function SciMLBase.HomotopyNonlinearFunction{iip, specialize}( sys::NonlinearSystem, args...; eval_expression = false, eval_module = @__MODULE__, - p = nothing, fraction_cancel_fn = SymbolicUtils.simplify_fractions, + p = nothing, fraction_cancel_fn = SymbolicUtils.simplify_fractions, cse = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationFunction`") @@ -511,9 +511,9 @@ function SciMLBase.HomotopyNonlinearFunction{iip, specialize}( # we want to create f, jac etc. according to `sys2` since that will do the solving # but the `sys` inside for symbolic indexing should be the non-polynomial system - fn = NonlinearFunction{iip}(sys2; eval_expression, eval_module, kwargs...) + fn = NonlinearFunction{iip}(sys2; eval_expression, eval_module, cse, kwargs...) obsfn = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) fn = remake(fn; sys = sys, observed = obsfn) denominator = build_explicit_observed_function(sys2, denoms) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 77215be5eb..856822492b 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -373,19 +373,19 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s eval_expression = false, eval_module = @__MODULE__, sparse = false, simplify = false, - initialization_data = nothing, + initialization_data = nothing, cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`") end - f_gen = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...) + f_gen = generate_function(sys, dvs, ps; expression = Val{true}, cse, kwargs...) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip) if jac jac_gen = generate_jacobian(sys, dvs, ps; simplify = simplify, sparse = sparse, - expression = Val{true}, kwargs...) + expression = Val{true}, cse, kwargs...) jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module) _jac = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(jac_oop, jac_iip) else @@ -393,7 +393,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s end observedfun = ObservedFunctionCache( - sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) + sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) if length(dvs) == length(equations(sys)) resid_prototype = nothing @@ -606,7 +606,7 @@ end function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT}, exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation}; - eval_expression = false, eval_module = @__MODULE__) + eval_expression = false, eval_module = @__MODULE__, cse = true) ps = parameters(sys; initial_parameters = true) rps = reorder_parameters(sys, ps) obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs] @@ -625,7 +625,7 @@ function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT}, sys, nothing, :out, DestructuredArgs(DestructuredArgs.(solsyms), generated_argument_name(1)), rps...; p_start = 3, p_end = length(rps) + 2, - expression = Val{true}, add_observed = false, + expression = Val{true}, add_observed = false, cse, extra_assignments = [array_assignments; obs_assigns; body]) fn = eval_or_rgf(fn; eval_expression, eval_module) fn = GeneratedFunctionWrapper{(3, 3, is_split(sys))}(fn, nothing) @@ -636,7 +636,7 @@ struct SCCNonlinearFunction{iip} end function SCCNonlinearFunction{iip}( sys::NonlinearSystem, _eqs, _dvs, _obs, cachesyms; eval_expression = false, - eval_module = @__MODULE__, kwargs...) where {iip} + eval_module = @__MODULE__, cse = true, kwargs...) where {iip} ps = parameters(sys; initial_parameters = true) rps = reorder_parameters(sys, ps) @@ -646,7 +646,7 @@ function SCCNonlinearFunction{iip}( f_gen = build_function_wrapper(sys, rhss, _dvs, rps..., cachesyms...; p_start = 2, p_end = length(rps) + length(cachesyms) + 1, add_observed = false, - extra_assignments = obs_assignments, expression = Val{true}) + extra_assignments = obs_assignments, expression = Val{true}, cse) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip) @@ -666,7 +666,8 @@ function SciMLBase.SCCNonlinearProblem(sys::NonlinearSystem, args...; kwargs...) end function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, - parammap = SciMLBase.NullParameters(); eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip} + parammap = SciMLBase.NullParameters(); eval_expression = false, eval_module = @__MODULE__, + cse = true, kwargs...) where {iip} if !iscomplete(sys) || get_tearing_state(sys) === nothing error("A simplified `NonlinearSystem` is required. Call `structural_simplify` on the system before creating an `SCCNonlinearProblem`.") end @@ -801,14 +802,14 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1))) push!(explicitfuns, CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs]; - eval_expression, eval_module)) + eval_expression, eval_module, cse)) end cachebufsyms = Tuple(map(cachetypes) do T get(cachevars, T, []) end) f = SCCNonlinearFunction{iip}( - sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, kwargs...) + sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs...) push!(nlfuns, f) end diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index e55f4b7871..be4567aee5 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -300,7 +300,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, cons_sparse = false, checkbounds = false, linenumbers = true, parallel = SerialForm(), eval_expression = false, eval_module = @__MODULE__, - checks = true, + checks = true, cse = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed `OptimizationSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `OptimizationProblem`") @@ -354,8 +354,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, f = let _f = eval_or_rgf( generate_function( - sys, checkbounds = checkbounds, linenumbers = linenumbers, - expression = Val{true}, wrap_mtkparameters = false); + sys; checkbounds = checkbounds, linenumbers = linenumbers, + expression = Val{true}, wrap_mtkparameters = false, cse); eval_expression, eval_module) __f(u, p) = _f(u, p) @@ -367,10 +367,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, if grad _grad = let (grad_oop, grad_iip) = eval_or_rgf.( generate_gradient( - sys, checkbounds = checkbounds, + sys; checkbounds = checkbounds, linenumbers = linenumbers, parallel = parallel, expression = Val{true}, - wrap_mtkparameters = false); + wrap_mtkparameters = false, cse); eval_expression, eval_module) _grad(u, p) = grad_oop(u, p) @@ -386,10 +386,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, if hess _hess = let (hess_oop, hess_iip) = eval_or_rgf.( generate_hessian( - sys, checkbounds = checkbounds, + sys; checkbounds = checkbounds, linenumbers = linenumbers, sparse = sparse, parallel = parallel, - expression = Val{true}, wrap_mtkparameters = false); + expression = Val{true}, wrap_mtkparameters = false, cse); eval_expression, eval_module) _hess(u, p) = hess_oop(u, p) @@ -408,14 +408,14 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, hess_prototype = nothing end - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds, cse) if length(cstr) > 0 @named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks) cons_sys = complete(cons_sys) - cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds, + cons, lcons_, ucons_ = generate_function(cons_sys; checkbounds = checkbounds, linenumbers = linenumbers, - expression = Val{true}; wrap_mtkparameters = false) + expression = Val{true}, wrap_mtkparameters = false, cse) cons = let (cons_oop, cons_iip) = eval_or_rgf.(cons; eval_expression, eval_module) _cons(u, p) = cons_oop(u, p) _cons(resid, u, p) = cons_iip(resid, u, p) @@ -428,7 +428,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, checkbounds = checkbounds, linenumbers = linenumbers, parallel = parallel, expression = Val{true}, - sparse = cons_sparse, wrap_mtkparameters = false); + sparse = cons_sparse, wrap_mtkparameters = false, cse); eval_expression, eval_module) _cons_j(u, p) = cons_jac_oop(u, p) @@ -443,10 +443,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, if cons_h _cons_h = let (cons_hess_oop, cons_hess_iip) = eval_or_rgf.( generate_hessian( - cons_sys, checkbounds = checkbounds, + cons_sys; checkbounds = checkbounds, linenumbers = linenumbers, sparse = cons_sparse, parallel = parallel, - expression = Val{true}, wrap_mtkparameters = false); + expression = Val{true}, wrap_mtkparameters = false, cse); eval_expression, eval_module) _cons_h(u, p) = cons_hess_oop(u, p) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 84539f0b37..7f4a3247ad 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -173,7 +173,7 @@ end p = [rand()] x = [rand()] u = [rand()] - @test f[1](x, u, p, 1) == -x + u + @test f[1](x, u, p, 1) ≈ -x + u # With disturbance inputs @variables x(t)=0 u(t)=0 [input = true] d(t)=0 @@ -191,7 +191,7 @@ end p = [rand()] x = [rand()] u = [rand()] - @test f[1](x, u, p, 1) == -x + u + @test f[1](x, u, p, 1) ≈ -x + u ## With added d argument @variables x(t)=0 u(t)=0 [input = true] d(t)=0 @@ -210,7 +210,7 @@ end x = [rand()] u = [rand()] d = [rand()] - @test f[1](x, u, p, t, d) == -x + u + [d[]^2] + @test f[1](x, u, p, t, d) ≈ -x + u + [d[]^2] end end @@ -434,7 +434,7 @@ matrices, ssys = linearize(augmented_sys, (; io_sys,) = ModelingToolkit.generate_control_function(sys, simplify = true) obsfn = ModelingToolkit.build_explicit_observed_function( io_sys, [x + u * t]; inputs = [u]) - @test obsfn([1.0], [2.0], MTKParameters(io_sys, []), 3.0) == [7.0] + @test obsfn([1.0], [2.0], MTKParameters(io_sys, []), 3.0) ≈ [7.0] end # https://github.com/SciML/ModelingToolkit.jl/issues/2896 @@ -445,7 +445,7 @@ end @named sys = ODESystem(eqs, t, [x], []) f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true) - @test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) == [1.0] + @test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) ≈ [1.0] end @testset "With callable symbolic" begin