diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py
index 0719870205..5b45fa40f4 100644
--- a/pytensor/graph/rewriting/basic.py
+++ b/pytensor/graph/rewriting/basic.py
@@ -1550,6 +1550,7 @@ def __init__(
         tracks=(),
         get_nodes=None,
         values_eq_approx=None,
+        allow_cast=True,
     ):
         """
 
@@ -1572,6 +1573,10 @@ def __init__(
             If you provide `tracks`, you must provide this parameter. It must be a
             function that takes the tracked node and returns a list of nodes on
             which we will try this rewrite.
+        values_eq_approx
+            If provided, assigned as the ``values_eq_approx`` tag of the rewritten output.
+        allow_cast
+            Automatically cast the rewritten output to the original dtype when the types differ only in dtype.
 
         Notes
         -----
@@ -1586,6 +1591,7 @@ def __init__(
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
+        self.allow_cast = allow_cast
         if isinstance(in_pattern, list | tuple):
             self.op = self.in_pattern[0]
         elif isinstance(in_pattern, dict):
@@ -1630,6 +1636,10 @@ def transform(self, fgraph, node, get_nodes=True):
         if node.op != self.op:
             return False
 
+        if len(node.outputs) != 1:
+            # PatternNodeRewriter doesn't support replacing multi-output nodes
+            return False
+
         s = unify(self.in_pattern, node.out)
 
         if s is False:
@@ -1652,19 +1662,20 @@ def transform(self, fgraph, node, get_nodes=True):
         ):
             return False
 
-        if ret.owner:
+        [old_out] = node.outputs
+        if not old_out.type.is_super(ret.type):
+            # Type doesn't match
             if not (
-                len(node.outputs) == len(ret.owner.outputs)
-                and all(
-                    o.type.is_super(new_o.type)
-                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
-                )
+                self.allow_cast
+                and isinstance(old_out.type, pytensor.tensor.TensorType)
+                and isinstance(ret.type, pytensor.tensor.TensorType)
             ):
                 return False
-        else:
-            # ret is just an input variable
-            assert len(node.outputs) == 1
-            if not node.outputs[0].type.is_super(ret.type):
+
+            # Try to cast the new output to the old dtype
+            ret = ret.astype(old_out.type.dtype)
+            if not old_out.type.is_super(ret.type):
+                # Still doesn't match
                 return False
 
         return [ret]
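
A minimal sketch of what `allow_cast` buys in practice, mirroring the second case of the updated `test_local_1msigmoid` further down; `rewrite_graph` and the expected result come from that test, the rest of the setup is illustrative:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    # The float64 constant makes the subtraction float64, while the rewritten
    # sigmoid(-x) is float32. Previously the type mismatch blocked the rewrite;
    # with allow_cast=True the replacement is cast back to float64 instead.
    x = pt.fscalar("x")
    out = np.array(1.0, "float64") - pt.sigmoid(x)
    rewritten = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
    # rewritten is now equivalent to pt.cast(pt.sigmoid(-x), "float64")
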
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index d126502bde..07dca66ad3 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -2385,7 +2385,7 @@ def local_log1p(fgraph, node):
             log_arg.owner.inputs, only_process_constants=True
         )
         # scalar_inputs are potentially dimshuffled and fill'd scalars
-        if scalars and np.allclose(np.sum(scalars), 1):
+        if scalars and isclose(np.sum(scalars), 1):
             if nonconsts:
                 ninp = variadic_add(*nonconsts)
                 if ninp.dtype != log_arg.type.dtype:
@@ -2990,6 +2990,21 @@ def check_input(inputs):
     return [ret]
 
 
+def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
+    """Check whether `x` is a constant close to `ref`.
+
+    Returns
+    -------
+    bool
+        True iff `x` is close to `ref`, by default within 10 ULPs of `ref`
+        (on top of any `rtol`/`atol` passed in).
+    """
+    x = np.asarray(x)
+    if np.issubdtype(x.dtype, np.floating):
+        atol = atol + num_ulps * np.abs(np.spacing(x.dtype.type(ref)))
+    return np.allclose(x, ref, rtol=rtol, atol=atol)
+
+
 def _skip_mul_1(r):
     if r.owner and r.owner.op == mul:
         not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
@@ -3008,7 +3023,7 @@ def _is_1(expr):
     """
     try:
         v = get_underlying_scalar_constant_value(expr)
-        return np.isclose(v, 1)
+        return isclose(v, 1)
     except NotScalarConstantError:
         return False
 
@@ -3069,7 +3084,7 @@ def is_1pexp(t, only_process_constants=True):
                 scal_sum = scalars[0]
                 for s in scalars[1:]:
                     scal_sum = scal_sum + s
-                if np.allclose(scal_sum, 1):
+                if isclose(scal_sum, 1):
                     return False, maybe_exp.owner.inputs[0]
     return None
 
@@ -3169,7 +3184,7 @@ def is_neg(var):
         for idx, mul_input in enumerate(var_node.inputs):
             try:
                 constant = get_underlying_scalar_constant_value(mul_input)
-                is_minus_1 = np.isclose(constant, -1)
+                is_minus_1 = isclose(constant, -1)
             except NotScalarConstantError:
                 is_minus_1 = False
             if is_minus_1:
@@ -3577,7 +3592,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
         # scalar_inputs are potentially dimshuffled and fill'd scalars
         if len(nonconsts) == 1:
             if nonconsts[0].owner and nonconsts[0].owner.op == exp:
-                if scalars_ and np.allclose(np.sum(scalars_), 1):
+                if scalars_ and isclose(np.sum(scalars_), 1):
                     out = [
                         alloc_like(
                             sigmoid(neg(nonconsts[0].owner.inputs[0])),
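
A quick numerical check of the 10-ULP default in the new `isclose` helper, using the same constants the updated `test_local_1msigmoid` exercises (plain numpy, nothing here is patch-specific):

    import numpy as np

    np.spacing(np.float32(1.0))  # 1 ULP at 1.0 in float32: ~1.19e-07
    np.spacing(np.float64(1.0))  # 1 ULP at 1.0 in float64: ~2.22e-16

    atol32 = 10 * np.spacing(np.float32(1.0))
    # float32 accumulation noise stays within the budget ...
    abs(np.sum(1 / np.array([2, 3, 6], "float32")) - 1.0) <= atol32  # True
    # ... but a deliberate offset of ~9e-6 (~75 ULPs) is rejected.
    abs(np.float32(1 - 9e-6) - 1.0) <= atol32  # False
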
diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py
index 3d36a2234a..d0cb94f9fb 100644
--- a/tests/graph/rewriting/test_basic.py
+++ b/tests/graph/rewriting/test_basic.py
@@ -41,6 +41,7 @@
     op_y,
     op_z,
 )
+from tests.unittest_tools import assert_equal_computations
 
 
 class AssertNoChanges(Feature):
@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
     assert e.type.is_super(fg.outputs[0].type)
 
 
-def test_patternsub_different_output_lengths():
-    # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
-    ps = PatternNodeRewriter(
-        (op1, "x"),
-        ("x"),
+def test_patternsub_multi_output_nodes():
+    # Test that PatternNodeRewriter won't attempt to replace multi-output nodes
+    multiple_op_ps = PatternNodeRewriter(
+        (op_multiple_outputs, "x"),
+        "x",
         name="ps",
     )
-    rewriter = in2out(ps)
+
+    single_op_ps = PatternNodeRewriter(
+        (op_y, "x"),
+        "x",
+        name="ps",
+    )
+
+    rewriter = in2out(multiple_op_ps, single_op_ps)
 
     x = MyVariable("x")
     e1, e2 = op_multiple_outputs(x)
-    o = op1(e1)
+    o1, o2 = op_y(e1), op_y(e2)
+
+    fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False)
+    rewriter.rewrite(fgraph)
+    # No rewrite: the pattern cannot say which output(s) of a multi-output node it matches
+    assert_equal_computations(fgraph.outputs, [e2, e1])
 
-    fgraph = FunctionGraph(inputs=[x], outputs=[o])
+    fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False)
     rewriter.rewrite(fgraph)
-    assert fgraph.outputs[0].owner.op == op1
+    # Variables that merely come out of a multi-output node can still be rewritten
+    assert_equal_computations(fgraph.outputs, [e2, e1])
 
 
 class TestSequentialNodeRewriter:
diff --git a/tests/graph/utils.py b/tests/graph/utils.py
index 86b52a7ed1..2e14fc79a4 100644
--- a/tests/graph/utils.py
+++ b/tests/graph/utils.py
@@ -107,6 +107,9 @@ def make_node(self, *inputs):
 
 
 class MyOpMultipleOutputs(MyOp):
+    def __init__(self, name, dmap=None, x=None):
+        super().__init__(name=name, dmap=dmap, x=x, n_outs=2)
+
     def make_node(self, input):
         outputs = [input.type(), input.type()]
         return Apply(self, [input], outputs)
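
A short sketch of the case the new guard covers, reusing the helpers from `tests/graph/utils.py` above (illustrative only):

    x = MyVariable("x")
    e1, e2 = op_multiple_outputs(x)  # a single Apply node with two outputs
    # transform() now returns False for e1.owner: an in_pattern like
    # (op_multiple_outputs, "x") cannot say which of the two outputs it matches.
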
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 3699a3fcff..15cf31a7ff 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -49,6 +49,7 @@
     bitwise_and,
     bitwise_or,
     bitwise_xor,
+    cast,
     conj,
     cosh,
     deg2rad,
@@ -123,6 +124,7 @@
     dvector,
     fmatrices,
     fmatrix,
+    fscalar,
     ftensor4,
     fvector,
     imatrices,
@@ -4114,25 +4116,36 @@ def test_exp_over_1_plus_exp(self):
 
     def test_local_1msigmoid(self):
         m = self.get_mode(excluding=["fusion", "inplace"])
-        x = fmatrix()
+        x = fscalar()
+        xd = dscalar()
 
         # Test `exp_over_1_plus_exp`
         f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
         # FIXME: PatternNodeRewriter does not copy stack trace
         # (see https://github.com/Theano/Theano/issues/4581)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
 
         # Test `inv_1_plus_exp`
         f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
 
         # Test float constant
-        f = pytensor.function(
-            [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
-        )
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        for out, expected in [
+            (np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
+            (np.array(1.0, "float64") - sigmoid(x), cast(sigmoid(-x), "float64")),
+            (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
+            (np.array(1.0, "float64") - sigmoid(xd), sigmoid(-xd)),
+            (np.sum(1 / np.array([2, 3, 6], "float32")) - sigmoid(x), sigmoid(-x)),
+            (np.sum(1 / np.array([2, 3, 6], "float64")) - sigmoid(xd), sigmoid(-xd)),
+            (np.float32(1 - 9e-6) - sigmoid(x), np.float32(1 - 9e-6) - sigmoid(x)),
+            (np.float64(1 - 1e-9) - sigmoid(xd), np.float64(1 - 1e-9) - sigmoid(xd)),
+        ]:
+            rewritten = rewrite_graph(
+                out, include=["canonicalize", "specialize", "stabilize"]
+            )
+            utt.assert_equal_computations([rewritten], [expected], original=out)
 
     def test_local_sigm_times_exp(self):
         """
@@ -4280,7 +4293,8 @@ def test_log1msigm_to_softplus(self):
         f(np.random.random((54, 11)).astype(config.floatX))
 
         # Test close to 1
-        out = log(1.000001 - sigmoid(x))
+        x_dtype = np.dtype(x.dtype).type
+        out = log(np.nextafter(x_dtype(1), x_dtype(2)) - sigmoid(x))
         f = pytensor.function([x], out, mode=self.m)
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 2
diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py
index 1bdfc01410..fee2fea3b4 100644
--- a/tests/unittest_tools.py
+++ b/tests/unittest_tools.py
@@ -11,6 +11,7 @@
 from pytensor.compile.debugmode import str_diagnostic
 from pytensor.configdefaults import config
 from pytensor.gradient import verify_grad as orig_verify_grad
+from pytensor.graph.basic import equal_computations
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.math import add as pt_add
@@ -279,6 +280,41 @@ def assert_allclose(expected, value, rtol=None, atol=None):
             raise WrongValue(expected, value, rtol, atol)
 
 
+def assert_equal_computations(rewritten, expected, *args, original=None, **kwargs):
+    """
+    Assert that `rewritten` computes the same as `expected`.
+
+    Parameters
+    ----------
+    rewritten : list of Variable
+        The expressions produced by the rewrite pass.
+    expected : list of Variable
+        The reference expressions to compare against.
+    *args, **kwargs
+        Extra arguments forwarded to `equal_computations`.
+    original : Variable, optional
+        If given, it is printed in the error message for extra context.
+    """
+    __tracebackhide__ = True  # Hide traceback for py.test
+
+    ok = equal_computations(rewritten, expected, *args, **kwargs)
+
+    if not ok:
+        parts = []
+
+        def _dprint(expr):
+            return pytensor.dprint(expr, print_type=True, file="str")
+
+        if original is not None:
+            parts.append(f"\nOriginal:\n{_dprint(original)}")
+        parts.append(f"\nRewritten:\n{_dprint(rewritten)}")
+        parts.append(f"\nExpected:\n{_dprint(expected)}")
+
+        raise AssertionError("equal_computations failed\n" + "".join(parts))
+
+    return True
+
+
 class AttemptManyTimes:
     """Decorator for unit tests that forces a unit test to be attempted
     multiple times. The test needs to pass a certain number of times for it to
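
A hypothetical standalone use of the new helper, assuming the stabilize pass rewrites `log(1 + exp(x))` to `softplus(x)` (the import paths mirror the test suite):

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph
    from tests import unittest_tools as utt

    x = pt.scalar("x")
    out = pt.log(1 + pt.exp(x))
    rewritten = rewrite_graph(out, include=["canonicalize", "stabilize"])
    # On failure this raises with dprint-ed Original/Rewritten/Expected graphs.
    utt.assert_equal_computations([rewritten], [pt.softplus(x)], original=out)
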