Skip to content

Add DDE support in System #2207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 151 additions & 12 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
implicit_dae = false,
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
nothing,
isdde = false,
has_difference = false,
kwargs...)
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
if isdde
eqs = delay_to_function(sys)
else
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
end
if !implicit_dae
check_operator_variables(eqs, Differential)
check_lhs(eqs, Differential, Set(dvs))
Expand All @@ -136,15 +141,59 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
p = map(x -> time_varying_as_func(value(x), sys), ps)
t = get_iv(sys)

pre, sol_states = get_substitutions_and_solved_states(sys,
no_postprocess = has_difference)
if isdde
build_function(rhss, u, DDE_HISTORY_FUN, p, t; kwargs...)
else
pre, sol_states = get_substitutions_and_solved_states(sys,
no_postprocess = has_difference)

if implicit_dae
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
if implicit_dae
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
end
end
end

function isdelay(var, iv)
iv === nothing && return false
isvariable(var) || return false
if istree(var) && !ModelingToolkit.isoperator(var, Symbolics.Operator)
args = arguments(var)
length(args) == 1 || return false
isequal(args[1], iv) || return true
end
return false
end
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
function delay_to_function(sys::AbstractODESystem)
delay_to_function(full_equations(sys),
get_iv(sys),
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(states(sys))),
parameters(sys),
DDE_HISTORY_FUN)
end
function delay_to_function(eqs::Vector{<:Equation}, iv, sts, ps, h)
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
end
function delay_to_function(eq::Equation, iv, sts, ps, h)
delay_to_function(eq.lhs, iv, sts, ps, h) ~ delay_to_function(eq.rhs, iv, sts, ps, h)
end
function delay_to_function(expr, iv, sts, ps, h)
if isdelay(expr, iv)
v = operation(expr)
time = arguments(expr)[1]
idx = sts[v]
return term(getindex, h(Sym{Any}(:ˍ₋arg3), time), idx, type = Real) # BIG BIG HACK
elseif istree(expr)
return similarterm(expr,
operation(expr),
map(x -> delay_to_function(x, iv, sts, ps, h), arguments(expr)))
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
return expr
end
end

Expand Down Expand Up @@ -485,6 +534,30 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
observed = observedfun)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
DDEFunction{true}(sys, args...; kwargs...)
end

