diff --git a/src/flint/test/test.py b/src/flint/test/test.py index c19a96d0..67263ca8 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -1279,6 +1279,7 @@ def test_nmod(): assert G(1,2) != G(0,2) assert G(0,2) != G(0,3) assert G(3,5) == G(8,5) + assert isinstance(hash(G(3, 5)), int) assert raises(lambda: G([], 3), TypeError) #assert G(3,5) == 8 # do we want this? #assert 8 == G(3,5) @@ -1300,9 +1301,14 @@ def test_nmod(): assert G(3,17) / G(2,17) == G(10,17) assert G(3,17) / 2 == G(10,17) assert 3 / G(2,17) == G(10,17) + assert G(0,3) / G(1,3) == G(0,3) assert G(3,17) * flint.fmpq(11,5) == G(10,17) assert G(3,17) / flint.fmpq(11,5) == G(6,17) + assert G(1,3) ** 2 == G(1,3) + assert G(2,3) ** flint.fmpz(2) == G(1,3) assert G(flint.fmpq(2, 3), 5) == G(4,5) + assert raises(lambda: G(2,5) ** G(2,5), TypeError) + assert raises(lambda: flint.fmpz(2) ** G(2,5), TypeError) assert raises(lambda: G(flint.fmpq(2, 3), 3), ZeroDivisionError) assert raises(lambda: G(2,5) / G(0,5), ZeroDivisionError) assert raises(lambda: G(2,5) / 0, ZeroDivisionError) @@ -1314,10 +1320,12 @@ def test_nmod(): assert raises(lambda: G(2,5) - [], TypeError) assert raises(lambda: G(2,5) * [], TypeError) assert raises(lambda: G(2,5) / [], TypeError) + assert raises(lambda: G(2,5) ** [], TypeError) assert raises(lambda: [] + G(2,5), TypeError) assert raises(lambda: [] - G(2,5), TypeError) assert raises(lambda: [] * G(2,5), TypeError) assert raises(lambda: [] / G(2,5), TypeError) + assert raises(lambda: [] ** G(2,5), TypeError) assert G(3,17).modulus() == 17 assert str(G(3,5)) == "3" assert G(3,5).repr() == "nmod(3, 5)" diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index e5e73eb6..87abfb5b 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -6,6 +6,7 @@ from flint.types.fmpz cimport fmpz from flint.types.fmpq cimport fmpq from flint.flintlib.fmpz cimport fmpz_t +from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv from flint.flintlib.nmod_vec cimport * from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui @@ -89,6 +90,9 @@ cdef class nmod(flint_scalar): return not res return NotImplemented + def __hash__(self): + return hash((int(self.val), self.modulus)) + def __nonzero__(self): return self.val != 0 @@ -178,6 +182,8 @@ cdef class nmod(flint_scalar): return NotImplemented if tval == 0: raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) + if not s: + return s # XXX: check invertibility? x = nmod_div(sval, tval, mod) if x == 0: @@ -195,3 +201,17 @@ cdef class nmod(flint_scalar): def __invert__(self): return (1 / self) # XXX: speed up + + def __pow__(self, exp): + cdef nmod r + e = any_as_fmpz(exp) + if e is NotImplemented: + return NotImplemented + r = nmod.__new__(nmod) + r.mod = self.mod + r.val = self.val + if e < 0: + r.val = nmod_inv(r.val, self.mod) + e = -e + r.val = nmod_pow_fmpz(r.val, (e).val, self.mod) + return r