Skip to content

Commit 86f74c6

Browse files
Merge pull request #90 from oscarbenjamin/pr_generic_tests
Add generic tests for polynomial types
2 parents f84c5b9 + b28deef commit 86f74c6

File tree

6 files changed

+371
-59
lines changed

6 files changed

+371
-59
lines changed

src/flint/test/test.py

Lines changed: 205 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -742,9 +742,9 @@ def test_fmpq():
742742
assert raises(lambda: Q("1.0"), ValueError)
743743
assert raises(lambda: Q("1.5"), ValueError)
744744
assert raises(lambda: Q("1/2/3"), ValueError)
745-
assert raises(lambda: Q([]), ValueError)
746-
assert raises(lambda: Q(1, []), ValueError)
747-
assert raises(lambda: Q([], 1), ValueError)
745+
assert raises(lambda: Q([]), TypeError)
746+
assert raises(lambda: Q(1, []), TypeError)
747+
assert raises(lambda: Q([], 1), TypeError)
748748
assert bool(Q(0)) == False
749749
assert bool(Q(1)) == True
750750
assert Q(1,3) + Q(2,3) == 1
@@ -1049,9 +1049,8 @@ def test_fmpq_mat():
10491049
assert raises(lambda: Q(None), TypeError)
10501050
assert Q([[1,2,3],[4,5,6]]) == Q(2,3,[1,2,3,4,5,6])
10511051
assert raises(lambda: Q(2,3,[1,2,3,4,5]), ValueError)
1052-
# XXX: Should be TypeError not ValueError:
1053-
assert raises(lambda: Q([[1,2,3],[4,[],6]]), ValueError)
1054-
assert raises(lambda: Q(2,3,[1,2,3,4,[],6]), ValueError)
1052+
assert raises(lambda: Q([[1,2,3],[4,[],6]]), TypeError)
1053+
assert raises(lambda: Q(2,3,[1,2,3,4,[],6]), TypeError)
10551054
assert raises(lambda: Q(2,3,[1,2],[3,4]), ValueError)
10561055
assert bool(Q([[1]])) is True
10571056
assert bool(Q([[0]])) is False
@@ -1815,6 +1814,204 @@ def test_fmpz_mod_dlog():
18151814
assert g**x == a
18161815

18171816

