From d365af12246727b33141e4dc955ff9598e6cafd2 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 25 Oct 2021 22:39:30 -0400 Subject: [PATCH 01/19] Expand connect during structural_simplify --- src/systems/abstractsystem.jl | 67 +++++++++++++++++++++++++++++++- src/systems/diffeqs/odesystem.jl | 4 +- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 869d401bc3..a86799d993 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -811,6 +811,7 @@ topological sort of the observed equations. When `simplify=true`, the `simplify` function will be applied during the tearing process. """ function structural_simplify(sys::AbstractSystem; simplify=false) + sys = expand_connects(sys) sys = initialize_system_structure(alias_elimination(sys)) check_consistency(sys) if sys isa ODESystem @@ -923,8 +924,72 @@ function promote_connect_type(T, S) error("Don't know how to connect systems of type $S and $T") end +struct Connect + syss +end + +function Base.show(io::IO, c::Connect) + syss = c.syss + if syss === nothing + print(io, "") + else + print(io, "<", join((nameof(s) for s in syss), ", "), ">") + end +end + function connect(syss...) - connect(promote_connect_type(map(get_connection_type, syss)...), syss...) + length(syss) >= 2 || error("connect takes at least two systems!") + length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") + Equation(Connect(nothing), Connect(syss)) # the RHS are connected systems +end + +function expand_connects(sys::AbstractSystem; debug=false) + sys = flatten(sys) + eqs′ = equations(sys) + eqs = Equation[] + cts = [] + for eq in eqs′ + eq.lhs isa Connect ? push!(cts, eq.rhs.syss) : push!(eqs, eq) # split connections and equations + end + + # O(n) algorithm for connection fusing + sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement + narg_connects = Vector{Any}[] + for (i, syss) in enumerate(cts) + # find intersecting connections + exclude = findfirst(s->haskey(sys2idx, nameof(s)), syss) + if exclude === nothing + push!(narg_connects, collect(syss)) + for s in syss + sys2idx[nameof(s)] = length(narg_connects) + end + else + # fuse intersecting connections + for (j, s) in enumerate(syss); j == exclude && continue + push!(narg_connects[idx], s) + end + end + end + + # validation + for syss in narg_connects + length(unique(nameof, syss)) == length(syss) || error("$(Connect(syss)) has duplicated connections") + end + + if debug + println("Connections:") + print_with_indent(x) = println(" " ^ 4, x) + foreach(print_with_indent ∘ Connect, narg_connects) + end + + # generate connections + for syss in narg_connects + T = promote_connect_type(map(get_connection_type, syss)...) + append!(eqs, connect(T, syss...)) + end + + @set! sys.eqs = eqs + return sys end ### diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index c67e263660..c4a2bf7025 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -167,7 +167,9 @@ function ODESystem(eqs, iv=nothing; kwargs...) end iv = value(iv) iv === nothing && throw(ArgumentError("Please pass in independent variables.")) + connecteqs = Equation[] for eq in eqs + eq.lhs isa Connect && (push!(connecteqs, eq); continue) collect_vars!(allstates, ps, eq.lhs, iv) collect_vars!(allstates, ps, eq.rhs, iv) if isdiffeq(eq) @@ -182,7 +184,7 @@ function ODESystem(eqs, iv=nothing; kwargs...) end algevars = setdiff(allstates, diffvars) # the orders here are very important! - return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...) + return ODESystem(Equation[diffeq; algeeq; connecteqs], iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...) end # NOTE: equality does not check cached Jacobian From 6d898881428cc6a0e42e967f1f2ddc0a9a43b88e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 25 Oct 2021 22:54:35 -0400 Subject: [PATCH 02/19] Fix bugs --- examples/rc_model.jl | 3 ++- examples/serial_inductor.jl | 3 ++- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 15 +++++++++++---- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/rc_model.jl b/examples/rc_model.jl index c7bc4b901f..88b112bee8 100644 --- a/examples/rc_model.jl +++ b/examples/rc_model.jl @@ -11,7 +11,8 @@ V = 1.0 rc_eqs = [ connect(source.p, resistor.p) connect(resistor.n, capacitor.p) - connect(capacitor.n, source.n, ground.g) + connect(capacitor.n, source.n) + connect(capacitor.n, ground.g) ] @named rc_model = ODESystem(rc_eqs, t) diff --git a/examples/serial_inductor.jl b/examples/serial_inductor.jl index 63d3215a8e..feff486a5d 100644 --- a/examples/serial_inductor.jl +++ b/examples/serial_inductor.jl @@ -10,7 +10,8 @@ eqs = [ connect(source.p, resistor.p) connect(resistor.n, inductor1.p) connect(inductor1.n, inductor2.p) - connect(source.n, inductor2.n, ground.g) + connect(source.n, inductor2.n) + connect(inductor2.n, ground.g) ] @named ll_model = ODESystem(eqs, t) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 0a785aca2e..68aab8ad9b 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -174,7 +174,7 @@ export Equation, ConstrainedEquation export Term, Sym export SymScope, LocalScope, ParentScope, GlobalScope export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure -export structural_simplify +export structural_simplify, expand_connections export DiscreteSystem, DiscreteProblem export calculate_jacobian, generate_jacobian, generate_function diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index a86799d993..0b8d915e78 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -811,7 +811,7 @@ topological sort of the observed equations. When `simplify=true`, the `simplify` function will be applied during the tearing process. """ function structural_simplify(sys::AbstractSystem; simplify=false) - sys = expand_connects(sys) + sys = expand_connections(sys) sys = initialize_system_structure(alias_elimination(sys)) check_consistency(sys) if sys isa ODESystem @@ -943,7 +943,7 @@ function connect(syss...) Equation(Connect(nothing), Connect(syss)) # the RHS are connected systems end -function expand_connects(sys::AbstractSystem; debug=false) +function expand_connections(sys::AbstractSystem; debug=false) sys = flatten(sys) eqs′ = equations(sys) eqs = Equation[] @@ -957,8 +957,15 @@ function expand_connects(sys::AbstractSystem; debug=false) narg_connects = Vector{Any}[] for (i, syss) in enumerate(cts) # find intersecting connections - exclude = findfirst(s->haskey(sys2idx, nameof(s)), syss) - if exclude === nothing + exclude = 0 # exclude the intersecting system + idx = 0 # idx of narg_connects + for (j, s) in enumerate(syss) + idx′ = get(sys2idx, nameof(s), nothing) + idx′ === nothing && continue + idx = idx′ + exclude = j + end + if exclude == 0 push!(narg_connects, collect(syss)) for s in syss sys2idx[nameof(s)] = length(narg_connects) From 9ba7fd20adc41c2c379574b0a2b6a17749626885 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 25 Oct 2021 23:14:12 -0400 Subject: [PATCH 03/19] Fix minor issue --- src/systems/abstractsystem.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 0b8d915e78..f32ce4f593 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -973,6 +973,7 @@ function expand_connections(sys::AbstractSystem; debug=false) else # fuse intersecting connections for (j, s) in enumerate(syss); j == exclude && continue + sys2idx[nameof(s)] = idx push!(narg_connects[idx], s) end end @@ -992,7 +993,8 @@ function expand_connections(sys::AbstractSystem; debug=false) # generate connections for syss in narg_connects T = promote_connect_type(map(get_connection_type, syss)...) - append!(eqs, connect(T, syss...)) + ceqs = connect(T, syss...) + ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) end @set! sys.eqs = eqs From 49b02f78ec00d6b766e176f0265b797792994dfc Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 25 Oct 2021 23:14:20 -0400 Subject: [PATCH 04/19] Add tests --- test/connectors.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/connectors.jl b/test/connectors.jl index f492c9f26c..9305a8378a 100644 --- a/test/connectors.jl +++ b/test/connectors.jl @@ -42,3 +42,35 @@ ModelingToolkit.promote_connect_rule(::Type{<:Goo}, ::Type{<:Foo}) = Foo # test conflict ModelingToolkit.promote_connect_rule(::Type{<:Goo}, ::Type{<:Foo}) = Goo @test_throws ArgumentError connect(f1, g) + +@connector Hoo(;name) = ODESystem(Equation[], t, [], [], name=name) +function ModelingToolkit.connect(::Type{<:Hoo}, ss...) + nameof.(ss) ~ 0 +end +@named hs[1:8] = Hoo() +@named sys = ODESystem([connect(hs[1], hs[2]), + connect(hs[1], hs[3])], t, [], []) +@test equations(expand_connections(sys)) == [(:hs_1, :hs_2, :hs_3) ~ 0] +@named sys = ODESystem([connect(hs[1], hs[2]), + connect(hs[2], hs[3])], t, [], []) +@test equations(expand_connections(sys)) == [(:hs_1, :hs_2, :hs_3) ~ 0] +@named sys = ODESystem([connect(hs[1], hs[2]), + connect(hs[4], hs[3])], t, [], []) +@test equations(expand_connections(sys)) == [(:hs_1, :hs_2) ~ 0, (:hs_4, :hs_3) ~ 0] +@named sys = ODESystem([connect(hs[1], hs[2]), + connect(hs[1], hs[2])], t, [], []) +@test_throws Any expand_connections(sys) +@named sys = ODESystem([connect(hs[1], hs[2]), + connect(hs[3], hs[2]), + connect(hs[1], hs[4]), + connect(hs[8], hs[4]), + connect(hs[7], hs[5]), + ], t, [], []) +@test equations(expand_connections(sys)) == [(:hs_1, :hs_2, :hs_3, :hs_4, :hs_8) ~ 0, (:hs_7, :hs_5) ~ 0] +@named sys = ODESystem([connect(hs[1], hs[2]), + connect(hs[3], hs[2]), + connect(hs[1], hs[4]), + connect(hs[8], hs[4]), + connect(hs[2], hs[8]), + ], t, [], []) +@test_throws Any expand_connections(sys) From fe6ea724ebeb7801f82da89560871e7ffb0f59d0 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Nov 2021 18:11:07 -0400 Subject: [PATCH 05/19] Connector overhaul --- src/systems/abstractsystem.jl | 105 ++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 6 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f32ce4f593..86aa991b8d 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -924,25 +924,117 @@ function promote_connect_type(T, S) error("Don't know how to connect systems of type $S and $T") end -struct Connect - syss +Base.@kwdef struct Connect + inners = nothing + outers = nothing end +Connect(syss) = Connect(inners=syss) +get_systems(c::Connect) = c.inners + function Base.show(io::IO, c::Connect) - syss = c.syss - if syss === nothing + @unpack outers, inners = c + if outers === nothing && inners === nothing print(io, "") else - print(io, "<", join((nameof(s) for s in syss), ", "), ">") + inner_str = join((string(nameof(s)) * "::inner" for s in inners), ", ") + outer_str = join((string(nameof(s)) * "::outer" for s in outers), ", ") + isempty(outer_str) || (outer_str = ", " * outer_str) + print(io, "<", inner_str, outer_str, ">") end end function connect(syss...) length(syss) >= 2 || error("connect takes at least two systems!") length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") - Equation(Connect(nothing), Connect(syss)) # the RHS are connected systems + Equation(Connect(), Connect(syss)) # the RHS are connected systems +end + +# fallback +connect(T::Type, c::Connect) = connect(T, c.outers..., c.inners...) + +function expand_connections(sys::AbstractSystem; debug=false) + subsys = get_systems(sys) + isempty(subsys) && return sys + + # post order traversal + @unpack sys.systems = map(s->expand_connections(s, debug=debug), subsys) + + + # Note that subconnectors in outer connectors are still outer connectors. + # Ref: https://specification.modelica.org/v3.4/Ch9.html see 9.1.2 + isouter = let outer_connectors=[nameof(s) for s in subsys if has_connection_type(s) && get_connection_type(s) !== nothing] + sys -> begin + s = string(nameof(sys)) + idx = findfirst(isequal('₊'), s) + parent_name = idx === nothing ? s : s[1:idx] + parent_name in isouter + end + end + + eqs′ = equations(sys) + eqs = Equation[] + cts = [] # connections + for eq in eqs′ + eq.lhs isa Connect ? push!(cts, get_systems(eq.rhs)) : push!(eqs, eq) # split connections and equations + end + + sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement + narg_connects = Vector{Connect}[] + for (i, syss) in enumerate(cts) + # find intersecting connections + exclude = 0 # exclude the intersecting system + idx = 0 # idx of narg_connects + for (j, s) in enumerate(syss) + idx′ = get(sys2idx, nameof(s), nothing) + idx′ === nothing && continue + idx = idx′ + exclude = j + end + if exclude == 0 + outers = [] + inners = [] + for s in syss + isouter(s) ? push!(outers, s) : push!(inners, s) + end + push!(narg_connects, Connect(outers=outers, inners=inners)) + for s in syss + sys2idx[nameof(s)] = length(narg_connects) + end + else + # fuse intersecting connections + for (j, s) in enumerate(syss); j == exclude && continue + sys2idx[nameof(s)] = idx + c = narg_connects[idx] + isouter(s) ? push!(c.outers, s) : push!(c.inners, s) + end + end + end + + # Bad things happen when there are more than one intersections + for c in narg_connects + @unpack outer, inner = c + len = length(outers) + length(inners) + length(unique(nameof, [outers; inners])) == len || error("$(Connect(syss)) has duplicated connections") + end + + if debug + println("Connections:") + print_with_indent(x) = println(" " ^ 4, x) + foreach(print_with_indent ∘ Connect, narg_connects) + end + + for c in narg_connects + T = promote_connect_type(map(get_connection_type, c.outers)..., map(get_connection_type, c.inners)...) + ceqs = connect(T, c) + ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) + end + + @set! sys.eqs = eqs + return sys end +#= function expand_connections(sys::AbstractSystem; debug=false) sys = flatten(sys) eqs′ = equations(sys) @@ -1000,6 +1092,7 @@ function expand_connections(sys::AbstractSystem; debug=false) @set! sys.eqs = eqs return sys end +=# ### ### Inheritance & composition From d6ea90802fadaae01e2b299220d0bad14fb3eef2 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 5 Nov 2021 18:45:06 -0400 Subject: [PATCH 06/19] Update tests --- examples/electrical_components.jl | 10 ++-- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 93 +++++-------------------------- src/systems/diffeqs/odesystem.jl | 2 +- 4 files changed, 23 insertions(+), 84 deletions(-) diff --git a/examples/electrical_components.jl b/examples/electrical_components.jl index 7e1f01380d..457fd1fd29 100644 --- a/examples/electrical_components.jl +++ b/examples/electrical_components.jl @@ -7,10 +7,12 @@ using ModelingToolkit, OrdinaryDiffEq ODESystem(Equation[], t, sts, []; name=name) end -function ModelingToolkit.connect(::Type{Pin}, ps...) - eqs = [ - 0 ~ sum(p->p.i, ps) # KCL - ] +function ModelingToolkit.connect(::Type{Pin}, c::Connection) + @unpack outers, inners = c + isum = isempty(inners) ? 0 : sum(p->p.i, inners) + osum = isempty(outers) ? 0 : sum(p->p.i, outers) + eqs = [0 ~ isum - osum] # KCL + ps = [outers; inners] # KVL for i in 1:length(ps)-1 push!(eqs, ps[i].v ~ ps[i+1].v) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 68aab8ad9b..5d438410c1 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -165,7 +165,7 @@ export SteadyStateProblem, SteadyStateProblemExpr export JumpProblem, DiscreteProblem export NonlinearSystem, OptimizationSystem export ControlSystem -export alias_elimination, flatten, connect, @connector +export alias_elimination, flatten, connect, @connector, Connection export ode_order_lowering, liouville_transform export runge_kutta_discretize export PDESystem diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 86aa991b8d..49852e1331 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -924,18 +924,18 @@ function promote_connect_type(T, S) error("Don't know how to connect systems of type $S and $T") end -Base.@kwdef struct Connect +Base.@kwdef struct Connection inners = nothing outers = nothing end -Connect(syss) = Connect(inners=syss) -get_systems(c::Connect) = c.inners +Connection(syss) = Connection(inners=syss) +get_systems(c::Connection) = c.inners -function Base.show(io::IO, c::Connect) +function Base.show(io::IO, c::Connection) @unpack outers, inners = c if outers === nothing && inners === nothing - print(io, "") + print(io, "") else inner_str = join((string(nameof(s)) * "::inner" for s in inners), ", ") outer_str = join((string(nameof(s)) * "::outer" for s in outers), ", ") @@ -947,18 +947,15 @@ end function connect(syss...) length(syss) >= 2 || error("connect takes at least two systems!") length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") - Equation(Connect(), Connect(syss)) # the RHS are connected systems + Equation(Connection(), Connection(syss)) # the RHS are connected systems end -# fallback -connect(T::Type, c::Connect) = connect(T, c.outers..., c.inners...) - function expand_connections(sys::AbstractSystem; debug=false) subsys = get_systems(sys) isempty(subsys) && return sys # post order traversal - @unpack sys.systems = map(s->expand_connections(s, debug=debug), subsys) + @set sys.systems = map(s->expand_connections(s, debug=debug), subsys) # Note that subconnectors in outer connectors are still outer connectors. @@ -968,19 +965,19 @@ function expand_connections(sys::AbstractSystem; debug=false) s = string(nameof(sys)) idx = findfirst(isequal('₊'), s) parent_name = idx === nothing ? s : s[1:idx] - parent_name in isouter + parent_name in outer_connectors end end - eqs′ = equations(sys) + eqs′ = get_eqs(sys) eqs = Equation[] cts = [] # connections for eq in eqs′ - eq.lhs isa Connect ? push!(cts, get_systems(eq.rhs)) : push!(eqs, eq) # split connections and equations + eq.lhs isa Connection ? push!(cts, get_systems(eq.rhs)) : push!(eqs, eq) # split connections and equations end sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement - narg_connects = Vector{Connect}[] + narg_connects = Connection[] for (i, syss) in enumerate(cts) # find intersecting connections exclude = 0 # exclude the intersecting system @@ -997,7 +994,7 @@ function expand_connections(sys::AbstractSystem; debug=false) for s in syss isouter(s) ? push!(outers, s) : push!(inners, s) end - push!(narg_connects, Connect(outers=outers, inners=inners)) + push!(narg_connects, Connection(outers=outers, inners=inners)) for s in syss sys2idx[nameof(s)] = length(narg_connects) end @@ -1013,15 +1010,15 @@ function expand_connections(sys::AbstractSystem; debug=false) # Bad things happen when there are more than one intersections for c in narg_connects - @unpack outer, inner = c + @unpack outers, inners = c len = length(outers) + length(inners) - length(unique(nameof, [outers; inners])) == len || error("$(Connect(syss)) has duplicated connections") + length(unique(nameof, [outers; inners])) == len || error("$(Connection(syss)) has duplicated connections") end if debug println("Connections:") print_with_indent(x) = println(" " ^ 4, x) - foreach(print_with_indent ∘ Connect, narg_connects) + foreach(print_with_indent, narg_connects) end for c in narg_connects @@ -1034,66 +1031,6 @@ function expand_connections(sys::AbstractSystem; debug=false) return sys end -#= -function expand_connections(sys::AbstractSystem; debug=false) - sys = flatten(sys) - eqs′ = equations(sys) - eqs = Equation[] - cts = [] - for eq in eqs′ - eq.lhs isa Connect ? push!(cts, eq.rhs.syss) : push!(eqs, eq) # split connections and equations - end - - # O(n) algorithm for connection fusing - sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement - narg_connects = Vector{Any}[] - for (i, syss) in enumerate(cts) - # find intersecting connections - exclude = 0 # exclude the intersecting system - idx = 0 # idx of narg_connects - for (j, s) in enumerate(syss) - idx′ = get(sys2idx, nameof(s), nothing) - idx′ === nothing && continue - idx = idx′ - exclude = j - end - if exclude == 0 - push!(narg_connects, collect(syss)) - for s in syss - sys2idx[nameof(s)] = length(narg_connects) - end - else - # fuse intersecting connections - for (j, s) in enumerate(syss); j == exclude && continue - sys2idx[nameof(s)] = idx - push!(narg_connects[idx], s) - end - end - end - - # validation - for syss in narg_connects - length(unique(nameof, syss)) == length(syss) || error("$(Connect(syss)) has duplicated connections") - end - - if debug - println("Connections:") - print_with_indent(x) = println(" " ^ 4, x) - foreach(print_with_indent ∘ Connect, narg_connects) - end - - # generate connections - for syss in narg_connects - T = promote_connect_type(map(get_connection_type, syss)...) - ceqs = connect(T, syss...) - ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) - end - - @set! sys.eqs = eqs - return sys -end -=# - ### ### Inheritance & composition ### diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index c4a2bf7025..b97f4e054f 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -169,7 +169,7 @@ function ODESystem(eqs, iv=nothing; kwargs...) iv === nothing && throw(ArgumentError("Please pass in independent variables.")) connecteqs = Equation[] for eq in eqs - eq.lhs isa Connect && (push!(connecteqs, eq); continue) + eq.lhs isa Connection && (push!(connecteqs, eq); continue) collect_vars!(allstates, ps, eq.lhs, iv) collect_vars!(allstates, ps, eq.rhs, iv) if isdiffeq(eq) From 1455768b0946a996da3b2941380ded22b3cdb74c Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 8 Nov 2021 17:06:09 -0500 Subject: [PATCH 07/19] Update tests --- Project.toml | 4 +- src/systems/abstractsystem.jl | 10 +++-- test/connectors.jl | 80 +++++++++++++++++++---------------- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 6697d7d73a..9ba32f65d3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Chris Rackauckas "] -version = "7.0.0" +version = "6.8.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -73,7 +73,7 @@ Setfield = "0.7, 0.8" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.10, 0.11, 0.12, 1.0" SymbolicUtils = "0.18" -Symbolics = "4.0.0" +Symbolics = "3, 4.0.0" UnPack = "0.1, 1.0" Unitful = "1.1" julia = "1.2" diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 49852e1331..a2df71ff41 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -932,15 +932,17 @@ end Connection(syss) = Connection(inners=syss) get_systems(c::Connection) = c.inners +const EMPTY_VEC = [] + function Base.show(io::IO, c::Connection) @unpack outers, inners = c if outers === nothing && inners === nothing print(io, "") else - inner_str = join((string(nameof(s)) * "::inner" for s in inners), ", ") - outer_str = join((string(nameof(s)) * "::outer" for s in outers), ", ") - isempty(outer_str) || (outer_str = ", " * outer_str) - print(io, "<", inner_str, outer_str, ">") + syss = Iterators.flatten((something(inners, EMPTY_VEC), something(outers, EMPTY_VEC))) + splitting_idx = length(inners) + sys_str = join((string(nameof(s)) * (i <= splitting_idx ? ("::inner") : ("::outers")) for (i, s) in enumerate(syss)), ", ") + print(io, "<", sys_str, ">") end end diff --git a/test/connectors.jl b/test/connectors.jl index 9305a8378a..360577d50a 100644 --- a/test/connectors.jl +++ b/test/connectors.jl @@ -13,7 +13,9 @@ end ODESystem(Equation[], t, [x], [p], defaults=Dict(x=>1.0, p=>1.0), name=name) end -function ModelingToolkit.connect(::Type{<:Foo}, ss...) +function ModelingToolkit.connect(::Type{<:Foo}, c::Connection) + @show c.inners + ss = c.inners n = length(ss)-1 eqs = Vector{Equation}(undef, n) for i in 1:n @@ -28,49 +30,55 @@ end @named f4 = Foo() @named g = Goo() -@test isequal(connect(f1, f2), [f1.x ~ f2.x]) -@test_throws ArgumentError connect(f1, g) +function connection_eqs(eqs, subsys) + @named sys = ODESystem(eqs, t) + @named newsys = compose(sys, subsys) + equations(expand_connections(newsys)) +end + +connection_eqs(subsys) = Base.Fix2(connection_eqs, subsys) +ceqs = connection_eqs([f1, f2, f3, f4, g]) + +@test isequal(ceqs(connect(f1, f2)), [f1.x ~ f2.x]) +@test_throws ArgumentError ceqs(connect(f1, g)) # Note that since there're overloadings, these tests are not re-runable. ModelingToolkit.promote_connect_rule(::Type{<:Foo}, ::Type{<:Goo}) = Foo -@test isequal(connect(f1, g), [f1.x ~ g.x]) -@test isequal(connect(f1, f2, g), [f1.x ~ f2.x; f2.x ~ g.x]) -@test isequal(connect(f1, f2, g, f3), [f1.x ~ f2.x; f2.x ~ g.x; g.x ~ f3.x]) -@test isequal(connect(f1, f2, g, f3, f4), [f1.x ~ f2.x; f2.x ~ g.x; g.x ~ f3.x; f3.x ~ f4.x]) +@test isequal(ceqs(connect(f1, g)), [f1.x ~ g.x]) +@test isequal(ceqs(connect(f1, f2, g)), [f1.x ~ f2.x; f2.x ~ g.x]) +@test isequal(ceqs(connect(f1, f2, g, f3)), [f1.x ~ f2.x; f2.x ~ g.x; g.x ~ f3.x]) +@test isequal(ceqs(connect(f1, f2, g, f3, f4)), [f1.x ~ f2.x; f2.x ~ g.x; g.x ~ f3.x; f3.x ~ f4.x]) ModelingToolkit.promote_connect_rule(::Type{<:Goo}, ::Type{<:Foo}) = Foo -@test isequal(connect(f1, g), [f1.x ~ g.x]) +@test isequal(ceqs(connect(f1, g)), [f1.x ~ g.x]) # test conflict ModelingToolkit.promote_connect_rule(::Type{<:Goo}, ::Type{<:Foo}) = Goo -@test_throws ArgumentError connect(f1, g) +@test_throws ArgumentError ceqs(connect(f1, g)) @connector Hoo(;name) = ODESystem(Equation[], t, [], [], name=name) -function ModelingToolkit.connect(::Type{<:Hoo}, ss...) +function ModelingToolkit.connect(::Type{<:Hoo}, c::Connection) + ss = c.inners nameof.(ss) ~ 0 end @named hs[1:8] = Hoo() -@named sys = ODESystem([connect(hs[1], hs[2]), - connect(hs[1], hs[3])], t, [], []) -@test equations(expand_connections(sys)) == [(:hs_1, :hs_2, :hs_3) ~ 0] -@named sys = ODESystem([connect(hs[1], hs[2]), - connect(hs[2], hs[3])], t, [], []) -@test equations(expand_connections(sys)) == [(:hs_1, :hs_2, :hs_3) ~ 0] -@named sys = ODESystem([connect(hs[1], hs[2]), - connect(hs[4], hs[3])], t, [], []) -@test equations(expand_connections(sys)) == [(:hs_1, :hs_2) ~ 0, (:hs_4, :hs_3) ~ 0] -@named sys = ODESystem([connect(hs[1], hs[2]), - connect(hs[1], hs[2])], t, [], []) -@test_throws Any expand_connections(sys) -@named sys = ODESystem([connect(hs[1], hs[2]), - connect(hs[3], hs[2]), - connect(hs[1], hs[4]), - connect(hs[8], hs[4]), - connect(hs[7], hs[5]), - ], t, [], []) -@test equations(expand_connections(sys)) == [(:hs_1, :hs_2, :hs_3, :hs_4, :hs_8) ~ 0, (:hs_7, :hs_5) ~ 0] -@named sys = ODESystem([connect(hs[1], hs[2]), - connect(hs[3], hs[2]), - connect(hs[1], hs[4]), - connect(hs[8], hs[4]), - connect(hs[2], hs[8]), - ], t, [], []) -@test_throws Any expand_connections(sys) +ceqs = connection_eqs(hs) + +@test ceqs([connect(hs[1], hs[2]), + connect(hs[1], hs[3])]) == [[:hs_1, :hs_2, :hs_3] ~ 0] + +@test ceqs([connect(hs[1], hs[2]), + connect(hs[2], hs[3])]) == [[:hs_1, :hs_2, :hs_3] ~ 0] + +@test ceqs([connect(hs[1], hs[2]), + connect(hs[4], hs[3])]) == [[:hs_1, :hs_2] ~ 0, [:hs_4, :hs_3] ~ 0] +@test_throws Any ceqs([connect(hs[1], hs[2]), + connect(hs[1], hs[2])]) +@test ceqs([connect(hs[1], hs[2]), + connect(hs[3], hs[2]), + connect(hs[1], hs[4]), + connect(hs[8], hs[4]), + connect(hs[7], hs[5]),]) == [[:hs_1, :hs_2, :hs_3, :hs_4, :hs_8] ~ 0, [:hs_7, :hs_5] ~ 0] +@test_throws Any ceqs([connect(hs[1], hs[2]), + connect(hs[3], hs[2]), + connect(hs[1], hs[4]), + connect(hs[8], hs[4]), + connect(hs[2], hs[8])]) From d1cd98b1064d566616da0c61ea0a84109ccd44dc Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 8 Nov 2021 23:20:28 -0500 Subject: [PATCH 08/19] Refactor --- src/systems/abstractsystem.jl | 33 ++++++++++++++++++++------------ src/systems/diffeqs/odesystem.jl | 6 +++--- src/utils.jl | 18 +++++++++++++++++ test/connectors.jl | 3 +++ 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index a2df71ff41..65c8c215ce 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -952,24 +952,31 @@ function connect(syss...) Equation(Connection(), Connection(syss)) # the RHS are connected systems end -function expand_connections(sys::AbstractSystem; debug=false) - subsys = get_systems(sys) - isempty(subsys) && return sys - - # post order traversal - @set sys.systems = map(s->expand_connections(s, debug=debug), subsys) - +isconnector(s::AbstractSystem) = has_connection_type(s) && get_connection_type(s) !== nothing +function isouterconnector(sys::AbstractSystem; check=true) + subsys = get_systems(sys) + outer_connectors = [nameof(s) for s in subsys if isconnector(sys)] # Note that subconnectors in outer connectors are still outer connectors. # Ref: https://specification.modelica.org/v3.4/Ch9.html see 9.1.2 - isouter = let outer_connectors=[nameof(s) for s in subsys if has_connection_type(s) && get_connection_type(s) !== nothing] - sys -> begin + let outer_connectors=outer_connectors, check=check + function isouter(sys)::Bool s = string(nameof(sys)) + check && (isconnector(sys) || error("$s is not a connector!")) idx = findfirst(isequal('₊'), s) parent_name = idx === nothing ? s : s[1:idx] parent_name in outer_connectors end end +end + +function expand_connections(sys::AbstractSystem; debug=false) + subsys = get_systems(sys) + isempty(subsys) && return sys + + # post order traversal + @set sys.systems = map(s->expand_connections(s, debug=debug), subsys) + isouter = isouterconnector(sys) eqs′ = get_eqs(sys) eqs = Equation[] @@ -1014,7 +1021,9 @@ function expand_connections(sys::AbstractSystem; debug=false) for c in narg_connects @unpack outers, inners = c len = length(outers) + length(inners) - length(unique(nameof, [outers; inners])) == len || error("$(Connection(syss)) has duplicated connections") + allconnectors = Iterators.flatten((outers, inners)) + dups = find_duplicates(nameof(c) for c in allconnectors) + length(dups) == 0 || error("$(Connection(syss)) has duplicated connections: $(dups).") end if debug @@ -1052,7 +1061,7 @@ function Base.hash(sys::AbstractSystem, s::UInt) end """ - $(TYPEDSIGNATURES) +$(TYPEDSIGNATURES) entend the `basesys` with `sys`, the resulting system would inherit `sys`'s name by default. @@ -1087,7 +1096,7 @@ end Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameof(sys)) = extend(sys, basesys; name=name) """ - $(SIGNATURES) +$(SIGNATURES) compose multiple systems together. The resulting system would inherit the first system's name. diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 7ea8a2c594..aec0a31b11 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -163,9 +163,9 @@ function ODESystem(eqs, iv=nothing; kwargs...) end iv = value(iv) iv === nothing && throw(ArgumentError("Please pass in independent variables.")) - connecteqs = Equation[] + compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)` for eq in eqs - eq.lhs isa Connection && (push!(connecteqs, eq); continue) + eq.lhs isa Symbolic || (push!(compressed_eqs, eq); continue) collect_vars!(allstates, ps, eq.lhs, iv) collect_vars!(allstates, ps, eq.rhs, iv) if isdiffeq(eq) @@ -180,7 +180,7 @@ function ODESystem(eqs, iv=nothing; kwargs...) end algevars = setdiff(allstates, diffvars) # the orders here are very important! - return ODESystem(Equation[diffeq; algeeq; connecteqs], iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...) + return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...) end # NOTE: equality does not check cached Jacobian diff --git a/src/utils.jl b/src/utils.jl index 9abfeffb6f..f328f24916 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -349,3 +349,21 @@ function get_postprocess_fbody(sys) end return pre_ end + +""" +$(SIGNATURES) + +find duplicates in an iterable object. +""" +function find_duplicates(xs) + appeared = Set() + duplicates = Set() + for x in xs + if x in appeared + push!(duplicates, x) + else + push!(appeared, x) + end + end + return duplicates +end diff --git a/test/connectors.jl b/test/connectors.jl index 360577d50a..90da4efafd 100644 --- a/test/connectors.jl +++ b/test/connectors.jl @@ -82,3 +82,6 @@ ceqs = connection_eqs(hs) connect(hs[1], hs[4]), connect(hs[8], hs[4]), connect(hs[2], hs[8])]) + +# Outer/inner connectors + From 24c01e528b8ccb1b9c198d16d021186ac1608546 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 8 Nov 2021 23:35:45 -0500 Subject: [PATCH 09/19] connection_type -> connector_type and move some code around --- src/ModelingToolkit.jl | 8 +- src/systems/abstractsystem.jl | 181 +----------------- src/systems/connectors.jl | 175 +++++++++++++++++ src/systems/diffeqs/odesystem.jl | 12 +- src/systems/diffeqs/sdesystem.jl | 10 +- .../discrete_system/discrete_system.jl | 10 +- src/systems/jumps/jumpsystem.jl | 10 +- src/systems/nonlinear/nonlinearsystem.jl | 10 +- src/systems/pde/pdesystem.jl | 6 +- src/variables.jl | 4 + test/discretesystem.jl | 4 +- 11 files changed, 215 insertions(+), 215 deletions(-) create mode 100644 src/systems/connectors.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index d993cc35d5..061d4179b5 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -116,6 +116,7 @@ include("utils.jl") include("domains.jl") include("systems/abstractsystem.jl") +include("systems/connectors.jl") include("systems/diffeqs/odesystem.jl") include("systems/diffeqs/sdesystem.jl") @@ -150,8 +151,6 @@ for S in subtypes(ModelingToolkit.AbstractSystem) @eval convert_system(::Type{<:$S}, sys::$S) = sys end -struct Flow end - export AbstractTimeDependentSystem, AbstractTimeIndependentSystem, AbstractMultivariateSystem export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr, convert_system export DAEFunctionExpr, DAEProblemExpr @@ -166,7 +165,8 @@ export SteadyStateProblem, SteadyStateProblemExpr export JumpProblem, DiscreteProblem export NonlinearSystem, OptimizationSystem export ControlSystem -export alias_elimination, flatten, connect, @connector, Connection +export alias_elimination, flatten +export connect, @connector, Connection, Flow, Stream export ode_order_lowering, liouville_transform export runge_kutta_discretize export PDESystem @@ -197,7 +197,7 @@ export toexpr, get_variables export simplify, substitute export build_function export modelingtoolkitize -export @variables, @parameters, Flow +export @variables, @parameters export @named, @nonamespace, @namespace, extend, compose end # module diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 65c8c215ce..1f31f38889 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -185,7 +185,7 @@ for prop in [ :domain :ivs :dvs - :connection_type + :connector_type :preface ] fname1 = Symbol(:get_, prop) @@ -863,185 +863,6 @@ function check_eqs_u0(eqs, dvs, u0; check_length=true, kwargs...) return nothing end -### -### Connectors -### - -function with_connection_type(expr) - @assert expr isa Expr && (expr.head == :function || (expr.head == :(=) && - expr.args[1] isa Expr && - expr.args[1].head == :call)) - - sig = expr.args[1] - body = expr.args[2] - - fname = sig.args[1] - args = sig.args[2:end] - - quote - struct $fname - $(gensym()) -> 1 # this removes the default constructor - end - function $fname($(args...)) - function f() - $body - end - res = f() - $isdefined(res, :connection_type) ? $Setfield.@set!(res.connection_type = $fname) : res - end - end -end - -macro connector(expr) - esc(with_connection_type(expr)) -end - -promote_connect_rule(::Type{T}, ::Type{S}) where {T, S} = Union{} -promote_connect_rule(::Type{T}, ::Type{T}) where {T} = T -promote_connect_type(t1::Type, t2::Type, ts::Type...) = promote_connect_type(promote_connect_rule(t1, t2), ts...) -@inline function promote_connect_type(::Type{T}, ::Type{S}) where {T,S} - promote_connect_result( - T, - S, - promote_connect_rule(T,S), - promote_connect_rule(S,T) - ) -end - -promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{Union{}}) where {T} = T -promote_connect_result(::Type, ::Type, ::Type{Union{}}, ::Type{S}) where {S} = S -promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{T}) where {T} = T -function promote_connect_result(::Type{T}, ::Type{S}, ::Type{P1}, ::Type{P2}) where {T,S,P1,P2} - throw(ArgumentError("connection promotion for $T and $S resulted in $P1 and $P2. " * - "Define promotion only in one direction.")) -end - -throw_connector_promotion(T, S) = throw(ArgumentError("Don't know how to connect systems of type $S and $T")) -promote_connect_result(::Type{T},::Type{S},::Type{Union{}},::Type{Union{}}) where {T,S} = throw_connector_promotion(T,S) - -promote_connect_type(::Type{T}, ::Type{T}) where {T} = T -function promote_connect_type(T, S) - error("Don't know how to connect systems of type $S and $T") -end - -Base.@kwdef struct Connection - inners = nothing - outers = nothing -end - -Connection(syss) = Connection(inners=syss) -get_systems(c::Connection) = c.inners - -const EMPTY_VEC = [] - -function Base.show(io::IO, c::Connection) - @unpack outers, inners = c - if outers === nothing && inners === nothing - print(io, "") - else - syss = Iterators.flatten((something(inners, EMPTY_VEC), something(outers, EMPTY_VEC))) - splitting_idx = length(inners) - sys_str = join((string(nameof(s)) * (i <= splitting_idx ? ("::inner") : ("::outers")) for (i, s) in enumerate(syss)), ", ") - print(io, "<", sys_str, ">") - end -end - -function connect(syss...) - length(syss) >= 2 || error("connect takes at least two systems!") - length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") - Equation(Connection(), Connection(syss)) # the RHS are connected systems -end - -isconnector(s::AbstractSystem) = has_connection_type(s) && get_connection_type(s) !== nothing - -function isouterconnector(sys::AbstractSystem; check=true) - subsys = get_systems(sys) - outer_connectors = [nameof(s) for s in subsys if isconnector(sys)] - # Note that subconnectors in outer connectors are still outer connectors. - # Ref: https://specification.modelica.org/v3.4/Ch9.html see 9.1.2 - let outer_connectors=outer_connectors, check=check - function isouter(sys)::Bool - s = string(nameof(sys)) - check && (isconnector(sys) || error("$s is not a connector!")) - idx = findfirst(isequal('₊'), s) - parent_name = idx === nothing ? s : s[1:idx] - parent_name in outer_connectors - end - end -end - -function expand_connections(sys::AbstractSystem; debug=false) - subsys = get_systems(sys) - isempty(subsys) && return sys - - # post order traversal - @set sys.systems = map(s->expand_connections(s, debug=debug), subsys) - isouter = isouterconnector(sys) - - eqs′ = get_eqs(sys) - eqs = Equation[] - cts = [] # connections - for eq in eqs′ - eq.lhs isa Connection ? push!(cts, get_systems(eq.rhs)) : push!(eqs, eq) # split connections and equations - end - - sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement - narg_connects = Connection[] - for (i, syss) in enumerate(cts) - # find intersecting connections - exclude = 0 # exclude the intersecting system - idx = 0 # idx of narg_connects - for (j, s) in enumerate(syss) - idx′ = get(sys2idx, nameof(s), nothing) - idx′ === nothing && continue - idx = idx′ - exclude = j - end - if exclude == 0 - outers = [] - inners = [] - for s in syss - isouter(s) ? push!(outers, s) : push!(inners, s) - end - push!(narg_connects, Connection(outers=outers, inners=inners)) - for s in syss - sys2idx[nameof(s)] = length(narg_connects) - end - else - # fuse intersecting connections - for (j, s) in enumerate(syss); j == exclude && continue - sys2idx[nameof(s)] = idx - c = narg_connects[idx] - isouter(s) ? push!(c.outers, s) : push!(c.inners, s) - end - end - end - - # Bad things happen when there are more than one intersections - for c in narg_connects - @unpack outers, inners = c - len = length(outers) + length(inners) - allconnectors = Iterators.flatten((outers, inners)) - dups = find_duplicates(nameof(c) for c in allconnectors) - length(dups) == 0 || error("$(Connection(syss)) has duplicated connections: $(dups).") - end - - if debug - println("Connections:") - print_with_indent(x) = println(" " ^ 4, x) - foreach(print_with_indent, narg_connects) - end - - for c in narg_connects - T = promote_connect_type(map(get_connection_type, c.outers)..., map(get_connection_type, c.inners)...) - ceqs = connect(T, c) - ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) - end - - @set! sys.eqs = eqs - return sys -end - ### ### Inheritance & composition ### diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl new file mode 100644 index 0000000000..3022b1ea1c --- /dev/null +++ b/src/systems/connectors.jl @@ -0,0 +1,175 @@ +function with_connector_type(expr) + @assert expr isa Expr && (expr.head == :function || (expr.head == :(=) && + expr.args[1] isa Expr && + expr.args[1].head == :call)) + + sig = expr.args[1] + body = expr.args[2] + + fname = sig.args[1] + args = sig.args[2:end] + + quote + function $fname($(args...)) + function f() + $body + end + res = f() + $isdefined(res, :connector_type) ? $Setfield.@set!(res.connector_type = $connector_type(res)) : res + end + end +end + +macro connector(expr) + esc(with_connector_type(expr)) +end + +function connector_type(sys::AbstractSystem) + states(sys) +end + +promote_connect_rule(::Type{T}, ::Type{S}) where {T, S} = Union{} +promote_connect_rule(::Type{T}, ::Type{T}) where {T} = T +promote_connect_type(t1::Type, t2::Type, ts::Type...) = promote_connect_type(promote_connect_rule(t1, t2), ts...) +@inline function promote_connect_type(::Type{T}, ::Type{S}) where {T,S} + promote_connect_result( + T, + S, + promote_connect_rule(T,S), + promote_connect_rule(S,T) + ) +end + +promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{Union{}}) where {T} = T +promote_connect_result(::Type, ::Type, ::Type{Union{}}, ::Type{S}) where {S} = S +promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{T}) where {T} = T +function promote_connect_result(::Type{T}, ::Type{S}, ::Type{P1}, ::Type{P2}) where {T,S,P1,P2} + throw(ArgumentError("connection promotion for $T and $S resulted in $P1 and $P2. " * + "Define promotion only in one direction.")) +end + +throw_connector_promotion(T, S) = throw(ArgumentError("Don't know how to connect systems of type $S and $T")) +promote_connect_result(::Type{T},::Type{S},::Type{Union{}},::Type{Union{}}) where {T,S} = throw_connector_promotion(T,S) + +promote_connect_type(::Type{T}, ::Type{T}) where {T} = T +function promote_connect_type(T, S) + error("Don't know how to connect systems of type $S and $T") +end + +Base.@kwdef struct Connection + inners = nothing + outers = nothing +end + +Connection(syss) = Connection(inners=syss) +get_systems(c::Connection) = c.inners + +const EMPTY_VEC = [] + +function Base.show(io::IO, c::Connection) + @unpack outers, inners = c + if outers === nothing && inners === nothing + print(io, "") + else + syss = Iterators.flatten((something(inners, EMPTY_VEC), something(outers, EMPTY_VEC))) + splitting_idx = length(inners) + sys_str = join((string(nameof(s)) * (i <= splitting_idx ? ("::inner") : ("::outers")) for (i, s) in enumerate(syss)), ", ") + print(io, "<", sys_str, ">") + end +end + +function connect(syss...) + length(syss) >= 2 || error("connect takes at least two systems!") + length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") + Equation(Connection(), Connection(syss)) # the RHS are connected systems +end + +isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing + +function isouterconnector(sys::AbstractSystem; check=true) + subsys = get_systems(sys) + outer_connectors = [nameof(s) for s in subsys if isconnector(sys)] + # Note that subconnectors in outer connectors are still outer connectors. + # Ref: https://specification.modelica.org/v3.4/Ch9.html see 9.1.2 + let outer_connectors=outer_connectors, check=check + function isouter(sys)::Bool + s = string(nameof(sys)) + check && (isconnector(sys) || error("$s is not a connector!")) + idx = findfirst(isequal('₊'), s) + parent_name = idx === nothing ? s : s[1:idx] + parent_name in outer_connectors + end + end +end + +function expand_connections(sys::AbstractSystem; debug=false) + subsys = get_systems(sys) + isempty(subsys) && return sys + + # post order traversal + @set sys.systems = map(s->expand_connections(s, debug=debug), subsys) + isouter = isouterconnector(sys) + + eqs′ = get_eqs(sys) + eqs = Equation[] + cts = [] # connections + for eq in eqs′ + eq.lhs isa Connection ? push!(cts, get_systems(eq.rhs)) : push!(eqs, eq) # split connections and equations + end + + sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement + narg_connects = Connection[] + for (i, syss) in enumerate(cts) + # find intersecting connections + exclude = 0 # exclude the intersecting system + idx = 0 # idx of narg_connects + for (j, s) in enumerate(syss) + idx′ = get(sys2idx, nameof(s), nothing) + idx′ === nothing && continue + idx = idx′ + exclude = j + end + if exclude == 0 + outers = [] + inners = [] + for s in syss + isouter(s) ? push!(outers, s) : push!(inners, s) + end + push!(narg_connects, Connection(outers=outers, inners=inners)) + for s in syss + sys2idx[nameof(s)] = length(narg_connects) + end + else + # fuse intersecting connections + for (j, s) in enumerate(syss); j == exclude && continue + sys2idx[nameof(s)] = idx + c = narg_connects[idx] + isouter(s) ? push!(c.outers, s) : push!(c.inners, s) + end + end + end + + # Bad things happen when there are more than one intersections + for c in narg_connects + @unpack outers, inners = c + len = length(outers) + length(inners) + allconnectors = Iterators.flatten((outers, inners)) + dups = find_duplicates(nameof(c) for c in allconnectors) + length(dups) == 0 || error("$(Connection(syss)) has duplicated connections: $(dups).") + end + + if debug + println("Connections:") + print_with_indent(x) = println(" " ^ 4, x) + foreach(print_with_indent, narg_connects) + end + + for c in narg_connects + T = promote_connect_type(map(get_connector_type, c.outers)..., map(get_connector_type, c.inners)...) + ceqs = connect(T, c) + ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) + end + + @set! sys.eqs = eqs + return sys +end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index aec0a31b11..e24ed7376d 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -80,22 +80,22 @@ struct ODESystem <: AbstractODESystem """ structure::Any """ - connection_type: type of the system + connector_type: type of the system """ - connection_type::Any + connector_type::Any """ preface: injuect assignment statements before the evaluation of the RHS function. """ preface::Any - function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type, preface; checks::Bool = true) + function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, preface; checks::Bool = true) if checks check_variables(dvs,iv) check_parameters(ps,iv) check_equations(deqs,iv) all_dimensionless([dvs;ps;iv]) ||check_units(deqs) end - new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type, preface) + new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, preface) end end @@ -108,7 +108,7 @@ function ODESystem( default_u0=Dict(), default_p=Dict(), defaults=_merge(Dict(default_u0), Dict(default_p)), - connection_type=nothing, + connector_type=nothing, preface=nothing, checks = true, ) @@ -140,7 +140,7 @@ function ODESystem( if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) end - ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type, preface, checks = checks) + ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connector_type, preface, checks = checks) end function ODESystem(eqs, iv=nothing; kwargs...) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 89c724c225..4861d42f0f 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -84,16 +84,16 @@ struct SDESystem <: AbstractODESystem """ type: type of the system """ - connection_type::Any + connector_type::Any - function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type; checks::Bool = true) + function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type; checks::Bool = true) if checks check_variables(dvs,iv) check_parameters(ps,iv) check_equations(deqs,iv) all_dimensionless([dvs;ps;iv]) || check_units(deqs,neqs) end - new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type) + new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type) end end @@ -105,7 +105,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; default_p=Dict(), defaults=_merge(Dict(default_u0), Dict(default_p)), name=nothing, - connection_type=nothing, + connector_type=nothing, checks = true, ) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @@ -134,7 +134,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0)) Wfact = RefValue(Matrix{Num}(undef, 0, 0)) Wfact_t = RefValue(Matrix{Num}(undef, 0, 0)) - SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type, checks = checks) + SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, checks = checks) end SDESystem(sys::ODESystem, neqs; kwargs...) = SDESystem(equations(sys), neqs, get_iv(sys), states(sys), parameters(sys); kwargs...) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 27a06f6fcb..5e119bad4a 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -54,14 +54,14 @@ struct DiscreteSystem <: AbstractTimeDependentSystem """ type: type of the system """ - connection_type::Any - function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connection_type; checks::Bool = true) + connector_type::Any + function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connector_type; checks::Bool = true) if checks check_variables(dvs, iv) check_parameters(ps, iv) all_dimensionless([dvs;ps;iv;ctrls]) ||check_units(discreteEqs) end - new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connection_type) + new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connector_type) end end @@ -79,7 +79,7 @@ function DiscreteSystem( default_u0=Dict(), default_p=Dict(), defaults=_merge(Dict(default_u0), Dict(default_p)), - connection_type=nothing, + connector_type=nothing, kwargs..., ) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @@ -103,7 +103,7 @@ function DiscreteSystem( if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) end - DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, connection_type, kwargs...) + DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, connector_type, kwargs...) end diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 27d1e427c6..331a19eb46 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -52,14 +52,14 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem """ type: type of the system """ - connection_type::Any - function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems, defaults, connection_type; checks::Bool = true) where U <: ArrayPartition + connector_type::Any + function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems, defaults, connector_type; checks::Bool = true) where U <: ArrayPartition if checks check_variables(states, iv) check_parameters(ps, iv) all_dimensionless([states;ps;iv]) || check_units(ap,iv) end - new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults, connection_type) + new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults, connector_type) end end @@ -70,7 +70,7 @@ function JumpSystem(eqs, iv, states, ps; default_p=Dict(), defaults=_merge(Dict(default_u0), Dict(default_p)), name=nothing, - connection_type=nothing, + connector_type=nothing, checks = true, kwargs...) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @@ -102,7 +102,7 @@ function JumpSystem(eqs, iv, states, ps; process_variables!(var_to_name, defaults, states) process_variables!(var_to_name, defaults, ps) - JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems, defaults, connection_type, checks = checks) + JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems, defaults, connector_type, checks = checks) end function generate_rate_function(js::JumpSystem, rate) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 8c6799a016..0e90d62f1b 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -53,12 +53,12 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem """ type: type of the system """ - connection_type::Any - function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connection_type; checks::Bool = true) + connector_type::Any + function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connector_type; checks::Bool = true) if checks all_dimensionless([states;ps]) ||check_units(eqs) end - new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connection_type) + new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connector_type) end end @@ -69,7 +69,7 @@ function NonlinearSystem(eqs, states, ps; default_p=Dict(), defaults=_merge(Dict(default_u0), Dict(default_p)), systems=NonlinearSystem[], - connection_type=nothing, + connector_type=nothing, checks = true, ) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @@ -93,7 +93,7 @@ function NonlinearSystem(eqs, states, ps; process_variables!(var_to_name, defaults, states) process_variables!(var_to_name, defaults, ps) - NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, nothing, connection_type, checks = checks) + NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, nothing, connector_type, checks = checks) end function calculate_jacobian(sys::NonlinearSystem; sparse=false, simplify=false) diff --git a/src/systems/pde/pdesystem.jl b/src/systems/pde/pdesystem.jl index 9c5a8623f8..56d168ea54 100644 --- a/src/systems/pde/pdesystem.jl +++ b/src/systems/pde/pdesystem.jl @@ -55,7 +55,7 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem """ type: type of the system """ - connection_type::Any + connector_type::Any """ name: the name of the system """ @@ -63,14 +63,14 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem @add_kwonly function PDESystem(eqs, bcs, domain, ivs, dvs, ps=SciMLBase.NullParameters(); defaults=Dict(), - connection_type = nothing, + connector_type = nothing, checks::Bool = true, name ) if checks all_dimensionless([dvs;ivs;ps]) ||check_units(eqs) end - new(eqs, bcs, domain, ivs, dvs, ps, defaults, connection_type, name) + new(eqs, bcs, domain, ivs, dvs, ps, defaults, connector_type, name) end end diff --git a/src/variables.jl b/src/variables.jl index 5d9d1789f9..01bf000f9f 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -11,6 +11,10 @@ Symbolics.option_to_metadata_type(::Val{:description}) = VariableDescriptionType Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput Symbolics.option_to_metadata_type(::Val{:output}) = VariableOutput +abstract type AbstractConnectType end +struct Flow <: AbstractConnectType end # sum to 0 +struct Stream <: AbstractConnectType end # special stream connector + function isvarkind(m, x) p = getparent(x, nothing) p === nothing || (x = p) diff --git a/test/discretesystem.jl b/test/discretesystem.jl index da227d9b94..d030be4e1f 100644 --- a/test/discretesystem.jl +++ b/test/discretesystem.jl @@ -113,7 +113,7 @@ linearized_eqs = [ ] @test all(eqs2 .== linearized_eqs) -# Test connection_type +# Test connector_type @connector function DiscreteComponent(;name) @variables v(t) i(t) DiscreteSystem(Equation[], t, [v, i], [], name=name, defaults=Dict(v=>1.0, i=>1.0)) @@ -121,4 +121,4 @@ end @named d1 = DiscreteComponent() -@test ModelingToolkit.get_connection_type(d1) == DiscreteComponent +@test ModelingToolkit.get_connector_type(d1) == DiscreteComponent From adfbb8e89a5b141dfaca385974a97d02a4583490 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 9 Nov 2021 00:31:10 -0500 Subject: [PATCH 10/19] Basically rewrite connectors --- examples/electrical_components.jl | 16 +----- src/systems/connectors.jl | 81 +++++++++++++++++++------------ src/variables.jl | 5 +- 3 files changed, 53 insertions(+), 49 deletions(-) diff --git a/examples/electrical_components.jl b/examples/electrical_components.jl index 457fd1fd29..9a4ca63f7a 100644 --- a/examples/electrical_components.jl +++ b/examples/electrical_components.jl @@ -3,24 +3,10 @@ using ModelingToolkit, OrdinaryDiffEq @parameters t @connector function Pin(;name) - sts = @variables v(t)=1.0 i(t)=1.0 + sts = @variables v(t)=1.0 i(t)=1.0 [connect = Flow] ODESystem(Equation[], t, sts, []; name=name) end -function ModelingToolkit.connect(::Type{Pin}, c::Connection) - @unpack outers, inners = c - isum = isempty(inners) ? 0 : sum(p->p.i, inners) - osum = isempty(outers) ? 0 : sum(p->p.i, outers) - eqs = [0 ~ isum - osum] # KCL - ps = [outers; inners] - # KVL - for i in 1:length(ps)-1 - push!(eqs, ps[i].v ~ ps[i+1].v) - end - - return eqs -end - function Ground(;name) @named g = Pin() eqs = [g.v ~ 0] diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 3022b1ea1c..44d81197a1 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -24,36 +24,14 @@ macro connector(expr) esc(with_connector_type(expr)) end -function connector_type(sys::AbstractSystem) - states(sys) -end +abstract type AbstractConnectorType end +struct StreamConnector <: AbstractConnectorType end +struct RegularConnector <: AbstractConnectorType end -promote_connect_rule(::Type{T}, ::Type{S}) where {T, S} = Union{} -promote_connect_rule(::Type{T}, ::Type{T}) where {T} = T -promote_connect_type(t1::Type, t2::Type, ts::Type...) = promote_connect_type(promote_connect_rule(t1, t2), ts...) -@inline function promote_connect_type(::Type{T}, ::Type{S}) where {T,S} - promote_connect_result( - T, - S, - promote_connect_rule(T,S), - promote_connect_rule(S,T) - ) -end - -promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{Union{}}) where {T} = T -promote_connect_result(::Type, ::Type, ::Type{Union{}}, ::Type{S}) where {S} = S -promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{T}) where {T} = T -function promote_connect_result(::Type{T}, ::Type{S}, ::Type{P1}, ::Type{P2}) where {T,S,P1,P2} - throw(ArgumentError("connection promotion for $T and $S resulted in $P1 and $P2. " * - "Define promotion only in one direction.")) -end - -throw_connector_promotion(T, S) = throw(ArgumentError("Don't know how to connect systems of type $S and $T")) -promote_connect_result(::Type{T},::Type{S},::Type{Union{}},::Type{Union{}}) where {T,S} = throw_connector_promotion(T,S) - -promote_connect_type(::Type{T}, ::Type{T}) where {T} = T -function promote_connect_type(T, S) - error("Don't know how to connect systems of type $S and $T") +function connector_type(sys::AbstractSystem) + sts = states(sys) + #TODO: check the criteria for stream connectors + any(s->getmetadata(s, ModelingToolkit.VariableConnectType, nothing) === Stream, sts) ? StreamConnector() : RegularConnector() end Base.@kwdef struct Connection @@ -61,6 +39,7 @@ Base.@kwdef struct Connection outers = nothing end +# everything is inner by default until we expand the connections Connection(syss) = Connection(inners=syss) get_systems(c::Connection) = c.inners @@ -78,12 +57,51 @@ function Base.show(io::IO, c::Connection) end end -function connect(syss...) +function connect(syss::AbstractSystem...) length(syss) >= 2 || error("connect takes at least two systems!") length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") Equation(Connection(), Connection(syss)) # the RHS are connected systems end +function connect(c::Connection; check=true) + @unpack inners, outers = c + + flow_eqs = Equation[] + other_eqs = Equation[] + + cnts = Iterators.flatten((inners, outers)) + fs, ss = Iterators.peel(cnts) + splitting_idx = length(inners) # anything after the splitting_idx is outer. + first_sts = get_states(fs) + first_sts_set = Set(getname.(first_sts)) + for sys in ss + current_sts = getname.(get_states(sys)) + Set(current_sts) == first_sts_set || error("$(nameof(sys)) ($current_sts) doesn't match the connection type of $(nameof(fs)) ($first_sts).") + end + + ceqs = Equation[] + for s in first_sts + name = getname(s) + isflow = getmetadata(s, VariableConnectType, Equality) === Flow + rhs = 0 # only used for flow variables + fix_val = getproperty(fs, name) # used for equality connections + for (i, c) in enumerate(cnts) + isinner = i <= splitting_idx + # https://specification.modelica.org/v3.4/Ch15.html + var = getproperty(c, name) + if isflow + rhs += isinner ? var : -var + else + i == 1 && continue # skip the first iteration + push!(ceqs, fix_val ~ getproperty(c, name)) + end + end + isflow && push!(ceqs, 0 ~ rhs) + end + + return ceqs +end + isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing function isouterconnector(sys::AbstractSystem; check=true) @@ -165,8 +183,7 @@ function expand_connections(sys::AbstractSystem; debug=false) end for c in narg_connects - T = promote_connect_type(map(get_connector_type, c.outers)..., map(get_connector_type, c.inners)...) - ceqs = connect(T, c) + ceqs = connect(c) ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) end diff --git a/src/variables.jl b/src/variables.jl index 01bf000f9f..1c27d3b098 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -12,8 +12,9 @@ Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput Symbolics.option_to_metadata_type(::Val{:output}) = VariableOutput abstract type AbstractConnectType end -struct Flow <: AbstractConnectType end # sum to 0 -struct Stream <: AbstractConnectType end # special stream connector +struct Equality <: AbstractConnectType end # Equality connection +struct Flow <: AbstractConnectType end # sum to 0 +struct Stream <: AbstractConnectType end # special stream connector function isvarkind(m, x) p = getparent(x, nothing) From 7b6666726bddfebd2c2a5c49062bb7afce8364e7 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 9 Nov 2021 02:01:44 -0500 Subject: [PATCH 11/19] Fix some typos --- src/systems/connectors.jl | 21 ++++++---- test/components.jl | 29 +++++++++++++ test/connectors.jl | 87 --------------------------------------- test/runtests.jl | 1 - 4 files changed, 42 insertions(+), 96 deletions(-) delete mode 100644 test/connectors.jl diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 44d81197a1..f4b3d62784 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -15,7 +15,7 @@ function with_connector_type(expr) $body end res = f() - $isdefined(res, :connector_type) ? $Setfield.@set!(res.connector_type = $connector_type(res)) : res + $isdefined(res, :connector_type) && $getfield(res, :connector_type) === nothing ? $Setfield.@set!(res.connector_type = $connector_type(res)) : res end end end @@ -106,7 +106,7 @@ isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) function isouterconnector(sys::AbstractSystem; check=true) subsys = get_systems(sys) - outer_connectors = [nameof(s) for s in subsys if isconnector(sys)] + outer_connectors = [nameof(s) for s in subsys if isconnector(s)] # Note that subconnectors in outer connectors are still outer connectors. # Ref: https://specification.modelica.org/v3.4/Ch9.html see 9.1.2 let outer_connectors=outer_connectors, check=check @@ -114,18 +114,20 @@ function isouterconnector(sys::AbstractSystem; check=true) s = string(nameof(sys)) check && (isconnector(sys) || error("$s is not a connector!")) idx = findfirst(isequal('₊'), s) - parent_name = idx === nothing ? s : s[1:idx] + parent_name = Symbol(idx === nothing ? s : s[1:idx]) parent_name in outer_connectors end end end +print_with_indent(n, x) = println(" " ^ n, x) + function expand_connections(sys::AbstractSystem; debug=false) subsys = get_systems(sys) isempty(subsys) && return sys # post order traversal - @set sys.systems = map(s->expand_connections(s, debug=debug), subsys) + @set! sys.systems = map(s->expand_connections(s, debug=debug), subsys) isouter = isouterconnector(sys) eqs′ = get_eqs(sys) @@ -176,15 +178,18 @@ function expand_connections(sys::AbstractSystem; debug=false) length(dups) == 0 || error("$(Connection(syss)) has duplicated connections: $(dups).") end - if debug + if debug && !isempty(narg_connects) println("Connections:") - print_with_indent(x) = println(" " ^ 4, x) - foreach(print_with_indent, narg_connects) + foreach(Base.Fix1(print_with_indent, 4), narg_connects) end for c in narg_connects ceqs = connect(c) - ceqs isa Equation ? push!(eqs, ceqs) : append!(eqs, ceqs) + if debug + println("Connection equations:") + foreach(Base.Fix1(print_with_indent, 4), ceqs) + end + append!(eqs, ceqs) end @set! sys.eqs = eqs diff --git a/test/components.jl b/test/components.jl index 8a9a8a321e..757e9adf45 100644 --- a/test/components.jl +++ b/test/components.jl @@ -78,3 +78,32 @@ end @unpack foo = goo @test ModelingToolkit.defaults(goo)[foo.a] == 3 @test ModelingToolkit.defaults(goo)[foo.b] == 300 + +# Outer/inner connections +function rc_component(;name) + R = 1 + C = 1 + @named p = Pin() + @named n = Pin() + @named resistor = Resistor(R=R) + @named capacitor = Capacitor(C=C) + eqs = [ + connect(p, resistor.p); + connect(resistor.n, capacitor.p); + connect(capacitor.n, n); + ] + @named sys = ODESystem(eqs, t) + compose(sys, [p, n, resistor, capacitor]; name=name) +end + +@named ground = Ground() +@named source = ConstantVoltage(V=1) +@named rc_comp = rc_component() +eqs = [ + connect(source.p, rc_comp.p) + connect(source.n, rc_comp.n) + connect(source.n, ground.g) + ] +@named sys′ = ODESystem(eqs, t) +@named sys = compose(sys′, [ground, source, rc_comp]) +expand_connections(sys, debug=true) diff --git a/test/connectors.jl b/test/connectors.jl deleted file mode 100644 index 90da4efafd..0000000000 --- a/test/connectors.jl +++ /dev/null @@ -1,87 +0,0 @@ -using Test, ModelingToolkit - -@parameters t - -@connector function Foo(;name) - @variables x(t) - ODESystem(Equation[], t, [x], [], defaults=Dict(x=>1.0), name=name) -end - -@connector function Goo(;name) - @variables x(t) - @parameters p - ODESystem(Equation[], t, [x], [p], defaults=Dict(x=>1.0, p=>1.0), name=name) -end - -function ModelingToolkit.connect(::Type{<:Foo}, c::Connection) - @show c.inners - ss = c.inners - n = length(ss)-1 - eqs = Vector{Equation}(undef, n) - for i in 1:n - eqs[i] = ss[i].x ~ ss[i+1].x - end - eqs -end - -@named f1 = Foo() -@named f2 = Foo() -@named f3 = Foo() -@named f4 = Foo() -@named g = Goo() - -function connection_eqs(eqs, subsys) - @named sys = ODESystem(eqs, t) - @named newsys = compose(sys, subsys) - equations(expand_connections(newsys)) -end - -connection_eqs(subsys) = Base.Fix2(connection_eqs, subsys) -ceqs = connection_eqs([f1, f2, f3, f4, g]) - -@test isequal(ceqs(connect(f1, f2)), [f1.x ~ f2.x]) -@test_throws ArgumentError ceqs(connect(f1, g)) - -# Note that since there're overloadings, these tests are not re-runable. -ModelingToolkit.promote_connect_rule(::Type{<:Foo}, ::Type{<:Goo}) = Foo -@test isequal(ceqs(connect(f1, g)), [f1.x ~ g.x]) -@test isequal(ceqs(connect(f1, f2, g)), [f1.x ~ f2.x; f2.x ~ g.x]) -@test isequal(ceqs(connect(f1, f2, g, f3)), [f1.x ~ f2.x; f2.x ~ g.x; g.x ~ f3.x]) -@test isequal(ceqs(connect(f1, f2, g, f3, f4)), [f1.x ~ f2.x; f2.x ~ g.x; g.x ~ f3.x; f3.x ~ f4.x]) -ModelingToolkit.promote_connect_rule(::Type{<:Goo}, ::Type{<:Foo}) = Foo -@test isequal(ceqs(connect(f1, g)), [f1.x ~ g.x]) -# test conflict -ModelingToolkit.promote_connect_rule(::Type{<:Goo}, ::Type{<:Foo}) = Goo -@test_throws ArgumentError ceqs(connect(f1, g)) - -@connector Hoo(;name) = ODESystem(Equation[], t, [], [], name=name) -function ModelingToolkit.connect(::Type{<:Hoo}, c::Connection) - ss = c.inners - nameof.(ss) ~ 0 -end -@named hs[1:8] = Hoo() -ceqs = connection_eqs(hs) - -@test ceqs([connect(hs[1], hs[2]), - connect(hs[1], hs[3])]) == [[:hs_1, :hs_2, :hs_3] ~ 0] - -@test ceqs([connect(hs[1], hs[2]), - connect(hs[2], hs[3])]) == [[:hs_1, :hs_2, :hs_3] ~ 0] - -@test ceqs([connect(hs[1], hs[2]), - connect(hs[4], hs[3])]) == [[:hs_1, :hs_2] ~ 0, [:hs_4, :hs_3] ~ 0] -@test_throws Any ceqs([connect(hs[1], hs[2]), - connect(hs[1], hs[2])]) -@test ceqs([connect(hs[1], hs[2]), - connect(hs[3], hs[2]), - connect(hs[1], hs[4]), - connect(hs[8], hs[4]), - connect(hs[7], hs[5]),]) == [[:hs_1, :hs_2, :hs_3, :hs_4, :hs_8] ~ 0, [:hs_7, :hs_5] ~ 0] -@test_throws Any ceqs([connect(hs[1], hs[2]), - connect(hs[3], hs[2]), - connect(hs[1], hs[4]), - connect(hs[8], hs[4]), - connect(hs[2], hs[8])]) - -# Outer/inner connectors - diff --git a/test/runtests.jl b/test/runtests.jl index 1d0e7883a8..b16524cc59 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,5 +37,4 @@ println("Last test requires gcc available in the path!") @safetestset "StructuralTransformations" begin include("structural_transformation/runtests.jl") end @testset "Serialization" begin include("serialization.jl") end @safetestset "print_tree" begin include("print_tree.jl") end -@safetestset "connectors" begin include("connectors.jl") end @safetestset "error_handling" begin include("error_handling.jl") end From 704fed370a080ec6f6142683d7f4743f58cb24c5 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 9 Nov 2021 02:11:44 -0500 Subject: [PATCH 12/19] Better tests --- test/components.jl | 228 +++++++++++++++++++++++---------------------- 1 file changed, 119 insertions(+), 109 deletions(-) diff --git a/test/components.jl b/test/components.jl index 757e9adf45..da79b67c78 100644 --- a/test/components.jl +++ b/test/components.jl @@ -1,109 +1,119 @@ -using Test -using ModelingToolkit, OrdinaryDiffEq - -include("../examples/rc_model.jl") - -sys = structural_simplify(rc_model) -@test !isempty(ModelingToolkit.defaults(sys)) -u0 = [ - capacitor.v => 0.0 - capacitor.p.i => 0.0 - resistor.v => 0.0 - ] -prob = ODEProblem(sys, u0, (0, 10.0)) -sol = solve(prob, Rodas4()) - -@test sol[resistor.p.i] == sol[capacitor.p.i] -@test sol[resistor.n.i] == -sol[capacitor.p.i] -@test sol[capacitor.n.i] == -sol[capacitor.p.i] -@test iszero(sol[ground.g.i]) -@test iszero(sol[ground.g.v]) -@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v] - -u0 = [ - capacitor.v => 0.0 - ] -prob = ODAEProblem(sys, u0, (0, 10.0)) -sol = solve(prob, Tsit5()) - -@test sol[resistor.p.i] == sol[capacitor.p.i] -@test sol[resistor.n.i] == -sol[capacitor.p.i] -@test sol[capacitor.n.i] == -sol[capacitor.p.i] -@test iszero(sol[ground.g.i]) -@test iszero(sol[ground.g.v]) -@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v] -#using Plots -#plot(sol) - -include("../examples/serial_inductor.jl") -sys = structural_simplify(ll_model) -u0 = [ - inductor1.i => 0.0 - inductor2.i => 0.0 - inductor2.v => 0.0 - ] -prob = ODEProblem(sys, u0, (0, 10.0)) -sol = solve(prob, Rodas4()) - -prob = ODAEProblem(sys, u0, (0, 10.0)) -sol = solve(prob, Tsit5()) - -@variables t x1(t) x2(t) x3(t) x4(t) -D = Differential(t) -@named sys1_inner = ODESystem([D(x1) ~ x1], t) -@named sys1_partial = compose(ODESystem([D(x2) ~ x2], t; name=:foo), sys1_inner) -@named sys1 = extend(ODESystem([D(x3) ~ x3], t; name=:foo), sys1_partial) -@named sys2 = compose(ODESystem([D(x4) ~ x4], t; name=:foo), sys1) -@test_nowarn sys2.sys1.sys1_inner.x1 # test the correct nesting - - -# compose tests -@parameters t - -function record_fun(;name) - pars = @parameters a=10 b=100 - ODESystem(Equation[], t, [], pars; name) -end - -function first_model(;name) - @named foo=record_fun() - - defs = Dict() - defs[foo.a] = 3 - defs[foo.b] = 300 - pars = @parameters x=2 y=20 - compose(ODESystem(Equation[], t, [], pars; name, defaults=defs), foo) -end -@named goo = first_model() -@unpack foo = goo -@test ModelingToolkit.defaults(goo)[foo.a] == 3 -@test ModelingToolkit.defaults(goo)[foo.b] == 300 - -# Outer/inner connections -function rc_component(;name) - R = 1 - C = 1 - @named p = Pin() - @named n = Pin() - @named resistor = Resistor(R=R) - @named capacitor = Capacitor(C=C) - eqs = [ - connect(p, resistor.p); - connect(resistor.n, capacitor.p); - connect(capacitor.n, n); - ] - @named sys = ODESystem(eqs, t) - compose(sys, [p, n, resistor, capacitor]; name=name) -end - -@named ground = Ground() -@named source = ConstantVoltage(V=1) -@named rc_comp = rc_component() -eqs = [ - connect(source.p, rc_comp.p) - connect(source.n, rc_comp.n) - connect(source.n, ground.g) - ] -@named sys′ = ODESystem(eqs, t) -@named sys = compose(sys′, [ground, source, rc_comp]) -expand_connections(sys, debug=true) +using Test +using ModelingToolkit, OrdinaryDiffEq + +include("../examples/rc_model.jl") + +sys = structural_simplify(rc_model) +@test !isempty(ModelingToolkit.defaults(sys)) +u0 = [ + capacitor.v => 0.0 + capacitor.p.i => 0.0 + resistor.v => 0.0 + ] +prob = ODEProblem(sys, u0, (0, 10.0)) +sol = solve(prob, Rodas4()) + +@test sol[resistor.p.i] == sol[capacitor.p.i] +@test sol[resistor.n.i] == -sol[capacitor.p.i] +@test sol[capacitor.n.i] == -sol[capacitor.p.i] +@test iszero(sol[ground.g.i]) +@test iszero(sol[ground.g.v]) +@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v] + +# Outer/inner connections +function rc_component(;name) + R = 1 + C = 1 + @named p = Pin() + @named n = Pin() + @named resistor = Resistor(R=R) + @named capacitor = Capacitor(C=C) + eqs = [ + connect(p, resistor.p); + connect(resistor.n, capacitor.p); + connect(capacitor.n, n); + ] + @named sys = ODESystem(eqs, t) + compose(sys, [p, n, resistor, capacitor]; name=name) +end + +@named ground = Ground() +@named source = ConstantVoltage(V=1) +@named rc_comp = rc_component() +eqs = [ + connect(source.p, rc_comp.p) + connect(source.n, rc_comp.n) + connect(source.n, ground.g) + ] +@named sys′ = ODESystem(eqs, t) +@named sys_inner_outer = compose(sys′, [ground, source, rc_comp]) +expand_connections(sys_inner_outer, debug=true) +sys_inner_outer = structural_simplify(sys_inner_outer) +@test !isempty(ModelingToolkit.defaults(sys_inner_outer)) +u0 = [ + rc_comp.capacitor.v => 0.0 + rc_comp.capacitor.p.i => 0.0 + rc_comp.resistor.v => 0.0 + ] +prob = ODEProblem(sys_inner_outer, u0, (0, 10.0)) +sol_inner_outer = solve(prob, Rodas4()) +@test sol[capacitor.v] ≈ sol_inner_outer[rc_comp.capacitor.v] + +u0 = [ + capacitor.v => 0.0 + ] +prob = ODAEProblem(sys, u0, (0, 10.0)) +sol = solve(prob, Tsit5()) + +@test sol[resistor.p.i] == sol[capacitor.p.i] +@test sol[resistor.n.i] == -sol[capacitor.p.i] +@test sol[capacitor.n.i] == -sol[capacitor.p.i] +@test iszero(sol[ground.g.i]) +@test iszero(sol[ground.g.v]) +@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v] +#using Plots +#plot(sol) + +include("../examples/serial_inductor.jl") +sys = structural_simplify(ll_model) +u0 = [ + inductor1.i => 0.0 + inductor2.i => 0.0 + inductor2.v => 0.0 + ] +prob = ODEProblem(sys, u0, (0, 10.0)) +sol = solve(prob, Rodas4()) + +prob = ODAEProblem(sys, u0, (0, 10.0)) +sol = solve(prob, Tsit5()) + +@variables t x1(t) x2(t) x3(t) x4(t) +D = Differential(t) +@named sys1_inner = ODESystem([D(x1) ~ x1], t) +@named sys1_partial = compose(ODESystem([D(x2) ~ x2], t; name=:foo), sys1_inner) +@named sys1 = extend(ODESystem([D(x3) ~ x3], t; name=:foo), sys1_partial) +@named sys2 = compose(ODESystem([D(x4) ~ x4], t; name=:foo), sys1) +@test_nowarn sys2.sys1.sys1_inner.x1 # test the correct nesting + + +# compose tests +@parameters t + +function record_fun(;name) + pars = @parameters a=10 b=100 + ODESystem(Equation[], t, [], pars; name) +end + +function first_model(;name) + @named foo=record_fun() + + defs = Dict() + defs[foo.a] = 3 + defs[foo.b] = 300 + pars = @parameters x=2 y=20 + compose(ODESystem(Equation[], t, [], pars; name, defaults=defs), foo) +end +@named goo = first_model() +@unpack foo = goo +@test ModelingToolkit.defaults(goo)[foo.a] == 3 +@test ModelingToolkit.defaults(goo)[foo.b] == 300 From 60a5430c65d0bac445a8cba2952fba78b4af8913 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 9 Nov 2021 13:31:49 -0500 Subject: [PATCH 13/19] Quick fix --- src/systems/connectors.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index f4b3d62784..00b3c1fa4a 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -52,7 +52,7 @@ function Base.show(io::IO, c::Connection) else syss = Iterators.flatten((something(inners, EMPTY_VEC), something(outers, EMPTY_VEC))) splitting_idx = length(inners) - sys_str = join((string(nameof(s)) * (i <= splitting_idx ? ("::inner") : ("::outers")) for (i, s) in enumerate(syss)), ", ") + sys_str = join((string(nameof(s)) * (i <= splitting_idx ? ("::inner") : ("::outer")) for (i, s) in enumerate(syss)), ", ") print(io, "<", sys_str, ">") end end @@ -179,19 +179,24 @@ function expand_connections(sys::AbstractSystem; debug=false) end if debug && !isempty(narg_connects) - println("Connections:") + println("============BEGIN================") + println("Connections for [$(nameof(sys))]:") foreach(Base.Fix1(print_with_indent, 4), narg_connects) end + connection_eqs = Equation[] for c in narg_connects ceqs = connect(c) - if debug - println("Connection equations:") - foreach(Base.Fix1(print_with_indent, 4), ceqs) - end + debug && append!(connection_eqs, ceqs) append!(eqs, ceqs) end + if debug && !isempty(narg_connects) + println("Connection equations:") + foreach(Base.Fix1(print_with_indent, 4), connection_eqs) + println("=============END=================") + end + @set! sys.eqs = eqs return sys end From b8aac3c920d37dd4765d88b06139dbac75300ac3 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 12 Nov 2021 18:22:34 -0500 Subject: [PATCH 14/19] WIP --- src/systems/abstractsystem.jl | 4 +- src/systems/connectors.jl | 164 +++++++++++++++++++++++++++++++--- src/utils.jl | 4 +- 3 files changed, 155 insertions(+), 17 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 1f31f38889..23c5281eef 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -248,7 +248,7 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace=false) elseif !isempty(systems) i = findfirst(x->nameof(x)==name, systems) if i !== nothing - return namespace ? rename(systems[i], renamespace(sys, name)) : systems[i] + return namespace ? renamespace(sys, systems[i]) : systems[i] end end @@ -333,6 +333,8 @@ function renamespace(sys, x) x end end + elseif x isa AbstractSystem + rename(x, renamespace(sys, nameof(x))) else Symbol(getname(sys), :₊, x) end diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 00b3c1fa4a..e44fe55c92 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -31,7 +31,15 @@ struct RegularConnector <: AbstractConnectorType end function connector_type(sys::AbstractSystem) sts = states(sys) #TODO: check the criteria for stream connectors - any(s->getmetadata(s, ModelingToolkit.VariableConnectType, nothing) === Stream, sts) ? StreamConnector() : RegularConnector() + n_stream = 0 + n_flow = 0 + for s in sts + vtype = getmetadata(s, ModelingToolkit.VariableConnectType, nothing) + vtype === Stream && (n_stream += 1) + vtype === Flow && (n_flow += 1) + end + (n_stream > 1 && n_flow > 1) && error("There are multiple flow variables in $(nameof(sys))!") + n_stream > 1 ? StreamConnector() : RegularConnector() end Base.@kwdef struct Connection @@ -57,12 +65,14 @@ function Base.show(io::IO, c::Connection) end end +# symbolic `connect` function connect(syss::AbstractSystem...) length(syss) >= 2 || error("connect takes at least two systems!") length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") Equation(Connection(), Connection(syss)) # the RHS are connected systems end +# the actual `connect`. function connect(c::Connection; check=true) @unpack inners, outers = c @@ -102,25 +112,137 @@ function connect(c::Connection; check=true) return ceqs end +instream(a) = term(instream, unwrap(a), type=symtype(a)) + isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing +isstreamconnector(s::AbstractSystem) = isconnector(s) && get_connector_type(s) === Stream + +print_with_indent(n, x) = println(" " ^ n, x) -function isouterconnector(sys::AbstractSystem; check=true) +function get_stream_connectors!(sc, sys::AbstractSystem) subsys = get_systems(sys) - outer_connectors = [nameof(s) for s in subsys if isconnector(s)] - # Note that subconnectors in outer connectors are still outer connectors. - # Ref: https://specification.modelica.org/v3.4/Ch9.html see 9.1.2 - let outer_connectors=outer_connectors, check=check - function isouter(sys)::Bool + isempty(subsys) && return nothing + for s in subsys; isstreamconnector(s) || continue + push!(sc, renamespace(sys, s)) + end + for s in subsys + get_stream_connectors!(sc, renamespace(sys, s)) + end + nothing +end + +collect_instream!(set, eq::Equation) = collect_instream!(set, eq.lhs) | collect_instream!(set, eq.rhs) + +function collect_instream!(set, expr, occurs=false) + istree(expr) || return occurs + op = operation(expr) + op === instream && (push!(set, expr); occurs = true) + for a in unsorted_arguments(expr) + occurs |= collect_instream!(set, a, occurs) + end + return occurs +end + +function split_var(var) + name = string(nameof(var)) + map(Symbol, split(name, '₊')) +end + +# inclusive means the first level of `var` is `sys` +function get_sys_var(sys::AbstractSystem, var; inclusive=true) + lvs = split_var(var) + if inclusive + sysn, lvs = Iterator.peel(lvs) + sysn === nameof(sys) || error("$(nameof(sys)) doesn't have $var!") + end + newsys = getproperty(sys, first(lvs)) + for i in 2:length(lvs)-1 + newsys = getproperty(newsys, lvs[i]) + end + newsys, lvs[end] +end + +function expand_instream(sys::AbstractSystem; debug=false) + subsys = get_systems(sys) + isempty(subsys) && return sys + + # post order traversal + @set! sys.systems = map(s->expand_connections(s, debug=debug), subsys) + + outer_sc = [] + for s in subsys + n = nameof(s) + isstreamconnector(s) && push!(outer_sc, n) + end + + # the number of stream connectors excluding the current level + inner_sc = [] + for s in subsys + get_stream_connectors!(inner_sc, renamespace(sys, s)) + end + + # error checking + # TODO: Error might never be possible anyway, because subsystem names must + # be distinct. + outer_names, dup = find_duplicates((nameof(s) for s in outer_sc), Val(true)) + isempty(dup) || error("$dup are duplicate stream connectors!") + inner_names, dup = find_duplicates((nameof(s) for s in inner_sc), Val(true)) + isempty(dup) || error("$dup are duplicate stream connectors!") + + foreach(Base.Fix1(get_stream_connectors!, inner_sc), subsys) + isouterstream = let stream_connectors=outer_sc + function isstream(sys)::Bool s = string(nameof(sys)) - check && (isconnector(sys) || error("$s is not a connector!")) - idx = findfirst(isequal('₊'), s) - parent_name = Symbol(idx === nothing ? s : s[1:idx]) - parent_name in outer_connectors + isstreamconnector(sys) || error("$s is not a stream connector!") + s in stream_connectors end end -end -print_with_indent(n, x) = println(" " ^ n, x) + eqs′ = get_eqs(sys) + eqs = Equation[] + instream_eqs = Equation[] + instream_exprs = Set() + for eq in eqs′ + if collect_instream!(instream_exprs, eq) + push!(instream_eqs, eq) + else + push!(eqs, eq) # split instreams and equations + end + end + + function check_in_stream_connectors(stream, sc) + stream = only(arguments(ex)) + stream_name = string(nameof(stream)) + connector_name = stream_name[1:something(findlast('₊', stream_name), end)] + connector_name in sc || error("$stream_name is not in any stream connector of $(nameof(sys))") + end + + # expand `instream`s + sub = Dict() + n_outers = length(outer_names) + n_inners = length(inner_names) + # https://specification.modelica.org/v3.4/Ch15.html + # Based on the above requirements, the following implementation is + # recommended: + if n_inners == 1 && n_outers == 0 + for ex in instream_exprs + stream = only(arguments(ex)) + check_in_stream_connectors(stream, inner_names) + sub[ex] = stream + end + elseif n_inners == 2 && n_outers == 0 + for ex in instream_exprs + stream = only(arguments(ex)) + check_in_stream_connectors(stream, inner_names) + + sub[ex] = stream + end + elseif n_inners == 1 && n_outers == 1 + elseif n_inners == 0 && n_outers == 2 + else + end + instream_eqs = map(Base.Fix2(substitute, sub), instream_eqs) +end function expand_connections(sys::AbstractSystem; debug=false) subsys = get_systems(sys) @@ -128,7 +250,21 @@ function expand_connections(sys::AbstractSystem; debug=false) # post order traversal @set! sys.systems = map(s->expand_connections(s, debug=debug), subsys) - isouter = isouterconnector(sys) + + outer_connectors = Symbol[] + for s in subsys + n = nameof(s) + isconnector(s) && push!(outer_connectors, n) + end + isouter = let outer_connectors=outer_connectors + function isouter(sys)::Bool + s = string(nameof(sys)) + isconnector(sys) || error("$s is not a connector!") + idx = findfirst(isequal('₊'), s) + parent_name = Symbol(idx === nothing ? s : s[1:idx]) + parent_name in outer_connectors + end + end eqs′ = get_eqs(sys) eqs = Equation[] diff --git a/src/utils.jl b/src/utils.jl index f328f24916..95ea631aa9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -355,7 +355,7 @@ $(SIGNATURES) find duplicates in an iterable object. """ -function find_duplicates(xs) +function find_duplicates(xs, ::Val{Ret}) where Ret appeared = Set() duplicates = Set() for x in xs @@ -365,5 +365,5 @@ function find_duplicates(xs) push!(appeared, x) end end - return duplicates + return Ret ? duplicates : (appeared, duplicates) end From 8d7b2d1669f62c5e3995aadb6b829d78ab66b84b Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 13 Nov 2021 03:58:58 -0500 Subject: [PATCH 15/19] Minor fix --- src/systems/diffeqs/odesystem.jl | 2 +- src/utils.jl | 4 ++-- test/discretesystem.jl | 10 ---------- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 7caed4b4b3..8648cea007 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -173,7 +173,7 @@ function ODESystem(eqs, iv=nothing; kwargs...) iv === nothing && throw(ArgumentError("Please pass in independent variables.")) compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)` for eq in eqs - eq.lhs isa Symbolic || (push!(compressed_eqs, eq); continue) + eq.lhs isa Union{Symbolic,Number} || (push!(compressed_eqs, eq); continue) collect_vars!(allstates, ps, eq.lhs, iv) collect_vars!(allstates, ps, eq.rhs, iv) if isdiffeq(eq) diff --git a/src/utils.jl b/src/utils.jl index dd6396442c..8a868889f5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -359,7 +359,7 @@ $(SIGNATURES) find duplicates in an iterable object. """ -function find_duplicates(xs, ::Val{Ret}) where Ret +function find_duplicates(xs, ::Val{Ret}=Val(false)) where Ret appeared = Set() duplicates = Set() for x in xs @@ -369,5 +369,5 @@ function find_duplicates(xs, ::Val{Ret}) where Ret push!(appeared, x) end end - return Ret ? duplicates : (appeared, duplicates) + return Ret ? (appeared, duplicates) : duplicates end diff --git a/test/discretesystem.jl b/test/discretesystem.jl index d030be4e1f..fe993a8661 100644 --- a/test/discretesystem.jl +++ b/test/discretesystem.jl @@ -112,13 +112,3 @@ linearized_eqs = [ y(t - 2.0) ~ y(t) ] @test all(eqs2 .== linearized_eqs) - -# Test connector_type -@connector function DiscreteComponent(;name) - @variables v(t) i(t) - DiscreteSystem(Equation[], t, [v, i], [], name=name, defaults=Dict(v=>1.0, i=>1.0)) -end - -@named d1 = DiscreteComponent() - -@test ModelingToolkit.get_connector_type(d1) == DiscreteComponent From fb791eb53a0b011616ba634f7403b2fe34246e8e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 15 Nov 2021 00:54:27 -0500 Subject: [PATCH 16/19] WIP --- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 2 + src/systems/connectors.jl | 149 +++++++++++++++++++++---------- src/systems/diffeqs/odesystem.jl | 10 ++- 4 files changed, 110 insertions(+), 53 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 8aac4b3321..15444ee086 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -173,7 +173,7 @@ export JumpProblem, DiscreteProblem export NonlinearSystem, OptimizationSystem export ControlSystem export alias_elimination, flatten -export connect, @connector, Connection, Flow, Stream +export connect, @connector, Connection, Flow, Stream, instream export ode_order_lowering, liouville_transform export runge_kutta_discretize export PDESystem diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 05ef2a33c4..79c81d0440 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -220,6 +220,7 @@ for prop in [ :ivs :dvs :connector_type + :connections :preface ] fname1 = Symbol(:get_, prop) @@ -355,6 +356,7 @@ GlobalScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, GlobalScope( renamespace(sys, eq::Equation) = namespace_equation(eq, sys) +renamespace(names::AbstractVector, x) = foldr(renamespace, names, init=x) function renamespace(sys, x) x = unwrap(x) if x isa Symbolic diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index e44fe55c92..4500f03fa4 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -50,6 +50,10 @@ end # everything is inner by default until we expand the connections Connection(syss) = Connection(inners=syss) get_systems(c::Connection) = c.inners +function Base.in(e::Symbol, c::Connection) + f = isequal(e) + any(f, c.inners) || any(f, c.outers) +end const EMPTY_VEC = [] @@ -137,7 +141,7 @@ function collect_instream!(set, expr, occurs=false) istree(expr) || return occurs op = operation(expr) op === instream && (push!(set, expr); occurs = true) - for a in unsorted_arguments(expr) + for a in SymbolicUtils.unsorted_arguments(expr) occurs |= collect_instream!(set, a, occurs) end return occurs @@ -162,42 +166,42 @@ function get_sys_var(sys::AbstractSystem, var; inclusive=true) newsys, lvs[end] end -function expand_instream(sys::AbstractSystem; debug=false) +function split_stream_var(var) + var_name = string(getname(var)) + @show var_name + sidx = findlast(isequal('₊'), var_name) + sidx === nothing && error("$var is not a stream variable") + connector_name = Symbol(var_name[1:prevind(var_name, sidx)]) + streamvar_name = Symbol(var_name[nextind(var_name, sidx):end]) + connector_name, streamvar_name +end + +function find_connection(connector_name, ogsys, names) + cs = get_connections(ogsys) + cs === nothing || for c in cs + @show renamespace(names, connector_name) + renamespace(names, connector_name) in c && return c + end + innersys = ogsys + for n in names + innersys = getproperty(innersys, n) + cs = get_connections(innersys) + cs === nothing || for c in cs + connector_name in c && return c + end + end + error("$connector_name cannot be found in $(nameof(ogsys)) with levels $(names)") +end + +function expand_instream(ogsys, sys::AbstractSystem=ogsys, names=[]; debug=false) subsys = get_systems(sys) isempty(subsys) && return sys # post order traversal - @set! sys.systems = map(s->expand_connections(s, debug=debug), subsys) - - outer_sc = [] - for s in subsys + @set! sys.systems = map(subsys) do s n = nameof(s) - isstreamconnector(s) && push!(outer_sc, n) - end - - # the number of stream connectors excluding the current level - inner_sc = [] - for s in subsys - get_stream_connectors!(inner_sc, renamespace(sys, s)) + expand_instream(ogsys, s, [names; n], debug=debug) end - - # error checking - # TODO: Error might never be possible anyway, because subsystem names must - # be distinct. - outer_names, dup = find_duplicates((nameof(s) for s in outer_sc), Val(true)) - isempty(dup) || error("$dup are duplicate stream connectors!") - inner_names, dup = find_duplicates((nameof(s) for s in inner_sc), Val(true)) - isempty(dup) || error("$dup are duplicate stream connectors!") - - foreach(Base.Fix1(get_stream_connectors!, inner_sc), subsys) - isouterstream = let stream_connectors=outer_sc - function isstream(sys)::Bool - s = string(nameof(sys)) - isstreamconnector(sys) || error("$s is not a stream connector!") - s in stream_connectors - end - end - eqs′ = get_eqs(sys) eqs = Equation[] instream_eqs = Equation[] @@ -209,47 +213,87 @@ function expand_instream(sys::AbstractSystem; debug=false) push!(eqs, eq) # split instreams and equations end end - - function check_in_stream_connectors(stream, sc) - stream = only(arguments(ex)) - stream_name = string(nameof(stream)) - connector_name = stream_name[1:something(findlast('₊', stream_name), end)] - connector_name in sc || error("$stream_name is not in any stream connector of $(nameof(sys))") + @show nameof(sys), names, instream_exprs + isempty(instream_eqs) && return sys + + for ex in instream_exprs + var = only(arguments(ex)) + connector_name, streamvar_name = split_stream_var(var) + + n_inners = 0 + n_outers = 0 + outer_names = Symbol[] + inner_names = Symbol[] + outer_sc = Symbol[] + inner_sc = Symbol[] + # find the connect + connect = find_connection(connector_name, ogsys, names) + @show connect end + return sys + #= + + #splitting_idx + + #n_inners + n_outers <= 0 && error("Model $(nameof(sys)) has no stream connectors, yet there are equations with `instream` functions: $(instream_eqs)") # expand `instream`s sub = Dict() - n_outers = length(outer_names) - n_inners = length(inner_names) + additional_eqs = Equation[] + seen = Set() # https://specification.modelica.org/v3.4/Ch15.html # Based on the above requirements, the following implementation is # recommended: if n_inners == 1 && n_outers == 0 for ex in instream_exprs - stream = only(arguments(ex)) - check_in_stream_connectors(stream, inner_names) - sub[ex] = stream + var = only(arguments(ex)) + connector_name, streamvar_name = split_stream_var(var) + idx = findfirst(isequal(connector_name), inner_names) + idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") + sub[ex] = var #getproperty(inner_sc[idx], streamvar_name) end elseif n_inners == 2 && n_outers == 0 for ex in instream_exprs - stream = only(arguments(ex)) - check_in_stream_connectors(stream, inner_names) - - sub[ex] = stream + var = only(arguments(ex)) + connector_name, streamvar_name = split_stream_var(var) + idx = findfirst(isequal(connector_name), inner_names) + idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") + other = idx == 1 ? 2 : 1 + sub[ex] = getproperty(inner_sc[other], streamvar_name) end elseif n_inners == 1 && n_outers == 1 + for ex in instream_exprs + var = only(arguments(ex)) # m_1.c.h_outflow + connector_name, streamvar_name = split_stream_var(var) + idx = findfirst(isequal(connector_name), inner_names) + idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") + outerinstream = getproperty(only(outer_sc), streamvar_name) # c_1.h_outflow + sub[ex] = outerinstream + if var in seen + push!(additional_eqs, outerinstream ~ var) + push!(seen, var) + end + end elseif n_inners == 0 && n_outers == 2 + push!(additional_eqs, outerinstream ~ var) else end instream_eqs = map(Base.Fix2(substitute, sub), instream_eqs) + =# end function expand_connections(sys::AbstractSystem; debug=false) + sys = collect_connections(sys; debug=debug) + sys = expand_instream(sys; debug=debug) + return sys +end + +function collect_connections(sys::AbstractSystem; debug=false) subsys = get_systems(sys) isempty(subsys) && return sys # post order traversal - @set! sys.systems = map(s->expand_connections(s, debug=debug), subsys) + @set! sys.systems = map(s->collect_connections(s, debug=debug), subsys) outer_connectors = Symbol[] for s in subsys @@ -273,6 +317,9 @@ function expand_connections(sys::AbstractSystem; debug=false) eq.lhs isa Connection ? push!(cts, get_systems(eq.rhs)) : push!(eqs, eq) # split connections and equations end + # if there are no connections, we are done + isempty(cts) && return sys + sys2idx = Dict{Symbol,Int}() # system (name) to n-th connect statement narg_connects = Connection[] for (i, syss) in enumerate(cts) @@ -305,6 +352,10 @@ function expand_connections(sys::AbstractSystem; debug=false) end end + isempty(narg_connects) && error("Unreachable reached. Please file an issue.") + + @set! sys.connections = narg_connects + # Bad things happen when there are more than one intersections for c in narg_connects @unpack outers, inners = c @@ -314,7 +365,7 @@ function expand_connections(sys::AbstractSystem; debug=false) length(dups) == 0 || error("$(Connection(syss)) has duplicated connections: $(dups).") end - if debug && !isempty(narg_connects) + if debug println("============BEGIN================") println("Connections for [$(nameof(sys))]:") foreach(Base.Fix1(print_with_indent, 4), narg_connects) @@ -327,7 +378,7 @@ function expand_connections(sys::AbstractSystem; debug=false) append!(eqs, ceqs) end - if debug && !isempty(narg_connects) + if debug println("Connection equations:") foreach(Base.Fix1(print_with_indent, 4), connection_eqs) println("=============END=================") diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 8648cea007..e573230cae 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -84,6 +84,10 @@ struct ODESystem <: AbstractODESystem """ connector_type::Any """ + connections: connections in a system + """ + connections::Any + """ preface: inject assignment statements before the evaluation of the RHS function. """ preface::Any @@ -93,7 +97,7 @@ struct ODESystem <: AbstractODESystem """ continuous_events::Vector{SymbolicContinuousCallback} - function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, preface, events; checks::Bool = true) + function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, connections, preface, events; checks::Bool = true) if checks check_variables(dvs,iv) check_parameters(ps,iv) @@ -101,7 +105,7 @@ struct ODESystem <: AbstractODESystem check_equations(equations(events),iv) all_dimensionless([dvs;ps;iv]) || check_units(deqs) end - new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, preface, events) + new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, connections, preface, events) end end @@ -148,7 +152,7 @@ function ODESystem( throw(ArgumentError("System names must be unique.")) end cont_callbacks = SymbolicContinuousCallbacks(continuous_events) - ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connector_type, preface, cont_callbacks, checks = checks) + ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connector_type, nothing, preface, cont_callbacks, checks = checks) end function ODESystem(eqs, iv=nothing; kwargs...) From ec9a51ff64dc5495ca49d255f9a1c6205ba1f3ae Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 15 Nov 2021 01:21:20 -0500 Subject: [PATCH 17/19] Implement draft inner/outer criterion for `instream` --- src/systems/connectors.jl | 137 ++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 66 deletions(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 4500f03fa4..ce8082145e 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -51,8 +51,7 @@ end Connection(syss) = Connection(inners=syss) get_systems(c::Connection) = c.inners function Base.in(e::Symbol, c::Connection) - f = isequal(e) - any(f, c.inners) || any(f, c.outers) + any(k->nameof(k) === e, c.inners) || any(k->nameof(k) === e, c.outers) end const EMPTY_VEC = [] @@ -166,11 +165,10 @@ function get_sys_var(sys::AbstractSystem, var; inclusive=true) newsys, lvs[end] end -function split_stream_var(var) +function split_sys_var(var) var_name = string(getname(var)) - @show var_name sidx = findlast(isequal('₊'), var_name) - sidx === nothing && error("$var is not a stream variable") + sidx === nothing && error("$var is not a namespaced variable") connector_name = Symbol(var_name[1:prevind(var_name, sidx)]) streamvar_name = Symbol(var_name[nextind(var_name, sidx):end]) connector_name, streamvar_name @@ -179,15 +177,15 @@ end function find_connection(connector_name, ogsys, names) cs = get_connections(ogsys) cs === nothing || for c in cs - @show renamespace(names, connector_name) - renamespace(names, connector_name) in c && return c + renamespace(names, connector_name) in c && return ogsys, c end innersys = ogsys - for n in names + for (i, n) in enumerate(names) innersys = getproperty(innersys, n) cs = get_connections(innersys) cs === nothing || for c in cs - connector_name in c && return c + nn = @view names[i+1:end] + renamespace(nn, connector_name) in c && return innersys, c end end error("$connector_name cannot be found in $(nameof(ogsys)) with levels $(names)") @@ -218,68 +216,75 @@ function expand_instream(ogsys, sys::AbstractSystem=ogsys, names=[]; debug=false for ex in instream_exprs var = only(arguments(ex)) - connector_name, streamvar_name = split_stream_var(var) - - n_inners = 0 - n_outers = 0 - outer_names = Symbol[] - inner_names = Symbol[] - outer_sc = Symbol[] - inner_sc = Symbol[] + connector_name, streamvar_name = split_sys_var(var) + + outer_sc = [] + inner_sc = [] # find the connect - connect = find_connection(connector_name, ogsys, names) - @show connect - end - return sys - #= - - #splitting_idx - - #n_inners + n_outers <= 0 && error("Model $(nameof(sys)) has no stream connectors, yet there are equations with `instream` functions: $(instream_eqs)") - - # expand `instream`s - sub = Dict() - additional_eqs = Equation[] - seen = Set() - # https://specification.modelica.org/v3.4/Ch15.html - # Based on the above requirements, the following implementation is - # recommended: - if n_inners == 1 && n_outers == 0 - for ex in instream_exprs - var = only(arguments(ex)) - connector_name, streamvar_name = split_stream_var(var) - idx = findfirst(isequal(connector_name), inner_names) - idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") - sub[ex] = var #getproperty(inner_sc[idx], streamvar_name) - end - elseif n_inners == 2 && n_outers == 0 - for ex in instream_exprs - var = only(arguments(ex)) - connector_name, streamvar_name = split_stream_var(var) - idx = findfirst(isequal(connector_name), inner_names) - idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") - other = idx == 1 ? 2 : 1 - sub[ex] = getproperty(inner_sc[other], streamvar_name) + parentsys, connect = find_connection(connector_name, ogsys, names) + if nameof(parentsys) != nameof(sys) + # everything is a inner connector w.r.t. `sys` + for s in Iterators.flatten((connect.inners, connect.outers)) + push!(inner_sc, s) + end + else + for s in Iterators.flatten((connect.inners, connect.outers)) + if connector_name == split_var(nameof(s))[1] + push!(inner_sc, s) + else + push!(outer_sc, s) + end + end end - elseif n_inners == 1 && n_outers == 1 - for ex in instream_exprs - var = only(arguments(ex)) # m_1.c.h_outflow - connector_name, streamvar_name = split_stream_var(var) - idx = findfirst(isequal(connector_name), inner_names) - idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") - outerinstream = getproperty(only(outer_sc), streamvar_name) # c_1.h_outflow - sub[ex] = outerinstream - if var in seen - push!(additional_eqs, outerinstream ~ var) - push!(seen, var) + + n_inners = length(outer_sc) + n_outers = length(inner_sc) + @show n_inners n_outers + + # expand `instream`s + sub = Dict() + additional_eqs = Equation[] + seen = Set() + # https://specification.modelica.org/v3.4/Ch15.html + # Based on the above requirements, the following implementation is + # recommended: + if n_inners == 1 && n_outers == 0 + for ex in instream_exprs + var = only(arguments(ex)) + connector_name, streamvar_name = split_stream_var(var) + idx = findfirst(isequal(connector_name), inner_names) + idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") + sub[ex] = var #getproperty(inner_sc[idx], streamvar_name) + end + elseif n_inners == 2 && n_outers == 0 + for ex in instream_exprs + var = only(arguments(ex)) + connector_name, streamvar_name = split_stream_var(var) + idx = findfirst(isequal(connector_name), inner_names) + idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") + other = idx == 1 ? 2 : 1 + sub[ex] = getproperty(inner_sc[other], streamvar_name) + end + elseif n_inners == 1 && n_outers == 1 + for ex in instream_exprs + var = only(arguments(ex)) # m_1.c.h_outflow + connector_name, streamvar_name = split_stream_var(var) + idx = findfirst(isequal(connector_name), inner_names) + idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") + outerinstream = getproperty(only(outer_sc), streamvar_name) # c_1.h_outflow + sub[ex] = outerinstream + if var in seen + push!(additional_eqs, outerinstream ~ var) + push!(seen, var) + end end + elseif n_inners == 0 && n_outers == 2 + push!(additional_eqs, outerinstream ~ var) + else end - elseif n_inners == 0 && n_outers == 2 - push!(additional_eqs, outerinstream ~ var) - else end instream_eqs = map(Base.Fix2(substitute, sub), instream_eqs) - =# + return sys end function expand_connections(sys::AbstractSystem; debug=false) @@ -305,7 +310,7 @@ function collect_connections(sys::AbstractSystem; debug=false) s = string(nameof(sys)) isconnector(sys) || error("$s is not a connector!") idx = findfirst(isequal('₊'), s) - parent_name = Symbol(idx === nothing ? s : s[1:idx]) + parent_name = Symbol(idx === nothing ? s : s[1:prevind(s, idx)]) parent_name in outer_connectors end end From 2f8aa2a32fccfa83b0d004a90671017586f3ed41 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 15 Nov 2021 16:26:39 -0500 Subject: [PATCH 18/19] Draft of various connections --- src/systems/connectors.jl | 150 +++++++++++++++++++++++++++++--------- 1 file changed, 115 insertions(+), 35 deletions(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index ce8082145e..f893d6dc84 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -29,7 +29,7 @@ struct StreamConnector <: AbstractConnectorType end struct RegularConnector <: AbstractConnectorType end function connector_type(sys::AbstractSystem) - sts = states(sys) + sts = get_states(sys) #TODO: check the criteria for stream connectors n_stream = 0 n_flow = 0 @@ -191,6 +191,15 @@ function find_connection(connector_name, ogsys, names) error("$connector_name cannot be found in $(nameof(ogsys)) with levels $(names)") end +function flowvar(sys::AbstractSystem) + sts = get_states(sys) + for s in sts + vtype = getmetadata(s, ModelingToolkit.VariableConnectType, nothing) + vtype === Flow && return s + end + error("There in no flow variable in $(nameof(sys))") +end + function expand_instream(ogsys, sys::AbstractSystem=ogsys, names=[]; debug=false) subsys = get_systems(sys) isempty(subsys) && return sys @@ -211,9 +220,12 @@ function expand_instream(ogsys, sys::AbstractSystem=ogsys, names=[]; debug=false push!(eqs, eq) # split instreams and equations end end - @show nameof(sys), names, instream_exprs + #@show nameof(sys), names, instream_exprs isempty(instream_eqs) && return sys + sub = Dict() + seen = Set() + additional_eqs = Equation[] for ex in instream_exprs var = only(arguments(ex)) connector_name, streamvar_name = split_sys_var(var) @@ -222,13 +234,16 @@ function expand_instream(ogsys, sys::AbstractSystem=ogsys, names=[]; debug=false inner_sc = [] # find the connect parentsys, connect = find_connection(connector_name, ogsys, names) + connectors = Iterators.flatten((connect.inners, connect.outers)) + # stream variable + sv = getproperty(first(connectors), streamvar_name; namespace=false) if nameof(parentsys) != nameof(sys) # everything is a inner connector w.r.t. `sys` - for s in Iterators.flatten((connect.inners, connect.outers)) + for s in connectors push!(inner_sc, s) end else - for s in Iterators.flatten((connect.inners, connect.outers)) + for s in connectors if connector_name == split_var(nameof(s))[1] push!(inner_sc, s) else @@ -239,52 +254,117 @@ function expand_instream(ogsys, sys::AbstractSystem=ogsys, names=[]; debug=false n_inners = length(outer_sc) n_outers = length(inner_sc) - @show n_inners n_outers + outer_names = (nameof(s) for s in outer_sc) + inner_names = (nameof(s) for s in inner_sc) + if debug + println("Expanding: $ex") + isempty(inner_names) || println("Inner connectors: $(collect(inner_names))") + isempty(outer_names) || println("Outer connectors: $(collect(outer_names))") + end # expand `instream`s - sub = Dict() - additional_eqs = Equation[] - seen = Set() # https://specification.modelica.org/v3.4/Ch15.html # Based on the above requirements, the following implementation is # recommended: if n_inners == 1 && n_outers == 0 - for ex in instream_exprs - var = only(arguments(ex)) - connector_name, streamvar_name = split_stream_var(var) - idx = findfirst(isequal(connector_name), inner_names) - idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") - sub[ex] = var #getproperty(inner_sc[idx], streamvar_name) - end + connector_name === only(inner_names) || error("$stream_name is not in any stream connector of $(nameof(ogsys))") + sub[ex] = var elseif n_inners == 2 && n_outers == 0 - for ex in instream_exprs - var = only(arguments(ex)) - connector_name, streamvar_name = split_stream_var(var) - idx = findfirst(isequal(connector_name), inner_names) - idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") - other = idx == 1 ? 2 : 1 - sub[ex] = getproperty(inner_sc[other], streamvar_name) - end + connector_name in inner_names || error("$stream_name is not in any stream connector of $(nameof(ogsys))") + idx = findfirst(c->nameof(c) === connector_name, inner_sc) + other = idx == 1 ? 2 : 1 + sub[ex] = states(inner_sc[other], sv) elseif n_inners == 1 && n_outers == 1 - for ex in instream_exprs - var = only(arguments(ex)) # m_1.c.h_outflow - connector_name, streamvar_name = split_stream_var(var) - idx = findfirst(isequal(connector_name), inner_names) - idx === nothing || error("$stream_name is not in any stream connector of $(nameof(sys))") - outerinstream = getproperty(only(outer_sc), streamvar_name) # c_1.h_outflow + isinner = connector_name === only(inner_names) + isouter = connector_name === only(outer_names) + (isinner || isouter) || error("$stream_name is not in any stream connector of $(nameof(ogsys))") + if isinner + outerinstream = states(only(outer_sc), sv) # c_1.h_outflow sub[ex] = outerinstream - if var in seen - push!(additional_eqs, outerinstream ~ var) - push!(seen, var) - end + end + if var in seen + push!(additional_eqs, outerinstream ~ var) + push!(seen, var) end elseif n_inners == 0 && n_outers == 2 - push!(additional_eqs, outerinstream ~ var) + # we don't expand `instream` in this case. + if var in seen + v1 = states(outer_sc[1], sv) + v2 = states(outer_sc[2], sv) + push!(additional_eqs, v1 ~ instream(v2)) + push!(additional_eqs, v2 ~ instream(v1)) + push!(seen, var) + end else + fv = flowvar(first(connectors)) + idx = findfirst(c->nameof(c) === connector_name, inner_sc) + if idx !== nothing + si = sum(s->max(states(s, fv), 0), outer_sc) + for j in 1:n_inners; j == i && continue + f = states(inner_sc[j], fv) + si += max(-f, 0) + end + + num = 0 + den = 0 + for j in 1:n_inners; j == i && continue + f = states(inner_sc[j], fv) + tmp = positivemax(-f, si) + den += tmp + num += tmp * states(inner_sc[j], sv) + end + for k in 1:n_outers + f = states(outer_sc[k], fv) + tmp = positivemax(f, si) + den += tmp + num += tmp * instream(states(outer_sc[k], sv)) + end + sub[ex] = num / den + end + + if var in seen + for q in 1:n_outers + sq += sum(s->max(-states(s, fv), 0), inner_sc) + for k in 1:n_outers; k == q && continue + f = states(outer_sc[j], fv) + si += max(f, 0) + end + + num = 0 + den = 0 + for j in 1:n_inners + f = states(inner_sc[j], fv) + tmp = positivemax(-f, sq) + den += tmp + num += tmp * states(inner_sc[j], sv) + end + for k in 1:n_outers; k == q && continue + f = states(outer_sc[k], fv) + tmp = positivemax(f, sq) + den += tmp + num += tmp * instream(states(outer_sc[k], sv)) + end + push!(additional_eqs, states(outer_sc[q], sv) ~ num / den) + end + push!(seen, var) + end end end instream_eqs = map(Base.Fix2(substitute, sub), instream_eqs) - return sys + if debug + println("Expanded equations:") + for eq in instream_eqs + print_with_indent(4, eq) + end + if !isempty(additional_eqs) + println("Additional equations:") + for eq in additional_eqs + print_with_indent(4, eq) + end + end + end + @set! sys.eqs = [eqs; instream_eqs; additional_eqs] + return flatten(sys) end function expand_connections(sys::AbstractSystem; debug=false) From fdb853532b3bca4006a266723ce5608651c94321 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 15 Nov 2021 17:25:25 -0500 Subject: [PATCH 19/19] Impose arity constraint in the `connect` signature Co-authored-by: Fredrik Bagge Carlson --- src/systems/connectors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index f893d6dc84..193027a247 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -69,8 +69,8 @@ function Base.show(io::IO, c::Connection) end # symbolic `connect` -function connect(syss::AbstractSystem...) - length(syss) >= 2 || error("connect takes at least two systems!") +function connect(sys1::AbstractSystem, sys2::AbstractSystem, syss::AbstractSystem...) + syss = (sys1, sys2, syss...) length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") Equation(Connection(), Connection(syss)) # the RHS are connected systems end