Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/problems/daeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ end
eval_module, check_compatibility, implicit_dae = true, expression, kwargs...)

kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
kwargs...)
op, kwargs...)

diffvars = collect_differential_variables(sys)
sts = unknowns(sys)
Expand Down
2 changes: 1 addition & 1 deletion src/problems/ddeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ end
end

kwargs = process_kwargs(
sys; expression, callback, eval_expression, eval_module, kwargs...)
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
args = (; f, u0, h, tspan, p)

return maybe_codegen_scimlproblem(expression, DDEProblem{iip}, args; kwargs...)
Expand Down
3 changes: 2 additions & 1 deletion src/problems/jumpproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@
end

# handle events, making sure to reset aggregators in the generated affect functions
cbs = process_events(sys; callback, eval_expression, eval_module, reset_jumps = true)
cbs = process_events(
sys; callback, eval_expression, eval_module, op, reset_jumps = true)

if rng !== nothing
kwargs = (; kwargs..., rng)
Expand Down
2 changes: 1 addition & 1 deletion src/problems/odeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
eval_module, expression, check_compatibility, kwargs...)

kwargs = process_kwargs(
sys; expression, callback, eval_expression, eval_module, kwargs...)
sys; expression, callback, eval_expression, eval_module, op, kwargs...)

ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
args = (; f, u0, tspan, p, ptype)
Expand Down
2 changes: 1 addition & 1 deletion src/problems/sddeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end
end

noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, op, kwargs...)

if expression == Val{true}
g = :(f.g)
Expand Down
2 changes: 1 addition & 1 deletion src/problems/sdeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
kwargs...)
op, kwargs...)

args = (; f, u0, tspan, p)
kwargs = (; noise, noise_rate_prototype, kwargs...)
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ function _distribute_shift(expr, shift)
(op isa Union{Pre, Initial, Sample, Hold}) && return expr
args = arguments(expr)

if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex
if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex &&
!ModelingToolkit.iscalledparameter(expr)
(length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) :
(return expr)
elseif op isa Shift
Expand Down
27 changes: 23 additions & 4 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ Base.show(io::IO, x::Pre) = print(io, "Pre")
input_timedomain(::Pre, _ = nothing) = ContinuousClock()
output_timedomain(::Pre, _ = nothing) = ContinuousClock()
unPre(x::Num) = unPre(unwrap(x))
unPre(x::BasicSymbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
unPre(x::Symbolics.Arr) = unPre(unwrap(x))
unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x

function (p::Pre)(x)
iw = Symbolics.iswrapped(x)
Expand Down Expand Up @@ -797,16 +798,34 @@ function add_integrator_header(
expr.body)
end

function default_operating_point(affsys::AffectSystem)
sys = system(affsys)

op = Dict(unknowns(sys) .=> 0.0)
for p in parameters(sys)
T = symtype(p)
if T <: Number
op[p] = false
elseif T <: Array{<:Real} && is_sized_array_symbolic(p)
op[p] = zeros(size(p))
end
end
return op
end

"""
Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
"""
function compile_equational_affect(
aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false,
eval_expression = false, eval_module = @__MODULE__, kwargs...)
eval_expression = false, eval_module = @__MODULE__, op = nothing, kwargs...)
if aff isa AbstractVector
aff = make_affect(
aff; iv = get_iv(sys), warn_no_algebraic = false)
end
if op === nothing
op = default_operating_point(aff)
end
affsys = system(aff)
ps_to_update = discretes(aff)
dvs_to_update = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
Expand Down Expand Up @@ -871,10 +890,10 @@ function compile_equational_affect(
p_getter = getsym(affsys, ps_to_update)

affprob = ImplicitDiscreteProblem(
affsys, Pair[unknowns(affsys) .=> 0; parameters(affsys) .=> 0],
affsys, op,
(0, 0);
build_initializeprob = false, check_length = false, eval_expression,
eval_module, check_compatibility = false)
eval_module, check_compatibility = false, kwargs...)

function implicit_affect!(integ)
new_u0 = affu_getter(integ)
Expand Down
27 changes: 27 additions & 0 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1348,3 +1348,30 @@ end
@test SciMLBase.successful_retcode(sol)
@test sol[inner.p][end] ≈ 1.0
end

mutable struct ParamTest
y::Any
end

@testset "callable parameter and symbolic affect" begin
(pt::ParamTest)(x) = pt.y - x

p1 = ParamTest(1)
tp1 = typeof(p1)
@parameters (p_1::tp1)(..) = p1
@parameters p2(t) = 1.0
@variables x(t) = 0.0
@variables x2(t)
event = [0.5] => [p2 ~ Pre(t)]

eq = [
D(x) ~ p2,
x2 ~ p_1(x)
]
@mtkcompile sys = ODESystem(eq, t, [x, x2], [p_1, p2], discrete_events = [event])

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob)
@test SciMLBase.successful_retcode(sol)
@test sol[x, end]≈1.0 atol=1e-6
end
Loading