Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
d78b814
refactor: change problem constructors to `XProblem(sys, op[, tspan])`
AayushSabharwal May 21, 2025
ff2ed42
fix: fix `DAEProblem` initialization with new problem syntax
AayushSabharwal May 21, 2025
f663b2e
fix: use new problem constructor syntax
AayushSabharwal May 21, 2025
8109604
test: use new problem construction syntax
AayushSabharwal May 21, 2025
c1e4dfd
refactor: make `better_varmap_to_vars` more predictable
AayushSabharwal May 22, 2025
8ef587c
feat: add `u0_eltype` kwarg to problems, only use `tofloat` for param…
AayushSabharwal May 22, 2025
4f67240
fix: use `default_toterm` when building `du0` for implicit DAEs
AayushSabharwal May 22, 2025
20344fa
fix: propagate kwargs properly in `JumpProblem`
AayushSabharwal May 22, 2025
06fda92
test: fix DAEProblem initialization test
AayushSabharwal May 22, 2025
a8d14e1
test: update error checking tests for improved error in `better_varma…
AayushSabharwal May 22, 2025
1550787
test: use `u0_eltype` in jumpsystem tests
AayushSabharwal May 22, 2025
800a509
test: mark appropriate optimization tests as broken
AayushSabharwal May 22, 2025
868a07d
refactor: accept a single `op` in `generate_initializesystem`
AayushSabharwal May 23, 2025
205a762
refactor: accept a single `op` in `InitializationProblem`
AayushSabharwal May 23, 2025
1b0ce61
refactor: update initialization code to account for new `op` syntax
AayushSabharwal May 23, 2025
4ce79ee
test: update tests to account for new `generate_initializesystem`/`In…
AayushSabharwal May 23, 2025
8aff22f
refactor: use new `process_SciMLProblem` in optimal-control constructors
AayushSabharwal May 23, 2025
4d79ef4
refactor: format
AayushSabharwal May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ext/MTKCasADiDynamicOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
steps = nothing,
guesses = Dict(), kwargs...)
MTK.warn_overdetermined(sys, u0map)
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
merge(Dict(u0map), Dict(guesses))
pmap = MTK.to_varmap(pmap, parameters(sys))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
t = tspan !== nothing ? tspan[1] : tspan, output_type = MX, kwargs...)

pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
Expand Down
12 changes: 8 additions & 4 deletions ext/MTKInfiniteOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ function MTK.JuMPDynamicOptProblem(sys::System, u0map, tspan, pmap;
steps = nothing,
guesses = Dict(), kwargs...)
MTK.warn_overdetermined(sys, u0map)
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
merge(Dict(u0map), Dict(guesses))
pmap = MTK.to_varmap(pmap, parameters(sys))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)

pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
Expand All @@ -86,8 +88,10 @@ function MTK.InfiniteOptDynamicOptProblem(sys::System, u0map, tspan, pmap;
steps = nothing,
guesses = Dict(), kwargs...)
MTK.warn_overdetermined(sys, u0map)
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
merge(Dict(u0map), Dict(guesses))
pmap = MTK.to_varmap(pmap, parameters(sys))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)

pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ PrecompileTools.@compile_workload begin
using ModelingToolkit
@variables x(ModelingToolkit.t_nounits)
@named sys = System([ModelingToolkit.D_nounits(x) ~ -x], ModelingToolkit.t_nounits)
prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), [], jac = true)
prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), jac = true)
@mtkmodel __testmod__ begin
@constants begin
c = 1.0
Expand Down
1 change: 0 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ macro mtkbuild(exprs...)
@mtkcompile $(exprs...)
end |> esc
end