1817+
def _all_polys():
1818+
return [
1819+
# (poly_type, scalar_type, is_field)
1820+
(flint.fmpz_poly, flint.fmpz, False),
1821+
(flint.fmpq_poly, flint.fmpq, True),
1822+
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True),
1823+
]
1824+
1825+
1826+
def test_polys():
1827+
for P, S, is_field in _all_polys():
1828+
1829+
assert P([S(1)]) == P([1]) == P(P([1])) == P(1)
1830+
1831+
assert raises(lambda: P([None]), TypeError)
1832+
assert raises(lambda: P(object()), TypeError)
1833+
assert raises(lambda: P(None), TypeError)
1834+
assert raises(lambda: P(None, None), TypeError)
1835+
assert raises(lambda: P([1,2], None), TypeError)
1836+
assert raises(lambda: P(1, None), TypeError)
1837+
1838+
assert len(P([])) == P([]).length() == 0
1839+
assert len(P([1])) == P([1]).length() == 1
1840+
assert len(P([1,2])) == P([1,2]).length() == 2
1841+
assert len(P([1,2,3])) == P([1,2,3]).length() == 3
1842+
1843+
assert P([]).degree() == -1
1844+
assert P([1]).degree() == 0
1845+
assert P([1,2]).degree() == 1
1846+
assert P([1,2,3]).degree() == 2
1847+
1848+
assert (P([1]) == P([1])) is True
1849+
assert (P([1]) != P([1])) is False
1850+
assert (P([1]) == P([2])) is False
1851+
assert (P([1]) != P([2])) is True
1852+
1853+
assert (P([1]) == None) is False
1854+
assert (P([1]) != None) is True
1855+
assert (None == P([1])) is False
1856+
assert (None != P([1])) is True
1857+
1858+
assert raises(lambda: P([1]) < P([1]), TypeError)
1859+
assert raises(lambda: P([1]) <= P([1]), TypeError)
1860+
assert raises(lambda: P([1]) > P([1]), TypeError)
1861+
assert raises(lambda: P([1]) >= P([1]), TypeError)
1862+
assert raises(lambda: P([1]) < None, TypeError)
1863+
assert raises(lambda: P([1]) <= None, TypeError)
1864+
assert raises(lambda: P([1]) > None, TypeError)
1865+
assert raises(lambda: P([1]) >= None, TypeError)
1866+
assert raises(lambda: None < P([1]), TypeError)
1867+
assert raises(lambda: None <= P([1]), TypeError)
1868+
assert raises(lambda: None > P([1]), TypeError)
1869+
assert raises(lambda: None >= P([1]), TypeError)
1870+
1871+
assert P([1, 2, 3])[1] == S(2)
1872+
assert P([1, 2, 3])[-1] == S(0)
1873+
assert P([1, 2, 3])[3] == S(0)
1874+
1875+
p = P([1, 2, 3])
1876+
p[1] = S(4)
1877+
assert p == P([1, 4, 3])
1878+
1879+
def setbad(obj, i, val):
1880+
obj[i] = val
1881+
1882+
assert raises(lambda: setbad(p, 2, None), TypeError)
1883+
assert raises(lambda: setbad(p, -1, 1), ValueError)
1884+
1885+
for v in [], [1], [1, 2]:
1886+
if P == flint.fmpz_poly:
1887+
assert P(v).repr() == f'fmpz_poly({v!r})'
1888+
elif P == flint.fmpq_poly:
1889+
assert P(v).repr() == f'fmpq_poly({v!r})'
1890+
else:
1891+
assert P(v).repr() == f'nmod_poly({v!r}, 17)'
1892+
1893+
assert repr(P([])) == '0'
1894+
assert repr(P([1])) == '1'
1895+
assert repr(P([1, 2])) == '2*x + 1'
1896+
assert repr(P([1, 2, 3])) == '3*x^2 + 2*x + 1'
1897+
1898+
p = P([1, 2, 3])
1899+
assert p(0) == p(S(0)) == S(1) == 1
1900+
assert p(1) == p(S(1)) == S(6) == 6
1901+
assert p(p) == P([6, 16, 36, 36, 27])
1902+
assert raises(lambda: p(None), TypeError)
1903+
1904+
assert bool(P([])) is False
1905+
assert bool(P([1])) is True
1906+
1907+
assert +P([1, 2, 3]) == P([1, 2, 3])
1908+
assert -P([1, 2, 3]) == P([-1, -2, -3])
1909+
1910+
assert P([1, 2, 3]) + P([4, 5, 6]) == P([5, 7, 9])
1911+
1912+
for T in [int, S, flint.fmpz]:
1913+
assert P([1, 2, 3]) + T(1) == P([2, 2, 3])
1914+
assert T(1) + P([1, 2, 3]) == P([2, 2, 3])
1915+
1916+
assert raises(lambda: P([1, 2, 3]) + None, TypeError)
1917+
assert raises(lambda: None + P([1, 2, 3]), TypeError)
1918+
1919+
assert P([1, 2, 3]) - P([4, 5, 6]) == P([-3, -3, -3])
1920+
1921+
for T in [int, S, flint.fmpz]:
1922+
assert P([1, 2, 3]) - T(1) == P([0, 2, 3])
1923+
assert T(1) - P([1, 2, 3]) == P([0, -2, -3])
1924+
1925+
assert raises(lambda: P([1, 2, 3]) - None, TypeError)
1926+
assert raises(lambda: None - P([1, 2, 3]), TypeError)
1927+
1928+
assert P([1, 2, 3]) * P([4, 5, 6]) == P([4, 13, 28, 27, 18])
1929+
1930+
for T in [int, S, flint.fmpz]:
1931+
assert P([1, 2, 3]) * T(2) == P([2, 4, 6])
1932+
assert T(2) * P([1, 2, 3]) == P([2, 4, 6])
1933+
1934+
assert raises(lambda: P([1, 2, 3]) * None, TypeError)
1935+
assert raises(lambda: None * P([1, 2, 3]), TypeError)
1936+
1937+
assert P([1, 2, 1]) // P([1, 1]) == P([1, 1])
1938+
assert P([1, 2, 1]) % P([1, 1]) == P([0])
1939+
assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0]))
1940+
1941+
if is_field:
1942+
assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2])
1943+
assert P([1, 1]) % 2 == P([0])
1944+
else:
1945+
assert P([1, 1]) // 2 == P([0, 0])
1946+
assert P([1, 1]) % 2 == P([1, 1])
1947+
1948+
assert 1 // P([1, 1]) == P([0])
1949+
assert 1 % P([1, 1]) == P([1])
1950+
assert divmod(1, P([1, 1])) == (P([0]), P([1]))
1951+
1952+
assert raises(lambda: P([1, 2, 1]) // None, TypeError)
1953+
assert raises(lambda: P([1, 2, 1]) % None, TypeError)
1954+
assert raises(lambda: divmod(P([1, 2, 1]), None), TypeError)
1955+
1956+
assert raises(lambda: None // P([1, 1]), TypeError)
1957+
assert raises(lambda: None % P([1, 1]), TypeError)
1958+
assert raises(lambda: divmod(None, P([1, 1])), TypeError)
1959+
1960+
assert raises(lambda: P([1, 2, 1]) // 0, ZeroDivisionError)
1961+
assert raises(lambda: P([1, 2, 1]) % 0, ZeroDivisionError)
1962+
assert raises(lambda: divmod(P([1, 2, 1]), 0), ZeroDivisionError)
1963+
1964+
assert raises(lambda: P([1, 2, 1]) // P([0]), ZeroDivisionError)
1965+
assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError)
1966+
assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError)
1967+
1968+
if is_field:
1969+
assert P([2, 2]) / 2 == P([1, 1])
1970+
assert P([1, 2]) / 2 == P([S(1)/2, 1])
1971+
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
1972+
else:
1973+
assert raises(lambda: P([2, 2]) / 2, TypeError)
1974+
1975+
assert raises(lambda: 1 / P([1, 1]), TypeError)
1976+
assert raises(lambda: P([1, 2, 1]) / P([1, 1]), TypeError)
1977+
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), TypeError)
1978+
1979+
assert P([1, 1]) ** 0 == P([1])
1980+
assert P([1, 1]) ** 1 == P([1, 1])
1981+
assert P([1, 1]) ** 2 == P([1, 2, 1])
1982+
assert raises(lambda: P([1, 1]) ** -1, ValueError)
1983+
assert raises(lambda: P([1, 1]) ** None, TypeError)
1984+
# XXX: Not sure what this should do in general:
1985+
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)
1986+
1987+
assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
1988+
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
1989+
1990+
if is_field:
1991+
p1 = P([1, 0, 1])
1992+
p2 = P([2, 1])
1993+
g, s, t = P([1]), P([1])/5, P([2, -1])/5
1994+
assert p1.xgcd(p2) == (g, s, t)
1995+
assert raises(lambda: p1.xgcd(None), TypeError)
1996+
1997+
assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)])
1998+
1999+
assert P([1, 2, 1]).sqrt() == P([1, 1])
2000+
assert P([1, 2, 2]).sqrt() is None
2001+
if P == flint.fmpq_poly:
2002+
assert P([1, 2, 1], 3).sqrt() is None
2003+
assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2)
2004+
2005+
assert P([]).deflation() == (P([]), 1)
2006+
assert P([1, 2]).deflation() == (P([1, 2]), 1)
2007+
assert P([1, 0, 2]).deflation() == (P([1, 2]), 2)
2008+
2009+
assert P([1, 2, 1]).derivative() == P([2, 2])
2010+
2011+
if is_field:
2012+
assert P([1, 2, 1]).integral() == P([0, 1, 1, S(1)/3])
2013+
2014+
18182015

