|
| 1 | +# !!! Note: |
| 2 | +# Many thanks to `@matthieubulte` for this contribution to `SymPy`. |
| 3 | + |
| 4 | +# The map_subscripts function is stolen from Symbolics.jl |
| 5 | +const IndexMap = Dict{Char,Char}( |
| 6 | + '-' => '₋', |
| 7 | + '0' => '₀', |
| 8 | + '1' => '₁', |
| 9 | + '2' => '₂', |
| 10 | + '3' => '₃', |
| 11 | + '4' => '₄', |
| 12 | + '5' => '₅', |
| 13 | + '6' => '₆', |
| 14 | + '7' => '₇', |
| 15 | + '8' => '₈', |
| 16 | + '9' => '₉') |
| 17 | + |
| 18 | +function map_subscripts(indices) |
| 19 | + str = string(indices) |
| 20 | + join(IndexMap[c] for c in str) |
| 21 | +end |
| 22 | + |
| 23 | +# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later. |
| 24 | +abstract type VarDecl end |
| 25 | + |
| 26 | +struct SymDecl <: VarDecl |
| 27 | + sym :: Symbol |
| 28 | +end |
| 29 | + |
| 30 | +struct NamedDecl <: VarDecl |
| 31 | + name :: String |
| 32 | + rest :: VarDecl |
| 33 | +end |
| 34 | + |
| 35 | +struct FunctionDecl <: VarDecl |
| 36 | + rest :: VarDecl |
| 37 | +end |
| 38 | + |
| 39 | +struct TensorDecl <: VarDecl |
| 40 | + ranges :: Vector{AbstractRange} |
| 41 | + rest :: VarDecl |
| 42 | +end |
| 43 | + |
| 44 | +struct AssumptionsDecl <: VarDecl |
| 45 | + assumptions :: Vector{Symbol} |
| 46 | + rest :: VarDecl |
| 47 | +end |
| 48 | + |
| 49 | +# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol |
| 50 | +function gendecl(x::VarDecl) |
| 51 | + asstokw(a) = Expr(:kw, esc(a), true) |
| 52 | + val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...))) |
| 53 | + :($(esc(sym(x))) = $(genreshape(val, x))) |
| 54 | +end |
| 55 | + |
| 56 | +# Transform an expression in a Decl struct |
| 57 | +function parsedecl(expr) |
| 58 | + # @vars x |
| 59 | + if isa(expr, Symbol) |
| 60 | + return SymDecl(expr) |
| 61 | + |
| 62 | + # @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...) |
| 63 | + #= no assumptions in SymEngine |
| 64 | + elseif isa(expr, Expr) && expr.head == :(::) |
| 65 | + symexpr, assumptions = expr.args |
| 66 | + assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args |
| 67 | + return AssumptionsDecl(assumptions, parsedecl(symexpr)) |
| 68 | + =# |
| 69 | + |
| 70 | + # @vars x=>"name" |
| 71 | + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>) |
| 72 | + length(expr.args) == 3 || parseerror() |
| 73 | + isa(expr.args[3], String) || parseerror() |
| 74 | + |
| 75 | + expr, strname = expr.args[2:end] |
| 76 | + return NamedDecl(strname, parsedecl(expr)) |
| 77 | + |
| 78 | + # @vars x() |
| 79 | + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>) |
| 80 | + length(expr.args) == 1 || parseerror() |
| 81 | + return FunctionDecl(parsedecl(expr.args[1])) |
| 82 | + |
| 83 | + # @vars x[1:5, 3:9] |
| 84 | + elseif isa(expr, Expr) && expr.head == :ref |
| 85 | + length(expr.args) > 1 || parseerror() |
| 86 | + ranges = map(parserange, expr.args[2:end]) |
| 87 | + return TensorDecl(ranges, parsedecl(expr.args[1])) |
| 88 | + else |
| 89 | + parseerror() |
| 90 | + end |
| 91 | +end |
| 92 | + |
| 93 | +function parserange(expr) |
| 94 | + range = eval(expr) |
| 95 | + isa(range, AbstractRange) || parseerror() |
| 96 | + range |
| 97 | +end |
| 98 | + |
| 99 | +sym(x::SymDecl) = x.sym |
| 100 | +sym(x::NamedDecl) = sym(x.rest) |
| 101 | +sym(x::FunctionDecl) = sym(x.rest) |
| 102 | +sym(x::TensorDecl) = sym(x.rest) |
| 103 | +sym(x::AssumptionsDecl) = sym(x.rest) |
| 104 | + |
| 105 | +ctor(::SymDecl) = :symbols |
| 106 | +ctor(x::NamedDecl) = ctor(x.rest) |
| 107 | +ctor(::FunctionDecl) = :SymFunction |
| 108 | +ctor(x::TensorDecl) = ctor(x.rest) |
| 109 | +ctor(x::AssumptionsDecl) = ctor(x.rest) |
| 110 | + |
| 111 | +assumptions(::SymDecl) = [] |
| 112 | +assumptions(x::NamedDecl) = assumptions(x.rest) |
| 113 | +assumptions(x::FunctionDecl) = assumptions(x.rest) |
| 114 | +assumptions(x::TensorDecl) = assumptions(x.rest) |
| 115 | +assumptions(x::AssumptionsDecl) = x.assumptions |
| 116 | + |
| 117 | +# Reshape is not used by most nodes, but TensorNodes require the output to be given |
| 118 | +# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should |
| 119 | +# have size(x) = (3, 5) |
| 120 | +genreshape(expr, ::SymDecl) = expr |
| 121 | +genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest) |
| 122 | +genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest) |
| 123 | +genreshape(expr, x::TensorDecl) = let |
| 124 | + shape = tuple(length.(x.ranges)...) |
| 125 | + :(reshape(collect($(expr)), $(shape))) |
| 126 | +end |
| 127 | +genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest) |
| 128 | + |
| 129 | +# To find out the name, we need to traverse in both directions to make sure that each node can get |
| 130 | +# information from parents and children about possible name. |
| 131 | +# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl |
| 132 | +# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or |
| 133 | +# based on the SymDecl leaf. |
| 134 | +name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym)) |
| 135 | +name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name) |
| 136 | +name(x::FunctionDecl, parentname) = name(x.rest, parentname) |
| 137 | +name(x::AssumptionsDecl, parentname) = name(x.rest, parentname) |
| 138 | +name(x::TensorDecl, parentname) = let |
| 139 | + basename = name(x.rest, parentname) |
| 140 | + # we need to double reverse the indices to make sure that we traverse them in the natural order |
| 141 | + namestensor = map(Iterators.product(x.ranges...)) do ind |
| 142 | + sub = join(map(map_subscripts, ind), "_") |
| 143 | + string(basename, sub) |
| 144 | + end |
| 145 | + join(namestensor[:], ", ") |
| 146 | +end |
| 147 | + |
| 148 | +function parseerror() |
| 149 | + error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.") |
| 150 | +end |
0 commit comments