Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,14 @@ One property to note is that if a system is complete, the system will no longer
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
"""
function complete(sys::AbstractSystem; split = true, flatten = true)
if !(sys isa JumpSystem)
newunknowns = OrderedSet()
newparams = OrderedSet()
iv = has_iv(sys) ? get_iv(sys) : nothing
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
# don't update unknowns to not disturb `structural_simplify` order
# `GlobalScope`d unknowns will be picked up and added there
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
end
newunknowns = OrderedSet()
newparams = OrderedSet()
iv = has_iv(sys) ? get_iv(sys) : nothing
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
# don't update unknowns to not disturb `structural_simplify` order
# `GlobalScope`d unknowns will be picked up and added there
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))

if flatten
eqs = equations(sys)
if eqs isa AbstractArray && eltype(eqs) <: Equation
Expand Down
81 changes: 59 additions & 22 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,40 @@ function JumpSystem(eqs, iv, unknowns, ps;
metadata = nothing,
gui_metadata = nothing,
kwargs...)

# variable processing, similar to ODESystem
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
eqs = scalarize.(eqs)
iv′ = value(iv)
us′ = value.(unknowns)
ps′ = value.(ps)
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
if !(isempty(default_u0) && isempty(default_p))
Base.depwarn(
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:JumpSystem, force = true)
end
defaults = Dict{Any, Any}(todict(defaults))
var_to_name = Dict()
process_variables!(var_to_name, defaults, us′)
process_variables!(var_to_name, defaults, ps′)
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
#! format: off
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults) if value(v) !== nothing)
#! format: on
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))

sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
end

# equation processing
# this and the treatment of continuous events are the only part
# unique to JumpSystems
eqs = scalarize.(eqs)
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
for eq in eqs
if eq isa MassActionJump
Expand All @@ -179,30 +206,42 @@ function JumpSystem(eqs, iv, unknowns, ps;
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
end
end
if !(isempty(default_u0) && isempty(default_p))
Base.depwarn(
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:JumpSystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

unknowns, ps = value.(unknowns), value.(ps)
var_to_name = Dict()
process_variables!(var_to_name, defaults, unknowns)
process_variables!(var_to_name, defaults, ps)
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
(continuous_events === nothing) ||
error("JumpSystems currently only support discrete events.")
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps)

JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
ap, value(iv), unknowns, ps, var_to_name, observed, name, description, systems,
ap, iv′, us′, ps, var_to_name, observed, name, description, systems,
defaults, connector_type, disc_callbacks, parameter_dependencies,
metadata, gui_metadata, checks = checks)
end

##### MTK dispatches for JumpSystems #####
eqtype_supports_collect_vars(j::MassActionJump) = true
function collect_vars!(unknowns, parameters, j::MassActionJump, iv; depth = 0,
op = Differential)
collect_vars!(unknowns, parameters, j.scaled_rates, iv; depth, op)
for field in (j.reactant_stoch, j.net_stoch)
for el in field
collect_vars!(unknowns, parameters, el, iv; depth, op)
end
end
return nothing
end

eqtype_supports_collect_vars(j::Union{ConstantRateJump, VariableRateJump}) = true
function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump, VariableRateJump},
iv; depth = 0, op = Differential)
collect_vars!(unknowns, parameters, j.rate, iv; depth, op)
for eq in j.affect!
(eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
return nothing
end

##########################################

has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
Expand Down Expand Up @@ -240,9 +279,8 @@ function assemble_vrj(

outputvars = (value(affect.lhs) for affect in vrj.affect!)
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(
generate_affect_function(js, vrj.affect!,
outputidxs); eval_expression, eval_module)
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
eval_expression, eval_module)
VariableRateJump(rate, affect)
end

Expand All @@ -269,9 +307,8 @@ function assemble_crj(

outputvars = (value(affect.lhs) for affect in crj.affect!)
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(
generate_affect_function(js, crj.affect!,
outputidxs); eval_expression, eval_module)
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);
eval_expression, eval_module)
ConstantRateJump(rate, affect)
end

Expand Down
82 changes: 82 additions & 0 deletions test/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,85 @@ let

@test all(abs.(cmean .- cmean2) .<= 0.05 .* cmean)
end

# collect_vars! tests for jumps
let
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
@parameters p1 p2 p3 p4 p5
j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
us = Set()
ps = Set()
iv = t

MT.collect_vars!(us, ps, j1, iv)
@test issetequal(us, [x1])
@test issetequal(ps, [p1])

empty!(us)
empty!(ps)
MT.collect_vars!(us, ps, j2, iv)
@test issetequal(us, [x2, x3])
@test issetequal(ps, [p2])

empty!(us)
empty!(ps)
MT.collect_vars!(us, ps, j3, iv)
@test issetequal(us, [x3, x4])
@test issetequal(ps, [p3])

empty!(us)
empty!(ps)
MT.collect_vars!(us, ps, j4, iv)
@test issetequal(us, [x1, x5, x2])
@test issetequal(ps, [p4, p5])
end

# scoping tests
let
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
x2 = ParentScope(x2)
x3 = ParentScope(ParentScope(x3))
x4 = DelayParentScope(x4, 2)
x5 = GlobalScope(x5)
@parameters p1 p2 p3 p4 p5
p2 = ParentScope(p2)
p3 = ParentScope(ParentScope(p3))
p4 = DelayParentScope(p4, 2)
p5 = GlobalScope(p5)

j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
@named js = JumpSystem([j1, j2, j3, j4], t, [x1, x2, x3, x4, x5], [p1, p2, p3, p4, p5])

us = Set()
ps = Set()
iv = t
MT.collect_scoped_vars!(us, ps, js, iv)
@test issetequal(us, [x2])
@test issetequal(ps, [p2])

empty!.((us, ps))
MT.collect_scoped_vars!(us, ps, js, iv; depth = 0)
@test issetequal(us, [x1])
@test issetequal(ps, [p1])

empty!.((us, ps))
MT.collect_scoped_vars!(us, ps, js, iv; depth = 1)
@test issetequal(us, [x2])
@test issetequal(ps, [p2])

empty!.((us, ps))
MT.collect_scoped_vars!(us, ps, js, iv; depth = 2)
@test issetequal(us, [x3, x4])
@test issetequal(ps, [p3, p4])

empty!.((us, ps))
MT.collect_scoped_vars!(us, ps, js, iv; depth = -1)
@test issetequal(us, [x5])
@test issetequal(ps, [p5])
end
Loading