function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
ps = parameters(sys), u0 = nothing;
eval_module = @__MODULE__,
checkbounds = false,
kwargs...) where {iip}
f_gen = generate_function(sys, dvs, ps; isdde = true,
expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
f(u, p, h, t) = f_oop(u, p, h, t)
f(du, u, p, h, t) = f_iip(du, u, p, h, t)

DDEFunction{iip}(f,
sys = sys,
syms = Symbol.(dvs),
indepsym = Symbol(get_iv(sys)),
paramsyms = Symbol.(ps))
end

"""
```julia
ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
Expand Down Expand Up @@ -577,9 +650,14 @@ end
"""
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)

Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
"""
function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
function get_u0_p(sys,
u0map,
parammap;
use_union = false,
tofloat = !use_union,
symbolic_u0 = false)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
Expand All @@ -588,7 +666,11 @@ function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
else
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
end
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
p = p === nothing ? SciMLBase.NullParameters() : p
u0, p, defs
Expand All @@ -604,13 +686,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
eval_expression = true,
use_union = false,
tofloat = !use_union,
symbolic_u0 = false,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
iv = get_iv(sys)

u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)

if implicit_dae && du0map !== nothing
ddvs = map(Differential(iv), dvs)
Expand Down Expand Up @@ -802,6 +885,62 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
end
end

function generate_history(sys::AbstractODESystem, u0; kwargs...)
build_function(u0, parameters(sys), get_iv(sys); expression = Val{false}, kwargs...)
end

function DiffEqBase.DDEProblem(sys::AbstractODESystem, args...; kwargs...)
DDEProblem{true}(sys, args...; kwargs...)
end
function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
callback = nothing,
check_length = true,
kwargs...) where {iip}
has_difference = any(isdifferenceeq, equations(sys))
f, u0, p = process_DEProblem(DDEFunction{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
has_difference = has_difference,
symbolic_u0 = true,
check_length, kwargs...)
h_oop, h_iip = generate_history(sys, u0)
h = h_oop
u0 = h(p, tspan[1])
cbs = process_events(sys; callback, has_difference, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
else
error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
end

"""
```julia
ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
Expand Down
7 changes: 6 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

iv′ = value(scalarize(iv))
dvs′ = value.(scalarize(dvs))
ps′ = value.(scalarize(ps))
ctrl′ = value.(scalarize(controls))
dvs′ = value.(scalarize(dvs))
dvs′ = filter(x -> !isdelay(x, iv), dvs′)

if !(isempty(default_u0) && isempty(default_p))
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
Expand Down Expand Up @@ -258,6 +259,10 @@ function ODESystem(eqs, iv = nothing; kwargs...)
push!(algeeq, eq)
end
end
for v in allstates
isdelay(v, iv) || continue
collect_vars!(allstates, ps, arguments(v)[1], iv)
end
algevars = setdiff(allstates, diffvars)
# the orders here are very important!
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
Expand Down
2 changes: 2 additions & 0 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ end
function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
iv = length(ivs) == 1 ? ivs[1] : nothing
eqs = copy(equations(sys))
neqs = length(eqs)
dervaridxs = OrderedSet{Int}()
Expand Down Expand Up @@ -287,6 +288,7 @@ function TearingState(sys; quick_cancel = false, check = true)
isalgeq = true
statevars = []
for var in vars
ModelingToolkit.isdelay(var, iv) && continue
set_incidence = true
@label ANOTHER_VAR
_var, _ = var_from_nested_derivative(var)
Expand Down
51 changes: 51 additions & 0 deletions test/dde.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using ModelingToolkit, DelayDiffEq, Test
p0 = 0.2;
q0 = 0.3;
v0 = 1;
d0 = 5;
p1 = 0.2;
q1 = 0.3;
v1 = 1;
d1 = 1;
d2 = 1;
beta0 = 1;
beta1 = 1;
tau = 1;
function bc_model(du, u, h, p, t)
du[1] = (v0 / (1 + beta0 * (h(p, t - tau)[3]^2))) * (p0 - q0) * u[1] - d0 * u[1]
du[2] = (v0 / (1 + beta0 * (h(p, t - tau)[3]^2))) * (1 - p0 + q0) * u[1] +
(v1 / (1 + beta1 * (h(p, t - tau)[3]^2))) * (p1 - q1) * u[2] - d1 * u[2]
du[3] = (v1 / (1 + beta1 * (h(p, t - tau)[3]^2))) * (1 - p1 + q1) * u[2] - d2 * u[3]
end
lags = [tau]
h(p, t) = ones(3)
h2(p, t) = ones(3) .- t * q1 * 10
tspan = (0.0, 10.0)
u0 = [1.0, 1.0, 1.0]
prob = DDEProblem(bc_model, u0, h, tspan, constant_lags = lags)
alg = MethodOfSteps(Vern9())
sol = solve(prob, alg, reltol = 1e-7, abstol = 1e-10)
prob2 = DDEProblem(bc_model, u0, h2, tspan, constant_lags = lags)
sol2 = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)

@parameters p0=0.2 p1=0.2 q0=0.3 q1=0.3 v0=1 v1=1 d0=5 d1=1 d2=1 beta0=1 beta1=1
@variables t x₀(t) x₁(t) x₂(..)
tau = 1
D = Differential(t)
eqs = [D(x₀) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (p0 - q0) * x₀ - d0 * x₀
D(x₁) ~ (v0 / (1 + beta0 * (x₂(t - tau)^2))) * (1 - p0 + q0) * x₀ +
(v1 / (1 + beta1 * (x₂(t - tau)^2))) * (p1 - q1) * x₁ - d1 * x₁
D(x₂(t)) ~ (v1 / (1 + beta1 * (x₂(t - tau)^2))) * (1 - p1 + q1) * x₁ - d2 * x₂(t)]
@named sys = System(eqs)
prob = DDEProblem(sys,
[x₀ => 1.0, x₁ => 1.0, x₂(t) => 1.0],
tspan,
constant_lags = [tau])
sol_mtk = solve(prob, alg, reltol = 1e-7, abstol = 1e-10)
@test sol_mtk.u[end] ≈ sol.u[end]
prob2 = DDEProblem(sys,
[x₀ => 1.0 - t * q1 * 10, x₁ => 1.0 - t * q1 * 10, x₂(t) => 1.0 - t * q1 * 10],
tspan,
constant_lags = [tau])
sol2_mtk = solve(prob2, alg, reltol = 1e-7, abstol = 1e-10)
@test sol2_mtk.u[end] ≈ sol2.u[end]