diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 24942b6b2c..3c925ac2f2 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -24,6 +24,7 @@ # pytensor code, since this code may want to log some messages. import logging import sys +import warnings from functools import singledispatch from pathlib import Path from typing import Any, NoReturn, Optional @@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v): If `v` is not some view of constant data, then raise a `NotScalarConstantError`. """ - # Is it necessary to test for presence of pytensor.sparse at runtime? - sparse = globals().get("sparse") - if sparse and isinstance(v.type, sparse.SparseTensorType): - if v.owner is not None and isinstance(v.owner.op, sparse.CSM): - data = v.owner.inputs[0] - return tensor.get_underlying_scalar_constant_value(data) - return tensor.get_underlying_scalar_constant_value(v) + warnings.warn( + "get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.", + FutureWarning, + ) + from pytensor.tensor.basic import get_underlying_scalar_constant_value + + return get_underlying_scalar_constant_value(v) # isort: off diff --git a/pytensor/gradient.py b/pytensor/gradient.py index dcf0b7427d..13ca943383 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -1329,7 +1329,7 @@ def try_to_copy_if_needed(var): f" {i}. Since this input is only connected " "to integer-valued outputs, it should " "evaluate to zeros, but it evaluates to" - f"{pytensor.get_underlying_scalar_constant(term)}." + f"{pytensor.get_underlying_scalar_constant_value(term)}." ) raise ValueError(msg) @@ -2157,6 +2157,9 @@ def _is_zero(x): 'maybe' means that x is an expression that is complicated enough that we can't tell that it simplifies to 0. 
""" + from pytensor.tensor import get_underlying_scalar_constant_value + from pytensor.tensor.exceptions import NotScalarConstantError + if not hasattr(x, "type"): return np.all(x == 0.0) if isinstance(x.type, NullType): @@ -2166,9 +2169,9 @@ def _is_zero(x): no_constant_value = True try: - constant_value = pytensor.get_underlying_scalar_constant(x) + constant_value = get_underlying_scalar_constant_value(x) no_constant_value = False - except pytensor.tensor.exceptions.NotScalarConstantError: + except NotScalarConstantError: pass if no_constant_value: diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 9cd9870616..2956afad02 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -18,7 +18,7 @@ Split, TensorFromScalar, Tri, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i @@ -103,7 +103,7 @@ def join(axis, *tensors): def jax_funcify_Split(op: Split, node, **kwargs): _, axis, splits = node.inputs try: - constant_axis = get_underlying_scalar_constant_value(axis) + constant_axis = get_scalar_constant_value(axis) except NotScalarConstantError: constant_axis = None warnings.warn( @@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( [ - get_underlying_scalar_constant_value(splits[i]) + get_scalar_constant_value(splits[i]) for i in range(get_vector_length(splits)) ] ) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 8b92e60085..dcae273aef 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -484,7 +484,7 @@ def wrap_into_list(x): n_fixed_steps = int(n_steps) else: try: - n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps) + n_fixed_steps = pt.get_scalar_constant_value(n_steps) except NotScalarConstantError: n_fixed_steps = None diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 07480c43c5..2ba282d8d6 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -54,7 +54,7 @@ from pytensor.tensor.basic import ( Alloc, AllocEmpty, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -71,7 +71,7 @@ get_slice_elements, set_subtensor, ) -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import TensorConstant list_opt_slice = [ @@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): all_ins = list(graph_inputs(op_outs)) for idx in range(op_info.n_seqs): node_inp = node.inputs[idx + 1] - if ( - isinstance(node_inp, TensorConstant) - and get_unique_constant_value(node_inp) is not None - ): + if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None: try: # This works if input is a constant that has all entries # equal @@ -668,8 +665,10 @@ def inner_sitsot_only_last_step_used( client = fgraph.clients[outer_var][0][0] if isinstance(client, Apply) and isinstance(client.op, Subtensor): lst = get_idx_list(client.inputs, client.op.idx_list) - if len(lst) == 1 and pt.extract_constant(lst[0]) == -1: - return True + return ( + len(lst) == 1 + and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1 + ) return False @@ -1344,10 +1343,17 @@ def scan_save_mem(fgraph, node): if isinstance(this_slice[0], 
slice) and this_slice[0].stop is None: global_nsteps = None if isinstance(cf_slice[0], slice): - stop = pt.extract_constant(cf_slice[0].stop) + stop = get_scalar_constant_value( + cf_slice[0].stop, raise_not_constant=False + ) else: - stop = pt.extract_constant(cf_slice[0]) + 1 - if stop == maxsize or stop == pt.extract_constant(length): + stop = ( + get_scalar_constant_value(cf_slice[0], raise_not_constant=False) + + 1 + ) + if stop == maxsize or stop == get_scalar_constant_value( + length, raise_not_constant=False + ): stop = None else: # there is a **gotcha** here ! Namely, scan returns an @@ -1451,9 +1457,13 @@ def scan_save_mem(fgraph, node): cf_slice = get_canonical_form_slice(this_slice[0], length) if isinstance(cf_slice[0], slice): - start = pt.extract_constant(cf_slice[0].start) + start = pt.get_scalar_constant_value( + cf_slice[0].start, raise_not_constant=False + ) else: - start = pt.extract_constant(cf_slice[0]) + start = pt.get_scalar_constant_value( + cf_slice[0], raise_not_constant=False + ) if start == 0 or store_steps[i] == 0: store_steps[i] = 0 @@ -1628,7 +1638,7 @@ def scan_save_mem(fgraph, node): # 3.6 Compose the new scan # TODO: currently we don't support scan with 0 step. So # don't create one. - if pt.extract_constant(node_ins[0]) == 0: + if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0: return False # Do not call make_node for test_value @@ -1965,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes): nsteps = node.inputs[0] try: - nsteps = int(get_underlying_scalar_constant_value(nsteps)) + nsteps = int(get_scalar_constant_value(nsteps)) except NotScalarConstantError: pass rep_nsteps = rep_node.inputs[0] try: - rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) + rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) except NotScalarConstantError: pass diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index e3aa0b96b2..c590bc804a 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -491,6 +491,10 @@ def __str__(self): def __repr__(self): return str(self) + @property + def unique_value(self): + return None + SparseTensorType.variable_type = SparseVariable SparseTensorType.constant_type = SparseConstant diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index cd874a2cc6..401642ddb9 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -19,7 +19,7 @@ import pytensor import pytensor.scalar.sharedvar -from pytensor import compile, config, printing +from pytensor import config, printing from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType, grad_undefined @@ -35,7 +35,7 @@ from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise, assert_op from pytensor.scalar import int32 -from pytensor.scalar.basic import ScalarConstant, ScalarVariable +from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.tensor import ( _as_tensor_variable, _get_vector_length, @@ -71,10 +71,10 @@ uint_dtypes, values_eq_approx_always_true, ) +from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import ( TensorConstant, TensorVariable, - get_unique_constant_value, ) @@ -268,27 +268,7 @@ def _obj_is_wrappable_as_tensor(x): ) -def get_scalar_constant_value( - v, elemwise=True, only_process_constants=False, max_recur=10 -): - """ - Checks whether 'v' is a scalar (ndim = 0). 
- - If 'v' is a scalar then this function fetches the underlying constant by calling - 'get_underlying_scalar_constant_value()'. - - If 'v' is not a scalar, it raises a NotScalarConstantError. - - """ - if isinstance(v, Variable | np.ndarray): - if v.ndim != 0: - raise NotScalarConstantError() - return get_underlying_scalar_constant_value( - v, elemwise, only_process_constants, max_recur - ) - - -def get_underlying_scalar_constant_value( +def _get_underlying_scalar_constant_value( orig_v, elemwise=True, only_process_constants=False, max_recur=10 ): """Return the constant scalar(0-D) value underlying variable `v`. @@ -319,6 +299,10 @@ def get_underlying_scalar_constant_value( but I'm not sure where it is. """ + from pytensor.compile.ops import DeepCopyOp, OutputGuard + from pytensor.sparse import CSM + from pytensor.tensor.subtensor import Subtensor + v = orig_v while True: if v is None: @@ -336,40 +320,28 @@ def get_underlying_scalar_constant_value( raise NotScalarConstantError() if isinstance(v, Constant): - unique_value = get_unique_constant_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data + if isinstance(v.type, TensorType) and v.unique_value is not None: + return v.unique_value - if isinstance(data, np.ndarray): - try: - return np.array(data.item(), dtype=v.dtype) - except ValueError: - raise NotScalarConstantError() + elif isinstance(v.type, ScalarType): + return v.data - from pytensor.sparse.type import SparseTensorType + elif isinstance(v.type, NoneTypeT): + return None - if isinstance(v.type, SparseTensorType): - raise NotScalarConstantError() - - return data + raise NotScalarConstantError() if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: + op = v.owner.op max_recur -= 1 if isinstance( - v.owner.op, - Alloc - | DimShuffle - | Unbroadcast - | compile.ops.OutputGuard - | compile.DeepCopyOp, + op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ): # OutputGuard is only used in debugmode but we # keep it here to avoid problems with old pickles v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, Shape_i): + elif isinstance(op, Shape_i): i = v.owner.op.i inp = v.owner.inputs[0] if isinstance(inp, Constant): @@ -383,19 +355,19 @@ def get_underlying_scalar_constant_value( # mess with the stabilization optimization and be too slow. # We put all the scalar Ops used by get_canonical_form_slice() # to allow it to determine the broadcast pattern correctly. 
- elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar): + elif isinstance(op, ScalarFromTensor | TensorFromScalar): v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, CheckAndRaise): + elif isinstance(op, CheckAndRaise): # check if all conditions are constant and true conds = [ - get_underlying_scalar_constant_value(c, max_recur=max_recur) + _get_underlying_scalar_constant_value(c, max_recur=max_recur) for c in v.owner.inputs[1:] ] if builtins.all(0 == c.ndim and c != 0 for c in conds): v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, ps.ScalarOp): + elif isinstance(op, ps.ScalarOp): if isinstance(v.owner.op, ps.Second): # We don't need both input to be constant for second shp, val = v.owner.inputs @@ -403,7 +375,7 @@ def get_underlying_scalar_constant_value( continue if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): const = [ - get_underlying_scalar_constant_value(i, max_recur=max_recur) + _get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -412,7 +384,7 @@ def get_underlying_scalar_constant_value( # In fast_compile, we don't enable local_fill_to_alloc, so # we need to investigate Second as Alloc. So elemwise # don't disable the check for Second. - elif isinstance(v.owner.op, Elemwise): + elif isinstance(op, Elemwise): if isinstance(v.owner.op.scalar_op, ps.Second): # We don't need both input to be constant for second shp, val = v.owner.inputs @@ -422,16 +394,13 @@ def get_underlying_scalar_constant_value( v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops ): const = [ - get_underlying_scalar_constant_value(i, max_recur=max_recur) + _get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] v.owner.op.perform(v.owner, const, ret) return np.asarray(ret[0][0].copy()) - elif ( - isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor) - and v.ndim == 0 - ): + elif isinstance(op, Subtensor) and v.ndim == 0: if isinstance(v.owner.inputs[0], TensorConstant): from pytensor.tensor.subtensor import get_constant_idx @@ -468,7 +437,7 @@ def get_underlying_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) try: @@ -502,14 +471,13 @@ def get_underlying_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) - # Python 2.4 does not support indexing with numpy.integer - # So we cast it. - idx = int(idx) ret = v.owner.inputs[0].owner.inputs[idx] - ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur) + ret = _get_underlying_scalar_constant_value( + ret, max_recur=max_recur + ) # MakeVector can cast implicitly its input in some case. 
                     return np.asarray(ret, dtype=v.type.dtype)
 
