Skip to content

Commit dbbf8bb

Browse files
Merge pull request #93 from oscarbenjamin/pr_pow3
Fix pow(int, int, fmpz)
2 parents 98c2883 + 0990591 commit dbbf8bb

File tree

2 files changed

+35
-26
lines changed

2 files changed

+35
-26
lines changed

src/flint/test/test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,16 @@ def test_fmpz():
136136
(2, 2, 3, 1),
137137
(2, -1, 5, 3),
138138
(2, 0, 5, 1),
139+
(2, 5, 1000, 32),
139140
]
140141
for a, b, c, ab_mod_c in pow_mod_examples:
141142
assert pow(a, b, c) == ab_mod_c
142143
assert pow(flint.fmpz(a), b, c) == ab_mod_c
143144
assert pow(a, flint.fmpz(b), c) == ab_mod_c
145+
assert pow(a, b, flint.fmpz(c)) == ab_mod_c
144146
assert pow(flint.fmpz(a), flint.fmpz(b), c) == ab_mod_c
147+
assert pow(flint.fmpz(a), b, flint.fmpz(c)) == ab_mod_c
148+
assert pow(a, flint.fmpz(b), flint.fmpz(c)) == ab_mod_c
145149
assert pow(flint.fmpz(a), flint.fmpz(b), flint.fmpz(c)) == ab_mod_c
146150

147151
assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError)

src/flint/types/fmpz.pyx

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -360,53 +360,58 @@ cdef class fmpz(flint_scalar):
360360
return u
361361

362362
def __pow__(s, t, m):
363+
cdef fmpz_struct sval[1]
363364
cdef fmpz_struct tval[1]
364365
cdef fmpz_struct mval[1]
366+
cdef int stype = FMPZ_UNKNOWN
365367
cdef int ttype = FMPZ_UNKNOWN
366368
cdef int mtype = FMPZ_UNKNOWN
367369
cdef int success
368370
u = NotImplemented
369-
ttype = fmpz_set_any_ref(tval, t)
370-
if ttype == FMPZ_UNKNOWN:
371-
return NotImplemented
372371

373-
if m is None:
374-
# fmpz_pow_fmpz throws if x is negative
375-
if fmpz_sgn(tval) == -1:
376-
if ttype == FMPZ_TMP: fmpz_clear(tval)
377-
raise ValueError("negative exponent")
372+
try:
373+
stype = fmpz_set_any_ref(sval, s)
374+
if stype == FMPZ_UNKNOWN:
375+
return NotImplemented
376+
ttype = fmpz_set_any_ref(tval, t)
377+
if ttype == FMPZ_UNKNOWN:
378+
return NotImplemented
379+
if m is None:
380+
# fmpz_pow_fmpz throws if x is negative
381+
if fmpz_sgn(tval) == -1:
382+
raise ValueError("negative exponent")
378383

379-
u = fmpz.__new__(fmpz)
380-
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)
384+
u = fmpz.__new__(fmpz)
385+
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)
381386

382-
if not success:
383-
if ttype == FMPZ_TMP: fmpz_clear(tval)
384-
raise OverflowError("fmpz_pow_fmpz: exponent too large")
385-
else:
386-
# Modular exponentiation
387-
mtype = fmpz_set_any_ref(mval, m)
388-
if mtype != FMPZ_UNKNOWN:
387+
if not success:
388+
raise OverflowError("fmpz_pow_fmpz: exponent too large")
389+
390+
return u
391+
else:
392+
# Modular exponentiation
393+
mtype = fmpz_set_any_ref(mval, m)
394+
if mtype == FMPZ_UNKNOWN:
395+
return NotImplemented
389396

390397
if fmpz_is_zero(mval):
391-
if ttype == FMPZ_TMP: fmpz_clear(tval)
392-
if mtype == FMPZ_TMP: fmpz_clear(mval)
393398
raise ValueError("pow(): modulus cannot be zero")
394399

395400
# The Flint docs say that fmpz_powm will throw if m is zero
396401
# but it also throws if m is negative. Python generally allows
397402
# e.g. pow(2, 2, -3) == (2^2) % (-3) == -2. We could implement
398403
# that here as well but it is not clear how useful it is.
399404
if fmpz_sgn(mval) == -1:
400-
if ttype == FMPZ_TMP: fmpz_clear(tval)
401-
if mtype == FMPZ_TMP: fmpz_clear(mval)
402-
raise ValueError("pow(): negative modulua not supported")
405+
raise ValueError("pow(): negative modulus not supported")
403406

404407
u = fmpz.__new__(fmpz)
405-
fmpz_powm((<fmpz>u).val, (<fmpz>s).val, tval, mval)
408+
fmpz_powm((<fmpz>u).val, sval, tval, mval)
406409

407-
if ttype == FMPZ_TMP: fmpz_clear(tval)
408-
if mtype == FMPZ_TMP: fmpz_clear(mval)
409-
return u
410+
return u
411+
finally:
412+
if stype == FMPZ_TMP: fmpz_clear(sval)
413+
if ttype == FMPZ_TMP: fmpz_clear(tval)
414+
if mtype == FMPZ_TMP: fmpz_clear(mval)
410415

411416
def __rpow__(s, t, m):
412417
t = any_as_fmpz(t)

0 commit comments

Comments
 (0)