18192016
all_tests = [
18202017
test_pyflint,
@@ -1835,5 +2032,6 @@ def test_fmpz_mod_dlog():
18352032
test_nmod_mat,
18362033
test_arb,
18372034
test_fmpz_mod,
1838-
test_fmpz_mod_dlog
2035+
test_fmpz_mod_dlog,
2036+
test_polys,
18392037
]

src/flint/types/fmpq.pyx

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ cdef class fmpq(flint_scalar):
7171
def __dealloc__(self):
7272
fmpq_clear(self.val)
7373

74-
def __init__(self, p=None, q=None):
75-
cdef long x
76-
if q is None:
77-
if p is None:
78-
return # zero
79-
elif typecheck(p, fmpq):
74+
def __init__(self, *args):
75+
if not args:
76+
return # zero
77+
elif len(args) == 2:
78+
p, q = args
79+
elif len(args) == 1:
80+
p = args[0]
81+
if typecheck(p, fmpq):
8082
fmpq_set(self.val, (<fmpq>p).val)
8183
return
8284
elif typecheck(p, str):
@@ -90,17 +92,21 @@ cdef class fmpq(flint_scalar):
9092
else:
9193
p = any_as_fmpq(p)
9294
if p is NotImplemented:
93-
raise ValueError("cannot create fmpq from object of type %s" % type(p))
95+
raise TypeError("cannot create fmpq from object of type %s" % type(p))
9496
fmpq_set(self.val, (<fmpq>p).val)
9597
return
98+
else:
99+
raise TypeError("fmpq() takes at most 2 arguments (%d given)" % len(args))
100+
96101
p = any_as_fmpz(p)
97102
if p is NotImplemented:
98-
raise ValueError("cannot create fmpq from object of type %s" % type(p))
103+
raise TypeError("cannot create fmpq from object of type %s" % type(p))
99104
q = any_as_fmpz(q)
100105
if q is NotImplemented:
101-
raise ValueError("cannot create fmpq from object of type %s" % type(q))
106+
raise TypeError("cannot create fmpq from object of type %s" % type(q))
102107
if fmpz_is_zero((<fmpz>q).val):
103108
raise ZeroDivisionError("cannot create rational number with zero denominator")
109+
104110
fmpz_set(fmpq_numref(self.val), (<fmpz>p).val)
105111
fmpz_set(fmpq_denref(self.val), (<fmpz>q).val)
106112
fmpq_canonicalise(self.val)