@@ -524,7 +492,7 @@ def get_underlying_scalar_constant_value(
                 idx_list = op.idx_list
                 idx = idx_list[0]
                 if isinstance(idx, Type):
-                    idx = get_underlying_scalar_constant_value(
+                    idx = _get_underlying_scalar_constant_value(
                         owner.inputs[1], max_recur=max_recur
                     )
                 grandparent = leftmost_parent.owner.inputs[0]
@@ -534,7 +502,9 @@ def get_underlying_scalar_constant_value(
                         grandparent.owner.op, Unbroadcast
                     ):
                         ggp_shape = grandparent.owner.inputs[0].type.shape
-                        l = [get_underlying_scalar_constant_value(s) for s in ggp_shape]
+                        l = [
+                            _get_underlying_scalar_constant_value(s) for s in ggp_shape
+                        ]
                         gp_shape = tuple(l)
 
                     if not (idx < ndim):
@@ -555,10 +525,105 @@ def get_underlying_scalar_constant_value(
                     if isinstance(grandparent, Constant):
                         return np.asarray(np.shape(grandparent.data)[idx])
 
+            elif isinstance(op, CSM):
+                data = _get_underlying_scalar_constant_value(
+                    v.owner.inputs[0], elemwise=elemwise, max_recur=max_recur
+                )
+                # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
+                if data == 0:
+                    return data
+
         break
 
     raise NotScalarConstantError()
 
 
+def get_underlying_scalar_constant_value(
+    v,
+    *,
+    elemwise=True,
+    only_process_constants=False,
+    max_recur=10,
+    raise_not_constant=True,
+):
+    """Return the unique constant scalar(0-D) value underlying variable `v`.
+
+    If `v` is the output of dimshuffles, fills, allocs, etc,
+    cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
+    and some pattern with Subtensor, this function digs through them.
+
+    If `v` is not some view of constant scalar data, then raise a
+    NotScalarConstantError.
+
+    This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by
+    constant folding the inputs of `v`.
+
+    Parameters
+    ----------
+    v: Variable
+    elemwise : bool
+        If False, we won't try to go into elemwise. So this call is faster.
+        But we still investigate in Second Elemwise (as this is a substitute
+        for Alloc)
+    only_process_constants : bool
+        If True, we only attempt to obtain the value of `v` if it's
+        directly constant and don't try to dig through dimshuffles, fills,
+        allocs, and others to figure out its value.
+    max_recur : int
+        The maximum number of recursions.
+    raise_not_constant: bool, default True
+        If True, raise a NotScalarConstantError if `v` does not have an
+        underlying constant scalar value. If False, return `v` as is.
+
+
+    Raises
+    ------
+    NotScalarConstantError
+        `v` does not have an underlying constant scalar value.
+        Only raised if raise_not_constant is True.
+
+    """
+    try:
+        return _get_underlying_scalar_constant_value(
+            v,
+            elemwise=elemwise,
+            only_process_constants=only_process_constants,
+            max_recur=max_recur,
+        )
+    except NotScalarConstantError:
+        if raise_not_constant:
+            raise
+        return v
+
+
+def get_scalar_constant_value(
+    v,
+    elemwise=True,
+    only_process_constants=False,
+    max_recur=10,
+    raise_not_constant: bool = True,
+):
+    """
+    Checks whether 'v' is a scalar (ndim = 0).
+
+    If 'v' is a scalar then this function fetches the underlying constant by calling
+    'get_underlying_scalar_constant_value()'.
+
+    If 'v' is not a scalar, it raises a NotScalarConstantError.
+
+    """
+    if isinstance(v, TensorVariable | np.ndarray):
+        if v.ndim != 0:
+            raise NotScalarConstantError("Input ndim != 0")
+    return get_underlying_scalar_constant_value(
+        v,
+        elemwise=elemwise,
+        only_process_constants=only_process_constants,
+        max_recur=max_recur,
+        raise_not_constant=raise_not_constant,
+    )
+
+
 class TensorFromScalar(COp):
     __props__ = ()
 
@@ -1743,7 +1808,7 @@ def do_constant_folding(self, fgraph, node):
 @_get_vector_length.register(Alloc)
 def _get_vector_length_Alloc(var_inst, var):
     try:
-        return get_underlying_scalar_constant_value(var.owner.inputs[1])
+        return get_scalar_constant_value(var.owner.inputs[1])
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")
 
@@ -2015,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
     ScalarVariable, we convert it to a tensor with tensor_from_scalar.
 
     """
-    try:
-        x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
-    except NotScalarConstantError:
-        pass
-    if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
-        if x.owner and isinstance(x.owner.op, ScalarFromTensor):
-            x = x.owner.inputs[0]
-        else:
-            x = tensor_from_scalar(x)
-    return x
+    warnings.warn(
+        "extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
+        FutureWarning,
+    )
+    return get_underlying_scalar_constant_value(
+        x,
+        elemwise=elemwise,
+        only_process_constants=only_process_constants,
+        raise_not_constant=False,
+    )
 
 
 def transpose(x, axes=None):
@@ -2444,7 +2509,7 @@ def make_node(self, axis, *tensors):
 
         if not isinstance(axis, int):
             try:
-                axis = int(get_underlying_scalar_constant_value(axis))
+                axis = int(get_scalar_constant_value(axis))
             except NotScalarConstantError:
                 pass
 
@@ -2688,7 +2753,7 @@ def infer_shape(self, fgraph, node, ishapes):
 def _get_vector_length_Join(op, var):
     axis, *arrays = var.owner.inputs
     try:
-        axis = get_underlying_scalar_constant_value(axis)
+        axis = get_scalar_constant_value(axis)
         assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
         return builtins.sum(get_vector_length(a) for a in arrays)
     except NotScalarConstantError:
@@ -4081,7 +4146,7 @@ def make_node(self, a, choices):
         static_out_shape = ()
         for s in out_shape:
             try:
-                s_val = pytensor.get_underlying_scalar_constant(s)
+                s_val = get_scalar_constant_value(s)
             except (NotScalarConstantError, AttributeError):
                 s_val = None
 
@@ -4404,7 +4469,6 @@ def ix_(*args):
     "split",
     "transpose",
     "matrix_transpose",
-    "extract_constant",
     "default",
     "tensor_copy",
     "transfer",
diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py
index 0addd2b5f0..d1dfe44b90 100644
--- a/pytensor/tensor/conv/abstract_conv.py
+++ b/pytensor/tensor/conv/abstract_conv.py
@@ -25,7 +25,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor.basic import (
     as_tensor_variable,
-    get_underlying_scalar_constant_value,
+    get_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -497,8 +497,8 @@ def check_dim(given, computed):
     if given is None or computed is None:
         return True
     try:
-        given = get_underlying_scalar_constant_value(given)
-        computed = get_underlying_scalar_constant_value(computed)
+        given = get_scalar_constant_value(given)
+        computed = get_scalar_constant_value(computed)
         return int(given) == int(computed)
     except NotScalarConstantError:
         # no answer possible, accept for now
@@ -534,7 +534,7
@@ def assert_conv_shape(shape): out_shape = [] for i, n in enumerate(shape): try: - const_n = get_underlying_scalar_constant_value(n) + const_n = get_scalar_constant_value(n) if i < 2: if const_n < 0: raise ValueError( @@ -2203,9 +2203,7 @@ def __init__( if imshp_i is not None: # Components of imshp should be constant or ints try: - get_underlying_scalar_constant_value( - imshp_i, only_process_constants=True - ) + get_scalar_constant_value(imshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "imshp should be None or a tuple of constant int values" @@ -2218,9 +2216,7 @@ def __init__( if kshp_i is not None: # Components of kshp should be constant or ints try: - get_underlying_scalar_constant_value( - kshp_i, only_process_constants=True - ) + get_scalar_constant_value(kshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "kshp should be None or a tuple of constant int values" diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 9fc6683200..fedcd32ab9 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -678,7 +678,7 @@ def make_node(self, x, repeats): out_shape = [None] else: try: - const_reps = ptb.get_underlying_scalar_constant_value(repeats) + const_reps = ptb.get_scalar_constant_value(repeats) except NotScalarConstantError: const_reps = None if const_reps == 1: diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 8e67d711eb..59148fae3b 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,7 +30,7 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph -from pytensor.graph.basic import Constant, Variable +from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, @@ -55,9 +55,8 @@ as_tensor_variable, atleast_Nd, cast, - extract_constant, fill, - get_underlying_scalar_constant_value, + get_scalar_constant_value, join, ones_like, register_infer_shape, @@ -478,7 +477,12 @@ def local_alloc_sink_dimshuffle(fgraph, node): output_shape = node.inputs[1:] num_dims_with_size_1_added_to_left = 0 for i in range(len(output_shape) - inp.ndim): - if extract_constant(output_shape[i], only_process_constants=True) == 1: + if ( + get_scalar_constant_value( + output_shape[i], only_process_constants=True, raise_not_constant=False + ) + == 1 + ): num_dims_with_size_1_added_to_left += 1 else: break @@ -538,93 +542,90 @@ def local_useless_elemwise(fgraph, node): xor(x, x) -> zeros_like(x) TODO: This implementation is painfully redundant. + TODO: Allow rewrite when useless input broadcasts output """ - if isinstance(node.op, Elemwise): - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype - - if node.op.scalar_op == ps.eq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. That will always be true - ret = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif node.op.scalar_op == ps.neq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. 
That will always be false - ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - - elif node.op.scalar_op == ps.mul and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - - elif node.op.scalar_op == ps.add and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - elif node.op.scalar_op == ps.identity and len(node.inputs) == 1: - return [node.inputs[0]] - - elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[1].astype(node.outputs[0].dtype)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[0].astype(node.outputs[0].dtype)] - - elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[1].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[1], dtype=dtype, opt=True)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[0].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[0], dtype=dtype, opt=True)] - - elif isinstance(node.op.scalar_op, ps.XOR) and len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + out_bcast = node.outputs[0].type.broadcastable + dtype = node.outputs[0].type.dtype + scalar_op = node.op.scalar_op + + if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + # it is the same var in the graph. That will always be true + ret = ones_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + # it is the same var in the graph. 
That will always be false + ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + + elif ( + isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity) + and len(node.inputs) == 1 + ): + # No need to copy over any stack trace + return [node.inputs[0]] + + elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: + if ( + isinstance(node.inputs[0], TensorConstant) + and node.inputs[1].type.broadcastable == out_bcast + ): + const_val = node.inputs[0].unique_value + if const_val is not None: + if const_val == 0: + return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[1].astype(node.outputs[0].dtype)] + + if ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[0].type.broadcastable == out_bcast + ): + const_val = node.inputs[1].unique_value + if const_val is not None: + if const_val == 0: + return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[0].astype(node.outputs[0].dtype)] + + elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: + if ( + isinstance(node.inputs[0], TensorConstant) + and node.inputs[1].type.broadcastable == out_bcast + ): + const_val = node.inputs[0].unique_value + if const_val is not None: + if const_val == 0: + return [node.inputs[1].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[1], dtype=dtype, opt=True)] + + if ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[0].type.broadcastable == out_bcast + ): + const_val = node.inputs[1].unique_value + if const_val is not None: + if const_val == 0: + return [node.inputs[0].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[0], dtype=dtype, opt=True)] @register_specialize @@ -737,7 +738,7 @@ def local_remove_useless_assert(fgraph, node): n_conds = len(node.inputs[1:]) for c in node.inputs[1:]: try: - const = get_underlying_scalar_constant_value(c) + const = get_scalar_constant_value(c) if 0 != const.ndim or const == 0: # Should we raise an error here? 
How to be sure it @@ -832,7 +833,7 @@ def local_join_empty(fgraph, node): return new_inputs = [] try: - join_idx = get_underlying_scalar_constant_value( + join_idx = get_scalar_constant_value( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: @@ -988,13 +989,10 @@ def local_useless_switch(fgraph, node): left = node.inputs[1] right = node.inputs[2] cond_var = node.inputs[0] - cond = extract_constant(cond_var, only_process_constants=True) out_bcast = node.outputs[0].type.broadcastable - if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( - cond, np.number | np.bool_ - ): - if cond == 0: + if isinstance(cond_var, TensorConstant) and cond_var.unique_value is not None: + if cond_var.unique_value == 0: correct_out = right else: correct_out = left @@ -1014,7 +1012,7 @@ def local_useless_switch(fgraph, node): # if left is right -> left if equivalent_up_to_constant_casting(left, right): if left.type.broadcastable != out_bcast: - left, _ = broadcast_arrays(left, cond) + left, _ = broadcast_arrays(left, cond_var) out_dtype = node.outputs[0].type.dtype if left.type.dtype != out_dtype: @@ -1026,13 +1024,22 @@ def local_useless_switch(fgraph, node): # This case happens with scan. # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) if ( - cond_var.owner + node.outputs[0].type.ndim == 0 + and cond_var.owner and isinstance(cond_var.owner.op, Elemwise) and isinstance(cond_var.owner.op.scalar_op, ps.LE) and cond_var.owner.inputs[0].owner and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) - and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 - and extract_constant(left, only_process_constants=True) == 0 + and get_scalar_constant_value( + cond_var.owner.inputs[1], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + and get_scalar_constant_value( + left, only_process_constants=True, raise_not_constant=False + ) + == 0 and right == cond_var.owner.inputs[0] ): assert node.outputs[0].type.is_super(right.type) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 3d0a1ef6d1..3226f9b5a7 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -41,7 +41,7 @@ register_specialize, ) from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import TensorConstant class InplaceElemwiseOptimizer(GraphRewriter): @@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): new_inputs.append(i) else: try: - # works only for scalars cval_i = get_underlying_scalar_constant_value( i, only_process_constants=True ) @@ -1218,11 +1217,13 @@ def local_inline_composite_constants(fgraph, node): node.inputs, composite_op.fgraph.inputs, strict=True ): # Complex variables don't have a `c_literal` that can be inlined - if "complex" not in outer_inp.type.dtype: - unique_value = get_unique_constant_value(outer_inp) - if unique_value is not None: + if ( + isinstance(outer_inp, TensorConstant) + and "complex" not in outer_inp.type.dtype + ): + if outer_inp.unique_value is not None: inner_replacements[inner_inp] = ps.constant( - unique_value, dtype=inner_inp.dtype + outer_inp.unique_value, dtype=inner_inp.dtype ) continue new_outer_inputs.append(outer_inp) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 185c962e0b..aa2d279f43 100644 --- a/pytensor/tensor/rewriting/math.py +++ 
b/pytensor/tensor/rewriting/math.py @@ -28,7 +28,6 @@ as_tensor_variable, cast, constant, - extract_constant, get_underlying_scalar_constant_value, moveaxis, ones_like, @@ -106,7 +105,6 @@ from pytensor.tensor.variable import ( TensorConstant, TensorVariable, - get_unique_constant_value, ) @@ -128,32 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): return consts, origconsts, nonconsts -def get_constant(v): - """ - - Returns - ------- - object - A numeric constant if v is a Constant or, well, a - numeric constant. If v is a plain Variable, returns None. - - """ - if isinstance(v, Constant): - unique_value = get_unique_constant_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data - if data.ndim == 0: - return data - else: - return None - elif isinstance(v, Variable): - return None - else: - return v - - @register_canonicalize @register_stabilize @node_rewriter([Dot]) @@ -163,18 +135,16 @@ def local_0_dot_x(fgraph, node): x = node.inputs[0] y = node.inputs[1] - replace = False - try: - if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass - - try: - if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass + replace = ( + get_underlying_scalar_constant_value( + x, only_process_constants=True, raise_not_constant=False + ) + == 0 + or get_underlying_scalar_constant_value( + y, only_process_constants=True, raise_not_constant=False + ) + == 0 + ) if replace: constant_zero = constant(0, dtype=node.outputs[0].type.dtype) @@ -565,27 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node): @register_stabilize @register_specialize @register_canonicalize -@node_rewriter([sub]) +@node_rewriter([add, sub]) def local_expm1(fgraph, node): - """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``.""" - in1, in2 = node.inputs - out = node.outputs[0] + """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``.""" + if len(node.inputs) != 2: + # TODO: handle more than two inputs in add + return None - if ( - in1.owner - and isinstance(in1.owner.op, Elemwise) - and isinstance(in1.owner.op.scalar_op, ps.Exp) - and extract_constant(in2, only_process_constants=False) == 1 - ): - in11 = in1.owner.inputs[0] - new_out = expm1(in11) + if isinstance(node.op.scalar_op, ps.Sub): + exp_x, other_inp = node.inputs + if not ( + exp_x.owner + and isinstance(exp_x.owner.op, Elemwise) + and isinstance(exp_x.owner.op.scalar_op, ps.Exp) + and get_underlying_scalar_constant_value( + other_inp, raise_not_constant=False + ) + == 1 + ): + return None + else: + # Try both orders + other_inp, exp_x = node.inputs + for i in range(2): + if i == 1: + other_inp, exp_x = exp_x, other_inp + if ( + exp_x.owner + and isinstance(exp_x.owner.op, Elemwise) + and isinstance(exp_x.owner.op.scalar_op, ps.Exp) + and get_underlying_scalar_constant_value( + other_inp, raise_not_constant=False + ) + == -1 + ): + break + else: # no break + return None - if new_out.dtype != out.dtype: - new_out = cast(new_out, dtype=out.dtype) + [old_out] = node.outputs - if not out.type.is_super(new_out.type): - return - return [new_out] + [x] = exp_x.owner.inputs + if x.type.broadcastable != old_out.type.broadcastable: + x = broadcast_arrays(x, other_inp)[0] + + new_out = expm1(x) + + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, dtype=old_out.dtype) + + if not old_out.type.is_super(new_out.type): + return None + 
+ return [new_out] @register_specialize @@ -628,7 +630,14 @@ def local_mul_switch_sink(fgraph, node): # Look for a zero as the first or second branch of the switch for branch in range(2): zero_switch_input = switch_node.inputs[1 + branch] - if not get_unique_constant_value(zero_switch_input) == 0.0: + if ( + not get_underlying_scalar_constant_value( + zero_switch_input, + only_process_constants=True, + raise_not_constant=False, + ) + == 0.0 + ): continue switch_cond = switch_node.inputs[0] @@ -685,7 +694,14 @@ def local_div_switch_sink(fgraph, node): # Look for a zero as the first or second branch of the switch for branch in range(2): zero_switch_input = switch_node.inputs[1 + branch] - if not get_unique_constant_value(zero_switch_input) == 0.0: + if ( + not get_underlying_scalar_constant_value( + zero_switch_input, + only_process_constants=True, + raise_not_constant=False, + ) + == 0.0 + ): continue switch_cond = switch_node.inputs[0] @@ -989,8 +1005,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): """ Find all constants and put them together into a single constant. - Finds all constants in orig_num and orig_denum (using - get_constant) and puts them together into a single + Finds all constants in orig_num and orig_denum + and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator. @@ -1011,17 +1027,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): numct, denumct = [], [] for v in orig_num: - ct = get_constant(v) - if ct is not None: + if isinstance(v, TensorConstant) and v.unique_value is not None: # We found a constant in the numerator! # We add it to numct - numct.append(ct) + numct.append(v.unique_value) else: num.append(v) for v in orig_denum: - ct = get_constant(v) - if ct is not None: - denumct.append(ct) + if isinstance(v, TensorConstant) and v.unique_value is not None: + denumct.append(v.unique_value) else: denum.append(v) @@ -1045,10 +1059,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: # In that case we should only have one constant in `ct`. - assert len(ct) == 1 - first_num_ct = get_constant(orig_num[0]) - if first_num_ct is not None and ct[0].type.values_eq( - ct[0].data, first_num_ct + [var_ct] = ct + first_num_var = orig_num[0] + first_num_ct = ( + first_num_var.unique_value + if isinstance(first_num_var, TensorConstant) + else None + ) + if first_num_ct is not None and var_ct.type.values_eq( + var_ct.data, first_num_ct ): # This is an important trick :( if it so happens that: # * there's exactly one constant on the numerator and none on @@ -1340,12 +1359,13 @@ def local_useless_elemwise_comparison(fgraph, node): the graph easier to read. """ + # TODO: Refactor this function. So much repeated code! + if node.op.scalar_op.nin != 2: return - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype + dtype = node.outputs[0].type.dtype + out_bcast = node.outputs[0].type.broadcastable # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) if ( @@ -1356,6 +1376,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Copy over stacktrace from previous output. 
copy_stack_trace(node.outputs, res) return [res] + # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.LE | ps.GE) @@ -1366,6 +1387,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[{minimum,maximum}](X, X) -> X if ( isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum) @@ -1381,64 +1403,72 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(node.op.scalar_op, ps.LT) and node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.GE) and node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = ones_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, ps.ScalarMaximum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - # No need to copy over stacktrace. - return [node.inputs[0]] - # Elemwise[maximum](0, X.shape[i]) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, ps.ScalarMaximum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - # No need to copy over stacktrace. - return [node.inputs[1]] - # Elemwise[minimum](X.shape[i], 0) -> 0 - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] + if isinstance(node.op.scalar_op, ps.ScalarMaximum): + for idx in range(2): + if ( + node.inputs[idx].owner + and isinstance(node.inputs[idx].owner.op, Shape_i) + and get_underlying_scalar_constant_value( + node.inputs[1 - idx], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + ): + res = node.inputs[idx] + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1 - idx])[0] + # No need to copy over stacktrace. 
+ return [res] - # Elemwise[minimum](0, X.shape[i]) -> 0 - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - res = zeros_like(node.inputs[1], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] + # Elemwise[minimum](X.shape[i], 0) -> 0 + if isinstance(node.op.scalar_op, ps.ScalarMinimum): + for idx in range(2): + if ( + node.inputs[idx].owner + and isinstance(node.inputs[idx].owner.op, Shape_i) + and get_underlying_scalar_constant_value( + node.inputs[1 - idx], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + ): + res = zeros_like(node.inputs[idx], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1 - idx])[0] + # No need to copy over stacktrace. + return [res] # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) if ( @@ -1450,12 +1480,18 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(var.owner and var.owner.op, Shape_i) for var in node.inputs[0].owner.inputs ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.GE) @@ -1466,57 +1502,61 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(var.owner and var.owner.op, Shape_i) for var in node.inputs[0].owner.inputs ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = ones_like(node.inputs[0], dtype=dtype, opt=True) - + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. 
copy_stack_trace(node.outputs, res) return [res] - # Elemwise[EQ](Subtensor(Shape(x)), -N) - # Elemwise[EQ](somegraph that only depend of shape, -N) - # TODO: handle the case where the -N is on either side - """ - |Elemwise{eq,no_inplace} [id B] '' - | |Subtensor{int64} [id C] '' - | | |Join [id D] '' - | | | |TensorConstant{0} [id E] - | | | |Subtensor{int64:int64:} [id F] '' - | | | | |Shape [id G] '' - """ + # Elemwise[EQ](Subtensor(Shape(x)), -N) + # Elemwise[EQ](somegraph that only depend of shape, -N) + # TODO: handle the case where the -N is on either side + """ +|Elemwise{eq,no_inplace} [id B] '' +| |Subtensor{int64} [id C] '' +| | |Join [id D] '' +| | | |TensorConstant{0} [id E] +| | | |Subtensor{int64:int64:} [id F] '' +| | | | |Shape [id G] '' + """ - def investigate(node): + def investigate_if_shape(node) -> bool: "Return True if values will be shapes, so >= 0" if isinstance(node.op, Shape | Shape_i): return True elif isinstance(node.op, Subtensor) and node.inputs[0].owner: - return investigate(node.inputs[0].owner) + return investigate_if_shape(node.inputs[0].owner) elif isinstance(node.op, Join): - return all(v.owner and investigate(v.owner) for v in node.inputs[1:]) + return all( + v.owner and investigate_if_shape(v.owner) for v in node.inputs[1:] + ) elif isinstance(node.op, MakeVector): - return all(v.owner and investigate(v.owner) for v in node.inputs) + return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs) + return False if ( isinstance(node.op.scalar_op, ps.EQ) and node.inputs[0].owner - and investigate(node.inputs[0].owner) + and investigate_if_shape(node.inputs[0].owner) + and ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[1].unique_value is not None + and node.inputs[1].unique_value < 0 + ) ): - try: - cst = get_underlying_scalar_constant_value( - node.inputs[1], only_process_constants=True - ) - - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - if cst < 0: - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - - return [res] + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] + # Copy over stacktrace from previous output. 
+ copy_stack_trace(node.outputs, res) + return [res] - except NotScalarConstantError: - pass return @@ -1813,12 +1853,6 @@ def local_add_neg_to_sub(fgraph, node): new_out = sub(first, pre_neg) return [new_out] - # Check if it is a negative constant - const = get_constant(second) - if const is not None and const < 0: - new_out = sub(first, np.abs(const)) - return [new_out] - @register_canonicalize @node_rewriter([mul]) @@ -1845,7 +1879,12 @@ def local_mul_zero(fgraph, node): @register_specialize @node_rewriter([true_div]) def local_div_to_reciprocal(fgraph, node): - if np.all(get_constant(node.inputs[0]) == 1.0): + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 1.0 + ): out = node.outputs[0] new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) # The ones could have forced upcasting @@ -1866,7 +1905,9 @@ def local_reciprocal_canon(fgraph, node): @register_canonicalize @node_rewriter([pt_pow]) def local_pow_canonicalize(fgraph, node): - cst = get_constant(node.inputs[1]) + cst = get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) if cst == 0: return [alloc_like(1, node.outputs[0], fgraph)] if cst == 1: @@ -1897,7 +1938,12 @@ def local_intdiv_by_one(fgraph, node): @node_rewriter([int_div, true_div]) def local_zero_div(fgraph, node): """0 / x -> 0""" - if get_constant(node.inputs[0]) == 0: + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 0 + ): ret = alloc_like(0, node.outputs[0], fgraph) ret.tag.values_eq_approx = values_eq_approx_remove_nan return [ret] @@ -1910,8 +1956,12 @@ def local_pow_specialize(fgraph, node): odtype = node.outputs[0].dtype xsym = node.inputs[0] ysym = node.inputs[1] - y = get_constant(ysym) - if (y is not None) and not broadcasted_by(xsym, ysym): + try: + y = get_underlying_scalar_constant_value(ysym, only_process_constants=True) + except NotScalarConstantError: + return + + if not broadcasted_by(xsym, ysym): rval = None if np.all(y == 2): @@ -1945,10 +1995,14 @@ def local_pow_to_nested_squaring(fgraph, node): """ # the idea here is that we have pow(x, y) + xsym, ysym = node.inputs + + try: + y = get_underlying_scalar_constant_value(ysym, only_process_constants=True) + except NotScalarConstantError: + return + odtype = node.outputs[0].dtype - xsym = node.inputs[0] - ysym = node.inputs[1] - y = get_constant(ysym) # the next line is needed to fix a strange case that I don't # know how to make a separate test. @@ -1964,7 +2018,7 @@ def local_pow_to_nested_squaring(fgraph, node): y = y[0] except IndexError: pass - if (y is not None) and not broadcasted_by(xsym, ysym): + if not broadcasted_by(xsym, ysym): rval = None # 512 is too small for the cpu and too big for some gpu! 
if abs(y) == int(abs(y)) and abs(y) <= 512: @@ -2031,7 +2085,9 @@ def local_mul_specialize(fgraph, node): nb_neg_node += 1 # remove special case arguments of 1, -1 or 0 - y = get_constant(inp) + y = get_underlying_scalar_constant_value( + inp, only_process_constants=True, raise_not_constant=False + ) if y == 1.0: nb_cst += 1 elif y == -1.0: @@ -2083,7 +2139,7 @@ def local_add_remove_zeros(fgraph, node): y = get_underlying_scalar_constant_value(inp) except NotScalarConstantError: y = inp - if np.all(y == 0.0): + if y == 0.0: continue new_inputs.append(inp) @@ -2181,7 +2237,7 @@ def local_abs_merge(fgraph, node): ) except NotScalarConstantError: return False - if not (const >= 0).all(): + if not const >= 0: return False inputs.append(i) else: @@ -2218,12 +2274,21 @@ def local_log1p(fgraph, node): return [alloc_like(log1p(ninp), node.outputs[0], fgraph)] elif log_arg.owner and log_arg.owner.op == sub: - one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) + one, other = log_arg.owner.inputs + try: + one = get_underlying_scalar_constant_value(one, only_process_constants=True) + except NotScalarConstantError: + return + if one != 1: return - other = log_arg.owner.inputs[1] - if other.dtype != log_arg.dtype: + + if other.type.broadcastable != log_arg.type.broadcastable: + other = broadcast_arrays(other, one)[0] + + if other.type.dtype != log_arg.type.dtype: other = other.astype(log_arg.dtype) + return [log1p(neg(other))] @@ -2561,9 +2626,9 @@ def local_greedy_distributor(fgraph, node): register_stabilize(local_one_minus_erfc) register_specialize(local_one_minus_erfc) -# erfc(-x)-1=>erf(x) +# -1 + erfc(-x)=>erf(x) local_erf_neg_minus_one = PatternNodeRewriter( - (sub, (erfc, (neg, "x")), 1), + (add, -1, (erfc, (neg, "x"))), (erf, "x"), allow_multiple_clients=True, name="local_erf_neg_minus_one", @@ -2824,7 +2889,7 @@ def _is_1(expr): """ try: v = get_underlying_scalar_constant_value(expr) - return np.allclose(v, 1) + return np.isclose(v, 1) except NotScalarConstantError: return False @@ -2992,7 +3057,7 @@ def is_neg(var): for idx, mul_input in enumerate(var_node.inputs): try: constant = get_underlying_scalar_constant_value(mul_input) - is_minus_1 = np.allclose(constant, -1) + is_minus_1 = np.isclose(constant, -1) except NotScalarConstantError: is_minus_1 = False if is_minus_1: diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index c1284aa81d..e277772ad4 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -22,8 +22,7 @@ as_tensor_variable, cast, constant, - extract_constant, - get_underlying_scalar_constant_value, + get_scalar_constant_value, register_infer_shape, stack, ) @@ -213,7 +212,7 @@ def shape_ir(self, i, r): # Do not call make_node for test_value s = Shape_i(i)(r) try: - s = get_underlying_scalar_constant_value(s) + s = get_scalar_constant_value(s) except NotScalarConstantError: pass return s @@ -297,7 +296,7 @@ def unpack(self, s_i, var): assert len(idx) == 1 idx = idx[0] try: - i = get_underlying_scalar_constant_value(idx) + i = get_scalar_constant_value(idx) except NotScalarConstantError: pass else: @@ -354,7 +353,9 @@ def set_shape(self, r, s, override=False): not hasattr(r.type, "shape") or r.type.shape[i] != 1 or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) + or self.lscalar_one.equals( + get_scalar_constant_value(shape_vars[i], raise_not_constant=False) + ) for i in range(r.type.ndim) ) self.shape_of[r] = tuple(shape_vars) @@ 
-450,7 +451,11 @@ def update_shape(self, r, other_r): ) or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals( - extract_constant(merged_shape[i], only_process_constants=True) + get_scalar_constant_value( + merged_shape[i], + only_process_constants=True, + raise_not_constant=False, + ) ) for i in range(r.type.ndim) ) @@ -474,7 +479,9 @@ def set_shape_i(self, r, i, s_i): not hasattr(r.type, "shape") or r.type.shape[idx] != 1 or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals(extract_constant(new_shape[idx])) + or self.lscalar_one.equals( + get_scalar_constant_value(new_shape[idx], raise_not_constant=False) + ) for idx in range(r.type.ndim) ) self.shape_of[r] = tuple(new_shape) @@ -847,7 +854,10 @@ def local_useless_reshape(fgraph, node): outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and len(outshp_i.owner.inputs) == 2 - and extract_constant(outshp_i.owner.inputs[1]) == dim + and get_scalar_constant_value( + outshp_i.owner.inputs[1], raise_not_constant=False + ) + == dim ): subtensor_inp = outshp_i.owner.inputs[0] if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): @@ -857,7 +867,9 @@ def local_useless_reshape(fgraph, node): continue # Match constant if input.type.shape[dim] == constant - cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) + cst_outshp_i = get_scalar_constant_value( + outshp_i, only_process_constants=True, raise_not_constant=False + ) if inp.type.shape[dim] == cst_outshp_i: shape_match[dim] = True continue @@ -872,8 +884,12 @@ def local_useless_reshape(fgraph, node): if shape_feature: inpshp_i = shape_feature.get_shape(inp, dim) if inpshp_i == outshp_i or ( - extract_constant(inpshp_i, only_process_constants=True) - == extract_constant(outshp_i, only_process_constants=True) + get_scalar_constant_value( + inpshp_i, only_process_constants=True, raise_not_constant=False + ) + == get_scalar_constant_value( + outshp_i, only_process_constants=True, raise_not_constant=False + ) ): shape_match[dim] = True continue @@ -909,11 +925,14 @@ def local_reshape_to_dimshuffle(fgraph, node): new_output_shape = [] index = 0 # index over the output of the new reshape for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust extract_constant + # Since output_shape is a symbolic vector, we trust get_scalar_constant_value # to go through however it is formed to see if its i-th element is 1. # We need only_process_constants=False for that. 
-        dim = extract_constant(
-            output_shape[i], only_process_constants=False, elemwise=False
+        dim = get_scalar_constant_value(
+            output_shape[i],
+            only_process_constants=False,
+            elemwise=False,
+            raise_not_constant=False,
         )
         if dim == 1:
             dimshuffle_new_order.append("x")
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 572d2bcab6..4b824e46cf 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -26,7 +26,7 @@
     as_tensor,
     cast,
     concatenate,
-    extract_constant,
+    get_scalar_constant_value,
     get_underlying_scalar_constant_value,
     register_infer_shape,
     switch,
@@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node):
             start = s.start
             stop = s.stop

