Skip to content

Commit 49fca4d

Browse files
committed
Merge branch 'issue_225' of https://github.com/jverzani/SymEngine.jl into issue_225
2 parents 7031588 + eb2f13d commit 49fca4d

File tree

6 files changed

+192
-16
lines changed

6 files changed

+192
-16
lines changed

LICENSE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1818
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
1919
THE SOFTWARE.
2020

21+
=============================================================================
22+
23+
Some parts of src/decl.jl is from Symbolics.jl and SymPy.jl licensed under the
24+
same license with the copyrights
25+
26+
Copyright (c) <2013> <j verzani>
27+
Copyright (c) 2021: Shashi Gowda, Yingbo Ma, Chris Rackauckas, Julia Computing.

src/SymEngine.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const libversion = get_libversion()
2323
include("exceptions.jl")
2424
include("types.jl")
2525
include("ctypes.jl")
26+
include("decl.jl")
2627
include("display.jl")
2728
include("mathops.jl")
2829
include("mathfuns.jl")

src/decl.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# !!! Note:
2+
# Many thanks to `@matthieubulte` for this contribution to `SymPy`.
3+
4+
# The map_subscripts function is stolen from Symbolics.jl
5+
const IndexMap = Dict{Char,Char}(
6+
'-' => '',
7+
'0' => '',
8+
'1' => '',
9+
'2' => '',
10+
'3' => '',
11+
'4' => '',
12+
'5' => '',
13+
'6' => '',
14+
'7' => '',
15+
'8' => '',
16+
'9' => '')
17+
18+
function map_subscripts(indices)
19+
str = string(indices)
20+
join(IndexMap[c] for c in str)
21+
end
22+
23+
# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later.
24+
abstract type VarDecl end
25+
26+
struct SymDecl <: VarDecl
27+
sym :: Symbol
28+
end
29+
30+
struct NamedDecl <: VarDecl
31+
name :: String
32+
rest :: VarDecl
33+
end
34+
35+
struct FunctionDecl <: VarDecl
36+
rest :: VarDecl
37+
end
38+
39+
struct TensorDecl <: VarDecl
40+
ranges :: Vector{AbstractRange}
41+
rest :: VarDecl
42+
end
43+
44+
struct AssumptionsDecl <: VarDecl
45+
assumptions :: Vector{Symbol}
46+
rest :: VarDecl
47+
end
48+
49+
# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol
50+
function gendecl(x::VarDecl)
51+
asstokw(a) = Expr(:kw, esc(a), true)
52+
val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...)))
53+
:($(esc(sym(x))) = $(genreshape(val, x)))
54+
end
55+
56+
# Transform an expression in a Decl struct
57+
function parsedecl(expr)
58+
# @vars x
59+
if isa(expr, Symbol)
60+
return SymDecl(expr)
61+
62+
# @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...)
63+
#= no assumptions in SymEngine
64+
elseif isa(expr, Expr) && expr.head == :(::)
65+
symexpr, assumptions = expr.args
66+
assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args
67+
return AssumptionsDecl(assumptions, parsedecl(symexpr))
68+
=#
69+
70+
# @vars x=>"name"
71+
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>)
72+
length(expr.args) == 3 || parseerror()
73+
isa(expr.args[3], String) || parseerror()
74+
75+
expr, strname = expr.args[2:end]
76+
return NamedDecl(strname, parsedecl(expr))
77+
78+
# @vars x()
79+
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>)
80+
length(expr.args) == 1 || parseerror()
81+
return FunctionDecl(parsedecl(expr.args[1]))
82+
83+
# @vars x[1:5, 3:9]
84+
elseif isa(expr, Expr) && expr.head == :ref
85+
length(expr.args) > 1 || parseerror()
86+
ranges = map(parserange, expr.args[2:end])
87+
return TensorDecl(ranges, parsedecl(expr.args[1]))
88+
else
89+
parseerror()
90+
end
91+
end
92+
93+
function parserange(expr)
94+
range = eval(expr)
95+
isa(range, AbstractRange) || parseerror()
96+
range
97+
end
98+
99+
sym(x::SymDecl) = x.sym
100+
sym(x::NamedDecl) = sym(x.rest)
101+
sym(x::FunctionDecl) = sym(x.rest)
102+
sym(x::TensorDecl) = sym(x.rest)
103+
sym(x::AssumptionsDecl) = sym(x.rest)
104+
105+
ctor(::SymDecl) = :symbols
106+
ctor(x::NamedDecl) = ctor(x.rest)
107+
ctor(::FunctionDecl) = :SymFunction
108+
ctor(x::TensorDecl) = ctor(x.rest)
109+
ctor(x::AssumptionsDecl) = ctor(x.rest)
110+
111+
assumptions(::SymDecl) = []
112+
assumptions(x::NamedDecl) = assumptions(x.rest)
113+
assumptions(x::FunctionDecl) = assumptions(x.rest)
114+
assumptions(x::TensorDecl) = assumptions(x.rest)
115+
assumptions(x::AssumptionsDecl) = x.assumptions
116+
117+
# Reshape is not used by most nodes, but TensorNodes require the output to be given
118+
# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should
119+
# have size(x) = (3, 5)
120+
genreshape(expr, ::SymDecl) = expr
121+
genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest)
122+
genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest)
123+
genreshape(expr, x::TensorDecl) = let
124+
shape = tuple(length.(x.ranges)...)
125+
:(reshape(collect($(expr)), $(shape)))
126+
end
127+
genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest)
128+
129+
# To find out the name, we need to traverse in both directions to make sure that each node can get
130+
# information from parents and children about possible name.
131+
# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl
132+
# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or
133+
# based on the SymDecl leaf.
134+
name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym))
135+
name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name)
136+
name(x::FunctionDecl, parentname) = name(x.rest, parentname)
137+
name(x::AssumptionsDecl, parentname) = name(x.rest, parentname)
138+
name(x::TensorDecl, parentname) = let
139+
basename = name(x.rest, parentname)
140+
# we need to double reverse the indices to make sure that we traverse them in the natural order
141+
namestensor = map(Iterators.product(x.ranges...)) do ind
142+
sub = join(map(map_subscripts, ind), "_")
143+
string(basename, sub)
144+
end
145+
join(namestensor[:], ", ")
146+
end
147+
148+
function parseerror()
149+
error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.")
150+
end