2 changes: 1 addition & 1 deletion src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function linearization_function(sys::AbstractSystem, inputs,
end

prob = ODEProblem{true, SciMLBase.FullSpecialize}(
sys, op, (nothing, nothing), p; allow_incomplete = true,
sys, merge(op, anydict(p)), (nothing, nothing); allow_incomplete = true,
algebraic_only = true, guesses)
u0 = state_values(prob)

Expand Down
13 changes: 7 additions & 6 deletions src/problems/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""
@fallback_iip_specialize function SciMLBase.BVProblem{iip, spec}(
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
sys::System, op, tspan;
check_compatibility = true, cse = true,
checkbounds = false, eval_expression = false, eval_module = @__MODULE__,
expression = Val{false}, guesses = Dict(), callback = nothing,
Expand All @@ -55,22 +55,23 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting

# Systems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
_op = has_alg_eqs(sys) ? op : merge(Dict(op), Dict(guesses))

fode, u0, p = process_SciMLProblem(
ODEFunction{iip, spec}, sys, _u0map, parammap; guesses,
ODEFunction{iip, spec}, sys, _op; guesses,
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility = false, cse,
checkbounds, time_dependent_init = false, expression, kwargs...)

dvs = unknowns(sys)
stidxmap = Dict([v => i for (i, v) in enumerate(dvs)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) : [stidxmap[k] for (k, v) in u0map]
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) :
[stidxmap[k] for (k, v) in op if haskey(stidxmap, k)]
fbc = generate_boundary_conditions(
sys, u0, u0_idxs, tspan[1]; expression = Val{false},
wrap_gfw = Val{true}, cse, checkbounds)

if (length(constraints(sys)) + length(u0map) > length(dvs))
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
if (length(constraints(sys)) + length(op) > length(dvs))
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by op) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
end

kwargs = process_kwargs(sys; expression, kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions src/problems/daeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@
end

@fallback_iip_specialize function SciMLBase.DAEProblem{iip, spec}(
sys::System, du0map, u0map, tspan, parammap = SciMLBase.NullParameters();
sys::System, op, tspan;
callback = nothing, check_length = true, eval_expression = false,
eval_module = @__MODULE__, check_compatibility = true,
expression = Val{false}, kwargs...) where {iip, spec}
check_complete(sys, DAEProblem)
check_compatibility && check_compatible_system(DAEProblem, sys)

f, du0, u0, p = process_SciMLProblem(DAEFunction{iip, spec}, sys, u0map, parammap;
du0map, t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
f, du0, u0, p = process_SciMLProblem(DAEFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
eval_module, check_compatibility, implicit_dae = true, expression, kwargs...)

kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
Expand Down
4 changes: 2 additions & 2 deletions src/problems/ddeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@
end

@fallback_iip_specialize function SciMLBase.DDEProblem{iip, spec}(
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
sys::System, op, tspan;
callback = nothing, check_length = true, cse = true, checkbounds = false,
eval_expression = false, eval_module = @__MODULE__, check_compatibility = true,
u0_constructor = identity, expression = Val{false}, kwargs...) where {iip, spec}
check_complete(sys, DDEProblem)
check_compatibility && check_compatible_system(DDEProblem, sys)

f, u0, p = process_SciMLProblem(DDEFunction{iip, spec}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(DDEFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_length, cse, checkbounds,
eval_expression, eval_module, check_compatibility, symbolic_u0 = true,
expression, u0_constructor, kwargs...)
Expand Down
4 changes: 2 additions & 2 deletions src/problems/discreteproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
end

@fallback_iip_specialize function SciMLBase.DiscreteProblem{iip, spec}(
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
sys::System, op, tspan;
check_compatibility = true, expression = Val{false}, kwargs...) where {iip, spec}
check_complete(sys, DiscreteProblem)
check_compatibility && check_compatible_system(DiscreteProblem, sys)

dvs = unknowns(sys)
u0map = to_varmap(u0map, dvs)
add_toterms!(u0map; replace = true)
f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility, expression,
kwargs...)

Expand Down
8 changes: 4 additions & 4 deletions src/problems/implicitdiscreteproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@
end

@fallback_iip_specialize function SciMLBase.ImplicitDiscreteProblem{iip, spec}(
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
sys::System, op, tspan;
check_compatibility = true, expression = Val{false}, kwargs...) where {iip, spec}
check_complete(sys, ImplicitDiscreteProblem)
check_compatibility && check_compatible_system(ImplicitDiscreteProblem, sys)

dvs = unknowns(sys)
u0map = to_varmap(u0map, dvs)
add_toterms!(u0map; replace = true)
op = to_varmap(op, dvs)
add_toterms!(op; replace = true)
f, u0, p = process_SciMLProblem(
ImplicitDiscreteFunction{iip, spec}, sys, u0map, parammap;
ImplicitDiscreteFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility,
expression, kwargs...)

Expand Down
38 changes: 19 additions & 19 deletions src/problems/initializationproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ struct InitializationProblem{iip, specialization} end

"""
```julia
InitializationProblem{iip}(sys::AbstractSystem, t, u0map,
parammap = DiffEqBase.NullParameters();
InitializationProblem{iip}(sys::AbstractSystem, t, op;
version = nothing, tgrad = false,
jac = false,
checkbounds = false, sparse = false,
Expand All @@ -20,8 +19,7 @@ initial conditions for the given DAE.
"""
@fallback_iip_specialize function InitializationProblem{iip, specialize}(
sys::AbstractSystem,
t, u0map = [],
parammap = DiffEqBase.NullParameters();
t, op = Dict();
guesses = [],
check_length = true,
warn_initialize_determined = true,
Expand All @@ -37,18 +35,24 @@ initial conditions for the given DAE.
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `mtkcompile` on the system before creating an `ODEProblem`")
end
if isempty(u0map) && get_initializesystem(sys) !== nothing
has_u0_ics = false
op = copy(anydict(op))
for k in keys(op)
has_u0_ics |= is_variable(sys, k) || isdifferential(k) ||
symbolic_type(k) == ArraySymbolic() &&
is_sized_array_symbolic(k) && is_variable(sys, first(collect(k)))
end
if !has_u0_ics && get_initializesystem(sys) !== nothing
isys = get_initializesystem(sys; initialization_eqs, check_units)
simplify_system = false
elseif isempty(u0map) && get_initializesystem(sys) === nothing
elseif !has_u0_ics && get_initializesystem(sys) === nothing
isys = generate_initializesystem(
sys; initialization_eqs, check_units, pmap = parammap,
guesses, algebraic_only)
sys; initialization_eqs, check_units, op, guesses, algebraic_only)
simplify_system = true
else
isys = generate_initializesystem(
sys; u0map, initialization_eqs, check_units, time_dependent_init,
pmap = parammap, guesses, algebraic_only)
sys; op, initialization_eqs, check_units, time_dependent_init,
guesses, algebraic_only)
simplify_system = true
end

Expand Down Expand Up @@ -106,20 +110,17 @@ initial conditions for the given DAE.
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true"
end

parammap = recursive_unwrap(anydict(parammap))
if t !== nothing
parammap[get_iv(sys)] = t
op[get_iv(sys)] = t
end
filter!(kvp -> kvp[2] !== missing, parammap)
filter!(kvp -> kvp[2] !== missing, op)

u0map = to_varmap(u0map, unknowns(sys))
if isempty(guesses)
guesses = Dict()
end

filter_missing_values!(u0map)
filter_missing_values!(parammap)
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), u0map)
filter_missing_values!(op)
op = merge(ModelingToolkit.guesses(sys), todict(guesses), op)

