diff --git a/doc/source/general.rst b/doc/source/general.rst index fb2ec9ac..797bb67d 100644 --- a/doc/source/general.rst +++ b/doc/source/general.rst @@ -55,6 +55,28 @@ The special method ``ctx.cleanup()`` frees up internal caches used by MPFR, FLINT and Arb. The user does normally not have to worry about this. +The context object ``flint.ctx`` can be controlled locally to increase the +working precision using python context managers:: + + >>> arb(2).sqrt() + [1.41421356237309 +/- 5.15e-15] + >>> with ctx.extraprec(15): + ... arb(2).sqrt() + ... + [1.414213562373095049 +/- 2.10e-19] + +In the same manner, it is possible to exactly set the working precision, +or to update it in terms of digits:: + + >>> with ctx.extradps(15): + ... arb(2).sqrt() + ... + [1.41421356237309504880168872421 +/- 6.27e-31] + >>> with ctx.workprec(15): + ... arb(2).sqrt() + ... + [1.414 +/- 2.46e-4] + Types and methods ----------------- diff --git a/src/flint/flint_base/flint_context.pyx b/src/flint/flint_base/flint_context.pyx index c4c02df6..324510cb 100644 --- a/src/flint/flint_base/flint_context.pyx +++ b/src/flint/flint_base/flint_context.pyx @@ -6,6 +6,8 @@ from flint.flintlib.types.flint cimport ( ) from flint.utils.conversion cimport prec_to_dps, dps_to_prec +from functools import wraps + cdef class FlintContext: def __init__(self): self.default() @@ -57,6 +59,88 @@ cdef class FlintContext: assert num >= 1 and num <= 64 flint_set_num_threads(num) + def extraprec(self, n): + """ + Adds n bits of precision to the current flint context. + + >>> from flint import arb, ctx + >>> with ctx.extraprec(5): x = arb(2).sqrt().str() + >>> x + '[1.414213562373095 +/- 5.53e-17]' + + This function also works as a wrapper: + + >>> from flint import arb, ctx + >>> @ctx.extraprec(10) + ... def f(x): + ... return x.sqrt().str() + >>> f(arb(2)) + '[1.41421356237309505 +/- 1.46e-18]' + """ + return self.workprec(n + self.prec) + + def extradps(self, n): + """ + Adds n digits of precision to the current flint context. + + >>> from flint import arb, ctx + >>> with ctx.extradps(5): x = arb(2).sqrt().str() + >>> x + '[1.4142135623730950488 +/- 2.76e-21]' + + This function also works as a wrapper: + + >>> from flint import arb, ctx + >>> @ctx.extradps(10) + ... def f(x): + ... return x.sqrt().str() + >>> f(arb(2)) + '[1.414213562373095048801689 +/- 3.13e-25]' + """ + return self.workdps(n + self.dps) + + def workprec(self, n): + """ + Sets the working precision for the current flint context, + using a python context manager. + + >>> from flint import arb, ctx + >>> with ctx.workprec(5): x = arb(2).sqrt().str() + >>> x + '[1e+0 +/- 0.438]' + + This function also works as a wrapper: + + >>> from flint import arb, ctx + >>> @ctx.workprec(24) + ... def f(x): + ... return x.sqrt().str() + >>> f(arb(2)) + '[1.41421 +/- 3.66e-6]' + """ + return PrecisionManager(self, eprec=n) + + def workdps(self, n): + """ + Sets the working precision in digits for the current + flint context, using a python context manager. + + >>> from flint import arb, ctx + >>> with ctx.workdps(5): x = arb(2).sqrt().str() + >>> x + '[1.4142 +/- 1.51e-5]' + + This function also works as a wrapper: + + >>> from flint import arb, ctx + >>> @ctx.workdps(10) + ... def f(x): + ... return x.sqrt().str() + >>> f(arb(2)) + '[1.414213562 +/- 3.85e-10]' + """ + return PrecisionManager(self, edps=n) + def __repr__(self): return "pretty = %-8s # pretty-print repr() output\n" \ "unicode = %-8s # use unicode characters in output\n" \ @@ -69,4 +153,51 @@ cdef class FlintContext: def cleanup(self): flint_cleanup() + +cdef class PrecisionManager: + cdef FlintContext ctx + cdef int eprec + cdef int edps + cdef int _oldprec + + def __init__(self, ctx, eprec=-1, edps=-1): + if eprec != -1 and edps != -1: + raise ValueError("two different precisions requested") + + self.ctx = ctx + + self.eprec = eprec + self.edps = edps + + def __call__(self, func): + @wraps(func) + def wrapped(*args, **kwargs): + _oldprec = self.ctx.prec + + try: + if self.eprec != -1: + self.ctx.prec = self.eprec + + if self.edps != -1: + self.ctx.dps = self.edps + + return func(*args, **kwargs) + finally: + self.ctx.prec = _oldprec + + return wrapped + + def __enter__(self): + self._oldprec = self.ctx.prec + + if self.eprec != -1: + self.ctx.prec = self.eprec + + if self.edps != -1: + self.ctx.dps = self.edps + + def __exit__(self, type, value, traceback): + self.ctx.prec = self._oldprec + + cdef FlintContext thectx = FlintContext()