-            if start is not None and extract_constant(
-                start, only_process_constants=True
+            if start is not None and get_scalar_constant_value(
+                start, only_process_constants=True, raise_not_constant=False
             ) == (0 if positive_step else -1):
                 change_flag = True
                 start = None
@@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node):
             if (
                 stop is not None
                 and x.type.shape[dim] is not None
-                and extract_constant(stop, only_process_constants=True)
+                and get_scalar_constant_value(
+                    stop, only_process_constants=True, raise_not_constant=False
+                )
                 == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
             ):
                 change_flag = True
@@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node):
         and e.stop is None
         and (
             e.step is None
-            or extract_constant(e.step, only_process_constants=True) == -1
+            or get_scalar_constant_value(
+                e.step, only_process_constants=True, raise_not_constant=False
+            )
+            == -1
         )
         for e in idx_cst
     ):
@@ -994,7 +999,7 @@ def local_useless_subtensor(fgraph, node):
             if isinstance(idx.stop, int | np.integer):
                 length_pos_data = sys.maxsize
                 try:
-                    length_pos_data = get_underlying_scalar_constant_value(
+                    length_pos_data = get_scalar_constant_value(
                         length_pos, only_process_constants=True
                     )
                 except NotScalarConstantError:
@@ -1059,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):

     # get length of the indexed tensor along the first axis
     try:
