Skip to content

Commit c0ea2bf

Browse files
committed
use multiple dispatch
1 parent 4cf00f1 commit c0ea2bf

File tree

1 file changed

+46
-67
lines changed

1 file changed

+46
-67
lines changed

src/numerics.jl

Lines changed: 46 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -63,43 +63,9 @@ N(a::Integer) = a
6363
N(a::Rational) = a
6464
N(a::Complex) = a
6565

66-
function N(b::Basic)
67-
if is_a_Integer(b)
68-
return _N_Integer(b)
69-
elseif is_a_Rational(b)
70-
return _N_Rational(b)
71-
elseif is_a_RealDouble(b)
72-
return _N_RealDouble(b)
73-
elseif is_a_RealMPFR(b)
74-
return _N_RealMPFR(b)
75-
elseif is_a_Complex(b) || is_a_ComplexDouble(b) || is_a_ComplexMPC(b)
76-
return complex(N(real(b)), N(imag(b)))
77-
elseif isnan(b)
78-
return _N_NaN(b)
79-
elseif b == oo
80-
return Inf
81-
elseif b == zoo
82-
return Complex(Inf,Inf)
83-
elseif b == PI
84-
return π
85-
elseif b == EulerGamma
86-
return γ
87-
elseif b == E
88-
return
89-
elseif b == Catalan
90-
return catalan
91-
elseif b == GoldenRatio
92-
return φ
93-
else
94-
is_constant(b) ||
95-
throw(ArgumentError("Object can have no free symbols"))
96-
out = evalf(b)
97-
imag(out) == Basic(0.0) ? N(real(out)) : N(out)
98-
end
99-
end
66+
N(b::Basic) = N(b, Val{get_symengine_class(b)}())
10067

101-
102-
function _N_Integer(b::Basic)
68+
function N(b::Basic, ::Val{:Integer})
10369
a = _convert(BigInt, b)
10470
if (a.size > 1 || a.size < -1)
10571
return a
@@ -119,12 +85,51 @@ function _N_Integer(b::Basic)
11985
end
12086

12187
# TODO: conditionally wrap rational_get_mpq from cwrapper.h
122-
_N_Rational(b::Basic) = Rational(N(numerator(b)), N(denominator(b)))
123-
_N_RealDouble(b::Basic) = _convert(Cdouble, b)
124-
_N_RealMPFR(b::Basic) = _convert(BigFloat, b)
125-
_N_NaN(b::Basic) = NaN
126-
_N_ComplexNumber(b::Basic) = complex(N(real(b)), N(imag(b)))
88+
N(b::Basic, ::Val{:Rational}) = Rational(N(numerator(b)), N(denominator(b)))
89+
N(b::Basic, ::Val{:RealDouble}) = _convert(Cdouble, b)
90+
N(b::Basic, ::Val{:RealMPFR}) = _convert(BigFloat, b)
91+
N(b::Basic, ::Val{:NaN}) = NaN
92+
N(b::Basic, ::Val{:Complex}) = complex(N(real(b)), N(imag(b)))
93+
N(b::Basic, ::Val{:ComplexDouble}) = complex(N(real(b)), N(imag(b)))
94+
N(b::Basic, ::Val{:ComplexMPC}) = complex(N(real(b)), N(imag(b)))
95+
96+
function N(b::Basic, ::Val{:Infty})
97+
if b == oo
98+
return Inf
99+
elseif b == zoo
100+
return Complex(Inf,Inf)
101+
elseif b == -oo
102+
return -Inf
103+
else
104+
throw(ArgumentError("Unknown infinity symbol"))
105+
end
106+
end
107+
108+
function N(b::Basic, ::Val{:Constant})
109+
if b == PI
110+
return π
111+
elseif b == EulerGamma
112+
return γ
113+
elseif b == E
114+
return
115+
elseif b == Catalan
116+
return catalan
117+
elseif b == GoldenRatio
118+
return φ
119+
else
120+
throw(ArgumentError("Unknown constant"))
121+
end
122+
end
123+
124+
function N(b::Basic, v)
125+
is_constant(b) ||
126+
throw(ArgumentError("Object can have no free symbols"))
127+
out = evalf(b)
128+
imag(out) == Basic(0.0) ? N(real(out)) : N(out)
129+
end
127130

131+
## deprecate N(::BasicType)
132+
N(b::BasicType{T}) where {T} = N(convert(Basic, b), T)
128133

129134
## define convert(T, x) methods leveraging N() when needed
130135
function convert(::Type{Float64}, x::Basic)
@@ -291,32 +296,6 @@ eps(::Type{BasicType{Val{:ComplexDouble}}}) = 2^-52
291296
eps(x::BasicType{Val{:RealMPFR}}) = evalf(Basic(2), prec(x), true) ^ (-prec(x)+1)
292297
eps(x::BasicType{Val{:ComplexMPFR}}) = eps(real(x))
293298

294-
295-
296-
## deprecate N(::BasicType)
297-
#N(b::Basic) = N(BasicType(b))
298-
function N(b::BasicType{Val{:Infty}})
299-
b == oo && return Inf
300-
b == -oo && return -Inf
301-
b == zoo && return Complex(Inf, Inf)
302-
end
303-
304-
## Mapping of SymEngine Constants into julia values
305-
constant_map = Dict("pi" => π, "eulergamma" => γ, "exp(1)" => e, "catalan" => catalan,
306-
"goldenratio" => φ)
307-
308-
N(b::BasicType{Val{:Constant}}) = constant_map[toString(b)]
309-
310-
function N(b::BasicType)
311-
b = convert(Basic, b)
312-
fs = free_symbols(b)
313-
if length(fs) > 0
314-
throw(ArgumentError("Object can have no free symbols"))
315-
end
316-
out = evalf(b)
317-
imag(out) == Basic(0.0) ? real(out) : out
318-
end
319-
320299
## convert from BasicType
321300
function convert(::Type{BigInt}, b::BasicType{Val{:Integer}})
322301
_convert(BigInt, Basic(b))

0 commit comments

Comments
 (0)