Skip to content

fix(fmpz_mod): avoid using malloc/free #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/flint/types/fmpz_mod.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ from flint.flintlib.fmpz_mod cimport (
cdef class fmpz_mod_ctx:
cdef fmpz_mod_ctx_t val
cdef bint _is_prime
cdef fmpz_mod_discrete_log_pohlig_hellman_t *L
cdef bint _init_L
cdef fmpz_mod_discrete_log_pohlig_hellman_t L

cdef set_any_as_fmpz_mod(self, fmpz_t val, obj)
cdef any_as_fmpz_mod(self, obj)
cdef _precompute_dlog_prime(self)
cdef discrete_log_pohlig_hellman_run(self, fmpz_t x, fmpz_t y)

cdef class fmpz_mod(flint_scalar):
cdef fmpz_mod_ctx ctx
cdef fmpz_t val
cdef fmpz_t *x_g
cdef fmpz_t x_g
63 changes: 22 additions & 41 deletions src/flint/types/fmpz_mod.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from flint.flintlib.fmpz cimport(
fmpz_divexact,
fmpz_gcd,
fmpz_is_one,
fmpz_is_zero,
fmpz_randm
)
from flint.flintlib.fmpz cimport fmpz_mod as fmpz_type_mod
Expand Down Expand Up @@ -41,14 +42,12 @@ cdef class fmpz_mod_ctx:
cdef fmpz one = fmpz.__new__(fmpz)
fmpz_one(one.val)
fmpz_mod_ctx_init(self.val, one.val)
self.L = NULL
fmpz_mod_discrete_log_pohlig_hellman_clear(self.L)
self._is_prime = 0


def __dealloc__(self):
fmpz_mod_ctx_clear(self.val)
if self.L:
fmpz_mod_discrete_log_pohlig_hellman_clear(self.L[0])
fmpz_mod_discrete_log_pohlig_hellman_clear(self.L)

def __init__(self, mod):
# Ensure modulus is fmpz type
Expand Down Expand Up @@ -137,18 +136,15 @@ cdef class fmpz_mod_ctx:

return res

cdef _precompute_dlog_prime(self):
"""
Initalise the dlog data, all discrete logs are solved with an
internally chosen base `y`
"""
self.L = <fmpz_mod_discrete_log_pohlig_hellman_t *>libc.stdlib.malloc(
cython.sizeof(fmpz_mod_discrete_log_pohlig_hellman_struct)
)
fmpz_mod_discrete_log_pohlig_hellman_init(self.L[0])
fmpz_mod_discrete_log_pohlig_hellman_precompute_prime(
self.L[0], self.val.n
)
cdef discrete_log_pohlig_hellman_run(self, fmpz_t x, fmpz_t y):
# First, Ensure that L has performed precomputations This generates a
# base which is a primative root, and used as the base in
# fmpz_mod_discrete_log_pohlig_hellman_run
if not self._init_L:
fmpz_mod_discrete_log_pohlig_hellman_precompute_prime(self.L, self.val.n)
self._init_L = True

fmpz_mod_discrete_log_pohlig_hellman_run(x, self.L, y)

cdef set_any_as_fmpz_mod(self, fmpz_t val, obj):
# Try and convert obj to fmpz
Expand Down Expand Up @@ -235,13 +231,11 @@ cdef class fmpz_mod(flint_scalar):

def __cinit__(self):
fmpz_init(self.val)
self.x_g = NULL
fmpz_init(self.x_g)

def __dealloc__(self):
fmpz_clear(self.val)
if self.x_g:
fmpz_clear(self.x_g[0])
libc.stdlib.free(self.x_g)
fmpz_clear(self.x_g)

def __init__(self, val, ctx):
if not typecheck(ctx, fmpz_mod_ctx):
Expand Down Expand Up @@ -354,33 +348,20 @@ cdef class fmpz_mod(flint_scalar):
if a is NotImplemented:
raise TypeError(f"Cannot solve the discrete log with {type(a)} as input")

# First, Ensure that self.ctx.L has performed precomputations
# This generates a `y` which is a primative root, and used as
# the base in `fmpz_mod_discrete_log_pohlig_hellman_run`
if not self.ctx.L:
self.ctx._precompute_dlog_prime()

# Solve the discrete log for the chosen base and target
# g = y^x_g and a = y^x_a
# We want to find x such that a = g^x =>
# (y^x_a) = (y^x_g)^x => x = (x_a / x_g) mod (p-1)

# For repeated calls to discrete_log, it's more efficient to
# store x_g rather than keep computing it
if not self.x_g:
self.x_g = <fmpz_t *>libc.stdlib.malloc(
cython.sizeof(fmpz_t)
)
fmpz_mod_discrete_log_pohlig_hellman_run(
self.x_g[0], self.ctx.L[0], self.val
)
if fmpz_is_zero(self.x_g):
self.ctx.discrete_log_pohlig_hellman_run(self.x_g, self.val)

# Then we need to compute x_a which will be different for each call
cdef fmpz_t x_a
fmpz_init(x_a)
fmpz_mod_discrete_log_pohlig_hellman_run(
x_a, self.ctx.L[0], (<fmpz_mod>a).val
)
self.ctx.discrete_log_pohlig_hellman_run(x_a, (<fmpz_mod>a).val)

# If g is not a primative root, then x_g and pm1 will share
# a common factor. We can use this to compute the order of
Expand All @@ -390,14 +371,14 @@ cdef class fmpz_mod(flint_scalar):
fmpz_init(g_order)
fmpz_init(x_g)

fmpz_gcd(g, self.x_g[0], self.ctx.L[0].pm1)
fmpz_gcd(g, self.x_g, self.ctx.L.pm1)
if not fmpz_is_one(g):
fmpz_divexact(x_g, self.x_g[0], g)
fmpz_divexact(x_g, self.x_g, g)
fmpz_divexact(x_a, x_a, g)
fmpz_divexact(g_order, self.ctx.L[0].pm1, g)
fmpz_divexact(g_order, self.ctx.L.pm1, g)
else:
fmpz_set(g_order, self.ctx.L[0].pm1)
fmpz_set(x_g, self.x_g[0])
fmpz_set(g_order, self.ctx.L.pm1)
fmpz_set(x_g, self.x_g)

# Finally, compute output exponent by computing
# (x_a / x_g) mod g_order
Expand Down