-        length = get_underlying_scalar_constant_value(
+        length = get_scalar_constant_value(
             shape_of[node.inputs[0]][0], only_process_constants=True
         )
     except NotScalarConstantError:
@@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
         and
         # Don't use only_process_constants=True. We need to
         # investigate Alloc of 0s but with non constant shape.
-        extract_constant(x, elemwise=False) != 0
+        get_underlying_scalar_constant_value(
+            x, elemwise=False, raise_not_constant=False
+        )
+        != 0
     ):
         return
@@ -1728,7 +1736,7 @@ def local_join_subtensors(fgraph, node):
     axis, tensors = node.inputs[0], node.inputs[1:]

     try:
-        axis = get_underlying_scalar_constant_value(axis)
+        axis = get_scalar_constant_value(axis)
     except NotScalarConstantError:
         return

@@ -1789,12 +1797,7 @@ def local_join_subtensors(fgraph, node):
             if step is None:
                 continue
             try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        step, only_process_constants=True
-                    )
-                    != 1
-                ):
+                if get_scalar_constant_value(step, only_process_constants=True) != 1:
                     return None
             except NotScalarConstantError:
                 return None
diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index a357f25672..8913d6fb4d 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -20,7 +20,7 @@
 from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
-from pytensor.tensor.type_other import NoneConst
+from pytensor.tensor.type_other import NoneConst, NoneTypeT
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -401,8 +401,6 @@ class SpecifyShape(COp):
     _output_type_depends_on_input_value = True

     def make_node(self, x, *shape):
