From e337c98d436a580e2c49de45a4ac005e7046279f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 May 2025 23:07:50 +0530 Subject: [PATCH 1/4] refactor: remove code related to parsing and substitution of constants --- src/inputoutput.jl | 3 +- src/problems/optimizationproblem.jl | 4 +- .../StructuralTransformations.jl | 2 +- src/structural_transformation/codegen.jl | 265 +----------------- src/structural_transformation/utils.jl | 4 +- src/systems/abstractsystem.jl | 12 - src/systems/callbacks.jl | 5 - src/systems/codegen_utils.jl | 12 +- src/systems/nonlinear/initializesystem.jl | 3 +- src/systems/parameter_buffer.jl | 3 +- src/systems/problem_utils.jl | 14 +- src/systems/systemstructure.jl | 6 +- src/utils.jl | 117 +------- 13 files changed, 19 insertions(+), 431 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 1beb229664..19603b76cd 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -222,7 +222,7 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs( disturbance_inputs = unwrap.(disturbance_inputs) eqs = [eq for eq in full_equations(sys)] - eqs = map(subs_constants, eqs) + if disturbance_inputs !== nothing && !disturbance_argument # Set all disturbance *inputs* to zero (we just want to keep the disturbance state) subs = Dict(disturbance_inputs .=> 0) @@ -237,7 +237,6 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs( p = reorder_parameters(sys, ps) t = get_iv(sys) - # pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys) if disturbance_argument args = (dvs, inputs, p..., t, disturbance_inputs) else diff --git a/src/problems/optimizationproblem.jl b/src/problems/optimizationproblem.jl index 243d453ada..e0de2f78ff 100644 --- a/src/problems/optimizationproblem.jl +++ b/src/problems/optimizationproblem.jl @@ -56,10 +56,10 @@ function SciMLBase.OptimizationFunction{iip}(sys::System; else _cons_h = cons_hess_prototype = nothing end - cons_expr = subs_constants(cstr) + cons_expr = cstr end - obj_expr = subs_constants(cost(sys)) + obj_expr = cost(sys) observedfun = ObservedFunctionCache( sys; expression, eval_expression, eval_module, checkbounds, cse) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 06d8e440cc..2ba469e26a 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -21,7 +21,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe has_tearing_state, defaults, InvalidSystemException, ExtraEquationsSystemException, ExtraVariablesSystemException, - get_postprocess_fbody, vars!, + vars!, IncrementalCycleTracker, add_edge_checked!, topological_sort, invalidate_cache!, Substitutions, get_or_construct_tearing_state, filter_kwargs, lower_varname_with_unit, diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index 144e19aa31..9afe7ec5e7 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -1,6 +1,6 @@ using LinearAlgebra -using ModelingToolkit: process_events, get_preprocess_constants +using ModelingToolkit: process_events const MAX_INLINE_NLSOLVE_SIZE = 8 @@ -96,136 +96,6 @@ function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_ sparse(I, J, true, length(eqs_idxs), length(states_idxs)) end -function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict, - assignments, (deps, invdeps), var2assignment; checkbounds = true) - isempty(vars) && throw(ArgumentError("vars may not be empty")) - length(eqs) == length(vars) || - throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of")) - rhss = map(x -> x.rhs, eqs) - # We use `vars` instead of `graph` to capture parameters, too. - paramset = ModelingToolkit.vars(r for r in rhss) - - # Compute necessary assignments for the nlsolve expr - init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)] - if isempty(init_assignments) - needed_assignments_idxs = Int[] - needed_assignments = similar(assignments, 0) - else - tmp = [init_assignments] - # `deps[init_assignments]` gives the dependency of `init_assignments` - while true - next_assignments = unique(reduce(vcat, deps[init_assignments])) - isempty(next_assignments) && break - init_assignments = next_assignments - push!(tmp, init_assignments) - end - needed_assignments_idxs = unique(reduce(vcat, reverse(tmp))) - needed_assignments = assignments[needed_assignments_idxs] - end - - # Compute `params`. They are like enclosed variables - rhsvars = [ModelingToolkit.vars(r.rhs) for r in needed_assignments] - vars_set = Set(vars) - outer_set = BitSet() - inner_set = BitSet() - for (i, vs) in enumerate(rhsvars) - j = needed_assignments_idxs[i] - if isdisjoint(vars_set, vs) - push!(outer_set, j) - else - push!(inner_set, j) - end - end - init_refine = BitSet() - for i in inner_set - union!(init_refine, invdeps[i]) - end - intersect!(init_refine, outer_set) - setdiff!(outer_set, init_refine) - union!(inner_set, init_refine) - - next_refine = BitSet() - while true - for i in init_refine - id = invdeps[i] - isempty(id) && break - union!(next_refine, id) - end - intersect!(next_refine, outer_set) - isempty(next_refine) && break - setdiff!(outer_set, next_refine) - union!(inner_set, next_refine) - - init_refine, next_refine = next_refine, init_refine - empty!(next_refine) - end - global2local = Dict(j => i for (i, j) in enumerate(needed_assignments_idxs)) - inner_idxs = [global2local[i] for i in collect(inner_set)] - outer_idxs = [global2local[i] for i in collect(outer_set)] - extravars = reduce(union!, rhsvars[inner_idxs], init = Set()) - union!(paramset, extravars) - setdiff!(paramset, vars) - setdiff!(paramset, [needed_assignments[i].lhs for i in inner_idxs]) - union!(paramset, [needed_assignments[i].lhs for i in outer_idxs]) - params = collect(paramset) - - # splatting to tighten the type - u0 = [] - for v in vars - v in keys(u0map) || (push!(u0, 1e-3); continue) - u = substitute(v, u0map) - for i in 1:length(u0map) - u = substitute(u, u0map) - u isa Number && (push!(u0, u); break) - end - u isa Number || error("$v doesn't have a default.") - end - u0 = [u0...] - # specialize on the scalar case - isscalar = length(u0) == 1 - u0 = isscalar ? u0[1] : SVector(u0...) - - fname = gensym("fun") - # f is the function to find roots on - if isscalar - funex = rhss[1] - pre = get_preprocess_constants(funex) - else - funex = MakeArray(rhss, SVector) - pre = get_preprocess_constants(rhss) - end - f = Func( - [DestructuredArgs(vars, inbounds = !checkbounds) - DestructuredArgs(params, inbounds = !checkbounds)], - [], - pre(Let(needed_assignments[inner_idxs], - funex, - false))) |> SymbolicUtils.Code.toexpr - - # solver call contains code to call the root-finding solver on the function f - solver_call = LiteralExpr(quote - $numerical_nlsolve($fname, - # initial guess - $u0, - # "captured variables" - ($(params...),)) - end) - - preassignments = [] - for i in outer_idxs - ii = needed_assignments_idxs[i] - is_not_prepended_assignment[ii] || continue - is_not_prepended_assignment[ii] = false - push!(preassignments, assignments[ii]) - end - - nlsolve_expr = Assignment[preassignments - fname ← drop_expr(@RuntimeGeneratedFunction(f)) - DestructuredArgs(vars, inbounds = !checkbounds) ← solver_call] - - nlsolve_expr -end - """ find_solve_sequence(sccs, vars) @@ -242,136 +112,3 @@ function find_solve_sequence(sccs, vars) return find_solve_sequence(sccs, vars′) end end - -function build_observed_function(state, ts, var_eq_matching, var_sccs, - is_solver_unknown_idxs, - assignments, - deps, - sol_states, - var2assignment; - expression = false, - output_type = Array, - checkbounds = true) - is_not_prepended_assignment = trues(length(assignments)) - if (isscalar = !(ts isa AbstractVector)) - ts = [ts] - end - ts = unwrap.(Symbolics.scalarize(ts)) - - vars = Set() - sys = state.sys - foreach(Base.Fix1(vars!, vars), ts) - ivs = independent_variables(sys) - dep_vars = collect(setdiff(vars, ivs)) - - fullvars = state.fullvars - s = state.structure - unknown_vars = fullvars[is_solver_unknown_idxs] - algvars = fullvars[.!is_solver_unknown_idxs] - - required_algvars = Set(intersect(algvars, vars)) - obs = observed(sys) - observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs)) - namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs) - namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys)) - sts = Set(unknowns(sys)) - - # FIXME: This is a rather rough estimate of dependencies. We assume - # the expression depends on everything before the `maxidx`. - subs = Dict() - maxidx = 0 - for (i, s) in enumerate(dep_vars) - idx = get(observed_idx, s, nothing) - if idx !== nothing - idx > maxidx && (maxidx = idx) - else - s′ = get(namespaced_to_obs, s, nothing) - if s′ !== nothing - subs[s] = s′ - s = s′ - idx = get(observed_idx, s, nothing) - end - if idx !== nothing - idx > maxidx && (maxidx = idx) - elseif !(s in sts) - s′ = get(namespaced_to_sts, s, nothing) - if s′ !== nothing - subs[s] = s′ - continue - end - throw(ArgumentError("$s is either an observed nor an unknown variable.")) - end - continue - end - end - ts = map(t -> substitute(t, subs), ts) - vs = Set() - for idx in 1:maxidx - vars!(vs, obs[idx].rhs) - union!(required_algvars, intersect(algvars, vs)) - empty!(vs) - end - for eq in assignments - vars!(vs, eq.rhs) - union!(required_algvars, intersect(algvars, vs)) - empty!(vs) - end - - varidxs = findall(x -> x in required_algvars, fullvars) - subset = find_solve_sequence(var_sccs, varidxs) - if !isempty(subset) - eqs = equations(sys) - - nested_torn_vars_idxs = [] - for iscc in subset - torn_vars_idxs = Int[var - for var in var_sccs[iscc] - if var_eq_matching[var] !== unassigned] - isempty(torn_vars_idxs) || push!(nested_torn_vars_idxs, torn_vars_idxs) - end - torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs] - for idxs in nested_torn_vars_idxs] - torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs] - u0map = defaults(sys) - assignments = copy(assignments) - solves = map(zip(torn_eqs, torn_vars)) do (eqs, vars) - gen_nlsolve!(is_not_prepended_assignment, eqs, vars, - u0map, assignments, deps, var2assignment; - checkbounds = checkbounds) - end - else - solves = [] - end - - subs = [] - for sym in vars - eqidx = get(observed_idx, sym, nothing) - eqidx === nothing && continue - push!(subs, sym ← obs[eqidx].rhs) - end - pre = get_postprocess_fbody(sys) - cpre = get_preprocess_constants([obs[1:maxidx]; - isscalar ? ts[1] : MakeArray(ts, output_type)]) - pre2 = x -> pre(cpre(x)) - ex = Code.toexpr( - Func( - [DestructuredArgs(unknown_vars, inbounds = !checkbounds) - DestructuredArgs(parameters(sys), inbounds = !checkbounds) - independent_variables(sys)], - [], - pre2(Let( - [collect(Iterators.flatten(solves)) - assignments[is_not_prepended_assignment] - map(eq -> eq.lhs ← eq.rhs, obs[1:maxidx]) - subs], - isscalar ? ts[1] : MakeArray(ts, output_type), - false))), - sol_states) - - expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex)) -end - -struct ODAEProblem{iip} end - -@deprecate ODAEProblem(args...; kw...) ODEProblem(args...; kw...) -@deprecate ODAEProblem{iip}(args...; kw...) where {iip} ODEProblem{iip}(args...; kw...) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 191c25ab68..2bf316cfa6 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -224,14 +224,12 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no a, b, islinear = linear_expansion(term, var) a, b = unwrap(a), unwrap(b) islinear || (all_int_vars = false; continue) - a = ModelingToolkit.fold_constants(a) - b = ModelingToolkit.fold_constants(b) if a isa Symbolic all_int_vars = false if !allow_symbolic if allow_parameter all( - x -> ModelingToolkit.isparameter(x) || ModelingToolkit.isconstant(x), + x -> ModelingToolkit.isparameter(x), vars(a)) || continue else continue diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 4021f1fc44..1a90faad92 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2540,18 +2540,6 @@ function debug_system( return sys end -function eliminate_constants(sys::AbstractSystem) - if has_eqs(sys) - eqs = get_eqs(sys) - eq_cs = collect_constants(eqs) - if !isempty(eq_cs) - new_eqs = eliminate_constants(eqs, eq_cs) - @set! sys.eqs = new_eqs - end - end - return sys -end - @latexrecipe function f(sys::AbstractSystem) return latexify(equations(sys)) end diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 0bb2318e4b..59ff0e0d98 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -635,11 +635,6 @@ function compile_condition( p = map.(value, reorder_parameters(sys, ps)) t = get_iv(sys) condit = conditions(cbs) - cs = collect_constants(condit) - if !isempty(cs) - cmap = map(x -> x => getdefault(x), cs) - condit = substitute(condit, Dict(cmap)) - end if !is_discrete(cbs) condit = reduce(vcat, flatten_equations(Vector{Equation}(condit))) diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl index 8c2c322e31..c3b652740e 100644 --- a/src/systems/codegen_utils.jl +++ b/src/systems/codegen_utils.jl @@ -247,15 +247,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, p_end += 1 end pdeps = parameter_dependencies(sys) - # get the constants to add to the code - cmap, _ = get_cmap(sys) - extra_constants = collect_constants(expr) - filter!(extra_constants) do c - !any(x -> isequal(c, x.lhs), cmap) - end - for c in extra_constants - push!(cmap, c ~ getdefault(c)) - end + # only get the necessary observed equations, avoiding extra computation if add_observed && !isempty(obs) obsidxs = observed_equations_used_by(sys, expr; obs) @@ -270,7 +262,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2, # assignments for reconstructing scalarized array symbolics assignments = array_variable_assignments(args...) - for eq in Iterators.flatten((cmap, pdeps[pdepidxs], obs[obsidxs])) + for eq in Iterators.flatten((pdeps[pdepidxs], obs[obsidxs])) push!(assignments, eq.lhs ← eq.rhs) end append!(assignments, extra_assignments) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 57913c5031..0263749957 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -534,7 +534,6 @@ function SciMLBase.remake_initialization_data( symbols_to_symbolics!(sys, pmap) guesses = Dict() defs = defaults(sys) - cmap, cs = get_cmap(sys) use_scc = true initialization_eqs = Equation[] @@ -589,7 +588,7 @@ function SciMLBase.remake_initialization_data( filter_missing_values!(pmap) op, missing_unknowns, missing_pars = build_operating_point!(sys, - u0map, pmap, defs, cmap, dvs, ps) + u0map, pmap, defs, dvs, ps) floatT = float_type_from_varmap(op) kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; time_dependent_init, diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 6142c95776..c3656d737c 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -44,13 +44,12 @@ function MTKParameters( p = to_varmap(p, ps) symbols_to_symbolics!(sys, p) defs = add_toterms(recursive_unwrap(defaults(sys))) - cmap, cs = get_cmap(sys) is_time_dependent(sys) && add_observed!(sys, u0) add_parameter_dependencies!(sys, p) op, missing_unknowns, missing_pars = build_operating_point!(sys, - u0, p, defs, cmap, dvs, ps) + u0, p, defs, dvs, ps) if t0 !== nothing op[get_iv(sys)] = t0 diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index b2019107cf..346d17ecf5 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -558,15 +558,15 @@ end $(TYPEDSIGNATURES) Construct the operating point of the system from the user-provided `u0map` and `pmap`, system -defaults `defs`, constant equations `cmap` (from `get_cmap(sys)`), unknowns `dvs` and -parameters `ps`. Return the operating point as a dictionary, the list of unknowns for which -no values can be determined, and the list of parameters for which no values can be determined. +defaults `defs`, unknowns `dvs` and parameters `ps`. Return the operating point as a dictionary, +the list of unknowns for which no values can be determined, and the list of parameters for which +no values can be determined. Also updates `u0map` and `pmap` in-place to contain all the initial conditions in `op`, split by unknowns and parameters respectively. """ function build_operating_point!(sys::AbstractSystem, - u0map::AbstractDict, pmap::AbstractDict, defs::AbstractDict, cmap, dvs, ps) + u0map::AbstractDict, pmap::AbstractDict, defs::AbstractDict, dvs, ps) op = add_toterms(u0map) missing_unknowns = add_fallbacks!(op, dvs, defs) for (k, v) in defs @@ -578,9 +578,6 @@ function build_operating_point!(sys::AbstractSystem, merge!(op, pmap) missing_pars = add_fallbacks!(op, ps, defs) filter_missing_values!(op; missing_values = missing_pars) - for eq in cmap - op[eq.lhs] = eq.rhs - end filter!(kvp -> kvp[2] === nothing, u0map) filter!(kvp -> kvp[2] === nothing, pmap) @@ -1084,7 +1081,6 @@ function process_SciMLProblem( check_inputmap_keys(sys, u0map, pmap) defs = add_toterms(recursive_unwrap(defaults(sys))) - cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) if eltype(eqs) <: Equation @@ -1094,7 +1090,7 @@ function process_SciMLProblem( end op, missing_unknowns, missing_pars = build_operating_point!(sys, - u0map, pmap, defs, cmap, dvs, ps) + u0map, pmap, defs, dvs, ps) floatT = Bool if u0Type <: AbstractArray && eltype(u0Type) <: Real diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 144aad148e..c1b2c337ac 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -5,7 +5,7 @@ using SymbolicUtils: quick_cancel, maketerm using ..ModelingToolkit import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, value, InvalidSystemException, isdifferential, _iszero, - isparameter, isconstant, + isparameter, independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, InferredTimeDomain, @@ -314,7 +314,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) _var, _ = var_from_nested_derivative(v) any(isequal(_var), ivs) && continue if isparameter(_var) || - (iscall(_var) && isparameter(operation(_var)) || isconstant(_var)) + (iscall(_var) && isparameter(operation(_var))) if is_time_dependent_parameter(_var, iv) && !haskey(param_derivative_map, Differential(iv)(_var)) # Parameter derivatives default to zero - they stay constant @@ -339,7 +339,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) _var, _ = var_from_nested_derivative(var) any(isequal(_var), ivs) && continue if isparameter(_var) || - (iscall(_var) && isparameter(operation(_var)) || isconstant(_var)) + (iscall(_var) && isparameter(operation(_var))) continue end varidx = addvar!(var) diff --git a/src/utils.jl b/src/utils.jl index efa6196af8..e1e8c50af4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -708,7 +708,7 @@ function collect_var!(unknowns, parameters, var, iv; depth = 0) collect_vars!(unknowns, parameters, arguments(var), iv) elseif isparameter(var) || (iscall(var) && isparameter(operation(var))) push!(parameters, var) - elseif !isconstant(var) + else push!(unknowns, var) end # Add also any parameters that appear only as defaults in the var @@ -734,90 +734,6 @@ function check_scope_depth(scope, depth) end end -""" -Find all the symbolic constants of some equations or terms and return them as a vector. -""" -function collect_constants(x) - constants = BasicSymbolic[] - collect_constants!(constants, x) - return constants -end - -collect_constants!(::Any, ::Symbol) = nothing - -function collect_constants!(constants, arr::AbstractArray) - for el in arr - collect_constants!(constants, el) - end -end - -function collect_constants!(constants, eq::Equation) - collect_constants!(constants, eq.lhs) - collect_constants!(constants, eq.rhs) -end - -function collect_constants!(constants, eq::Inequality) - collect_constants!(constants, eq.lhs) - collect_constants!(constants, eq.rhs) -end - -collect_constants!(constants, x::Num) = collect_constants!(constants, unwrap(x)) -collect_constants!(constants, x::Real) = nothing -collect_constants(n::Nothing) = BasicSymbolic[] - -function collect_constants!(constants, expr::Symbolic) - if issym(expr) && isconstant(expr) - push!(constants, expr) - else - evars = vars(expr) - if length(evars) == 1 && isequal(only(evars), expr) - return nothing #avoid infinite recursion for vars(x(t)) == [x(t)] - else - for var in evars - collect_constants!(constants, var) - end - end - end -end - -function collect_constants!(constants, expr::Union{ConstantRateJump, VariableRateJump}) - collect_constants!(constants, expr.rate) - collect_constants!(constants, expr.affect!) -end - -function collect_constants!(constants, ::MassActionJump) - return constants -end - -""" -Replace symbolic constants with their literal values -""" -function eliminate_constants(eqs, cs) - cmap = Dict(x => getdefault(x) for x in cs) - return substitute(eqs, cmap) -end - -""" -Create a function preface containing assignments of default values to constants. -""" -function get_preprocess_constants(eqs) - cs = collect_constants(eqs) - pre = ex -> Let(Assignment[Assignment(x, getdefault(x)) for x in cs], - ex, false) - return pre -end - -function get_postprocess_fbody(sys) - if has_preface(sys) && (pre = preface(sys); pre !== nothing) - pre_ = let pre = pre - ex -> Let(pre, ex, false) - end - else - pre_ = ex -> ex - end - return pre_ -end - """ $(SIGNATURES) @@ -838,22 +754,6 @@ end isarray(x) = x isa AbstractArray || x isa Symbolics.Arr -function get_cmap(sys, exprs = nothing) - #Inject substitutions for constants => values - buffer = [] - has_eqs(sys) && append!(buffer, collect(get_eqs(sys))) - has_observed(sys) && append!(buffer, collect(get_observed(sys))) - has_op(sys) && push!(buffer, get_op(sys)) - has_constraints(sys) && append!(buffer, get_constraints(sys)) - cs = collect_constants(buffer) #ctrls? what else? - if exprs !== nothing - cs = [cs; collect_constants(exprs)] - end - # Swap constants for their values - cmap = map(x -> x ~ getdefault(x), cs) - return cmap, cs -end - function empty_substitutions(sys) isempty(observed(sys)) end @@ -1043,21 +943,6 @@ function Base.iterate(it::StatefulBFS, queue = (eltype(it)[(0, it.t)])) return (lv, t), queue end -function fold_constants(ex) - if iscall(ex) - maketerm(typeof(ex), operation(ex), map(fold_constants, arguments(ex)), - metadata(ex)) - elseif issym(ex) && isconstant(ex) - if (unit = getmetadata(ex, VariableUnit, nothing); unit !== nothing) - ex # we cannot fold constant with units - else - getdefault(ex) - end - else - ex - end -end - normalize_to_differential(s) = s function restrict_array_to_union(arr) From 6cc55d42d07e9593e65eda2cc016c8a07f6eaf12 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 May 2025 23:08:06 +0530 Subject: [PATCH 2/4] refactor: make `@constants` create non-tunable parameters --- src/constants.jl | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/src/constants.jl b/src/constants.jl index a0a38fd057..4113287ad4 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -1,13 +1,9 @@ -import SymbolicUtils: symtype, term, hasmetadata, issym -struct MTKConstantCtx end - -isconstant(x::Num) = isconstant(unwrap(x)) """ Test whether `x` is a constant-type Sym. """ function isconstant(x) x = unwrap(x) - x isa Symbolic && getmetadata(x, MTKConstantCtx, false) + x isa Symbolic && !getmetadata(x, VariableTunable, true) end """ @@ -16,12 +12,11 @@ end Maps the parameter to a constant. The parameter must have a default. """ function toconstant(s) - hasmetadata(s, Symbolics.VariableDefaultValue) || - throw(ArgumentError("Constant `$(s)` must be assigned a default value.")) - setmetadata(s, MTKConstantCtx, true) + s = toparam(s) + setmetadata(s, VariableTunable, false) end -toconstant(s::Num) = wrap(toconstant(value(s))) +toconstant(s::Union{Num, Symbolics.Arr}) = wrap(toconstant(value(s))) """ $(SIGNATURES) @@ -36,15 +31,3 @@ macro constants(xs...) xs, toconstant) |> esc end - -""" -Substitute all `@constants` in the given expression -""" -function subs_constants(eqs) - consts = collect_constants(eqs) - if !isempty(consts) - csubs = Dict(c => getdefault(c) for c in consts) - eqs = substitute(eqs, csubs) - end - return eqs -end From 80c4a383e24ca88f2183c2b0ca084d33240c269b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 May 2025 23:08:16 +0530 Subject: [PATCH 3/4] refactor: update `@constants` parsing in `@mtkmodel` --- src/systems/model_parsing.jl | 76 +++++++++++++----------------------- 1 file changed, 28 insertions(+), 48 deletions(-) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 8fe07f7f99..9d293fd40d 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -53,7 +53,6 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) end exprs = Expr(:block) dict = Dict{Symbol, Any}( - :constants => Dict{Symbol, Dict}(), :defaults => Dict{Symbol, Any}(), :kwargs => Dict{Symbol, Dict}(), :structural_parameters => Dict{Symbol, Dict}() @@ -125,7 +124,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) description = get(dict, :description, "") @inline pop_structure_dict!.( - Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters]) + Ref(dict), [:defaults, :kwargs, :structural_parameters]) sys = :($type($(flatten_equations)(equations), $iv, variables, parameters; name, description = $description, systems, gui_metadata = $gui_metadata, @@ -320,6 +319,10 @@ Base.@nospecializeinfer function parse_variable_def!( Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $first(@parameters ($a[$(indices...)]::$type = $varval), $meta_val)) + elseif varclass == :constants + Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) + var = :($varname = $first(@constants ($a[$(indices...)]::$type = $varval), + $meta_val)) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") @@ -351,6 +354,12 @@ Base.@nospecializeinfer function parse_variable_def!( var = :($varname = $varname === $NO_VALUE ? $val : $varname; $varname = $first(@parameters ($a[$(indices...)]::$type = $varval), $(def_n_meta...))) + elseif varclass == :constants + Meta.isexpr(a, :call) && + assert_unique_independent_var(dict, a.args[end]) + var = :($varname = $varname === $NO_VALUE ? $val : $varname; + $varname = $first(@constants ($a[$(indices...)]::$type = $varval), + $(def_n_meta...))) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") @@ -366,6 +375,11 @@ Base.@nospecializeinfer function parse_variable_def!( assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; $varname = $first(@parameters $a[$(indices...)]::$type = $varname)) + elseif varclass == :constants + Meta.isexpr(a, :call) && + assert_unique_independent_var(dict, a.args[end]) + var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; + $varname = $first(@constants $a[$(indices...)]::$type = $varname)) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") @@ -393,6 +407,9 @@ Base.@nospecializeinfer function parse_variable_def!( if varclass == :parameters Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $first(@parameters $a[$(indices...)]::$type = $varname)) + elseif varclass == :constants + Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) + var = :($varname = $first(@constants $a[$(indices...)]::$type = $varname)) elseif varclass == :variables Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") @@ -453,6 +470,8 @@ function generate_var(a, varclass; type = Real) var = Symbolics.variable(a; T = type) if varclass == :parameters var = toparam(var) + elseif varclass == :constants + var = toconstant(var) elseif varclass == :independent_variables var = toiv(var) end @@ -513,6 +532,8 @@ function generate_var!(dict, a, b, varclass, mod; end if varclass == :parameters var = toparam(var) + elseif varclass == :constants + var = toconstant(var) end var end @@ -622,7 +643,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, elseif mname == Symbol("@equations") parse_equations!(exprs, eqs, dict, body) elseif mname == Symbol("@constants") - parse_constants!(exprs, dict, body, mod) + parse_variables!(exprs, ps, dict, mod, body, :constants, kwargs, where_types) elseif mname == Symbol("@continuous_events") parse_continuous_events!(c_evts, dict, body) elseif mname == Symbol("@discrete_events") @@ -643,49 +664,6 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, end end -function parse_constants!(exprs, dict, body, mod) - Base.remove_linenums!(body) - for arg in body.args - MLStyle.@match arg begin - Expr(:(=), Expr(:(::), a, type), Expr(:tuple, b, metadata)) || Expr(:(=), Expr(:(::), a, type), b) => begin - type = getfield(mod, type) - b = _type_check!(get_var(mod, b), a, type, :constants) - push!(exprs, - :($(Symbolics._parse_vars( - :constants, type, [:($a = $b), metadata], toconstant)))) - dict[:constants][a] = Dict(:value => b, :type => type) - if @isdefined metadata - for data in metadata.args - dict[:constants][a][data.args[1]] = data.args[2] - end - end - end - Expr(:(=), a, Expr(:tuple, b, metadata)) => begin - push!(exprs, - :($(Symbolics._parse_vars( - :constants, Real, [:($a = $b), metadata], toconstant)))) - dict[:constants][a] = Dict{Symbol, Any}(:value => get_var(mod, b)) - for data in metadata.args - dict[:constants][a][data.args[1]] = data.args[2] - end - end - Expr(:(=), a, b) => begin - push!(exprs, - :($(Symbolics._parse_vars( - :constants, Real, [:($a = $b)], toconstant)))) - dict[:constants][a] = Dict(:value => get_var(mod, b)) - end - _ => error("""Malformed constant definition `$arg`. Please use the following syntax: - ``` - @constants begin - var = value, [description = "This is an example constant."] - end - ``` - """) - end - end -end - push_additional_defaults!(dict, a, b::Number) = dict[:defaults][a] = b push_additional_defaults!(dict, a, b::QuoteNode) = dict[:defaults][a] = b.value function push_additional_defaults!(dict, a, b::Expr) @@ -950,6 +928,7 @@ function handle_conditional_vars!( arg, conditional_branch, mod, varclass, kwargs, where_types) conditional_dict = Dict(:kwargs => Dict(), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}()], + :constants => Any[Dict{Symbol, Dict{Symbol, Any}}()], :variables => Any[Dict{Symbol, Dict{Symbol, Any}}()]) for _arg in arg.args name, ex = parse_variable_arg( @@ -964,7 +943,7 @@ function prune_conditional_dict!(conditional_tuple::Tuple) prune_conditional_dict!.(collect(conditional_tuple)) end function prune_conditional_dict!(conditional_dict::Dict) - for k in [:parameters, :variables] + for k in [:parameters, :variables, :constants] length(conditional_dict[k]) == 1 && isempty(first(conditional_dict[k])) && delete!(conditional_dict, k) end @@ -981,7 +960,7 @@ end function get_conditional_dict!(conditional_dict::Dict, conditional_y_tuple::Dict) merge!(conditional_dict[:kwargs], conditional_y_tuple[:kwargs]) - for key in [:parameters, :variables] + for key in [:parameters, :variables, :constants] merge!(conditional_dict[key][1], conditional_y_tuple[key][1]) end conditional_dict @@ -1000,6 +979,7 @@ function push_conditional_dict!(dict, condition, conditional_dict, end conditional_y_dict = Dict(:kwargs => Dict(), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}()], + :constants => Any[Dict{Symbol, Dict{Symbol, Any}}()], :variables => Any[Dict{Symbol, Dict{Symbol, Any}}()]) get_conditional_dict!(conditional_y_dict, conditional_y_tuple) From 83a7cb034ce35d18b7336e87666036bf9c160b23 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 19 May 2025 23:08:36 +0530 Subject: [PATCH 4/4] test: account for new `@constants` behavior in tests --- test/components.jl | 2 +- test/constants.jl | 13 +++++-------- test/discrete_system.jl | 2 +- test/dq_units.jl | 3 +-- test/funcaffect.jl | 2 +- test/input_output_handling.jl | 2 +- test/jumpsystem.jl | 13 ++++++++----- test/model_parsing.jl | 2 +- test/nonlinearsystem.jl | 18 +++++++++--------- test/odesystem.jl | 13 +++++++------ test/parameter_dependencies.jl | 8 ++++---- test/structural_transformation/tearing.jl | 12 +++++------- 12 files changed, 44 insertions(+), 46 deletions(-) diff --git a/test/components.jl b/test/components.jl index 7680afc50c..6102762d01 100644 --- a/test/components.jl +++ b/test/components.jl @@ -230,7 +230,7 @@ end eqs = [ v ~ i * R ] - extend(System(eqs, t, [], []; name = name), oneport) + extend(System(eqs, t, [], [R]; name = name), oneport) end capacitor = Capacitor(; name = :c1, C = 1.0) resistor = FixedResistor(; name = :r1) diff --git a/test/constants.jl b/test/constants.jl index 5e97d52d7f..bd28517ae6 100644 --- a/test/constants.jl +++ b/test/constants.jl @@ -4,7 +4,8 @@ MT = ModelingToolkit UMT = ModelingToolkit.UnitfulUnitCheck @constants a = 1 -@test_throws MT.ArgumentError @constants b +@test isconstant(a) +@test !istunable(a) @independent_variables t @variables x(t) w(t) @@ -14,9 +15,6 @@ eqs = [D(x) ~ a] prob = ODEProblem(complete(sys), [0], [0.0, 1.0], []) sol = solve(prob, Tsit5()) -newsys = MT.eliminate_constants(sys) -@test isequal(equations(newsys), [D(x) ~ 1]) - # Test structural_simplify substitutions & observed values eqs = [D(x) ~ 1, w ~ a] @@ -29,6 +27,7 @@ simp = structural_simplify(sys) @constants β=1 [unit = u"m/s"] UMT.get_unit(β) @test MT.isconstant(β) +@test !MT.istunable(β) @independent_variables t [unit = u"s"] @variables x(t) [unit = u"m"] D = Differential(t) @@ -36,17 +35,15 @@ eqs = [D(x) ~ β] @named sys = System(eqs, t) simp = structural_simplify(sys) -@test isempty(MT.collect_constants(nothing)) - @testset "Issue#3044" begin - @constants h = 1 + @constants h @parameters τ = 0.5 * h @variables x(MT.t_nounits) = h eqs = [MT.D_nounits(x) ~ (h - x) / τ] @mtkbuild fol_model = System(eqs, MT.t_nounits) - prob = ODEProblem(fol_model, [], (0.0, 10.0)) + prob = ODEProblem(fol_model, [], (0.0, 10.0), [h => 1]) @test prob[x] ≈ 1 @test prob.ps[τ] ≈ 0.5 end diff --git a/test/discrete_system.jl b/test/discrete_system.jl index 1ed6438d76..f3c5bff496 100644 --- a/test/discrete_system.jl +++ b/test/discrete_system.jl @@ -30,7 +30,7 @@ eqs = [S ~ S(k - 1) - infection * h, R ~ R(k - 1) + recovery] # System -@named sys = System(eqs, t, [S, I, R], [c, nsteps, δt, β, γ]) +@named sys = System(eqs, t, [S, I, R], [c, nsteps, δt, β, γ, h]) syss = structural_simplify(sys) @test syss == syss diff --git a/test/dq_units.jl b/test/dq_units.jl index 4d8c245e06..f0dc2dbe23 100644 --- a/test/dq_units.jl +++ b/test/dq_units.jl @@ -233,8 +233,7 @@ end L(t), [unit = u"m"] L_out(t), [unit = u"1"] end -@test to_m in ModelingToolkit.vars(ModelingToolkit.fold_constants(Symbolics.unwrap(L_out * - -to_m))) +@test to_m in ModelingToolkit.vars(Symbolics.unwrap(L_out * -to_m)) # test units for registered functions let diff --git a/test/funcaffect.jl b/test/funcaffect.jl index 8e73280eb3..b0745c8a9d 100644 --- a/test/funcaffect.jl +++ b/test/funcaffect.jl @@ -282,7 +282,7 @@ function bb_affect!(integ, u, p, ctx) integ.u[u.v] = -integ.u[u.v] end -@named bb_model = System(bb_eqs, t, sts, par, +@named bb_model = System(bb_eqs, t, sts, [par; zr], continuous_events = [ [y ~ zr] => (bb_affect!, [v], [], [], nothing) ]) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 1fd7732c50..1f14bc3814 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -447,7 +447,7 @@ end @constants c = 2.0 @variables x(t) eqs = [D(x) ~ c * x] - @mtkbuild sys = System(eqs, t, [x], []) + @mtkbuild sys = System(eqs, t, [x], [c]) f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys) @test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) ≈ [1.0] diff --git a/test/jumpsystem.jl b/test/jumpsystem.jl index 50d59b3313..a42ba5ecb9 100644 --- a/test/jumpsystem.jl +++ b/test/jumpsystem.jl @@ -1,4 +1,5 @@ using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra +using SymbolicIndexingInterface using Random, StableRNGs, NonlinearSolve using OrdinaryDiffEq using ModelingToolkit: t_nounits as t, D_nounits as D @@ -17,7 +18,7 @@ rate₂ = γ * I + t affect₂ = [I ~ Pre(I) - 1, R ~ Pre(R) + 1] j₁ = ConstantRateJump(rate₁, affect₁) j₂ = VariableRateJump(rate₂, affect₂) -@named js = JumpSystem([j₁, j₂], t, [S, I, R], [β, γ]) +@named js = JumpSystem([j₁, j₂], t, [S, I, R], [β, γ, h]) unknowntoid = Dict(MT.value(unknown) => i for (i, unknown) in enumerate(unknowns(js))) mtjump1 = MT.assemble_crj(js, j₁, unknowntoid) mtjump2 = MT.assemble_vrj(js, j₂, unknowntoid) @@ -38,7 +39,7 @@ jump2 = VariableRateJump(rate2, affect2!) # test crjs u = [100, 9, 5] -p = (0.1 / 1000, 0.01) +p = (0.1 / 1000, 0.01, 1) tf = 1.0 mutable struct TestInt{U, V, T} u::U @@ -62,15 +63,15 @@ jump2.affect!(integrator) rate₃ = γ * I * h affect₃ = [I ~ Pre(I) * h - 1, R ~ Pre(R) + 1] j₃ = ConstantRateJump(rate₃, affect₃) -@named js2 = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ]) +@named js2 = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ, h]) js2 = complete(js2) u₀ = [999, 1, 0]; -p = (0.1 / 1000, 0.01); tspan = (0.0, 250.0); u₀map = [S => 999, I => 1, R => 0] parammap = [β => 0.1 / 1000, γ => 0.01] jprob = JumpProblem(js2, u₀map, tspan, parammap; aggregator = Direct(), save_positions = (false, false), rng) +p = parameter_values(jprob) @test jprob.prob isa DiscreteProblem Nsims = 30000 function getmean(jprob, Nsims; use_stepper = true) @@ -90,7 +91,7 @@ mb = getmean(jprobb, Nsims; use_stepper = false) @variables S2(t) obs = [S2 ~ 2 * S] -@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs) +@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ, h], observed = obs) js2b = complete(js2b) jprob = JumpProblem(js2b, u₀map, tspan, parammap; aggregator = Direct(), save_positions = (false, false), rng) @@ -110,6 +111,8 @@ jump2 = ConstantRateJump(rate2, affect2!) mtjumps = jprob.discrete_jump_aggregation @test abs(mtjumps.rates[1](u, p, tf) - jump1.rate(u, p, tf)) < 10 * eps() @test abs(mtjumps.rates[2](u, p, tf) - jump2.rate(u, p, tf)) < 10 * eps() + +ModelingToolkit.@set! mtintegrator.p = (mtintegrator.p, (1,)) mtjumps.affects![1](mtintegrator) jump1.affect!(integrator) @test all(integrator.u .== mtintegrator.u) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 705ec79816..88fa7faac7 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -511,7 +511,7 @@ using ModelingToolkit: getdefault, scalarize @test eval(ModelWithComponentArray.structure[:parameters][:r][:unit]) == eval(u"Ω") - @test lastindex(parameters(model_with_component_array)) == 3 + @test lastindex(parameters(model_with_component_array)) == 4 # Test the constant `k`. Manually k's value should be kept in sync here # and the ModelParsingPrecompile. diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index 38778fd17b..a4f72c00de 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -26,13 +26,13 @@ end eqs = [0 ~ σ * (y - x) * h, 0 ~ x * (ρ - z) - y, 0 ~ x * y - β * z] -@named ns = System(eqs, [x, y, z], [σ, ρ, β], defaults = Dict(x => 2)) +@named ns = System(eqs, [x, y, z], [σ, ρ, β, h], defaults = Dict(x => 2)) @test eval(toexpr(ns)) == ns -test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β)) +test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β, h)) @test begin - f = generate_rhs(ns, [x, y, z], [σ, ρ, β], expression = Val{false})[2] + f = generate_rhs(ns, [x, y, z], [σ, ρ, β, h], expression = Val{false})[2] du = [0.0, 0.0, 0.0] - f(du, [1, 2, 3], [1, 2, 3]) + f(du, [1, 2, 3], [1, 2, 3, 1]) du ≈ [1, -3, -7] end @@ -64,9 +64,9 @@ a = y - x eqs = [0 ~ σ * a * h, 0 ~ x * (ρ - z) - y, 0 ~ x * y - β * z] -@named ns = System(eqs, [x, y, z], [σ, ρ, β]) +@named ns = System(eqs, [x, y, z], [σ, ρ, β, h]) ns = complete(ns) -nlsys_func = generate_rhs(ns, [x, y, z], [σ, ρ, β]) +nlsys_func = generate_rhs(ns, [x, y, z], [σ, ρ, β, h]) nf = NonlinearFunction(ns) jac = calculate_jacobian(ns) @@ -99,7 +99,7 @@ eqs1 = [ 0 ~ x + y - z - u ] -lorenz = name -> System(eqs1, [x, y, z, u, F], [σ, ρ, β], name = name) +lorenz = name -> System(eqs1, [x, y, z, u, F], [σ, ρ, β, h], name = name) lorenz1 = lorenz(:lorenz1) @test_throws ArgumentError NonlinearProblem(complete(lorenz1), zeros(5), zeros(3)) lorenz2 = lorenz(:lorenz2) @@ -132,7 +132,7 @@ sol = solve(prob, FBDF(), reltol = 1e-7, abstol = 1e-7) eqs = [0 ~ σ * (y - x), 0 ~ x * (ρ - z) - y, 0 ~ x * y - β * z * h] -@named ns = System(eqs, [x, y, z], [σ, ρ, β]) +@named ns = System(eqs, [x, y, z], [σ, ρ, β, h]) np = NonlinearProblem( complete(ns), [0, 0, 0], [σ => 1, ρ => 2, β => 3], jac = true, sparse = true) @test calculate_jacobian(ns, sparse = true) isa SparseMatrixCSC @@ -214,7 +214,7 @@ testdict = Dict([:test => 1]) eqs = [0 ~ a * (y - x) * h, 0 ~ x * (b - z) - y, 0 ~ x * y - c * z] - @named sys = System(eqs, [x, y, z], [a, b, c], defaults = Dict(x => 2.0)) + @named sys = System(eqs, [x, y, z], [a, b, c, h], defaults = Dict(x => 2.0)) sys = complete(sys) prob = NonlinearProblem(sys, ones(length(unknowns(sys)))) diff --git a/test/odesystem.jl b/test/odesystem.jl index 2bfca7dcdd..bfad3af0df 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -27,7 +27,7 @@ ModelingToolkit.toexpr.(eqs)[1] @named de = System(eqs, t; defaults = Dict(x => 1)) subed = substitute(de, [σ => k]) ssort(eqs) = sort(eqs, by = string) -@test isequal(ssort(parameters(subed)), [k, β, ρ]) +@test isequal(ssort(parameters(subed)), [k, β, κ, ρ]) @test isequal(equations(subed), [D(x) ~ k * (y - x) D(y) ~ (ρ - z) * x - y @@ -47,7 +47,7 @@ function test_diffeq_inference(name, sys, iv, dvs, ps) end end -test_diffeq_inference("standard", de, t, [x, y, z], [ρ, σ, β]) +test_diffeq_inference("standard", de, t, [x, y, z], [ρ, σ, β, κ]) jac_expr = generate_jacobian(de) jac = calculate_jacobian(de) jacfun = eval(jac_expr[2]) @@ -138,11 +138,11 @@ tgrad_iip(du, u, p, t) eqs = [D(x) ~ σ(t - 1) * (y - x), D(y) ~ x * (ρ - z) - y, D(z) ~ x * y - β * z * κ] -@named de = System(eqs, t, [x, y, z], [σ, ρ, β]) -test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ, ρ, β)) +@named de = System(eqs, t, [x, y, z], [σ, ρ, β, κ]) +test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ, ρ, β, κ)) f = generate_rhs(de, expression = Val{false}, wrap_gfw = Val{true}) du = [0.0, 0.0, 0.0] -f(du, [1.0, 2.0, 3.0], [x -> x + 7, 2, 3], 5.0) +f(du, [1.0, 2.0, 3.0], [x -> x + 7, 2, 3, 1], 5.0) @test du ≈ [11, -3, -7] eqs = [D(x) ~ x + 10σ(t - 1) + 100σ(t - 2) + 1000σ(t^2)] @@ -1202,7 +1202,8 @@ end prob = ODEProblem(sys, u0, (0.0, 1.0), p) # evaluate - u0_v, p_v, _ = ModelingToolkit.get_u0_p(sys, u0, p) + u0_v = prob.u0 + p_v = prob.p @test prob.f(u0_v, p_v, 0.0) == [c_b, c_a] end diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 624dbe8b6b..5498ecf4d8 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -296,9 +296,9 @@ end j₁ = ConstantRateJump(rate₁, affect₁) j₃ = ConstantRateJump(rate₃, affect₃) @named js2 = JumpSystem( - [j₃], t, [S, I, R], [γ]; parameter_dependencies = [β => 0.01γ]) - @test isequal(only(parameters(js2)), γ) - @test Set(full_parameters(js2)) == Set([γ, β]) + [j₃], t, [S, I, R], [γ, h]; parameter_dependencies = [β => 0.01γ]) + @test issetequal(parameters(js2), [γ, h]) + @test Set(full_parameters(js2)) == Set([γ, β, h]) js2 = complete(js2) tspan = (0.0, 250.0) u₀map = [S => 999, I => 1, R => 0] @@ -310,7 +310,7 @@ end @test_nowarn solve(jprob, SSAStepper()) @named js2 = JumpSystem( - [j₁, j₃], t, [S, I, R], [γ]; parameter_dependencies = [β => 0.01γ], + [j₁, j₃], t, [S, I, R], [γ, h]; parameter_dependencies = [β => 0.01γ], discrete_events = [SymbolicDiscreteCallback( [10.0] => [γ ~ 0.02], discrete_parameters = [γ])]) js2 = complete(js2) diff --git a/test/structural_transformation/tearing.jl b/test/structural_transformation/tearing.jl index e91f3fa988..f6c4fe5c44 100644 --- a/test/structural_transformation/tearing.jl +++ b/test/structural_transformation/tearing.jl @@ -18,7 +18,7 @@ eqs = [ 0 ~ u4 - hypot(u2, u3), 0 ~ u5 - hypot(u4, u1) ] -@named sys = System(eqs, [u1, u2, u3, u4, u5], []) +@named sys = System(eqs, [u1, u2, u3, u4, u5], [h]) state = TearingState(sys) StructuralTransformations.find_solvables!(state) @@ -149,19 +149,17 @@ eqs = [D(x) ~ z * h 0 ~ sin(z) + y - p * t] @named daesys = System(eqs, t) newdaesys = structural_simplify(daesys) -@test equations(newdaesys) == [D(x) ~ z; 0 ~ y + sin(z) - p * t] -@test equations(tearing_substitution(newdaesys)) == [D(x) ~ z; 0 ~ x + sin(z) - p * t] +@test equations(newdaesys) == [D(x) ~ h * z; 0 ~ y + sin(z) - p * t] +@test equations(tearing_substitution(newdaesys)) == [D(x) ~ h * z; 0 ~ x + sin(z) - p * t] @test isequal(unknowns(newdaesys), [x, z]) -@test isequal(unknowns(newdaesys), [x, z]) -@test_deprecated ODAEProblem(newdaesys, [x => 1.0, z => -0.5π], (0, 1.0), [p => 0.2]) prob = ODEProblem(newdaesys, [x => 1.0, z => -0.5π], (0, 1.0), [p => 0.2]) du = [0.0, 0.0]; u = [1.0, -0.5π]; -pr = 0.2; +pr = prob.p; tt = 0.1; @test (@ballocated $(prob.f)($du, $u, $pr, $tt)) == 0 prob.f(du, u, pr, tt) -@test du≈[u[2], u[1] + sin(u[2]) - pr * tt] atol=1e-5 +@test du≈[u[2], u[1] + sin(u[2]) - prob.ps[p] * tt] atol=1e-5 # test the initial guess is respected @named sys = System(eqs, t, defaults = Dict(z => NaN))