Skip to content

Commit 4cf00f1

Browse files
committed
Merge branch 'master' of github.com:symengine/SymEngine.jl into N_speedup
2 parents dc2c82d + d4f4c53 commit 4cf00f1

File tree

9 files changed

+269
-35
lines changed

9 files changed

+269
-35
lines changed

LICENSE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1818
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
1919
THE SOFTWARE.
2020

21+
=============================================================================
22+
23+
Some parts of src/decl.jl is from Symbolics.jl and SymPy.jl licensed under the
24+
same license with the copyrights
25+
26+
Copyright (c) <2013> <j verzani>
27+
Copyright (c) 2021: Shashi Gowda, Yingbo Ma, Chris Rackauckas, Julia Computing.

ext/SymEngineTermInterfaceExt.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module SymEngineTermInterfaceExt
22

33
import SymEngine
4-
import SymEngine: SymbolicType
54
import TermInterface
65

76

@@ -22,7 +21,27 @@ import TermInterface
2221
λ(::Val{:Csch}) = csch; λ(::Val{:Sech}) = sech; λ(::Val{:Coth}) = coth
2322
λ(::Val{:Asinh}) = asinh; λ(::Val{:Acosh}) = acosh; λ(::Val{:Atanh}) = atanh
2423
λ(::Val{:Acsch}) = acsch; λ(::Val{:Asech}) = asech; λ(::Val{:Acoth}) = acoth
25-
λ(::Val{:Gamma}) = gamma; λ(::Val{:Zeta}) = zeta; λ(::Val{:LambertW}) = lambertw
24+
λ(::Val{:ATan2}) = atan;
25+
λ(::Val{:Beta}) = SymEngine.SpecialFunctions.beta;
26+
λ(::Val{:Gamma}) = SymEngine.SpecialFunctions.gamma;
27+
λ(::Val{:PolyGamma}) = SymEngine.SpecialFunctions.polygamma;
28+
λ(::Val{:LogGamma}) = SymEngine.SpecialFunctions.loggamma;
29+
λ(::Val{:Erf}) = SymEngine.SpecialFunctions.erf;
30+
λ(::Val{:Erfc}) = SymEngine.SpecialFunctions.erfc;
31+
λ(::Val{:Zeta}) = SymEngine.SpecialFunctions.zeta;
32+
λ(::Val{:LambertW}) = SymEngine.SpecialFunctions.lambertw
33+
34+
35+
36+
const julia_operations = Vector{Any}(missing, length(SymEngine.symengine_classes))
37+
for (i,s) enumerate(SymEngine.symengine_classes)
38+
val = try
39+
λ(Val(s))
40+
catch err
41+
missing
42+
end
43+
julia_operations[i] = val
44+
end
2645

2746
#==
2847
Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression.
@@ -40,7 +59,7 @@ TermInterface.isexpr(x::SymEngine.SymbolicType) = TermInterface.iscall(x)
4059

4160
function TermInterface.operation(x::SymEngine.SymbolicType)
4261
TermInterface.iscall(x) || error("$(typeof(x)) doesn't have an operation!")
43-
return λ(x)
62+
return julia_operations[SymEngine.get_type(x) + 1]
4463
end
4564

4665
function TermInterface.arguments(x::SymEngine.SymbolicType)

src/SymEngine.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const libversion = get_libversion()
2323
include("exceptions.jl")
2424
include("types.jl")
2525
include("ctypes.jl")
26+
include("decl.jl")
2627
include("display.jl")
2728
include("mathops.jl")
2829
include("mathfuns.jl")

