diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index 1c1297d88b..a8e293d6e6 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -34,6 +34,7 @@
     upgrade_to_float64,
     upgrade_to_float_no_complex,
 )
+from pytensor.scalar.scan import ScalarScanOp
 
 
 class Erf(UnaryScalarOp):
@@ -751,87 +752,172 @@ def c_code(self, *args, **kwargs):
 gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
 
 
-class GammaIncCDer(BinaryScalarOp):
-    """
-    Gradient of the the regularized upper gamma function (Q) wrt to the first
-    argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
-    """
-
-    @staticmethod
-    def st_impl(k, x):
-        gamma_k = scipy.special.gamma(k)
-        digamma_k = scipy.special.digamma(k)
-        log_x = np.log(x)
-
-        # asymptotic expansion http://dlmf.nist.gov/8.11#E2
-        if (x >= k) and (x >= 8):
-            S = 0
-            k_minus_one_minus_n = k - 1
-            fac = k_minus_one_minus_n
-            dfac = 1
-            xpow = x
+# class GammaIncCDer(BinaryScalarOp):
+#     """
+#     Gradient of the the regularized upper gamma function (Q) wrt to the first
+#     argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
+#     """
+#
+#     @staticmethod
+#     def st_impl(k, x):
+#         gamma_k = scipy.special.gamma(k)
+#         digamma_k = scipy.special.digamma(k)
+#         log_x = np.log(x)
+#
+#         # asymptotic expansion http://dlmf.nist.gov/8.11#E2
+#         if (x >= k) and (x >= 8):
+#             S = 0
+#             k_minus_one_minus_n = k - 1
+#             fac = k_minus_one_minus_n
+#             dfac = 1
+#             xpow = x
+#             delta = dfac / xpow
+#
+#             for n in range(1, 10):
+#                 k_minus_one_minus_n -= 1
+#                 S += delta
+#                 xpow *= x
+#                 dfac = k_minus_one_minus_n * dfac + fac
+#                 fac *= k_minus_one_minus_n
+#                 delta = dfac / xpow
+#                 if np.isinf(delta):
+#                     warnings.warn(
+#                         "gammaincc_der did not converge",
+#                         RuntimeWarning,
+#                     )
+#                     return np.nan
+#
+#             return (
+#                 scipy.special.gammaincc(k, x) * (log_x - digamma_k)
+#                 + np.exp(-x + (k - 1) * log_x) * S / gamma_k
+#             )
+#
+#         # gradient of series expansion http://dlmf.nist.gov/8.7#E3
+#         else:
+#             log_precision = np.log(1e-6)
+#             max_iters = int(1e5)
+#             S = 0
+#             log_s = 0.0
+#             s_sign = 1
+#             log_delta = log_s - 2 * np.log(k)
+#             for n in range(1, max_iters + 1):
+#                 S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
+#                 s_sign = -s_sign
+#                 log_s += log_x - np.log(n)
+#                 log_delta = log_s - 2 * np.log(n + k)
+#
+#                 if np.isinf(log_delta):
+#                     warnings.warn(
+#                         "gammaincc_der did not converge",
+#                         RuntimeWarning,
+#                     )
+#                     return np.nan
+#
+#                 if log_delta <= log_precision:
+#                     return (
+#                         scipy.special.gammainc(k, x) * (digamma_k - log_x)
+#                         + np.exp(k * log_x) * S / gamma_k
+#                     )
+#
+#             warnings.warn(
+#                 f"gammaincc_der did not converge after {n} iterations",
+#                 RuntimeWarning,
+#             )
+#             return np.nan
+#
+#     def impl(self, k, x):
+#         return self.st_impl(k, x)
+#
+#     def c_code(self, *args, **kwargs):
+#         raise NotImplementedError()
+#
+#
+# gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
+
+
+class GammaIncCDerInnerScan1(ScalarScanOp):
+    nin = 7
+    nout = 6
+    n_steps = 9
+
+    @property
+    def fn(self):
+        def inner_fn(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x):
+            S += delta
+            xpow *= x
+            k_minus_one_minus_n -= 1
+            dfac = k_minus_one_minus_n * dfac + fac
+            fac *= k_minus_one_minus_n
             delta = dfac / xpow
+            return S, delta, xpow, k_minus_one_minus_n, dfac, fac
 
-            for n in range(1, 10):
-                k_minus_one_minus_n -= 1
-                S += delta
-                xpow *= x
-                dfac = k_minus_one_minus_n * dfac + fac
-                fac *= k_minus_one_minus_n
-                delta = dfac / xpow
-                if np.isinf(delta):
-                    warnings.warn(
-                        "gammaincc_der did not converge",
-                        RuntimeWarning,
-                    )
-                    return np.nan
+        return inner_fn
 
