Skip to content

Commit 267124c

Browse files
authored
Merge pull request #279 from jverzani/issue_225
Speed up diff; extend interface
2 parents 6519f4c + a48ef7d commit 267124c

File tree

2 files changed

+56
-14
lines changed

2 files changed

+56
-14
lines changed

src/calculus.jl

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,64 @@ import Base: diff
55
## what is the rest of the interface. This does:
66
## diff(ex, x, n) f^(n)
77
## diff(ex, x, y, ...) f_{xy...} # also diff(ex, (x,y))
8-
## no support for diff(ex, x,n1, y,n2, ...), but can do diff(ex, (x,y), (n1, n2))
8+
## Support for diff(ex, x,n1, y,n2, ...),
9+
## but can also do diff(ex, (x,y), (n1, n2))
910

10-
function diff(b1::SymbolicType, b2::BasicType{Val{:Symbol}})
11-
a = Basic()
11+
12+
function diff!(a::Basic, b1::SymbolicType, b2::Basic)
13+
is_symbol(b2) || throw(ArgumentError("Must differentiate with respect to a symbol"))
1214
ret = ccall((:basic_diff, libsymengine), Int, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
1315
return a
1416
end
1517

16-
diff(b1::SymbolicType, b2::BasicType) =
17-
throw(ArgumentError("Second argument must be of Symbol type"))
18+
function diff(b1::SymbolicType, b2::Basic)
19+
a = Basic()
20+
diff!(a, b1, b2)
21+
a
22+
end
1823

19-
function diff(b1::SymbolicType, b2::SymbolicType, n::Integer=1)
24+
function diff(b1::SymbolicType, b2::SymbolicType, n::Integer)
2025
n < 0 && throw(DomainError("n must be non-negative integer"))
21-
n==0 && return b1
22-
n==1 && return diff(b1, BasicType(b2))
23-
n > 1 && return diff(diff(b1, BasicType(b2)), BasicType(b2), n-1)
26+
n == 0 && return b1
27+
x = Basic(b2)
28+
out = Basic()
29+
diff!(out, b1, x)
30+
for _ in (n-1):-1:1
31+
diff!(out, out, x)
32+
end
33+
out
34+
end
35+
36+
function diff(b1::SymbolicType, b2::SymbolicType, n::Integer, xs...)
37+
diff(diff(b1,b2,n), xs...)
2438
end
2539

2640
function diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType)
27-
isa(BasicType(b3), BasicType{Val{:Integer}}) ? diff(b1, b2, N(b3)) : diff(b1, (b2, b3))
41+
if isinteger(b3)
42+
n = N(b3)::Int
43+
diff(b1, b2, n)
44+
else
45+
ex = diff(b1, b2)
46+
diff(ex, b3)
47+
end
48+
end
49+
50+
function diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType, bs...)
51+
diff(diff(b1,b2,b3), bs...)
2852
end
2953

30-
diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType, b4::SymbolicType, b5...) =
31-
diff(b1, (b2,b3,b4,b5...))
54+
function diff(b1::SymbolicType)
55+
xs = free_symbols(b1)
56+
n = length(xs)
57+
n == 0 && return zero(b1)
58+
n > 1 && throw(ArgumentError("More than one variable; one must be specified"))
59+
diff(b1, only(xs))
60+
end
61+
62+
## deprecate
63+
diff(b1::SymbolicType, b2::BasicType{Val{:Symbol}}) = diff(b1, Basic(b2))
64+
diff(b1::SymbolicType, b2::BasicType) =
65+
throw(ArgumentError("Second argument must be of Symbol type"))
3266

3367
## mixed partials
3468
diff(ex::SymbolicType, bs::Tuple) = reduce((ex, x) -> diff(ex, x), bs, init=ex)

test/runtests.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,22 @@ u,v,w = x(2.1), x(1), x(0)
101101

102102
## calculus
103103
x,y = symbols("x y")
104+
@test diff(log(x)) == 1/x
105+
@test diff(log(x),x) == 1/x
106+
@test_throws ArgumentError diff(log(x), x^2)
107+
104108
n = Basic(2)
105109
ex = sin(x*y)
106-
@test diff(log(x),x) == 1/x
110+
@test_throws ArgumentError diff(ex)
107111
@test diff(ex, x) == y * cos(x*y)
108112
@test diff(ex, x, 2) == diff(diff(ex,x), x)
109113
@test diff(ex, x, n) == diff(diff(ex,x), x)
110114
@test diff(ex, x, y) == diff(diff(ex,x), y)
111-
@test diff(ex, x, y,x) == diff(diff(diff(ex,x), y), x)
115+
@test diff(ex, x, y, x) == diff(diff(diff(ex,x), y), x)
116+
@test diff(ex, x, 2, y, 3) == diff(ex, x,x,y,y,y)
117+
@test diff(ex, x, n, y, 3) == diff(ex, x,x,y,y,y)
118+
@test diff(ex, x, 2, y, x) == diff(ex, x,x,x,y)
119+
112120
@test series(sin(x), x, 0, 2) == x
113121
@test series(sin(x), x, 0, 3) == x - x^3/6
114122

0 commit comments

Comments
 (0)