src/decl.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# !!! Note:
2+
# Many thanks to `@matthieubulte` for this contribution to `SymPy`.
3+
4+
# The map_subscripts function is stolen from Symbolics.jl
5+
const IndexMap = Dict{Char,Char}(
6+
'-' => '',
7+
'0' => '',
8+
'1' => '',
9+
'2' => '',
10+
'3' => '',
11+
'4' => '',
12+
'5' => '',
13+
'6' => '',
14+
'7' => '',
15+
'8' => '',
16+
'9' => '')
17+
18+
function map_subscripts(indices)
19+
str = string(indices)
20+
join(IndexMap[c] for c in str)
21+
end
22+
23+
# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later.
24+
abstract type VarDecl end
25+
26+
struct SymDecl <: VarDecl
27+
sym :: Symbol
28+
end
29+
30+
struct NamedDecl <: VarDecl
31+
name :: String
32+
rest :: VarDecl
33+
end
34+
35+
struct FunctionDecl <: VarDecl
36+
rest :: VarDecl
37+
end
38+
39+
struct TensorDecl <: VarDecl
40+
ranges :: Vector{AbstractRange}
41+
rest :: VarDecl
42+
end
43+
44+
struct AssumptionsDecl <: VarDecl
45+
assumptions :: Vector{Symbol}
46+
rest :: VarDecl
47+
end
48+
49+
# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol
50+
function gendecl(x::VarDecl)
51+
asstokw(a) = Expr(:kw, esc(a), true)
52+
val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...)))
53+
:($(esc(sym(x))) = $(genreshape(val, x)))
54+
end
55+
56+
# Transform an expression in a Decl struct
57+
function parsedecl(expr)
58+
# @vars x
59+
if isa(expr, Symbol)
60+
return SymDecl(expr)
61+
62+
# @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...)
63+
#= no assumptions in SymEngine
64+
elseif isa(expr, Expr) && expr.head == :(::)
65+
symexpr, assumptions = expr.args
66+
assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args
67+
return AssumptionsDecl(assumptions, parsedecl(symexpr))
68+
=#
69+
70+
# @vars x=>"name"
71+
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>)
72+
length(expr.args) == 3 || parseerror()
73+
isa(expr.args[3], String) || parseerror()
74+
75+
expr, strname = expr.args[2:end]
76+
return NamedDecl(strname, parsedecl(expr))
77+
78+
# @vars x()
79+
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>)
80+
length(expr.args) == 1 || parseerror()
81+
return FunctionDecl(parsedecl(expr.args[1]))
82+
83+
# @vars x[1:5, 3:9]
84+
elseif isa(expr, Expr) && expr.head == :ref
85+
length(expr.args) > 1 || parseerror()
86+
ranges = map(parserange, expr.args[2:end])
87+
return TensorDecl(ranges, parsedecl(expr.args[1]))
88+
else
89+
parseerror()
90+
end
91+
end
92+
93+
function parserange(expr)
94+
range = eval(expr)
95+
isa(range, AbstractRange) || parseerror()
96+
range
97+
end
98+
99+
sym(x::SymDecl) = x.sym
100+
sym(x::NamedDecl) = sym(x.rest)
101+
sym(x::FunctionDecl) = sym(x.rest)
102+
sym(x::TensorDecl) = sym(x.rest)
103+
sym(x::AssumptionsDecl) = sym(x.rest)
104+
105+
ctor(::SymDecl) = :symbols
106+
ctor(x::NamedDecl) = ctor(x.rest)
107+
ctor(::FunctionDecl) = :SymFunction
108+
ctor(x::TensorDecl) = ctor(x.rest)
109+
ctor(x::AssumptionsDecl) = ctor(x.rest)
110+
111+
assumptions(::SymDecl) = []
112+
assumptions(x::NamedDecl) = assumptions(x.rest)
113+
assumptions(x::FunctionDecl) = assumptions(x.rest)
114+
assumptions(x::TensorDecl) = assumptions(x.rest)
115+
assumptions(x::AssumptionsDecl) = x.assumptions
116+
117+
# Reshape is not used by most nodes, but TensorNodes require the output to be given
118+
# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should
119+
# have size(x) = (3, 5)
120+
genreshape(expr, ::SymDecl) = expr
121+
genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest)
122+
genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest)
123+
genreshape(expr, x::TensorDecl) = let
124+
shape = tuple(length.(x.ranges)...)
125+
:(reshape(collect($(expr)), $(shape)))
126+
end
127+
genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest)
128+
129+
# To find out the name, we need to traverse in both directions to make sure that each node can get
130+
# information from parents and children about possible name.
131+
# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl
132+
# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or
133+
# based on the SymDecl leaf.
134+
name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym))
135+
name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name)
136+
name(x::FunctionDecl, parentname) = name(x.rest, parentname)
137+
name(x::AssumptionsDecl, parentname) = name(x.rest, parentname)
138+
name(x::TensorDecl, parentname) = let
139+
basename = name(x.rest, parentname)
140+
# we need to double reverse the indices to make sure that we traverse them in the natural order
141+
namestensor = map(Iterators.product(x.ranges...)) do ind
142+
sub = join(map(map_subscripts, ind), "_")
143+
string(basename, sub)
144+
end
145+
join(namestensor[:], ", ")
146+
end
147+
148+
function parseerror()
149+
error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.")
150+
end

