Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 68 additions & 26 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end

function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
if dvs !== nothing
Expand All @@ -235,16 +235,42 @@ function wrap_array_vars(
push!(inds, j)
end
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
array_vars[k] = inds′
end
end

uind = 1
else
uind = 0
end
# tunables are scalarized and concatenated, so we need to have assignments
# for the non-scalarized versions
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
# Other parameters may be scalarized arrays but used in the vector form
# values are (indexes, index of buffer, size of parameter)
array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
# If for some reason different elements of an array parameter are in different buffers
other_array_parameters = Dict{Any, Any}()

hasinputs = inputs !== nothing
input_vars = Dict{Any, AbstractArray{Int}}()
if hasinputs
for (j, x) in enumerate(inputs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], input_vars, arg)
push!(inds, j)
end
end
for (k, inds) in input_vars
if inds == (inds′ = inds[1]:inds[end])
input_vars[k] = inds′
end
end
end
if has_index_cache(sys)
ic = get_index_cache(sys)
else
ic = nothing
end
if ps isa Tuple && eltype(ps) <: AbstractArray
ps = Iterators.flatten(ps)
end
Expand All @@ -257,25 +283,33 @@ function wrap_array_vars(
scal = collect(p)
# all scalarized variables are in `ps`
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue
(haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue

idx = parameter_index(sys, p)
idx isa Int && continue
if idx isa ParameterIndex
if idx.portion != SciMLStructures.Tunable()
continue
end
idxs = vec(idx.idx)
sz = size(idx.idx)
array_parameters[p] = (vec(idx.idx), 1, size(idx.idx))
else
# idx === nothing
idxs = map(Base.Fix1(parameter_index, sys), scal)
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
idxs = map(x -> x.idx, idxs)
end
if !all(x -> x isa Int, idxs)
other_array_parameters[p] = scal
continue
if first(idxs) isa ParameterIndex
buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs)
if allequal(buffer_idxs)
buffer_idx = first(buffer_idxs)
if first(idxs).portion == SciMLStructures.Tunable()
idxs = map(x -> x.idx, idxs)
else
idxs = map(x -> x.idx[end], idxs)
end
else
other_array_parameters[p] = scal
continue
end
else
buffer_idx = 1
end

sz = size(idxs)
Expand All @@ -285,12 +319,7 @@ function wrap_array_vars(
idxs = idxs[begin]:-1:idxs[end]
end
idxs = vec(idxs)
end
array_tunables[p] = (idxs, sz)
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
array_vars[k] = inds′
array_parameters[p] = (idxs, buffer_idx, sz)
end
end
if isscalar
Expand All @@ -301,8 +330,12 @@ function wrap_array_vars(
Let(
vcat(
[k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k ← :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k ← :(view($(expr.args[uind + hasinputs].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
Expand All @@ -319,8 +352,12 @@ function wrap_array_vars(
Let(
vcat(
[k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k ← :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k ← :(view($(expr.args[uind + hasinputs].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
Expand All @@ -337,8 +374,13 @@ function wrap_array_vars(
vcat(
[k ← :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k ← :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k ← :(view($(expr.args[uind + hasinputs + 1].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
$idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,9 @@ function build_explicit_observed_function(sys, ts;
pre = get_postprocess_fbody(sys)

array_wrapper = if param_only
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing, inputs)
else
wrap_array_vars(sys, ts; ps = _ps)
wrap_array_vars(sys, ts; ps = _ps, inputs)
end
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
Expand Down
30 changes: 30 additions & 0 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,33 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
end
return result
end

# Given a parameter index, find the index of the buffer it is in when
# `MTKParameters` is iterated
function iterated_buffer_index(ic::IndexCache, ind::ParameterIndex)
idx = 0
if ind.portion isa SciMLStructures.Tunable
return idx + 1
elseif ic.tunable_buffer_size.length > 0
idx += 1
end
if ind.portion isa SciMLStructures.Discrete
return idx + length(first(ic.discrete_buffer_sizes)) * (ind.idx[1] - 1) + ind.idx[2]
elseif !isempty(ic.discrete_buffer_sizes)
idx += length(first(ic.discrete_buffer_sizes)) * length(ic.discrete_buffer_sizes)
end
if ind.portion isa SciMLStructures.Constants
return return idx + ind.idx[1]
elseif !isempty(ic.constant_buffer_sizes)
idx += length(ic.constant_buffer_sizes)
end
if ind.portion == DEPENDENT_PORTION
return idx + ind.idx[1]
elseif !isempty(ic.dependent_buffer_sizes)
idx += length(ic.dependent_buffer_sizes)
end
if ind.portion == NONNUMERIC_PORTION
return idx + ind.idx[1]
end
error("Unhandled portion $(ind.portion)")
end