diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73139a4d58..fea9ad4031 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: sphinx-lint args: ["."] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.3 + rev: v0.9.7 hooks: - id: ruff args: ["--fix", "--output-format=full"] diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 37acfc8e86..fc19bddc5a 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -847,7 +847,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # for each input: # same as range(ndim), but with 'x' at all broadcastable positions orders = [ - [s == 1 and "x" or i for i, s in enumerate(input.type.shape)] + [(s == 1 and "x") or i for i, s in enumerate(input.type.shape)] for input in inputs ] @@ -1671,8 +1671,10 @@ def construct(symbol): scalar_op = getattr(scalar, symbolname) rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) + # Set the docstring to be just the original function's docstring + # without appending the generic Elemwise docstring if getattr(symbol, "__doc__"): - rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__ + rval.__doc__ = symbol.__doc__ # for the meaning of this see the ./epydoc script # it makes epydoc display rval as if it were a function, not an object diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 2aa6ad2381..a39fb22a26 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -622,7 +622,36 @@ def ge(a, b): @scalar_elemwise def eq(a, b): - """a == b""" + """Element-wise equality comparison. + + Computes a tensor of 0s and 1s where 1 indicates the corresponding + elements of a and b are equal. + + Parameters + ---------- + a, b : tensor_like + Input tensors of same shape. + + Returns + ------- + tensor + Boolean tensor of same shape as inputs, with 1s where corresponding + elements are equal, 0s otherwise. + + Examples + -------- + >>> import pytensor.tensor as pt + >>> import numpy as np + >>> x = pt.vector() + >>> y = pt.vector() + >>> f = pt.function([x, y], pt.eq(x, y)) + >>> f([1, 2, 3], [1, 4, 3]) + array([ True, False, True]) + + Notes + ----- + This function supports the Python syntax `a == b` when used with PyTensor tensors. + """ @scalar_elemwise @@ -850,7 +879,29 @@ def abs(a): @scalar_elemwise def exp(a): - """e^`a`""" + """Exponential function (e^a). + + Computes the element-wise exponential of a tensor. + + Parameters + ---------- + a : tensor_like + Input tensor + + Returns + ------- + tensor + Output tensor with the exponential of each element in `a`. + + Examples + -------- + >>> import pytensor.tensor as pt + >>> import numpy as np + >>> x = pt.vector() + >>> f = pt.function([x], pt.exp(x)) + >>> f([0.0, 1.0, 2.0]) + array([1. , 2.7182817 , 7.389056 ], dtype=float32) + """ @scalar_elemwise @@ -875,7 +926,34 @@ def reciprocal(a): @scalar_elemwise def log(a): - """base e logarithm of a""" + """Natural logarithm (base e). + + Computes the element-wise natural logarithm of a tensor. + + Parameters + ---------- + a : tensor_like + Input tensor. Should contain only positive values. + + Returns + ------- + tensor + Output tensor with the natural logarithm of each element in `a`. + + Examples + -------- + >>> import pytensor.tensor as pt + >>> import numpy as np + >>> x = pt.vector() + >>> f = pt.function([x], pt.log(x)) + >>> f([1.0, 2.7182817, 7.389056]) + array([0. , 1. , 1.9999998], dtype=float32) + + Notes + ----- + For negative or zero values, this function will output NaN. + Consider using log1p(x-1) for values close to 1 to avoid numerical precision issues. + """ @scalar_elemwise @@ -1956,8 +2034,7 @@ def _tensordot_as_dot(a, b, axes, dot, batched): if not np.isscalar(axes) and len(axes) != 2: raise ValueError( - "Axes should be an integer or a " - f"list/tuple of len 2 ({axes} was provided)" + f"Axes should be an integer or a list/tuple of len 2 ({axes} was provided)" ) # if 'axes' is a number of axes to multiply and sum over (trailing axes @@ -2934,150 +3011,150 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None): not_equal = neq __all__ = [ - "max_and_argmax", - "max", - "matmul", - "argmax", - "min", - "argmin", - "smallest", - "largest", - "lt", - "less", - "gt", - "greater", - "le", - "less_equal", - "ge", - "greater_equal", - "eq", - "equal", - "neq", - "not_equal", - "isnan", - "isinf", - "isposinf", - "isneginf", + "abs", + "add", + "all", "allclose", - "isclose", "and_", + "angle", + "any", + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctan2", + "arctanh", + "argmax", + "argmin", + "betainc", + "betaincinv", "bitwise_and", - "or_", + "bitwise_not", "bitwise_or", - "xor", "bitwise_xor", - "invert", - "bitwise_not", - "abs", - "exp", - "exp2", - "expm1", - "neg", - "reciprocal", - "log", - "log2", - "log10", - "log1p", - "sgn", - "sign", "ceil", - "floor", - "trunc", - "iround", - "round", - "round_half_to_even", - "round_half_away_from_zero", - "sqr", - "square", - "cov", - "sqrt", - "deg2rad", - "rad2deg", + "ceil_intdiv", + "chi2sf", + "clip", + "complex", + "complex_from_polar", + "conj", + "conjugate", "cos", - "arccos", - "sin", - "arcsin", - "tan", - "arctan", - "arctan2", "cosh", - "arccosh", - "sinh", - "arcsinh", - "tanh", - "arctanh", + "cov", + "deg2rad", + "dense_dot", + "digamma", + "divmod", + "dot", + "eq", + "equal", "erf", "erfc", + "erfcinv", "erfcx", "erfinv", - "erfcinv", - "owens_t", + "exp", + "exp2", + "expit", + "expm1", + "floor", + "floor_div", "gamma", - "gammaln", - "psi", - "digamma", - "tri_gamma", - "polygamma", - "chi2sf", "gammainc", "gammaincc", - "gammau", - "gammal", - "gammaincinv", "gammainccinv", - "j0", - "j1", - "jv", + "gammaincinv", + "gammal", + "gammaln", + "gammau", + "ge", + "greater", + "greater_equal", + "gt", + "hyp2f1", "i0", "i1", + "imag", + "int_div", + "invert", + "iround", + "isclose", + "isinf", + "isnan", + "isneginf", + "isposinf", "iv", "ive", + "j0", + "j1", + "jv", "kv", "kve", - "sigmoid", - "expit", - "softplus", - "log1pexp", + "largest", + "le", + "less", + "less_equal", + "log", "log1mexp", - "betainc", - "betaincinv", - "real", - "imag", - "angle", - "complex", - "conj", - "conjugate", - "complex_from_polar", - "sum", - "prod", + "log1p", + "log1pexp", + "log2", + "log10", + "logaddexp", + "logsumexp", + "lt", + "matmul", + "max", + "max_and_argmax", + "maximum", "mean", "median", - "var", - "std", - "std", - "maximum", + "min", "minimum", - "divmod", - "add", - "sub", - "mul", - "true_div", - "int_div", - "floor_div", - "ceil_intdiv", "mod", - "pow", - "clip", - "dot", - "dense_dot", - "tensordot", + "mul", + "nan_to_num", + "neg", + "neq", + "not_equal", + "or_", "outer", - "any", - "all", - "ptp", + "owens_t", + "polygamma", + "pow", "power", - "logaddexp", - "logsumexp", - "hyp2f1", - "nan_to_num", + "prod", + "psi", + "ptp", + "rad2deg", + "real", + "reciprocal", + "round", + "round_half_away_from_zero", + "round_half_to_even", + "sgn", + "sigmoid", + "sign", + "sin", + "sinh", + "smallest", + "softplus", + "sqr", + "sqrt", + "square", + "std", + "std", + "sub", + "sum", + "tan", + "tanh", + "tensordot", + "tri_gamma", + "true_div", + "trunc", + "var", + "xor", ]