src/mathfuns.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,42 @@ for (meth, libnm, modu) in [
5151
(:acsch,:acsch,:Base),
5252
(:atanh,:atanh,:Base),
5353
(:acoth,:acoth,:Base),
54-
(:gamma,:gamma,:SpecialFunctions),
5554
(:log,:log,:Base),
5655
(:sqrt,:sqrt,:Base),
5756
(:exp,:exp,:Base),
5857
(:sign, :sign, :Base),
59-
(:eta,:dirichlet_eta,:SpecialFunctions),
60-
(:zeta,:zeta,:SpecialFunctions),
58+
(:ceil, :ceiling, :Base),
59+
(:floor, :floor, :Base)
6160
]
6261
eval(:(import $modu.$meth))
6362
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
6463
end
64+
65+
for (meth, libnm, modu) in [
66+
(:gamma,:gamma,:SpecialFunctions),
67+
(:loggamma,:loggamma,:SpecialFunctions),
68+
(:eta,:dirichlet_eta,:SpecialFunctions),
69+
(:zeta,:zeta,:SpecialFunctions),
70+
(:erf, :erf, :SpecialFunctions),
71+
(:erfc, :erfc, :SpecialFunctions)
72+
]
73+
eval(:(import $modu.$meth))
74+
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
75+
end
76+
77+
for (meth, libnm, modu) in [
78+
(:beta, :beta, :SpecialFunctions),
79+
(:polygamma, :polygamma, :SpecialFunctions),
80+
(:loggamma,:loggamma,:SpecialFunctions),
81+
]
82+
eval(:(import $modu.$meth))
83+
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
84+
end
85+
6586
Base.abs2(x::SymEngine.Basic) = abs(x)^2
6687

88+
89+
6790
if get_symbol(:basic_atan2) != C_NULL
6891
import Base.atan
6992
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)

src/numerics.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,6 @@ end
274274
trunc(x::Basic, args...) = Basic(trunc(N(x), args...))
275275
trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,args...))
276276

277-
ceil(x::Basic) = Basic(ceil(N(x)))
278-
ceil(::Type{T},x::Basic) where {T <: Integer} = convert(T, ceil(x))
279-
280-
floor(x::Basic) = Basic(floor(N(x)))
281-
floor(::Type{T},x::Basic) where {T <: Integer} = convert(T, floor(x))
282-
283277
round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...))
284278
round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...))
285279

src/subs.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d
3333
subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2])
3434
subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...)
3535
subs(ex::T, d::Pair...) where {T <: SymbolicType} = subs(ex, [(p.first, p.second) for p in d]...)
36-
36+
subs(ex::SymbolicType) = ex
3737

3838
## Allow an expression to be called, as with ex(2). When there is more than one symbol, one can rely on order of `free_symbols` or
3939
## be explicit by passing in pairs : `ex(x=>1, y=>2)` or a dict `ex(Dict(x=>1, y=>2))`.
@@ -62,26 +62,30 @@ fn_map = Dict(
6262

6363
map_fn(key, fn_map) = haskey(fn_map, key) ? fn_map[key] : Symbol(lowercase(string(key)))
6464

65+
const julia_classes = map_fn.(symengine_classes, (fn_map,))
66+
get_julia_class(x::Basic) = julia_classes[get_type(x) + 1]
67+
Base.nameof(ex::Basic) = Symbol(toString(ex))
68+
6569
function _convert(::Type{Expr}, ex::Basic)
6670
fn = get_symengine_class(ex)
6771

6872
if fn == :Symbol
69-
return Symbol(toString(ex))
73+
return nameof(ex)
7074
elseif (fn in number_types) || (fn == :Constant)
7175
return N(ex)
7276
end
7377

7478
as = get_args(ex)
75-
76-
Expr(:call, map_fn(fn, fn_map), [_convert(Expr,a) for a in as]...)
79+
fn′ = get_julia_class(ex)
80+
Expr(:call, fn′, [_convert(Expr,a) for a in as]...)
7781
end
7882

7983

8084
function convert(::Type{Expr}, ex::Basic)
8185
fn = get_symengine_class(ex)
8286

8387
if fn == :Symbol
84-
return Expr(:call, :*, Symbol(toString(ex)), 1)
88+
return Expr(:call, :*, nameof(ex), 1)
8589
elseif (fn in number_types) || (fn == :Constant)
8690
return Expr(:call, :*, N(ex), 1)
8791
end

0 commit comments

Comments
 (0)