diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 4e18078282..4b71ad8817 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -224,7 +224,7 @@ function wrap_assignments(isscalar, assignments; let_block = false) end function wrap_array_vars( - sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys)) + sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing) isscalar = !(exprs isa AbstractArray) array_vars = Dict{Any, AbstractArray{Int}}() if dvs !== nothing @@ -235,16 +235,42 @@ function wrap_array_vars( push!(inds, j) end end + for (k, inds) in array_vars + if inds == (inds′ = inds[1]:inds[end]) + array_vars[k] = inds′ + end + end + uind = 1 else uind = 0 end - # tunables are scalarized and concatenated, so we need to have assignments - # for the non-scalarized versions - array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}() - # Other parameters may be scalarized arrays but used in the vector form + # values are (indexes, index of buffer, size of parameter) + array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}() + # If for some reason different elements of an array parameter are in different buffers other_array_parameters = Dict{Any, Any}() + hasinputs = inputs !== nothing + input_vars = Dict{Any, AbstractArray{Int}}() + if hasinputs + for (j, x) in enumerate(inputs) + if iscall(x) && operation(x) == getindex + arg = arguments(x)[1] + inds = get!(() -> Int[], input_vars, arg) + push!(inds, j) + end + end + for (k, inds) in input_vars + if inds == (inds′ = inds[1]:inds[end]) + input_vars[k] = inds′ + end + end + end + if has_index_cache(sys) + ic = get_index_cache(sys) + else + ic = nothing + end if ps isa Tuple && eltype(ps) <: AbstractArray ps = Iterators.flatten(ps) end @@ -257,7 +283,7 @@ function wrap_array_vars( scal = collect(p) # all scalarized variables are in `ps` any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue - (haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue + (haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue idx = parameter_index(sys, p) idx isa Int && continue @@ -265,17 +291,25 @@ function wrap_array_vars( if idx.portion != SciMLStructures.Tunable() continue end - idxs = vec(idx.idx) - sz = size(idx.idx) + array_parameters[p] = (vec(idx.idx), 1, size(idx.idx)) else # idx === nothing idxs = map(Base.Fix1(parameter_index, sys), scal) - if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs) - idxs = map(x -> x.idx, idxs) - end - if !all(x -> x isa Int, idxs) - other_array_parameters[p] = scal - continue + if first(idxs) isa ParameterIndex + buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs) + if allequal(buffer_idxs) + buffer_idx = first(buffer_idxs) + if first(idxs).portion == SciMLStructures.Tunable() + idxs = map(x -> x.idx, idxs) + else + idxs = map(x -> x.idx[end], idxs) + end + else + other_array_parameters[p] = scal + continue + end + else + buffer_idx = 1 end sz = size(idxs) @@ -285,12 +319,7 @@ function wrap_array_vars( idxs = idxs[begin]:-1:idxs[end] end idxs = vec(idxs) - end - array_tunables[p] = (idxs, sz) - end - for (k, inds) in array_vars - if inds == (inds′ = inds[1]:inds[end]) - array_vars[k] = inds′ + array_parameters[p] = (idxs, buffer_idx, sz) end end if isscalar @@ -301,8 +330,12 @@ function wrap_array_vars( Let( vcat( [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], - [k ← :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz)) - for (k, (idxs, sz)) in array_tunables], + [k ← :(view($(expr.args[uind + hasinputs].name), $v)) + for (k, v) in input_vars], + [k ← :(reshape( + view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs), + $sz)) + for (k, (idxs, buffer_idx, sz)) in array_parameters], [k ← Code.MakeArray(v, symtype(k)) for (k, v) in other_array_parameters] ), @@ -319,8 +352,12 @@ function wrap_array_vars( Let( vcat( [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], - [k ← :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz)) - for (k, (idxs, sz)) in array_tunables], + [k ← :(view($(expr.args[uind + hasinputs].name), $v)) + for (k, v) in input_vars], + [k ← :(reshape( + view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs), + $sz)) + for (k, (idxs, buffer_idx, sz)) in array_parameters], [k ← Code.MakeArray(v, symtype(k)) for (k, v) in other_array_parameters] ), @@ -337,8 +374,13 @@ function wrap_array_vars( vcat( [k ← :(view($(expr.args[uind + 1].name), $v)) for (k, v) in array_vars], - [k ← :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz)) - for (k, (idxs, sz)) in array_tunables], + [k ← :(view($(expr.args[uind + hasinputs + 1].name), $v)) + for (k, v) in input_vars], + [k ← :(reshape( + view($(expr.args[uind + hasinputs + buffer_idx + 1].name), + $idxs), + $sz)) + for (k, (idxs, buffer_idx, sz)) in array_parameters], [k ← Code.MakeArray(v, symtype(k)) for (k, v) in other_array_parameters] ), diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index fedc3a4c33..4ef6111aad 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -498,9 +498,9 @@ function build_explicit_observed_function(sys, ts; pre = get_postprocess_fbody(sys) array_wrapper = if param_only - wrap_array_vars(sys, ts; ps = _ps, dvs = nothing) + wrap_array_vars(sys, ts; ps = _ps, dvs = nothing, inputs) else - wrap_array_vars(sys, ts; ps = _ps) + wrap_array_vars(sys, ts; ps = _ps, inputs) end # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 112a0d196d..fa8987382c 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -501,3 +501,33 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false) end return result end + +# Given a parameter index, find the index of the buffer it is in when +# `MTKParameters` is iterated +function iterated_buffer_index(ic::IndexCache, ind::ParameterIndex) + idx = 0 + if ind.portion isa SciMLStructures.Tunable + return idx + 1 + elseif ic.tunable_buffer_size.length > 0 + idx += 1 + end + if ind.portion isa SciMLStructures.Discrete + return idx + length(first(ic.discrete_buffer_sizes)) * (ind.idx[1] - 1) + ind.idx[2] + elseif !isempty(ic.discrete_buffer_sizes) + idx += length(first(ic.discrete_buffer_sizes)) * length(ic.discrete_buffer_sizes) + end + if ind.portion isa SciMLStructures.Constants + return return idx + ind.idx[1] + elseif !isempty(ic.constant_buffer_sizes) + idx += length(ic.constant_buffer_sizes) + end + if ind.portion == DEPENDENT_PORTION + return idx + ind.idx[1] + elseif !isempty(ic.dependent_buffer_sizes) + idx += length(ic.dependent_buffer_sizes) + end + if ind.portion == NONNUMERIC_PORTION + return idx + ind.idx[1] + end + error("Unhandled portion $(ind.portion)") +end