-        from pytensor.tensor.basic import get_underlying_scalar_constant_value
-
         x = ptb.as_tensor_variable(x)

         shape = tuple(
@@ -428,11 +426,9 @@ def make_node(self, x, *shape):
         for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)):
             if xts is not None:
                 type_shape[i] = xts
-            else:
+            elif not isinstance(s.type, NoneTypeT):
                 try:
-                    type_s = get_underlying_scalar_constant_value(s)
-                    if type_s is not None:
-                        type_shape[i] = int(type_s)
+                    type_shape[i] = int(ptb.get_scalar_constant_value(s))
                 except NotScalarConstantError:
                     pass

@@ -460,22 +456,13 @@ def perform(self, node, inp, out_):
     def infer_shape(self, fgraph, node, shapes):
         xshape, *_ = shapes
         shape = node.inputs[1:]
-        new_shape = []
-        for dim in range(node.inputs[0].type.ndim):
-            s = shape[dim]
-            try:
-                s = ptb.get_underlying_scalar_constant_value(s)
-                # We assume that `None` shapes are always retrieved by
-                # `get_underlying_scalar_constant_value`, and only in that case do we default to
-                # the shape of the input variable
-                if s is None:
-                    s = xshape[dim]
-            except NotScalarConstantError:
-                pass
-            new_shape.append(ptb.as_tensor_variable(s))
-
-        assert len(new_shape) == len(xshape)
-        return [new_shape]
+        # Use x shape if specified dim is None, otherwise the specified shape
+        return [
+            [
+                xshape[i] if isinstance(dim.type, NoneTypeT) else dim
+                for i, dim in enumerate(shape)
+            ]
+        ]

     def connection_pattern(self, node):
         return [[True], *[[False]] * len(node.inputs[1:])]
