Skip to content

Commit 9efac94

Browse files
feat: implement DAEProblem and DAEFunction for System
1 parent 350f661 commit 9efac94

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

src/problems/daeproblem.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
@fallback_iip_specialize function SciMLBase.DAEFunction{iip, spec}(
2+
sys::System, _d = nothing, u0 = nothing, p = nothing; tgrad = false, jac = false,
3+
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
4+
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
5+
simplify = false, cse = true, initialization_data = nothing,
6+
check_compatibility = true, kwargs...) where {iip, spec}
7+
check_complete(sys, DAEFunction)
8+
check_compatibility && check_compatible_system(DAEFunction, sys)
9+
10+
dvs = unknowns(sys)
11+
ps = parameters(sys)
12+
f = generate_rhs(sys, dvs, ps; expression = Val{false}, implicit_dae = true,
13+
eval_expression, eval_module, checkbounds = checkbounds, cse,
14+
kwargs...)
15+
16+
if spec === SciMLBase.FunctionWrapperSpecialize && iip
17+
if u0 === nothing || p === nothing || t === nothing
18+
error("u0, p, and t must be specified for FunctionWrapperSpecialize on ODEFunction.")
19+
end
20+
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
21+
end
22+
23+
if jac
24+
_jac = generate_dae_jacobian(sys, dvs, ps; expression = Val{false},
25+
simplify, sparse, cse, eval_expression, eval_module, checkbounds, kwargs...)
26+
else
27+
_jac = nothing
28+
end
29+
30+
observedfun = ObservedFunctionCache(
31+
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
32+
33+
jac_prototype = if sparse
34+
uElType = u0 === nothing ? Float64 : eltype(u0)
35+
if jac
36+
J1 = calculate_jacobian(sys, sparse = sparse)
37+
derivatives = Differential(get_iv(sys)).(unknowns(sys))
38+
J2 = calculate_jacobian(sys; sparse = sparse, dvs = derivatives)
39+
similar(J1 + J2, uElType)
40+
else
41+
similar(jacobian_dae_sparsity(sys), uElType)
42+
end
43+
else
44+
nothing
45+
end
46+
47+
DAEFunction{iip, spec}(f;
48+
sys = sys,
49+
jac = _jac,
50+
jac_prototype = jac_prototype,
51+
observed = observedfun,
52+
analytic = analytic,
53+
initialization_data)
54+
end
55+
56+
@fallback_iip_specialize function SciMLBase.DAEProblem{iip, spec}(
57+
sys::System, du0map, u0map, tspan, parammap = SciMLBase.NullParameters();
58+
callback = nothing, check_length = true, eval_expression = false,
59+
eval_module = @__MODULE__, check_compatibility = true, kwargs...) where {iip, spec}
60+
check_complete(sys, DAEProblem)
61+
check_compatibility && check_compatible_system(DAEProblem, sys)
62+
63+
f, du0, u0, p = process_SciMLProblem(DAEFunction{iip, spec}, sys, u0map, parammap;
64+
du0map, t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
65+
eval_module, check_compatibility, implicit_dae = true, kwargs...)
66+
67+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
68+
69+
diffvars = collect_differential_variables(sys)
70+
sts = unknowns(sys)
71+
differential_vars = map(Base.Fix2(in, diffvars), sts)
72+
73+
# Call `remake` so it runs initialization if it is trivial
74+
return remake(DAEProblem{iip}(
75+
f, du0, u0, tspan, p; kwargs...))
76+
end

src/problems/odeproblem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,11 @@ end
7272
f, u0, tspan, p, StandardODEProblem(); kwargs...))
7373
end
7474

75-
function check_compatible_system(T::Union{Type{ODEFunction}, Type{ODEProblem}}, sys::System)
75+
function check_compatible_system(
76+
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction}, Type{DAEProblem}},
77+
sys::System)
7678
check_time_dependent(sys, T)
77-
check_not_dde(sys, T)
79+
check_not_dde(sys)
7880
check_no_cost(sys, T)
7981
check_no_constraints(sys, T)
8082
check_no_jumps(sys, T)

0 commit comments

Comments
 (0)