10
10
11
11
function CacheWriter (sys:: AbstractSystem , buffer_types:: Vector{TypeT} ,
12
12
exprs:: Dict{TypeT, Vector{Any}} , solsyms, obseqs:: Vector{Equation} ;
13
- eval_expression = false , eval_module = @__MODULE__ , cse = true )
13
+ eval_expression = false , eval_module = @__MODULE__ , cse = true , sparse = false )
14
14
ps = parameters (sys; initial_parameters = true )
15
15
rps = reorder_parameters (sys, ps)
16
16
obs_assigns = [eq. lhs ← eq. rhs for eq in obseqs]
39
39
struct SCCNonlinearFunction{iip} end
40
40
41
41
function SCCNonlinearFunction {iip} (
42
- sys:: System , _eqs, _dvs, _obs, cachesyms; eval_expression = false ,
42
+ sys:: System , _eqs, _dvs, _obs, cachesyms, op ; eval_expression = false ,
43
43
eval_module = @__MODULE__ , cse = true , kwargs... ) where {iip}
44
44
ps = parameters (sys; initial_parameters = true )
45
+ subsys = System (
46
+ _eqs, _dvs, ps; observed = _obs, name = nameof (sys), defaults = defaults (sys))
47
+ @set! subsys. parameter_dependencies = parameter_dependencies (sys)
48
+ if get_index_cache (sys) != = nothing
49
+ @set! subsys. index_cache = subset_unknowns_observed (
50
+ get_index_cache (sys), sys, _dvs, getproperty .(_obs, (:lhs ,)))
51
+ @set! subsys. complete = true
52
+ end
53
+ # generate linear problem instead
54
+ if isaffine (subsys)
55
+ return LinearFunction {iip} (
56
+ subsys; eval_expression, eval_module, cse, cachesyms, kwargs... )
57
+ end
45
58
rps = reorder_parameters (sys, ps)
46
59
47
60
obs_assignments = [eq. lhs ← eq. rhs for eq in _obs]
@@ -54,14 +67,6 @@ function SCCNonlinearFunction{iip}(
54
67
f_oop, f_iip = eval_or_rgf .(f_gen; eval_expression, eval_module)
55
68
f = GeneratedFunctionWrapper {(2, 2, is_split(sys))} (f_oop, f_iip)
56
69
57
- subsys = System (_eqs, _dvs, ps; observed = _obs,
58
- parameter_dependencies = parameter_dependencies (sys), name = nameof (sys))
59
- if get_index_cache (sys) != = nothing
60
- @set! subsys. index_cache = subset_unknowns_observed (
61
- get_index_cache (sys), sys, _dvs, getproperty .(_obs, (:lhs ,)))
62
- @set! subsys. complete = true
63
- end
64
-
65
70
return NonlinearFunction {iip} (f; sys = subsys)
66
71
end
67
72
@@ -70,7 +75,7 @@ function SciMLBase.SCCNonlinearProblem(sys::System, args...; kwargs...)
70
75
end
71
76
72
77
function SciMLBase. SCCNonlinearProblem {iip} (sys:: System , op; eval_expression = false ,
73
- eval_module = @__MODULE__ , cse = true , kwargs... ) where {iip}
78
+ eval_module = @__MODULE__ , cse = true , u0_constructor = identity, kwargs... ) where {iip}
74
79
if ! iscomplete (sys) || get_tearing_state (sys) === nothing
75
80
error (" A simplified `System` is required. Call `mtkcompile` on the system before creating an `SCCNonlinearProblem`." )
76
81
end
@@ -113,7 +118,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
113
118
114
119
_, u0,
115
120
p = process_SciMLProblem (
116
- EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, kwargs... )
121
+ EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, symbolic_u0 = true , kwargs... )
117
122
118
123
explicitfuns = []
119
124
nlfuns = []
@@ -224,28 +229,57 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
224
229
get (cachevars, T, [])
225
230
end )
226
231
f = SCCNonlinearFunction {iip} (
227
- sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs... )
232
+ sys, _eqs, _dvs, _obs, cachebufsyms, op;
233
+ eval_expression, eval_module, cse, kwargs... )
228
234
push! (nlfuns, f)
229
235
end
230
236
237
+ u0_eltype = Union{}
238
+ for x in u0
239
+ symbolic_type (x) == NotSymbolic () || continue
240
+ u0_eltype = typeof (x)
241
+ break
242
+ end
243
+ if u0_eltype == Union{}
244
+ u0_eltype = Float64
245
+ end
246
+ u0_eltype = float (u0_eltype)
247
+
231
248
if ! isempty (cachetypes)
232
249
templates = map (cachetypes, cachesizes) do T, n
233
250
# Real refers to `eltype(u0)`
234
251
if T == Real
235
- T = eltype (u0)
252
+ T = u0_eltype
236
253
elseif T <: Array && eltype (T) == Real
237
- T = Array{eltype (u0) , ndims (T)}
254
+ T = Array{u0_eltype , ndims (T)}
238
255
end
239
256
BufferTemplate (T, n)
240
257
end
241
258
p = rebuild_with_caches (p, templates... )
242
259
end
243
260
261
+ # yes, `get_p_constructor` since this is only used for `LinearProblem` and
262
+ # will retain the shape of `A`
263
+ u0_constructor = get_p_constructor (u0_constructor, typeof (u0), u0_eltype)
244
264
subprobs = []
245
- for (f, vscc) in zip (nlfuns, var_sccs)
265
+ for (i, ( f, vscc)) in enumerate ( zip (nlfuns, var_sccs) )
246
266
_u0 = SymbolicUtils. Code. create_array (
247
267
typeof (u0), eltype (u0), Val (1 ), Val (length (vscc)), u0[vscc]. .. )
248
- prob = NonlinearProblem (f, _u0, p)
268
+ symbolic_idxs = findall (x -> symbolic_type (x) != NotSymbolic (), _u0)
269
+ explicitfuns[i](p, subprobs)
270
+ if f isa LinearFunction
271
+ _u0 = isempty (symbolic_idxs) ? _u0 : zeros (u0_eltype, length (_u0))
272
+ _u0 = u0_eltype .(_u0)
273
+ symbolic_interface = f. interface
274
+ A,
275
+ b = get_A_b_from_LinearFunction (
276
+ sys, f, p; eval_expression, eval_module, u0_constructor, u0_eltype)
277
+ prob = LinearProblem {iip} (A, b, p; f = symbolic_interface, u0 = _u0)
278
+ else
279
+ isempty (symbolic_idxs) || throw (MissingGuessError (dvs[vscc], _u0))
280
+ _u0 = u0_eltype .(_u0)
281
+ prob = NonlinearProblem (f, _u0, p)
282
+ end
249
283
push! (subprobs, prob)
250
284
end
251
285
@@ -255,5 +289,5 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
255
289
@set! sys. eqs = new_eqs
256
290
@set! sys. index_cache = subset_unknowns_observed (
257
291
get_index_cache (sys), sys, new_dvs, getproperty .(obs, (:lhs ,)))
258
- return SCCNonlinearProblem (subprobs, explicitfuns, p, true ; sys)
292
+ return SCCNonlinearProblem (Tuple ( subprobs), Tuple ( explicitfuns) , p, true ; sys)
259
293
end
0 commit comments