src/subs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d
3333
subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2])
3434
subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...)
3535
subs(ex::T, d::Pair...) where {T <: SymbolicType} = subs(ex, [(p.first, p.second) for p in d]...)
36-
36+
subs(ex::SymbolicType) = ex
3737

3838
## Allow an expression to be called, as with ex(2). When there is more than one symbol, one can rely on order of `free_symbols` or
3939
## be explicit by passing in pairs : `ex(x=>1, y=>2)` or a dict `ex(Dict(x=>1, y=>2))`.

src/types.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,30 +145,44 @@ end
145145
## Follow, somewhat, the python names: symbols to construct symbols, @vars
146146

147147
"""
148-
Macro to define 1 or more variables in the main workspace.
148+
@vars x y[1:5] z()
149149
150-
Symbolic values are defined with `_symbol`. This is a convenience
150+
Macro to define 1 or more variables or symbolic function
151151
152152
Example
153153
```
154154
@vars x y z
155+
@vars x[1:4]
156+
@vars u(), x
155157
```
158+
156159
"""
157-
macro vars(x...)
158-
q=Expr(:block)
159-
if length(x) == 1 && isa(x[1],Expr)
160-
@assert x[1].head === :tuple "@syms expected a list of symbols"
161-
x = x[1].args
160+
macro vars(xs...)
161+
# If the user separates declaration with commas, the top-level expression is a tuple
162+
if length(xs) == 1 && isa(xs[1], Expr) && xs[1].head == :tuple
163+
_gensyms(xs[1].args...)
164+
elseif length(xs) > 0
165+
_gensyms(xs...)
162166
end
163-
for s in x
164-
@assert isa(s,Symbol) "@syms expected a list of symbols"
165-
push!(q.args, Expr(:(=), esc(s), Expr(:call, :(SymEngine._symbol), Expr(:quote, s))))
167+
end
168+
169+
function _gensyms(xs...)
170+
asstokw(a) = Expr(:kw, esc(a), true)
171+
172+
# Each declaration is parsed and generates a declaration using `symbols`
173+
symdefs = map(xs) do expr
174+
decl = parsedecl(expr)
175+
symname = sym(decl)
176+
symname, gendecl(decl)
166177
end
167-
push!(q.args, Expr(:tuple, map(esc, x)...))
168-
q
178+
syms, defs = collect(zip(symdefs...))
179+
180+
# The macro returns a tuple of Symbols that were declared
181+
Expr(:block, defs..., :(tuple($(map(esc,syms)...))))
169182
end
170183

171184

185+
172186
## We also have a wrapper type that can be used to control dispatch
173187
## pros: wrapping adds overhead, so if possible best to use Basic
174188
## cons: have to write methods meth(x::Basic, ...) = meth(BasicType(x),...)

test/runtests.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ end
2121
@test_throws UndefVarError isdefined(w)
2222
@test_throws Exception show(Basic())
2323

24+
# test @vars constructions
25+
@vars a, b[0:4], c(), d=>"D"
26+
@test length(b) == 5
27+
@test isa(c, SymFunction)
28+
@test repr(d) == "D"
29+
2430
a = x^2 + x/2 - x*y*5
2531
b = diff(a, x)
2632
@test b == 2*x + 1//2 - 5*y
@@ -63,9 +69,6 @@ c = Basic(-5)
6369
@test abs(c) == 5
6470
@test abs(c) != 4
6571

66-
# test show
67-
a = x^2 + x/2 - x*y*5
68-
b = diff(a, x)
6972
repr("text/plain", a) == (1/2)*x - 5*x*y + x^2
7073
repr("text/plain", b) == 1/2 + 2*x - 5*y
7174

@@ -160,6 +163,7 @@ for val in samples
160163
@test subs(ex, x => val) == val^2 + y^2
161164
@test subs(ex, SymEngine.CMapBasicBasic(Dict(x=>val))) == val^2 + y^2
162165
@test subs(ex, Dict(x=>val)) == val^2 + y^2
166+
@test subs(ex) == ex
163167
end
164168
# This probably results in a number of redundant tests (operator order).
165169
for val1 in samples, val2 in samples

0 commit comments

Comments
 (0)