TProb = if neqs == nunknown && isempty(unassigned_vars)
if use_scc && neqs > 0
Expand All @@ -135,8 +136,7 @@ initial conditions for the given DAE.
else
NonlinearLeastSquaresProblem
end
TProb{iip}(isys, u0map, parammap; kwargs...,
build_initializeprob = false, is_initializeprob = true)
TProb{iip}(isys, op; kwargs..., build_initializeprob = false, is_initializeprob = true)
end

const INCOMPLETE_INITIALIZATION_MESSAGE = """
Expand Down
4 changes: 3 additions & 1 deletion src/problems/intervalnonlinearproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ function SciMLBase.IntervalNonlinearProblem(
check_compatibility && check_compatible_system(IntervalNonlinearProblem, sys)

u0map = unknowns(sys) .=> uspan[1]
f, u0, p = process_SciMLProblem(IntervalNonlinearFunction, sys, u0map, parammap;
op = anydict([unknowns(sys)[1] => uspan[1]])
merge!(op, to_varmap(parammap, parameters(sys)))
f, u0, p = process_SciMLProblem(IntervalNonlinearFunction, sys, op;
check_compatibility, expression, kwargs...)

kwargs = process_kwargs(sys; kwargs...)
Expand Down
16 changes: 8 additions & 8 deletions src/problems/jumpproblem.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@fallback_iip_specialize function JumpProcesses.JumpProblem{iip, spec}(
sys::System, u0map, tspan::Union{Tuple, Nothing}, pmap = SciMLBase.NullParameters();
sys::System, op, tspan::Union{Tuple, Nothing};
check_compatibility = true, eval_expression = false, eval_module = @__MODULE__,
checkbounds = false, cse = true, aggregator = JumpProcesses.NullAggregator(),
callback = nothing, rng = nothing, kwargs...) where {iip, spec}
Expand All @@ -13,27 +13,27 @@
if (has_vrjs || has_eqs)
if has_eqs && has_noise
prob = SDEProblem{iip, spec}(
sys, u0map, tspan, pmap; check_compatibility = false,
sys, op, tspan; check_compatibility = false,
build_initializeprob = false, checkbounds, cse, check_length = false,
kwargs...)
elseif has_eqs
prob = ODEProblem{iip, spec}(
sys, u0map, tspan, pmap; check_compatibility = false,
sys, op, tspan; check_compatibility = false,
build_initializeprob = false, checkbounds, cse, check_length = false,
kwargs...)
else
_, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, u0map, pmap;
t = tspan === nothing ? nothing : tspan[1], tofloat = false,
check_length = false, build_initializeprob = false)
_, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, op;
t = tspan === nothing ? nothing : tspan[1],
check_length = false, build_initializeprob = false, kwargs...)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module,
checkbounds, cse)
f = (du, u, p, t) -> (du .= 0; nothing)
df = ODEFunction{true, spec}(f; sys, observed = observedfun)
prob = ODEProblem{true}(df, u0, tspan, p; kwargs...)
end
else
_f, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, u0map, pmap;
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse)
_f, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, op;
t = tspan === nothing ? nothing : tspan[1], check_length = false, build_initializeprob = false, cse, kwargs...)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(
Expand Down
8 changes: 4 additions & 4 deletions src/problems/nonlinearproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@
end

@fallback_iip_specialize function SciMLBase.NonlinearProblem{iip, spec}(
sys::System, u0map, parammap = SciMLBase.NullParameters(); expression = Val{false},
sys::System, op; expression = Val{false},
check_length = true, check_compatibility = true, kwargs...) where {iip, spec}
check_complete(sys, NonlinearProblem)
if is_time_dependent(sys)
sys = NonlinearSystem(sys)
end
check_compatibility && check_compatible_system(NonlinearProblem, sys)

f, u0, p = process_SciMLProblem(NonlinearFunction{iip, spec}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(NonlinearFunction{iip, spec}, sys, op;
check_length, check_compatibility, expression, kwargs...)

kwargs = process_kwargs(sys; kwargs...)
Expand All @@ -75,12 +75,12 @@ end
end

@fallback_iip_specialize function SciMLBase.NonlinearLeastSquaresProblem{iip, spec}(
sys::System, u0map, parammap = DiffEqBase.NullParameters(); check_length = false,
sys::System, op; check_length = false,
check_compatibility = true, expression = Val{false}, kwargs...) where {iip, spec}
check_complete(sys, NonlinearLeastSquaresProblem)
check_compatibility && check_compatible_system(NonlinearLeastSquaresProblem, sys)

f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, op;
check_length, expression, kwargs...)

kwargs = process_kwargs(sys; kwargs...)
Expand Down
8 changes: 4 additions & 4 deletions src/problems/odeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@
end

@fallback_iip_specialize function SciMLBase.ODEProblem{iip, spec}(
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
sys::System, op, tspan;
callback = nothing, check_length = true, eval_expression = false,
expression = Val{false}, eval_module = @__MODULE__, check_compatibility = true,
kwargs...) where {iip, spec}
check_complete(sys, ODEProblem)
check_compatibility && check_compatible_system(ODEProblem, sys)

f, u0, p = process_SciMLProblem(ODEFunction{iip, spec}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(ODEFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
eval_module, expression, check_compatibility, kwargs...)

Expand All @@ -98,12 +98,12 @@ Generates an SteadyStateProblem from a `System` of ODEs and allows for automatic
symbolically calculating numerical enhancements.
"""
@fallback_iip_specialize function DiffEqBase.SteadyStateProblem{iip, spec}(
sys::System, u0map, parammap; check_length = true, check_compatibility = true,
sys::System, op; check_length = true, check_compatibility = true,
expression = Val{false}, kwargs...) where {iip, spec}
check_complete(sys, SteadyStateProblem)
check_compatibility && check_compatible_system(SteadyStateProblem, sys)

f, u0, p = process_SciMLProblem(ODEFunction{iip}, sys, u0map, parammap;
f, u0, p = process_SciMLProblem(ODEFunction{iip}, sys, op;
steady_state = true, check_length, check_compatibility, expression,
force_initialization_time_independent = true, kwargs...)

Expand Down
Loading
Loading