-            return (
-                scipy.special.gammaincc(k, x) * (log_x - digamma_k)
-                + np.exp(-x + (k - 1) * log_x) * S / gamma_k
-            )
 
-        # gradient of series expansion http://dlmf.nist.gov/8.7#E3
-        else:
-            log_precision = np.log(1e-6)
-            max_iters = int(1e5)
-            S = 0
-            log_s = 0.0
-            s_sign = 1
-            log_delta = log_s - 2 * np.log(k)
-            for n in range(1, max_iters + 1):
-                S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
-                s_sign = -s_sign
-                log_s += log_x - np.log(n)
-                log_delta = log_s - 2 * np.log(n + k)
-
-                if np.isinf(log_delta):
-                    warnings.warn(
-                        "gammaincc_der did not converge",
-                        RuntimeWarning,
-                    )
-                    return np.nan
-
-                if log_delta <= log_precision:
-                    return (
-                        scipy.special.gammainc(k, x) * (digamma_k - log_x)
-                        + np.exp(k * log_x) * S / gamma_k
-                    )
+_gammaincc_der_scan1 = GammaIncCDerInnerScan1()
 
-            warnings.warn(
-                f"gammaincc_der did not converge after {n} iterations",
-                RuntimeWarning,
-            )
-            return np.nan
 
-    def impl(self, k, x):
-        return self.st_impl(k, x)
+class GammaIncCDerInnerScan2(ScalarScanOp):
+    nin = 7
+    nout = 5
+    n_steps = int(1e5)  # maximum number of iterations
+    log_precision = np.log(1e-6)
 
-    def c_code(self, *args, **kwargs):
-        raise NotImplementedError()
+    @property
+    def fn(self):
+        import pytensor.tensor as pt
+        from pytensor.scan import until
 
+        def inner_fn(S, log_s, s_sign, log_delta, n, k, log_x):
+            delta = pt.exp(log_delta)
+            S += pt.switch(s_sign > 0, delta, -delta)
+            s_sign = -s_sign
+            log_s += log_x - pt.log(n)
+            log_delta = log_s - 2 * pt.log(n + k)
+            n += 1
+            return (
+                (S, log_s, s_sign, log_delta, n),
+                {},
+                until(pt.all(log_delta < self.log_precision)),
+            )
 
-gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
+        return inner_fn
+
+
+_gammaincc_der_scan2 = GammaIncCDerInnerScan2()
+
+
+def gammaincc_der(k, x):
+    gamma_k = gamma(k)
+    digamma_k = psi(k)
+    log_x = log(x)
+
+    # asymptotic expansion http://dlmf.nist.gov/8.11#E2
+    S = np.array(0.0, dtype="float64")
+    dfac = np.array(1.0, dtype="float64")
+    xpow = x
+    k_minus_one_minus_n = k - 1
+    fac = k_minus_one_minus_n
+    delta = true_div(dfac, xpow)
+    S, *_ = _gammaincc_der_scan1(S, delta, xpow, k_minus_one_minus_n, fac, dfac, x)
+    res1 = (
+        gammaincc(k, x) * (log_x - digamma_k) + exp(-x + (k - 1) * log_x) * S / gamma_k
+    )
+
+    # gradient of series expansion http://dlmf.nist.gov/8.7#E3
+    S = np.array(0.0, dtype="float64")
+    log_s = np.array(0.0, dtype="float64")
+    s_sign = np.array(1, dtype="int8")
+    n = np.array(1, dtype="int64")
+    log_delta = log_s - 2 * log(k)
+    S, *_ = _gammaincc_der_scan2(S, log_s, s_sign, log_delta, n, k, log_x)
+    res2 = gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * S / gamma_k
+
+    return switch(
+        (x >= k) & (x >= 8),
+        res1,
+        res2,
+    )
 
 
 class GammaU(BinaryScalarOp):
diff --git a/pytensor/scalar/scan.py b/pytensor/scalar/scan.py
new file mode 100644
index 0000000000..5c26082566
--- /dev/null
+++ b/pytensor/scalar/scan.py
@@ -0,0 +1,23 @@
+from pytensor.scalar.basic import ScalarOp, same_out
+
+
+class ScalarScanOp(ScalarOp):
+    """Dummy Scalar Op that encapsulates a scalar scan operation.
+
+    This Op is never supposed to be evaluated. It can safely be converted
+    to an Elemwise which is rewritten into a Scan node during compilation.
+
+    TODO: FINISH DOCSTRINGS
+    TODO: ABC for fn property
+    """
+
+    def __init__(self, output_types_preference=None, **kwargs):
+        if output_types_preference is None:
+
+            def output_types_preference(*types):
+                return tuple(same_out(type)[0] for type in types[: self.nout])
+
+        super().__init__(output_types_preference=output_types_preference, **kwargs)
+
+    def impl(self, *args, **kwargs):
+        raise RuntimeError("Scalar Scan Ops should never be evaluated!")
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 4ac1ffdd33..1d737b7d08 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -638,6 +638,9 @@ def transform(r):
                 return DimShuffle((), ["x"] * nd)(res)
 
             new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