@@ -593,7 +580,7 @@ def specify_shape(
 @_get_vector_length.register(SpecifyShape)  # type: ignore
 def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
     try:
-        return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
+        return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item())
     except NotScalarConstantError:
         raise ValueError(f"Length of {var} cannot be determined")

@@ -674,7 +661,7 @@ def make_node(self, x, shp):
             y = shp_list[index]
             y = ptb.as_tensor_variable(y)
             try:
-                s_val = ptb.get_underlying_scalar_constant_value(y).item()
+                s_val = ptb.get_scalar_constant_value(y).item()
                 if s_val >= 0:
                     out_shape[index] = s_val
             except NotScalarConstantError:
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 4904259d25..325567918a 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -259,9 +259,10 @@ def make_node(self, A, b):
             raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")

         # Infer dtype by solving the most simple case with 1x1 matrices
-        o_dtype = scipy.linalg.solve(
-            np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
-        ).dtype
+        inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
+        out_arr = [[None]]
+        self.perform(None, inp_arr, out_arr)
+        o_dtype = out_arr[0][0].dtype
         x = tensor(dtype=o_dtype, shape=b.type.shape)
         return Apply(self, [A, b], [x])

diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index fe4d06f152..a3a81f63bd 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -29,7 +29,7 @@
 from pytensor.tensor.basic import (
     ScalarFromTensor,
     alloc,
-    get_underlying_scalar_constant_value,
+    get_scalar_constant_value,
     nonzero,
     scalar_from_tensor,
 )