src/flint/types/fmpq_poly.pyx

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,26 @@ cdef class fmpq_poly(flint_poly):
7878
def __dealloc__(self):
7979
fmpq_poly_clear(self.val)
8080

81-
def __init__(self, p=None, q=None):
82-
if p is not None:
83-
if typecheck(p, fmpq_poly):
84-
fmpq_poly_set(self.val, (<fmpq_poly>p).val)
85-
elif typecheck(p, fmpz_poly):
86-
fmpq_poly_set_fmpz_poly(self.val, (<fmpz_poly>p).val)
87-
elif isinstance(p, list):
88-
fmpq_poly_set_list(self.val, p)
89-
else:
90-
raise TypeError("cannot create fmpq_poly from input of type %s", type(p))
91-
if q is not None:
92-
q = any_as_fmpz(q)
81+
def __init__(self, *args):
82+
if len(args) == 0:
83+
return
84+
elif len(args) > 2:
85+
raise TypeError("fmpq_poly() takes 0, 1 or 2 arguments (%d given)" % len(args))
86+
87+
p = args[0]
88+
if typecheck(p, fmpq_poly):
89+
fmpq_poly_set(self.val, (<fmpq_poly>p).val)
90+
elif typecheck(p, fmpz_poly):
91+
fmpq_poly_set_fmpz_poly(self.val, (<fmpz_poly>p).val)
92+
elif isinstance(p, list):
93+
fmpq_poly_set_list(self.val, p)
94+
elif (v := any_as_fmpq(p)) is not NotImplemented:
95+
fmpq_poly_set_fmpq(self.val, (<fmpq>v).val)
96+
else:
97+
raise TypeError("cannot create fmpq_poly from input of type %s", type(p))
98+
99+
if len(args) == 2:
100+
q = any_as_fmpz(args[1])
93101
if q is NotImplemented:
94102
raise TypeError("denominator must be an integer, got %s", type(q))
95103
if fmpz_is_zero((<fmpz>q).val):
@@ -326,12 +334,14 @@ cdef class fmpq_poly(flint_poly):
326334
return t
327335
return t._divmod_(s)
328336

329-
def __pow__(fmpq_poly self, ulong exp, mod):
337+
def __pow__(fmpq_poly self, exp, mod):
330338
cdef fmpq_poly res
331339
if mod is not None:
332340
raise NotImplementedError("fmpz_poly modular exponentiation")
341+
if exp < 0:
342+
raise ValueError("fmpq_poly negative exponent")
333343
res = fmpq_poly.__new__(fmpq_poly)
334-
fmpq_poly_pow(res.val, self.val, exp)
344+
fmpq_poly_pow(res.val, self.val, <ulong>exp)
335345
return res
336346

337347
def gcd(self, other):
@@ -384,6 +394,32 @@ cdef class fmpq_poly(flint_poly):
384394
fac[i] = (base, exp)
385395
return c / self.denom(), fac
386396

397+
def sqrt(self):
398+
"""
399+
Return the exact square root of this polynomial or ``None``.
400+
401+
>>> p = fmpq_poly([1,2,1],4)
402+
>>> p
403+
1/4*x^2 + 1/2*x + 1/4
404+
>>> p.sqrt()
405+
1/2*x + 1/2
406+
407+
"""
408+
d = self.denom()
409+
n = self.numer()
410+
d, r = d.sqrtrem()
411+
if r != 0:
412+
return None
413+
n = n.sqrt()
414+
if n is None:
415+
return None
416+
return fmpq_poly(n, d)
417+
418+
def deflation(self):
419+
num, n = self.numer().deflation()
420+
num = fmpq_poly(num, self.denom())
421+
return num, n
422+
387423
def complex_roots(self, **kwargs):
388424
"""
389425
Computes the complex roots of this polynomial. See

0 commit comments

Comments
 (0)