Skip to content

Commit 0bd0ed0

Browse files
Merge pull request #111 from SciML/as/array-indexing
fix: fix indexing related to array symbolics
2 parents 10bb71e + 5cabf15 commit 0bd0ed0

File tree

5 files changed

+66
-3
lines changed

5 files changed

+66
-3
lines changed

src/parameter_indexing.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,8 @@ function _getp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p)
628628
idx = parameter_index(sys, p)
629629
if is_timeseries_parameter(sys, p)
630630
ts_idx = timeseries_parameter_index(sys, p)
631-
return GetParameterTimeseriesIndex(idx, ts_idx)
631+
return GetParameterTimeseriesIndex(
632+
GetParameterIndex(idx), GetParameterIndex(ts_idx))
632633
else
633634
return GetParameterIndex(idx)
634635
end
@@ -750,5 +751,5 @@ function _setp_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
750751
if is_parameter(indp, sym)
751752
return OOPSetter(_root_indp(indp), parameter_index(indp, sym), false)
752753
end
753-
error("$sym is not a valid parameter")
754+
return setp_oop(indp, collect(sym))
754755
end

src/state_indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ function _setsym_oop(indp, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
468468
return setsym_oop(indp, idx)
469469
elseif (idx = parameter_index(indp, sym)) !== nothing
470470
return FullSetter(
471-
nothing, OOPSetter(indp, idx isa AbstractArray ? idx : (idx,), false))
471+
nothing, OOPSetter(indp, idx, false))
472472
end
473473
return setsym_oop(indp, collect(sym))
474474
end

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
66

77
[compat]
88
SymbolicUtils = "3.2"
9+
ModelingToolkit = "9.60"

test/downstream/array_indexing.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using ModelingToolkit, SymbolicIndexingInterface
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
4+
@variables x(t)[1:2]
5+
@parameters p[1:2, 1:2] q(t)[1:2] r[1:2]
6+
7+
ev = [x[1] ~ 2.0] => [q ~ -ones(2)]
8+
@mtkbuild sys = ODESystem(
9+
[D(x) ~ p * x + q + r], t, [x], [p, q, r...]; continuous_events = [ev])
10+
@test is_timeseries_parameter(sys, q)
11+
@test !is_timeseries_parameter(sys, p)
12+
@test !is_parameter(sys, r)
13+
@test is_parameter(sys, r[1])
14+
@test is_parameter(sys, r[2])
15+
16+
prob = ODEProblem(
17+
sys, [x => ones(2)], (0.0, 10.0), [p => ones(2, 2), q => ones(2), r => 2ones(2)])
18+
@test prob.ps[q] ones(2)
19+
@test prob.ps[p] ones(2, 2)
20+
@test prob.ps[r] 2ones(2)
21+
@test prob.ps[p * q] 2ones(2)
22+
23+
@test getu(sys, p)(prob) ones(2, 2)
24+
@test getu(sys, r)(prob) 2ones(2)
25+
26+
prob.ps[p] = 2ones(2, 2)
27+
@test prob.ps[p] 2ones(2, 2)
28+
prob.ps[q] = 2ones(2)
29+
@test prob.ps[q] 2ones(2)
30+
prob.ps[r] = ones(2)
31+
@test prob.ps[r] ones(2)
32+
33+
setter = setp_oop(sys, p)
34+
newp = setter(prob, 3ones(2, 2))
35+
@test getp(sys, p)(newp) 3ones(2, 2)
36+
setter = setp_oop(sys, r)
37+
newp = setter(prob, 3ones(2))
38+
@test getp(sys, r)(newp) 3ones(2)
39+
40+
setter = setsym_oop(sys, p)
41+
_, newp = setter(prob, 3ones(2, 2))
42+
@test getp(sys, p)(newp) 3ones(2, 2)
43+
setter = setsym_oop(sys, r)
44+
_, newp = setter(prob, 3ones(2))
45+
@test getp(sys, r)(newp) 3ones(2)
46+
47+
@test prob[x] ones(2)
48+
prob[x] = 2ones(2)
49+
@test prob[x] 2ones(2)
50+
51+
setu(sys, p)(prob, 4ones(2, 2))
52+
@test prob.ps[p] 4ones(2, 2)
53+
setu(sys, r)(prob, 4ones(2))
54+
@test prob.ps[r] 4ones(2)
55+
56+
setter = setsym_oop(sys, x)
57+
newu, newp = setter(prob, 3ones(2))
58+
@test getu(sys, x)(newu) 3ones(2)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,7 @@ if GROUP == "All" || GROUP == "Downstream"
5858
@safetestset "remake_buffer with array symbolics test" begin
5959
@time include("downstream/remake_arrayvars.jl")
6060
end
61+
@safetestset "array indexing" begin
62+
@time include("downstream/array_indexing.jl")
63+
end
6164
end

0 commit comments

Comments
 (0)