@@ -778,7 +778,7 @@ def conv(val):
             return slice(conv(val.start), conv(val.stop), conv(val.step))
         else:
             try:
-                return get_underlying_scalar_constant_value(
+                return get_scalar_constant_value(
                     val,
                     only_process_constants=only_process_constants,
                     elemwise=elemwise,
@@ -855,7 +855,7 @@ def extract_const(value):
         if value is None:
             return value, True
         try:
-            value = get_underlying_scalar_constant_value(value)
+            value = get_scalar_constant_value(value)
             return value, True
         except NotScalarConstantError:
             return value, False
@@ -3022,17 +3022,17 @@ def _get_vector_length_Subtensor(op, var):
         start = (
             None
             if indices[0].start is None
-            else get_underlying_scalar_constant_value(indices[0].start)
+            else get_scalar_constant_value(indices[0].start)
         )
         stop = (
             None
             if indices[0].stop is None
-            else get_underlying_scalar_constant_value(indices[0].stop)
+            else get_scalar_constant_value(indices[0].stop)
         )
         step = (
             None
             if indices[0].step is None
-            else get_underlying_scalar_constant_value(indices[0].step)
+            else get_scalar_constant_value(indices[0].step)
         )
         if start == stop:
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index ae515d7432..ac89283bb6 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -11,7 +11,10 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Constant, OptionalApplyType, Variable
 from pytensor.graph.utils import MetaType
-from pytensor.scalar import ComplexError, IntegerDivisionError
+from pytensor.scalar import (
+    ComplexError,
+    IntegerDivisionError,
+)
 from pytensor.tensor import _get_vector_length
 from pytensor.tensor.exceptions import AdvancedIndexingError
 from pytensor.tensor.type import TensorType
@@ -1042,17 +1045,9 @@ def no_nan(self):

 def get_unique_constant_value(x: TensorVariable) -> Number | None:
     """Return the unique value of a tensor, if there is one"""
-    if isinstance(x, Constant):
-        data = x.data
-
-        if isinstance(data, np.ndarray) and data.size > 0:
-            if data.size == 1:
-                return data.squeeze()
-
-            flat_data = data.ravel()
-            if (flat_data == flat_data[0]).all():
-                return flat_data[0]
-
+    warnings.warn("get_unique_constant_value is deprecated.", FutureWarning)
+    if isinstance(x, TensorConstant):
+        return x.unique_value
     return None

@@ -1081,6 +1076,30 @@ def __init__(self, type: _TensorTypeType, data, name=None):
     def signature(self):
         return TensorConstantSignature((self.type, self.data))

+    @property
+    def unique_value(self) -> Number | None:
+        """Return the unique value of a tensor, if there is one"""
+        try:
+            return self._unique_value
+        except AttributeError:
+            data = self.data
+            unique_value = None
+            if data.size > 0:
+                if data.size == 1:
+                    unique_value = data.squeeze()
+                else:
+                    flat_data = data.ravel()
+                    if (flat_data == flat_data[0]).all():
+                        unique_value = flat_data[0]
+
+            if unique_value is not None:
+                # Don't allow the unique value to be changed
+                unique_value.setflags(write=False)
+
+            self._unique_value = unique_value
+
+            return self._unique_value
+
     def equals(self, other):
         # Override Constant.equals to allow to compare with
         # numpy.ndarray, and python type.
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 33c61f48bc..debcf44c64 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -1383,11 +1383,11 @@ def assert_eqs_const(self, f, val, op=deep_copy_op):
         if op == deep_copy_op:
             assert len(elem.inputs) == 1, elem.inputs
             assert isinstance(elem.inputs[0], TensorConstant), elem
-            assert pt.extract_constant(elem.inputs[0]) == val, val
+            assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
         else:
             assert len(elem.inputs) == 2, elem.inputs
             assert isinstance(elem.inputs[0], TensorConstant), elem
-            assert pt.extract_constant(elem.inputs[0]) == val, val
+            assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val

     def assert_identity(self, f):
         topo = f.maker.fgraph.toposort()
@@ -3806,14 +3806,9 @@ def test_local_expm1():
         for n in h.maker.fgraph.toposort()
     )

-    # This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
-    expect_rewrite = config.mode != "FAST_COMPILE"
-    assert (
-        any(
-            isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
-            for n in r.maker.fgraph.toposort()
-        )
-        == expect_rewrite
+    assert any(
+        isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
+        for n in r.maker.fgraph.toposort()
     )
@@ -4440,23 +4435,6 @@ def test_local_add_neg_to_sub(first_negative):
     assert np.allclose(f(x_test, y_test), exp)


-def test_local_add_neg_to_sub_const():
-    x = vector("x")
-    const = 5.0
-
-    f = function([x], x + (-const), mode=Mode("py"))
-
-    nodes = [
-        node.op
-        for node in f.maker.fgraph.toposort()
-        if not isinstance(node.op, DimShuffle)
-    ]
-    assert nodes == [pt.sub]
-
-    x_test = np.array([3, 4], dtype=config.floatX)
-    assert np.allclose(f(x_test), x_test + (-const))
-
-
 def test_log1mexp_stabilization():
     mode = Mode("py").including("stabilize")

diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index c3ddae4b9f..ff8751e411 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -46,7 +46,6 @@
     default,
     diag,
     expand_dims,
-    extract_constant,
     eye,
     fill,
     flatnonzero,
@@ -3571,12 +3570,13 @@ def test_second(self):
         assert get_underlying_scalar_constant_value(s) == c.data

     def test_copy(self):
-        # Make sure we do not return the internal storage of a constant,
+        # Make sure we do not return a writeable internal storage of a constant,
         # so we cannot change the value of a constant by mistake.
         c = constant(3)
-        d = extract_constant(c)
-        d += 1
-        e = extract_constant(c)
+        d = get_scalar_constant_value(c)
+        with pytest.raises(ValueError, match="output array is read-only"):
+            d += 1
+        e = get_scalar_constant_value(c)
         assert e == 3, (c, d, e)

     @pytest.mark.parametrize("only_process_constants", (True, False))
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 8ce40d48ef..51862562ac 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -590,7 +590,7 @@ def core_scipy_fn(A, b):
             A_val_copy, b_val_copy
         )
         np.testing.assert_allclose(
-            out, expected_out, atol=1e-5 if config.floatX == "float32" else 0
+            out, expected_out, atol=1e-4 if config.floatX == "float32" else 0
         )

         # Confirm input was destroyed
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index c1644e41e1..081e495127 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -19,7 +19,7 @@
 from pytensor.link.basic import PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.tensor import as_tensor_variable
-from pytensor.tensor.basic import second
+from pytensor.tensor.basic import get_scalar_constant_value, second
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Any, Sum, exp
 from pytensor.tensor.math import all as pt_all
@@ -807,8 +807,8 @@ def test_partial_static_shape_info(self):
         assert len(res_shape) == 1
         assert len(res_shape[0]) == 2
-        assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
-        assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
+        assert get_scalar_constant_value(res_shape[0][0]) == 1
+        assert get_scalar_constant_value(res_shape[0][1]) == 1

     def test_infer_shape_multi_output(self):
         class CustomElemwise(Elemwise):
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 3d4b6697b8..f46d771938 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -169,7 +169,12 @@ def test_eigvalsh_grad():
     )


