diff --git a/src/problems/daeproblem.jl b/src/problems/daeproblem.jl index 6134923d4d..d98e275f17 100644 --- a/src/problems/daeproblem.jl +++ b/src/problems/daeproblem.jl @@ -72,7 +72,7 @@ end eval_module, check_compatibility, implicit_dae = true, expression, kwargs...) kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module, - kwargs...) + op, kwargs...) diffvars = collect_differential_variables(sys) sts = unknowns(sys) diff --git a/src/problems/ddeproblem.jl b/src/problems/ddeproblem.jl index 45af3edddb..4fb9b3f37b 100644 --- a/src/problems/ddeproblem.jl +++ b/src/problems/ddeproblem.jl @@ -66,7 +66,7 @@ end end kwargs = process_kwargs( - sys; expression, callback, eval_expression, eval_module, kwargs...) + sys; expression, callback, eval_expression, eval_module, op, kwargs...) args = (; f, u0, h, tspan, p) return maybe_codegen_scimlproblem(expression, DDEProblem{iip}, args; kwargs...) diff --git a/src/problems/jumpproblem.jl b/src/problems/jumpproblem.jl index 28de46c35f..dd3fd9ba1c 100644 --- a/src/problems/jumpproblem.jl +++ b/src/problems/jumpproblem.jl @@ -80,7 +80,8 @@ end # handle events, making sure to reset aggregators in the generated affect functions - cbs = process_events(sys; callback, eval_expression, eval_module, reset_jumps = true) + cbs = process_events( + sys; callback, eval_expression, eval_module, op, reset_jumps = true) if rng !== nothing kwargs = (; kwargs..., rng) diff --git a/src/problems/odeproblem.jl b/src/problems/odeproblem.jl index 6726322907..bc8b9cf701 100644 --- a/src/problems/odeproblem.jl +++ b/src/problems/odeproblem.jl @@ -75,7 +75,7 @@ end eval_module, expression, check_compatibility, kwargs...) kwargs = process_kwargs( - sys; expression, callback, eval_expression, eval_module, kwargs...) + sys; expression, callback, eval_expression, eval_module, op, kwargs...) ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem()) args = (; f, u0, tspan, p, ptype) diff --git a/src/problems/sddeproblem.jl b/src/problems/sddeproblem.jl index e1cc00b2a7..0e3201c1d1 100644 --- a/src/problems/sddeproblem.jl +++ b/src/problems/sddeproblem.jl @@ -68,7 +68,7 @@ end end noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise) - kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...) + kwargs = process_kwargs(sys; callback, eval_expression, eval_module, op, kwargs...) if expression == Val{true} g = :(f.g) diff --git a/src/problems/sdeproblem.jl b/src/problems/sdeproblem.jl index 1bc47118ff..83a050f7e4 100644 --- a/src/problems/sdeproblem.jl +++ b/src/problems/sdeproblem.jl @@ -78,7 +78,7 @@ end noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise) kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module, - kwargs...) + op, kwargs...) args = (; f, u0, tspan, p) kwargs = (; noise, noise_rate_prototype, kwargs...) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 02a9763226..67eb5b0c73 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -596,7 +596,8 @@ function _distribute_shift(expr, shift) (op isa Union{Pre, Initial, Sample, Hold}) && return expr args = arguments(expr) - if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex + if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex && + !ModelingToolkit.iscalledparameter(expr) (length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) : (return expr) elseif op isa Shift diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index cf9c53e610..a4f39243d9 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -63,7 +63,8 @@ Base.show(io::IO, x::Pre) = print(io, "Pre") input_timedomain(::Pre, _ = nothing) = ContinuousClock() output_timedomain(::Pre, _ = nothing) = ContinuousClock() unPre(x::Num) = unPre(unwrap(x)) -unPre(x::BasicSymbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x +unPre(x::Symbolics.Arr) = unPre(unwrap(x)) +unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x function (p::Pre)(x) iw = Symbolics.iswrapped(x) @@ -797,16 +798,34 @@ function add_integrator_header( expr.body) end +function default_operating_point(affsys::AffectSystem) + sys = system(affsys) + + op = Dict(unknowns(sys) .=> 0.0) + for p in parameters(sys) + T = symtype(p) + if T <: Number + op[p] = false + elseif T <: Array{<:Real} && is_sized_array_symbolic(p) + op[p] = zeros(size(p)) + end + end + return op +end + """ Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates. """ function compile_equational_affect( aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false, - eval_expression = false, eval_module = @__MODULE__, kwargs...) + eval_expression = false, eval_module = @__MODULE__, op = nothing, kwargs...) if aff isa AbstractVector aff = make_affect( aff; iv = get_iv(sys), warn_no_algebraic = false) end + if op === nothing + op = default_operating_point(aff) + end affsys = system(aff) ps_to_update = discretes(aff) dvs_to_update = setdiff(unknowns(aff), getfield.(observed(sys), :lhs)) @@ -871,10 +890,10 @@ function compile_equational_affect( p_getter = getsym(affsys, ps_to_update) affprob = ImplicitDiscreteProblem( - affsys, Pair[unknowns(affsys) .=> 0; parameters(affsys) .=> 0], + affsys, op, (0, 0); build_initializeprob = false, check_length = false, eval_expression, - eval_module, check_compatibility = false) + eval_module, check_compatibility = false, kwargs...) function implicit_affect!(integ) new_u0 = affu_getter(integ) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 856a97484f..d82bf8a5c1 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1348,3 +1348,30 @@ end @test SciMLBase.successful_retcode(sol) @test sol[inner.p][end] ≈ 1.0 end + +mutable struct ParamTest + y::Any +end + +@testset "callable parameter and symbolic affect" begin + (pt::ParamTest)(x) = pt.y - x + + p1 = ParamTest(1) + tp1 = typeof(p1) + @parameters (p_1::tp1)(..) = p1 + @parameters p2(t) = 1.0 + @variables x(t) = 0.0 + @variables x2(t) + event = [0.5] => [p2 ~ Pre(t)] + + eq = [ + D(x) ~ p2, + x2 ~ p_1(x) + ] + @mtkcompile sys = ODESystem(eq, t, [x, x2], [p_1, p2], discrete_events = [event]) + + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob) + @test SciMLBase.successful_retcode(sol) + @test sol[x, end]≈1.0 atol=1e-6 +end