From e7d248d8e8d501ffaa229bec0aabca8021d515b4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jul 2024 13:13:56 +0200 Subject: [PATCH 1/4] Simplify makeKeepDims --- pytensor/tensor/math.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index d7c69135ae..461c4bce3c 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -297,32 +297,9 @@ def makeKeepDims(x, y, axis): """ x = as_tensor_variable(x) - y = as_tensor_variable(y) - if axis is None: axis = list(range(x.type.ndim)) - elif isinstance(axis, int | np.integer): - axis = [axis] - elif isinstance(axis, np.ndarray) and axis.ndim == 0: - axis = [int(axis)] - else: - axis = [int(a) for a in axis] - newaxis = [] - for a in axis: - if not isinstance(a, int): - raise ValueError("keepdims option can be used only with constant axis") - if a < 0: - a += x.type.ndim - newaxis.append(a) - i = 0 - new_dims = [] - for j, _ in enumerate(x.type.broadcastable): - if j in newaxis: - new_dims.append("x") - else: - new_dims.append(i) - i += 1 - return DimShuffle(y.type.broadcastable, new_dims)(y) + return expand_dims(y, axis) def check_and_normalize_axes(x, axis): From d00ce0088e55c8e71c3b264c6fb003b99646874c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jul 2024 13:04:50 +0200 Subject: [PATCH 2/4] Move argmax helper close to class definition --- pytensor/tensor/math.py | 44 ++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 461c4bce3c..97de67fa0d 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -277,6 +277,28 @@ def grad(self, inp, grads): return [x.zeros_like()] +def argmax(x, axis=None, keepdims=False): + """ + Returns indices of maximum elements obtained by iterating over given axis. + + When axis is None (the default value), the argmax is performed + over the flattened tensor. + + Parameters + ---------- + keepdims : bool + If this is set to True, the axes which are reduced are left in + the result as dimensions with size one. With this option, the result + will broadcast correctly against the original tensor. + + """ + argout = max_and_argmax(x, axis)[1] + + if keepdims: + argout = makeKeepDims(x, argout, axis) + return argout + + @_vectorize_node.register(Argmax) def vectorize_argmax_node(op, node, batch_x): core_ndim = node.inputs[0].type.ndim @@ -549,28 +571,6 @@ def max(x, axis=None, keepdims=False): return out -def argmax(x, axis=None, keepdims=False): - """ - Returns indices of maximum elements obtained by iterating over given axis. - - When axis is None (the default value), the argmax is performed - over the flattened tensor. - - Parameters - ---------- - keepdims : bool - If this is set to True, the axes which are reduced are left in - the result as dimensions with size one. With this option, the result - will broadcast correctly against the original tensor. - - """ - argout = max_and_argmax(x, axis)[1] - - if keepdims: - argout = makeKeepDims(x, argout, axis) - return argout - - def min(x, axis=None, keepdims=False): """ Returns minimum elements obtained by iterating over given axis. 
From 385a7277a45d034772f6d268db8a221dd4b0f29e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jul 2024 12:44:40 +0200 Subject: [PATCH 3/4] Cleanup Max and Argmax --- pytensor/tensor/elemwise.py | 33 +++---- pytensor/tensor/math.py | 169 +++++++++--------------------------- pytensor/tensor/utils.py | 18 ++++ tests/tensor/test_math.py | 58 ++++++------- tests/test_rop.py | 4 +- 5 files changed, 100 insertions(+), 182 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 2fdc8e7fd5..d40a5b9d43 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -30,7 +30,11 @@ float_dtypes, lvector, ) -from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string +from pytensor.tensor.utils import ( + broadcast_static_dim_lengths, + import_func_from_string, + normalize_reduce_axis, +) from pytensor.tensor.variable import TensorVariable from pytensor.utils import uniq @@ -1371,7 +1375,6 @@ def _acc_dtype(self, idtype): def make_node(self, input): input = as_tensor_variable(input) - inp_dims = input.type.ndim inp_dtype = input.type.dtype # We need to redefine make_node so that, if self.dtype is None, @@ -1383,29 +1386,19 @@ def make_node(self, input): assert dtype is not None assert acc_dtype is not None - axis = self.axis + axis = normalize_reduce_axis(self.axis, ndim=input.type.ndim) - # scalar inputs are treated as 1D regarding axis in this `Op` - if axis is not None: - try: - axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims)) - except np.AxisError: - raise np.AxisError(axis, ndim=inp_dims) + if axis != self.axis or dtype != self.dtype or acc_dtype != self.acc_dtype: + op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype) + else: + op = self + if axis is None: + out_shape = () + else: out_shape = tuple( s for i, s in enumerate(input.type.shape) if i not in axis ) - else: - out_shape = () - - if ( - (axis is not None and any(a < 0 for a in axis)) - or dtype != self.dtype - or acc_dtype != self.acc_dtype - ): - op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype) - else: - op = self output = TensorType(dtype=dtype, shape=out_shape)() diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 97de67fa0d..b55adb0312 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -8,7 +8,6 @@ from pytensor import config, printing from pytensor import scalar as ps -from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op from pytensor.graph.replace import _vectorize_node @@ -26,9 +25,9 @@ cast, concatenate, constant, + expand_dims, stack, switch, - zeros_like, ) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.elemwise import ( @@ -45,14 +44,11 @@ continuous_dtypes, discrete_dtypes, int_dtypes, - integer_dtypes, tensor, uint_dtypes, ) -from pytensor.tensor.type_other import NoneConst -from pytensor.tensor.utils import as_list +from pytensor.tensor.utils import as_list, normalize_reduce_axis from pytensor.tensor.variable import ( - TensorConstant, TensorVariable, _tensor_py_operators, ) @@ -157,7 +153,7 @@ class Argmax(COp): def __init__(self, axis): if axis is not None: - axis = tuple(axis) + axis = tuple(sorted(axis)) self.axis = axis def get_params(self, node): @@ -168,7 +164,7 @@ def get_params(self, node): c_axis = np.int64(-1) return self.params_type.get_params(c_axis=c_axis) - def make_node(self, x, axis=None): + def make_node(self, x): x = 
as_tensor_variable(x) if self.axis is None: all_axes = list(range(x.ndim)) @@ -198,7 +194,9 @@ def perform(self, node, inp, outs): # Work around keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") # Not-reduced axes in front - transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) + transposed_x = np.transpose( + x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64"))) + ) kept_shape = transposed_x.shape[: len(keep_axes)] reduced_shape = transposed_x.shape[len(keep_axes) :] new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) @@ -214,7 +212,7 @@ def c_code(self, node, name, inp, out, sub): if self.axis is None: axis_code = "axis = NPY_MAXDIMS;" else: - if len(self.axis) > 1: + if len(self.axis) != 1: raise NotImplementedError() # params is only used here for now axis_code = """ @@ -253,7 +251,7 @@ def c_code(self, node, name, inp, out, sub): return ret % locals() def c_code_cache_version(self): - return (1,) + return (2,) def infer_shape(self, fgraph, node, shapes): (ishape,) = shapes @@ -277,7 +275,7 @@ def grad(self, inp, grads): return [x.zeros_like()] -def argmax(x, axis=None, keepdims=False): +def argmax(x: TensorLike, axis=None, keepdims: bool = False): """ Returns indices of maximum elements obtained by iterating over given axis. @@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False): Parameters ---------- + x: TensorLike + Array on which to compute argmax + axis: + Axis along which to compute argmax. Unlike numpy multiple partial axis are supported. keepdims : bool If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original tensor. + Returns + ------- + TensorVariable + TensorVariable representing the argmax operation + """ - argout = max_and_argmax(x, axis)[1] + x = as_tensor_variable(x) + axis = normalize_reduce_axis(axis, ndim=x.type.ndim) + out = Argmax(axis)(x) if keepdims: - argout = makeKeepDims(x, argout, axis) - return argout + out = makeKeepDims(x, out, axis) + + return out @_vectorize_node.register(Argmax) @@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis): return expand_dims(y, axis) -def check_and_normalize_axes(x, axis): - """Check axes, normalize and convert them to a Python list of integers. - - Parameters - ---------- - x: TensorVariable - axis: int, tuple or list of integers - - Returns - ------- - axis: list of integers - Return an empty list if argument is None. - - """ - x = as_tensor_variable(x) - if axis is None: - axis = [] - elif isinstance(axis, int | np.integer) or ( - isinstance(axis, np.ndarray) and axis.ndim == 0 - ): - axis = [int(axis)] - elif isinstance(axis, tuple | list | np.ndarray): - axis = [int(i) for i in axis] - elif isinstance(axis, Variable): - if NoneConst.equals(axis): - axis = [] - elif not isinstance(axis, TensorConstant): - raise TypeError(f"Computation needs a constant axis. Got {axis}") - else: - assert axis.dtype in integer_dtypes - if isinstance(axis.data, int | np.integer) or ( - isinstance(axis.data, np.ndarray) and axis.data.ndim == 0 - ): - axis = [int(axis.data)] - elif isinstance(axis.data, list | np.ndarray): - axis = [int(i) for i in axis.data] - else: - raise TypeError( - f"Axis must be an integer, tuple, list of integers or a TensorVariable. 
Got {axis}" - ) - if len(axis) > 0: - for i in range(len(axis)): - if axis[i] < 0: - axis[i] += x.type.ndim - if axis[i] < 0 or axis[i] >= x.type.ndim: - raise ValueError( - f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(axis[i])}" - ) - axis = list(set(axis)) - axis.sort() - return axis - - def max_and_argmax(a, axis=None, keepdims=False): """ Returns maximum elements and their indices obtained by iterating over @@ -395,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False): """ # Check axis and convert it to a Python list of integers. # Axis will be used as an op param of Max and Argmax. - a = as_tensor_variable(a) - - is_axis_empty = False - if axis == (): - is_axis_empty = True - - axis = check_and_normalize_axes(a, axis) - - if len(axis) == 0 and not is_axis_empty: - axis = None - - out = Max(axis)(a) - - if not is_axis_empty: - argout = Argmax(axis)(a) - else: - argout = zeros_like(a, dtype="int64") - - if keepdims: - out = makeKeepDims(a, out, axis) - argout = makeKeepDims(a, argout, axis) - return [out, argout] + return [ + max(a, axis=axis, keepdims=keepdims), + argmax(a, axis=axis, keepdims=keepdims), + ] class FixedOpCAReduce(CAReduce): @@ -465,7 +404,7 @@ def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) return type(self)(axis=axis) - def grad(self, inp, grads): + def L_op(self, inputs, outputs, grads): # The strict sense mathematical gradient of the maximum function is # not calculated here for it is not defined at every point where some # coordinates are identical. However, since the latter set has null @@ -479,53 +418,27 @@ def grad(self, inp, grads): # g_max has one less dimension than x, so you need to complete # g_max to x's shape when axis=0 the broadcasting mechanism # does it automatically - x = inp[0] - if self.axis is None: - self.axis = tuple(range(x.ndim)) - axis = as_tensor_variable(self.axis) - (g_max,) = grads - - g_max_disconnected = isinstance(g_max.type, DisconnectedType) + [x] = inputs + [out] = outputs + [g_out] = grads - # if the op is totally disconnected, so are its inputs - if g_max_disconnected: - return [DisconnectedType()()] - - # if NoneConst.equals(axis): - if axis is None: - axis_ = list(range(x.ndim)) - else: - axis_ = axis - xmax = max(x, axis_) - - # Raise the g_max and xmax to the same number of dim as the input. - pattern = [] - out_dim = 0 - if NoneConst.equals(axis): - # We are taking the max/argmax over all dimensions. - axis = None - for i in range(x.ndim): - if axis is None or i in axis.data: - pattern.append("x") - else: - pattern.append(out_dim) - out_dim += 1 - g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max) - xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax) + axis = tuple(range(x.ndim)) if self.axis is None else self.axis + out_pad = expand_dims(out, axis) + g_out_pad = expand_dims(g_out, axis) # Set the grad to the correct position. 
- g_x = eq(xmax_pad, x) * g_max_pad + g_x = eq(out_pad, x) * g_out_pad return (g_x,) def R_op(self, inputs, eval_points): if eval_points[0] is None: return [None, None] if len(self.axis) != 1: - raise ValueError("R_op supported for arg_max only for one axis!") + raise ValueError("R_op supported for max only for one axis!") if self.axis[0] > 1: - raise ValueError("R_op supported for arg_max only when axis is 0 or 1") + raise ValueError("R_op supported for max only when axis is 0 or 1") if inputs[0].ndim != 2: - raise ValueError("R_op supported for arg_max only when input is a matrix") + raise ValueError("R_op supported for max only when input is a matrix") max_pos = Argmax(self.axis).make_node(*inputs).outputs # print(eval_points[0].eval()) if self.axis[0] == 0: @@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False): We return an error as numpy when we reduce a dim with a shape of 0. """ - out = max_and_argmax(x, axis)[0] + out = Max(axis=axis)(x) if keepdims: out = makeKeepDims(x, out, axis) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index b8ae1e780b..60ae8ebed8 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -1,7 +1,9 @@ import re from collections.abc import Sequence +from typing import cast import numpy as np +from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.utils import hash_from_code @@ -223,3 +225,19 @@ def operand_sig(operand_ndim: int, prefix: str) -> str: operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim) ) return f"{inputs_sig}->{outputs_sig}" + + +def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: + """Normalize the axis parameter for reduce operations.""" + if axis is None: + return None + + # scalar inputs are treated as 1D regarding axis in reduce operations + if axis is not None: + try: + axis = normalize_axis_tuple(axis, ndim=max(1, ndim)) + except np.AxisError: + raise np.AxisError(axis, ndim=ndim) + + # TODO: If axis tuple is equivalent to None, return None for more canonicalization? 
+ return cast(tuple, axis) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index b66599e3ca..e86bd4ec17 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -154,7 +154,6 @@ vectors, zvector, ) -from pytensor.tensor.type_other import NoneConst from tests import unittest_tools as utt from tests.link.test_link import make_function from tests.tensor.utils import ( @@ -767,9 +766,10 @@ def setup_method(self): Max.debug = 0 Argmax.debug = 0 - def test_basic(self): + @pytest.mark.parametrize("empty_axis", [(), None]) + def test_empty_axis_scalar(self, empty_axis): n = as_tensor_variable(5) - v, i = eval_outputs(max_and_argmax(n, axis=())) + v, i = eval_outputs(max_and_argmax(n, axis=empty_axis)) assert v == 5.0 assert i == 0 assert i.dtype == "int64" @@ -778,6 +778,29 @@ def test_basic(self): v = eval_outputs(max_and_argmax(n)[1].shape) assert len(v) == 0 + def test_empty_axis_tensor(self): + x = np.random.normal(size=(2, 3, 5, 7)) + axis = () + + non_axis = tuple(i for i in range(x.ndim) if i not in axis) + shape_axis = tuple(x.shape[dim] for dim in axis) + shape_non_axis = tuple(x.shape[dim] for dim in non_axis) + x_transposed = x.transpose(*axis, *non_axis) + + x_axis_raveled = x_transposed.reshape( + np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int) + ) + max_x = max_and_argmax(x, axis=axis)[0].eval() + argmax_x = max_and_argmax(x, axis=axis)[1].eval() + + raveled_max = x_axis_raveled[ + argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int)) + ] + indirect_max = raveled_max.reshape(shape_non_axis) + + np.testing.assert_allclose(max_x, x.max(axis=axis)) + np.testing.assert_allclose(indirect_max, x.max(axis=axis)) + def test_basic_1(self): n = as_tensor_variable([1, 2, 3, 2, -6]) v, i = eval_outputs(max_and_argmax(n)) @@ -796,8 +819,6 @@ def test_basic_1(self): (None, None), ([0, 1], None), ([1, 0], None), - (NoneConst.clone(), None), - (constant(0), 0), ], ) def test_basic_2(self, axis, np_axis): @@ -826,8 +847,6 @@ def test_basic_2(self, axis, np_axis): (None, None), ([0, 1], None), ([1, 0], None), - (NoneConst.clone(), None), - (constant(0), 0), ], ) def test_basic_2_float16(self, axis, np_axis): @@ -986,7 +1005,7 @@ def check_grad_max(data, max_grad_data, axis=None): safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) # Test grad with multiple axes - for i in [[0, 1], [0, 0]]: + for i in [[0, 1], [0, 2, 3]]: safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data]) safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data]) @@ -1043,29 +1062,6 @@ def test_vectorize(self, core_axis, batch_axis): assert isinstance(new_node.op, Argmax) assert new_node.op.axis == batch_axis - def test_max_empty_axis(self): - x = np.random.normal(size=(2, 3, 5, 7)) - axis = () - - non_axis = tuple(i for i in range(x.ndim) if i not in axis) - shape_axis = tuple(x.shape[dim] for dim in axis) - shape_non_axis = tuple(x.shape[dim] for dim in non_axis) - x_transposed = x.transpose(*axis, *non_axis) - - x_axis_raveled = x_transposed.reshape( - np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int) - ) - max_x = max_and_argmax(x, axis=axis)[0].eval() - argmax_x = max_and_argmax(x, axis=axis)[1].eval() - - raveled_max = x_axis_raveled[ - argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int)) - ] - indirect_max = raveled_max.reshape(shape_non_axis) - - np.testing.assert_allclose(max_x, x.max(axis=axis)) - np.testing.assert_allclose(indirect_max, x.max(axis=axis)) - class TestArgminArgmax: def 
setup_method(self): diff --git a/tests/test_rop.py b/tests/test_rop.py index d8fc78a51b..0b9fe41a1e 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -192,9 +192,7 @@ def check_rop_lop(self, y, out_shape): class TestRopLop(RopLopChecker): def test_max(self): - # If we call max directly, we will return an CAReduce object - # which doesn't have R_op implemented! - # self.check_mat_rop_lop(at_max(self.mx, axis=[0,1])[0], ()) + # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ()) self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],)) self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],)) From 6cb0302755fd67cbecec1bcd674aaddeb35822b8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jul 2024 15:44:41 +0200 Subject: [PATCH 4/4] Get rid of new `"output"` uses Useless since 9ba6d99fb7a8a24c5c57bd9b2f266c7ff346a0d7 --- pytensor/tensor/rewriting/linalg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 6e34c27d43..1e7d16a612 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -509,8 +509,6 @@ def svd_uv_merge(fgraph, node): # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. # First, iterate to see if there is an SVD Op that can be reused. for cl, _ in fgraph.clients[x]: - if cl == "output": - continue if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): if not cl.op.core_op.compute_uv: return { @@ -529,8 +527,6 @@ def svd_uv_merge(fgraph, node): # We want rewrite if there is another one with compute_uv=True. # For this case, just reuse the `s` from the one with compute_uv=True. for cl, _ in fgraph.clients[x]: - if cl == "output": - continue if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): if cl.op.core_op.compute_uv and ( len(fgraph.clients[cl.outputs[0]]) > 0