-class TestSolveBase(utt.InferShapeTester):
+class TestSolveBase:
+    class SolveTest(SolveBase):
+        def perform(self, node, inputs, outputs):
+            A, b = inputs
+            outputs[0][0] = scipy.linalg.solve(A, b)
+
     @pytest.mark.parametrize(
         "A_func, b_func, error_message",
         [
@@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message):
         with pytest.raises(ValueError, match=error_message):
             A = A_func()
             b = b_func()
-            SolveBase(b_ndim=2)(A, b)
+            self.SolveTest(b_ndim=2)(A, b)

     def test__repr__(self):
         np.random.default_rng(utt.fetch_seed())
         A = matrix()
         b = matrix()
-        y = SolveBase(b_ndim=2)(A, b)
+        y = self.SolveTest(b_ndim=2)(A, b)
         assert (
             y.__repr__()
-            == "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
+            == "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
         )

@@ -239,8 +244,9 @@ def test_correctness(self):
         A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
         A_val = np.dot(A_val.transpose(), A_val)

-        assert np.allclose(
-            scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val)
+        np.testing.assert_allclose(
+            scipy.linalg.solve(A_val, b_val, assume_a="gen"),
+            gen_solve_func(A_val, b_val),
         )

         A_undef = np.array(
@@ -253,7 +259,7 @@ def test_correctness(self):
             ],
             dtype=config.floatX,
         )
-        assert np.allclose(
+        np.testing.assert_allclose(
             scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val)
         )

@@ -450,7 +456,7 @@ def test_solve_dtype(self):
         fn = function([A, b], x)
         x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))

-        assert x.dtype == x_result.dtype
+        assert x.dtype == x_result.dtype, (A_dtype, b_dtype)


 def test_cho_solve():