+            if isinstance(new_r, (list, tuple)):
+                # Scalar Op with multiple outputs
+                new_r = new_r[r.owner.outputs.index(r)]
             return new_r
 
         ret = []
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index e9952a3908..4d553b100e 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -7,7 +7,7 @@
 import pytensor
 import pytensor.scalar.basic as aes
 from pytensor import compile
-from pytensor.compile.mode import get_target_language
+from pytensor.compile.mode import get_target_language, optdb
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Constant, io_toposort
 from pytensor.graph.features import ReplaceValidate
@@ -20,11 +20,14 @@
 )
 from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
+from pytensor.scalar import ScalarScanOp
 from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
 from pytensor.tensor.shape import shape_padleft
+from pytensor.tensor.subtensor import IncSubtensor
 from pytensor.tensor.var import TensorConstant
 
 
@@ -1025,3 +1028,57 @@ def local_careduce_fusion(fgraph, node):
     "fusion",
     position=49,
 )
+
+
+@node_rewriter([Elemwise])
+def inline_elemwise_scan(fgraph, node):
+    from pytensor.scan.basic import scan
+    from pytensor.scan.utils import expand_empty
+
+    scalar_op = node.op.scalar_op
+
+    if not isinstance(scalar_op, ScalarScanOp):
+        return None
+
+    # TODO: Add non-batched implementation? That should be better for scans with big difference in required n_steps
+    bcasted_inputs = broadcast_arrays(*node.inputs)
+    ret, updates = scan(
+        scalar_op.fn,
+        outputs_info=bcasted_inputs[: scalar_op.nout],
+        non_sequences=bcasted_inputs[scalar_op.nout :],
+        n_steps=scalar_op.n_steps,
+        sequences=None,
+        strict=True,
+    )
+    if updates:
+        raise ValueError("Scalar scan should never return updates")
+    if scalar_op.nout == 1:
+        ret = (ret,)
+
+    # Scan output size is given by the size of the input leading dimension, by default it's n_steps + 1.
+    # If we only want to store the last elements we can shorten the leading dimension to 1
+    scan_node = ret[0].owner.inputs[0].owner
+    scan_inputs = scan_node.inputs
+    n_steps = scan_inputs[0]
+    n_non_seqs = scan_node.op.info.n_non_seqs
+    carried_inputs = scan_inputs[1 : len(scan_inputs) - n_non_seqs :]
+    constant_inputs = scan_inputs[len(scan_inputs) - n_non_seqs :]
+    new_carried_inputs = []
+    for carried_input in carried_inputs:
+        assert isinstance(carried_input.owner.op, IncSubtensor)
+        fill_value = carried_input.owner.inputs[1]
+        # TODO: Check for the global flag where this is controlled
+        new_carried_inputs.append(expand_empty(fill_value, 1))
+    ret = scan_node.op.make_node(n_steps, *new_carried_inputs, *constant_inputs).outputs
+
+    return [r[1] for r in ret]
+
+
+# We want to run this after the scan save mem rewrite, as we already applied it here
+optdb.register(
+    "inline_elemwise_scan",
+    in2out(inline_elemwise_scan),
+    "fast_compile",
+    "fast_run",
+    position=1.62,
+)
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index 6d1b0ad576..1e26b2c319 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 
+from pytensor.gradient import verify_grad
 
 scipy = pytest.importorskip("scipy")
 
@@ -9,11 +10,11 @@
 import scipy.special
 import scipy.stats
 
-from pytensor import function
+from pytensor import function, grad
 from pytensor import tensor as at
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.tensor import inplace
+from pytensor.tensor import inplace, vector, gammaincc
 from tests import unittest_tools as utt
 from tests.tensor.utils import (
     _good_broadcast_unary_chi2sf,
@@ -422,6 +423,23 @@ def test_gammainc_ddk_tabulated_values():
     )
 
 
+def test_gammaincc_ddk_performance(benchmark):
+    rng = np.random.default_rng(1)
+    k = vector("k")
+    x = vector("x")
+
+    out = gammaincc(k, x)
+    grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
+    vals = [
+        # Values that hit the second branch of the gradient
+        np.full((1000,), 3.2),
+        np.full((1000,), 0.01),
+    ]
+
+    verify_grad(gammaincc, vals, rng=rng)
+    benchmark(grad_fn, *vals)
+
+
 TestGammaUBroadcast = makeBroadcastTester(
     op=at.gammau,
     expected=expected_gammau,