diff --git a/docs/src/basics/MTKLanguage.md b/docs/src/basics/MTKLanguage.md index 537e1b0349..06250f90b4 100644 --- a/docs/src/basics/MTKLanguage.md +++ b/docs/src/basics/MTKLanguage.md @@ -147,9 +147,9 @@ julia> ModelingToolkit.getdefault(model_c1.v) 2.0 ``` -#### `@extend` begin block +#### `@extend` statement -Partial systems can be extended in a higher system in two ways: +One or more partial systems can be extended in a higher system with `@extend` statements. This can be done in two ways: - `@extend PartialSystem(var1 = value1)` @@ -313,7 +313,8 @@ end - `:components`: The list of sub-components in the form of [[name, sub_component_name],...]. - `:constants`: Dictionary of constants mapped to its metadata. - `:defaults`: Dictionary of variables and default values specified in the `@defaults`. - - `:extend`: The list of extended unknowns, name given to the base system, and name of the base system. + - `:extend`: The list of extended unknowns, parameters and components, name given to the base system, and name of the base system. + When multiple extend statements are present, latter two are returned as lists. - `:structural_parameters`: Dictionary of structural parameters mapped to their metadata. - `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For parameter arrays, length is added to the metadata as `:size`. diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 21a3600749..e715dc6eea 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -918,7 +918,7 @@ Mark a system as completed. A completed system is a system which is done being defined/modified and is ready for structural analysis or other transformations. This allows for analyses and optimizations to be performed which require knowing the global structure of the system. - + One property to note is that if a system is complete, the system will no longer namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`. """ @@ -1933,7 +1933,7 @@ function Base.show( end end limited = nrows < nsubs - limited && print(io, "\n ⋮") # too many to print + limited && print(io, "\n ⋮") # too many to print # Print equations eqs = equations(sys) @@ -3043,10 +3043,19 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; return T(args...; kwargs...) end +function extend(sys, basesys::Vector{T}) where {T <: AbstractSystem} + foldl(extend, basesys, init = sys) +end + function Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; kwargs...) extend(sys, basesys; kwargs...) end +function Base.:(&)( + sys::AbstractSystem, basesys::Vector{T}; kwargs...) where {T <: AbstractSystem} + extend(sys, basesys; kwargs...) +end + """ $(SIGNATURES) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index e8955b5b84..5b79eaa91b 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -50,7 +50,7 @@ function _model_macro(mod, name, expr, isconnector) :structural_parameters => Dict{Symbol, Dict}() ) comps = Union{Symbol, Expr}[] - ext = Ref{Any}(nothing) + ext = [] eqs = Expr[] icon = Ref{Union{String, URI}}() ps, sps, vs, = [], [], [] @@ -115,10 +115,10 @@ function _model_macro(mod, name, expr, isconnector) sys = :($ODESystem($(flatten_equations)(equations), $iv, variables, parameters; name, systems, gui_metadata = $gui_metadata, defaults)) - if ext[] === nothing + if length(ext) == 0 push!(exprs.args, :(var"#___sys___" = $sys)) else - push!(exprs.args, :(var"#___sys___" = $extend($sys, $(ext[])))) + push!(exprs.args, :(var"#___sys___" = $extend($sys, [$(ext...)]))) end isconnector && push!(exprs.args, @@ -240,7 +240,7 @@ function unit_handled_variable_value(meta, varname) end # This function parses various variable/parameter definitions. -# +# # The comments indicate the syntax matched by a block; either when parsed directly # when it is called recursively for parsing a part of an expression. # These variable definitions are part of test suite in `test/model_parsing.jl` @@ -286,7 +286,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types; # `(l2(t)[1:N, 1:M] = 2), [description = "l is more than 1D, with arbitrary length"]` # `(l3(t)[1:3] = 3), [description = "l2 is 1D"]` # `(l4(t)[1:N] = 4), [description = "l2 is 1D, with arbitrary length"]` - # + # # Condition 2 parses: # `(l5(t)[1:3]::Int = 5), [description = "l3 is 1D and has a type"]` # `(l6(t)[1:N]::Int = 6), [description = "l3 is 1D and has a type, with arbitrary length"]` @@ -373,7 +373,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types; # Condition 1 is recursively called by: # `par5[1:3]::BigFloat` # `par6(t)[1:3]::BigFloat` - # + # # Condition 2 parses: # `b2(t)[1:2]` # `a2[1:2]` @@ -791,11 +791,17 @@ function _parse_extend!(ext, a, b, dict, expr, kwargs, vars, implicit_arglist) end end - ext[] = a + push!(ext, a) push!(b.args, Expr(:kw, :name, Meta.quot(a))) push!(expr.args, :($a = $b)) - dict[:extend] = [Symbol.(vars.args), a, b.args[1]] + if !haskey(dict, :extend) + dict[:extend] = [Symbol.(vars.args), a, b.args[1]] + else + push!(dict[:extend][1], Symbol.(vars.args)...) + dict[:extend][2] = vcat(dict[:extend][2], a) + dict[:extend][3] = vcat(dict[:extend][3], b.args[1]) + end push!(expr.args, :(@unpack $vars = $a)) end diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 9cdd8712d4..fd223800e3 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -945,6 +945,15 @@ end end end +@mtkmodel MidModelB begin + @parameters begin + b + end + @components begin + inmodel_b = InnerModel() + end +end + @mtkmodel OuterModel begin @extend MidModel() @equations begin @@ -958,3 +967,15 @@ end @named out = OuterModel() @test OuterModel.structure[:extend][1] == [:inmodel] end + +@mtkmodel MultipleExtend begin + @extend MidModel() + @extend MidModelB() +end + +@testset "Multiple extend statements" begin + @named multiple_extend = MultipleExtend() + @test collect(nameof.(multiple_extend.systems)) == [:inmodel_b, :inmodel] + @test MultipleExtend.structure[:extend][1] == [:inmodel, :b, :inmodel_b] + @test tosymbol.(parameters(multiple_extend)) == [:b, :inmodel_b₊p, :inmodel₊p] +end