Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ jobs:
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
17 changes: 13 additions & 4 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ end
function has_observed_with_lhs(sys, sym)
has_observed(sys) || return false
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return any(isequal(sym), ic.observed_syms)
return haskey(ic.observed_syms_to_timeseries, sym)
else
return any(isequal(sym), [eq.lhs for eq in observed(sys)])
end
Expand All @@ -740,7 +740,7 @@ end
function has_parameter_dependency_with_lhs(sys, sym)
has_parameter_dependencies(sys) || return false
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return any(isequal(sym), ic.dependent_pars)
return haskey(ic.dependent_pars_to_timeseries, unwrap(sym))
else
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
end
Expand All @@ -762,11 +762,20 @@ for traitT in [
allsyms = vars(sym; op = Symbolics.Operator)
for s in allsyms
s = unwrap(s)
if is_variable(sys, s) || is_independent_variable(sys, s) ||
has_observed_with_lhs(sys, s)
if is_variable(sys, s) || is_independent_variable(sys, s)
push!(ts_idxs, ContinuousTimeseries())
elseif is_timeseries_parameter(sys, s)
push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx)
elseif is_time_dependent(sys) && iscall(s) && issym(operation(s)) &&
is_variable(sys, operation(s)(get_iv(sys)))
# DDEs case, to detect x(t - k)
push!(ts_idxs, ContinuousTimeseries())
elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing
union!(ts_idxs, ts)
elseif (ts = get(ic.dependent_pars_to_timeseries, s, nothing)) !== nothing
union!(ts_idxs, ts)
end
end
end
end
Expand Down
11 changes: 10 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,16 @@ function build_explicit_observed_function(sys, ts;
ivs = independent_variables(sys)
dep_vars = scalarize(setdiff(vars, ivs))

obs = param_only ? Equation[] : observed(sys)
obs = observed(sys)
if param_only
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
obs = filter(obs) do eq
!(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
end
else
obs = Equation[]
end
end

cs = collect_constants(obs)
if !isempty(cs) > 0
Expand Down
85 changes: 56 additions & 29 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const UnknownIndexMap = Dict{
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
const TunableIndexMap = Dict{BasicSymbolic,
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}

struct IndexCache
unknown_idx::UnknownIndexMap
Expand All @@ -48,8 +49,9 @@ struct IndexCache
tunable_idx::TunableIndexMap
constant_idx::ParamIndexMap
nonnumeric_idx::NonnumericMap
observed_syms::Set{BasicSymbolic}
dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}}
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
dependent_pars_to_timeseries::Dict{
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
Expand Down Expand Up @@ -91,20 +93,6 @@ function IndexCache(sys::AbstractSystem)
end
end

observed_syms = Set{BasicSymbolic}()
for eq in observed(sys)
if symbolic_type(eq.lhs) != NotSymbolic()
sym = eq.lhs
ttsym = default_toterm(sym)
rsym = renamespace(sys, sym)
rttsym = renamespace(sys, ttsym)
push!(observed_syms, sym)
push!(observed_syms, ttsym)
push!(observed_syms, rsym)
push!(observed_syms, rttsym)
end
end

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
Expand Down Expand Up @@ -267,38 +255,77 @@ function IndexCache(sys::AbstractSystem)
end
end

for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
keys(const_idxs), keys(nonnumeric_idxs),
observed_syms, independent_variable_symbols(sys)))
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
symbol_to_variable[getname(sym)] = sym
end
end

dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}()
dependent_pars_to_timeseries = Dict{
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()

for eq in parameter_dependencies(sys)
sym = eq.lhs
vs = vars(eq.rhs)
timeseries = TimeseriesSetType()
if is_time_dependent(sys)
for v in vs
if (idx = get(disc_idxs, v, nothing)) !== nothing
push!(timeseries, idx.clock_idx)
end
end
end
ttsym = default_toterm(sym)
rsym = renamespace(sys, sym)
rttsym = renamespace(sys, ttsym)
for s in [sym, ttsym, rsym, rttsym]
push!(dependent_pars, s)
for s in (sym, ttsym, rsym, rttsym)
dependent_pars_to_timeseries[s] = timeseries
if hasname(s) && (!iscall(s) || operation(s) != getindex)
symbol_to_variable[getname(s)] = sym
end
end
end

observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}()
for eq in observed(sys)
if symbolic_type(eq.lhs) != NotSymbolic()
sym = eq.lhs
vs = vars(eq.rhs; op = Nothing)
timeseries = TimeseriesSetType()
if is_time_dependent(sys)
for v in vs
if (idx = get(disc_idxs, v, nothing)) !== nothing
push!(timeseries, idx.clock_idx)
elseif haskey(observed_syms_to_timeseries, v)
union!(timeseries, observed_syms_to_timeseries[v])
elseif haskey(dependent_pars_to_timeseries, v)
union!(timeseries, dependent_pars_to_timeseries[v])
end
end
if isempty(timeseries)
push!(timeseries, ContinuousTimeseries())
end
end
ttsym = default_toterm(sym)
rsym = renamespace(sys, sym)
rttsym = renamespace(sys, ttsym)
for s in (sym, ttsym, rsym, rttsym)
observed_syms_to_timeseries[s] = timeseries
end
end
end

for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
keys(const_idxs), keys(nonnumeric_idxs),
keys(observed_syms_to_timeseries), independent_variable_symbols(sys)))
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
symbol_to_variable[getname(sym)] = sym
end
end

return IndexCache(
unk_idxs,
disc_idxs,
callback_to_clocks,
tunable_idxs,
const_idxs,
nonnumeric_idxs,
observed_syms,
dependent_pars,
observed_syms_to_timeseries,
dependent_pars_to_timeseries,
disc_buffer_templates,
BufferTemplate(Real, tunable_buffer_size),
const_buffer_sizes,
Expand Down
11 changes: 11 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1504,3 +1504,14 @@ end
sys2 = complete(sys; split = false)
@test ModelingToolkit.get_index_cache(sys2) === nothing
end

# https://github.com/SciML/SciMLBase.jl/issues/786
@testset "Observed variables dependent on discrete parameters" begin
@variables x(t) obs(t)
@parameters c(t)
@mtkbuild sys = ODESystem(
[D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
sol = solve(prob, Tsit5())
@test sol[obs] ≈ 1:7
end
Comment on lines +1509 to +1517
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this MWE I provided in the issue is not a sufficient testcase. I'd propose the following

    @variables x1(t)=0 x2(t)=0 obs1(t) obs2(t)
    @parameters c1(t)=1 c2=1
    @mtkbuild sys = ODESystem(
        [D(x1) ~ c1,
        D(x2) ~ c2,
        obs1 ~ x1*c1,
        obs2 ~ x2*c2], t; discrete_events = [[1.0] => [c1 ~ 0, c2 ~ 0]])
    prob = ODEProblem(sys, [x1=>0, x2=>0], (0.0, 2))
    sol = solve(prob, Tsit5())

    # tests that should pass (?)
    @test sol([0,2], idxs=c1) == [1.0, 0.0]
    @test sol([0,0.9,1.1,2], idxs=obs1)  [0, 0.9, 0, 0]
    @test sol[obs1] == sol(sol.t, idxs=obs1) # errors because of mixed timeseries

    # the following tests check an observable which depends on a parameter which is not declared time dependent, which is done in the docs on discrete events
    # i don't know how this should be handled. Personally, as a user i'd expect all parameters to be discrete timeseries implicitly
    # Depending on your API design, those failures might be by design and don't need fixing.
    @test sol([0,2], idxs=c2) == [1.0, 0.0]
    @test sol([0,0.9,1.1,2], idxs=obs2)  [0, 0.9, 0, 0]
    @test sol[obs2] == sol(sol.t, idxs=obs2)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work with the existing SII infrastructure, because parameters changed in callbacks are piecewise continuous (green in the diagram in #3106 (comment)) but SII always assumes discrete variables are clocked (red). It explicitly disallows indexing variables like sol[obs1] because it is a mix of discrete and continuous, and SII doesn't know which timeseries to return. As a workaround, the interpolation syntax still works, so sol(sol.t, idxs=obs1) returns the required values.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Supporting this is part of the vision for [email protected]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay.

Loading