From a52f8bcd85ac4a1f338c39f5d94a55290d6f54fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Mon, 9 Jan 2023 11:16:21 +0100 Subject: [PATCH 1/8] pytensor-54: Rewrite products of exponents as exponent of sum. Rewrite e^x*e^y to e^(x+y), e^x/e^y to e^(x-y). --- pytensor/tensor/rewriting/math.py | 29 +++++++++++++++++ tests/tensor/rewriting/test_math.py | 50 +++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 590625445f..323a3d738d 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -423,6 +423,35 @@ def local_sumsqr2dot(fgraph, node): return [new_out] +@register_canonicalize +@register_specialize +@node_rewriter([Elemwise]) +def local_mulexp2expadd(fgraph, node): + """ + This rewrite detects e^x * e^y and converts it to e^(x+y). + Similarly, e^x / e^y becomes e^(x-y). + """ + if ( + isinstance(node.op, Elemwise) + and isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Elemwise) + and isinstance(node.inputs[0].owner.op.scalar_op, aes.Exp) + and node.inputs[1].owner + and isinstance(node.inputs[1].owner.op, Elemwise) + and isinstance(node.inputs[1].owner.op.scalar_op, aes.Exp) + ): + input1 = node.inputs[0].owner.inputs[0] + input2 = node.inputs[1].owner.inputs[0] + if isinstance(node.op.scalar_op, aes.Mul): + new_out = exp(input1 + input2) + else: # TrueDiv + new_out = exp(input1 - input2) + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + return [new_out] + + @register_stabilize @register_specialize @register_canonicalize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 3f508af79b..e85e77ff21 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4014,6 +4014,56 @@ def test_local_sumsqr2dot(): ) +def test_local_mulexp2expadd(): + # e^x * e^y = e^(x+y) + # test simple scalars first + x = scalar("x") + y = scalar("y") + expx = exp(x) + expy = exp(y) + expx_expy = expx * expy + f = function([x, y], expx_expy) + utt.assert_allclose(f(3, 4), np.exp(3 + 4)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Add) for n in inner_graph) + + # expect same for matrices as well + mx = matrix("mx") + my = matrix("my") + f = function([mx, my], exp(mx) * exp(my)) + M1 = np.array([[1.0, 2.0], [3.0, 4.0]]) + M2 = np.array([[5.0, 6.0], [7.0, 8.0]]) + utt.assert_allclose(f(M1, M2), np.exp(M1 + M2)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Add) for n in inner_graph) + + # checking whether further rewrites can proceed after this one as one would expect + # e^x * e^(-x) = e^(x-x) = e^0 = 1 + f = function([x], expx * exp(neg(x))) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].inputs[0], TensorConstant) + utt.assert_allclose(f(42), 1) + + # e^x / e^y = e^(x-y) + expx_div_expy = expx / expy + f = function([x, y], expx_div_expy) + utt.assert_allclose(f(5, 3), np.exp(5 - 3)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Sub) for n in inner_graph) + + # e^x / e^x = e^(x-x) = e^0 = 1 + f = function([x], expx / expx) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].inputs[0], TensorConstant) + utt.assert_allclose(f(42), 1) + + def test_local_expm1(): x = matrix("x") u = scalar("u") From 28fdc862cbe14f1ce615d42889e908028a66dae6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Tue, 10 Jan 2023 15:53:19 +0100 Subject: [PATCH 2/8] pytensor-54: Handle properly the scenarios where a Mul node has more than two factors with some of which may not be an exp --- pytensor/tensor/rewriting/math.py | 54 +++++++++++++-------- tests/tensor/rewriting/test_math.py | 73 ++++++++++++++++++++++------- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 323a3d738d..f6a0e9fc86 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -423,33 +423,47 @@ def local_sumsqr2dot(fgraph, node): return [new_out] -@register_canonicalize @register_specialize -@node_rewriter([Elemwise]) +@node_rewriter([mul, true_div]) def local_mulexp2expadd(fgraph, node): """ This rewrite detects e^x * e^y and converts it to e^(x+y). Similarly, e^x / e^y becomes e^(x-y). """ - if ( - isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Elemwise) - and isinstance(node.inputs[0].owner.op.scalar_op, aes.Exp) - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Elemwise) - and isinstance(node.inputs[1].owner.op.scalar_op, aes.Exp) + if isinstance(node.op, Elemwise) and isinstance( + node.op.scalar_op, (aes.Mul, aes.TrueDiv) ): - input1 = node.inputs[0].owner.inputs[0] - input2 = node.inputs[1].owner.inputs[0] - if isinstance(node.op.scalar_op, aes.Mul): - new_out = exp(input1 + input2) - else: # TrueDiv - new_out = exp(input1 - input2) - if new_out.dtype != node.outputs[0].dtype: - new_out = cast(new_out, dtype=node.outputs[0].dtype) - return [new_out] + exps = [ + n.owner.inputs[0] + for n in node.inputs + if n.owner + and hasattr(n.owner.op, "scalar_op") + and isinstance(n.owner.op.scalar_op, aes.Exp) + ] + # Can only do any rewrite if there are at least two exp-s + if len(exps) >= 2: + # Mul -> add; TrueDiv -> sub + orig_op, new_op = mul, add + if isinstance(node.op.scalar_op, aes.TrueDiv): + orig_op, new_op = true_div, sub + new_out = exp(new_op(*exps)) + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + # The original Mul may have more than two factors, some of which may not be exp nodes. + # If so, we keep multiplying them with the new exp(sum) node. + # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w + rest = [ + n + for n in node.inputs + if not n.owner + or not hasattr(n.owner.op, "scalar_op") + or not isinstance(n.owner.op.scalar_op, aes.Exp) + ] + if len(rest) > 0: + new_out = orig_op(new_out, *rest) + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + return [new_out] @register_stabilize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index e85e77ff21..c41549b12b 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4015,19 +4015,68 @@ def test_local_sumsqr2dot(): def test_local_mulexp2expadd(): - # e^x * e^y = e^(x+y) - # test simple scalars first x = scalar("x") y = scalar("y") + z = scalar("z") + w = scalar("w") expx = exp(x) expy = exp(y) - expx_expy = expx * expy - f = function([x, y], expx_expy) - utt.assert_allclose(f(3, 4), np.exp(3 + 4)) + expz = exp(z) + expw = exp(w) + + # e^x * e^y * e^z * e^w = e^(x+y+z+w) + op = expx * expy * expz * expw + f = function([x, y, z, w], op) + utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Add) for n in inner_graph) + assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + + # e^x * e^y * e^z / e^w = e^(x+y+z-w) + op = expx * expy * expz / expw + f = function([x, y, z, w], op) + utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Add) for n in inner_graph) + assert any(isinstance(n.op, aes.Sub) for n in inner_graph) + assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph) + + # e^x * e^y / e^z * e^w = e^(x+y-z+w) + op = expx * expy / expz * expw + f = function([x, y, z, w], op) + utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Add) for n in inner_graph) + assert any(isinstance(n.op, aes.Sub) for n in inner_graph) + assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph) + + # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z) + op = expx / expy / expz + f = function([x, y, z], op) + utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Sub) for n in inner_graph) + assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph) + + # e^x * y * e^z * w = e^(x+z) * y * w + op = expx * y * expz * w + f = function([x, y, z, w], op) + utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6) graph = f.maker.fgraph.toposort() assert isinstance(graph[0].op, Elemwise) inner_graph = graph[0].op.scalar_op.fgraph.toposort() assert any(isinstance(n.op, aes.Add) for n in inner_graph) + assert any(isinstance(n.op, aes.Mul) for n in inner_graph) # expect same for matrices as well mx = matrix("mx") @@ -4040,28 +4089,20 @@ def test_local_mulexp2expadd(): assert isinstance(graph[0].op, Elemwise) inner_graph = graph[0].op.scalar_op.fgraph.toposort() assert any(isinstance(n.op, aes.Add) for n in inner_graph) + assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) # checking whether further rewrites can proceed after this one as one would expect # e^x * e^(-x) = e^(x-x) = e^0 = 1 f = function([x], expx * exp(neg(x))) - graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].inputs[0], TensorConstant) utt.assert_allclose(f(42), 1) - - # e^x / e^y = e^(x-y) - expx_div_expy = expx / expy - f = function([x, y], expx_div_expy) - utt.assert_allclose(f(5, 3), np.exp(5 - 3)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Sub) for n in inner_graph) + assert isinstance(graph[0].inputs[0], TensorConstant) # e^x / e^x = e^(x-x) = e^0 = 1 f = function([x], expx / expx) + utt.assert_allclose(f(42), 1) graph = f.maker.fgraph.toposort() assert isinstance(graph[0].inputs[0], TensorConstant) - utt.assert_allclose(f(42), 1) def test_local_expm1(): From 8466acd2d30b716423f97136f1cac52a22c2d3b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Thu, 12 Jan 2023 15:16:56 +0100 Subject: [PATCH 3/8] pytensor-54: Rewrite a^x * a^y to a^(x+y) --- pytensor/tensor/rewriting/math.py | 59 +++++++++++++++++++++++++++++ tests/tensor/rewriting/test_math.py | 58 ++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index f6a0e9fc86..5e1641dccd 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -466,6 +466,65 @@ def local_mulexp2expadd(fgraph, node): return [new_out] +@register_specialize +@node_rewriter([mul, true_div]) +def local_mulpow2powadd(fgraph, node): + """ + This rewrite detects a^x * a^y and converts it to a^(x+y). + Similarly, a^x / a^y becomes a^(x-y). + """ + if isinstance(node.op, Elemwise) and isinstance( + node.op.scalar_op, (aes.Mul, aes.TrueDiv) + ): + from collections import defaultdict + + # search for pow-s and group them by their bases + pow_nodes = defaultdict(list) + rest = [] + for n in node.inputs: + if ( + n.owner + and hasattr(n.owner.op, "scalar_op") + and isinstance(n.owner.op.scalar_op, aes.Pow) + ): + base_node = n.owner.inputs[0] + # exponent is at n.owner.inputs[1], but we need to store the full node + # in case this particular power node remains alone and can't be rewritten + pow_nodes[base_node].append(n) + else: + rest.append(n) + + # Can only do any rewrite if there are at least two pow-s with the same base + can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2] + if len(can_rewrite) >= 1: + # Mul -> add; TrueDiv -> sub + orig_op, new_op = mul, add + if isinstance(node.op.scalar_op, aes.TrueDiv): + orig_op, new_op = true_div, sub + pow_factors = [] + # Rewrite pow-s having the same base for each different base + # E.g.: a^x * a^y --> a^(x+y) + for base in can_rewrite: + exponents = [n.owner.inputs[1] for n in pow_nodes[base]] + new_node = base ** new_op(*exponents) + if new_node.dtype != node.outputs[0].dtype: + new_node = cast(new_node, dtype=node.outputs[0].dtype) + pow_factors.append(new_node) + # Don't forget about those sole pow-s that couldn't be rewriten + sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite] + # Combine the rewritten pow-s and other, non-pow factors of the original Mul + # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v + if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0: + new_out = orig_op(*pow_factors, *sole_pows, *rest) + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + else: + # if all factors of the original mul were pows-s with the same base, + # we can get rid of the mul completely. + new_out = pow_factors[0] + return [new_out] + + @register_stabilize @register_specialize @register_canonicalize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index c41549b12b..fc0620cea9 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4105,6 +4105,64 @@ def test_local_mulexp2expadd(): assert isinstance(graph[0].inputs[0], TensorConstant) +def test_local_mulpow2powadd(): + x = scalar("x") + y = scalar("y") + z = scalar("z") + w = scalar("w") + v = scalar("v") + u = scalar("u") + t = scalar("t") + s = scalar("s") + a = scalar("a") + b = scalar("b") + c = scalar("c") + + # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w) + op = 2**x * 2**y * 2**z * 2**w + f = function([x, y, z, w], op) + utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert any(isinstance(n.op, aes.Add) for n in inner_graph) + assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + + # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s + op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t + f = function([x, y, z, w, v, u, t, s, a, b, c], op) + utt.assert_allclose( + f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5), + 2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11, + ) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert len([True for n in inner_graph if isinstance(n.op, aes.Add)]) == 3 + assert len([True for n in inner_graph if isinstance(n.op, aes.Pow)]) == 4 + assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + + # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w) + op = 2**x / 2**y * (a**z / a**w) + f = function([x, y, z, w, a], op) + utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert len([True for n in inner_graph if isinstance(n.op, aes.Sub)]) == 2 + assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + + # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w) + op = a**x * a**y * exp(z) * exp(w) + f = function([x, y, z, w, a], op) + utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6)) + graph = f.maker.fgraph.toposort() + assert isinstance(graph[0].op, Elemwise) + inner_graph = graph[0].op.scalar_op.fgraph.toposort() + assert len([True for n in inner_graph if isinstance(n.op, aes.Add)]) == 2 + assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + + def test_local_expm1(): x = matrix("x") u = scalar("u") From 799722b8ae1bf55df37de30fbef73af582ccf637 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Thu, 12 Jan 2023 20:08:53 +0100 Subject: [PATCH 4/8] pytensor-54: Rename functions according to naming conventions. Removed a redundant check. Moved import statement to top of file. --- pytensor/tensor/rewriting/math.py | 15 +++++---------- tests/tensor/rewriting/test_math.py | 4 ++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 5e1641dccd..928ec74f1b 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -2,6 +2,7 @@ import itertools import operator +from collections import defaultdict from functools import partial, reduce import numpy as np @@ -425,14 +426,12 @@ def local_sumsqr2dot(fgraph, node): @register_specialize @node_rewriter([mul, true_div]) -def local_mulexp2expadd(fgraph, node): +def local_mul_exp_to_exp_add(fgraph, node): """ This rewrite detects e^x * e^y and converts it to e^(x+y). Similarly, e^x / e^y becomes e^(x-y). """ - if isinstance(node.op, Elemwise) and isinstance( - node.op.scalar_op, (aes.Mul, aes.TrueDiv) - ): + if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)): exps = [ n.owner.inputs[0] for n in node.inputs @@ -468,16 +467,12 @@ def local_mulexp2expadd(fgraph, node): @register_specialize @node_rewriter([mul, true_div]) -def local_mulpow2powadd(fgraph, node): +def local_mul_pow_to_pow_add(fgraph, node): """ This rewrite detects a^x * a^y and converts it to a^(x+y). Similarly, a^x / a^y becomes a^(x-y). """ - if isinstance(node.op, Elemwise) and isinstance( - node.op.scalar_op, (aes.Mul, aes.TrueDiv) - ): - from collections import defaultdict - + if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)): # search for pow-s and group them by their bases pow_nodes = defaultdict(list) rest = [] diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index fc0620cea9..da1cf8ec7e 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4014,7 +4014,7 @@ def test_local_sumsqr2dot(): ) -def test_local_mulexp2expadd(): +def test_local_mul_exp_to_exp_add(): x = scalar("x") y = scalar("y") z = scalar("z") @@ -4105,7 +4105,7 @@ def test_local_mulexp2expadd(): assert isinstance(graph[0].inputs[0], TensorConstant) -def test_local_mulpow2powadd(): +def test_local_mul_pow_to_pow_add(): x = scalar("x") y = scalar("y") z = scalar("z") From 426f0c0277cb1f0fe8d394e712df4fb9af31fb6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Thu, 12 Jan 2023 21:40:06 +0100 Subject: [PATCH 5/8] pytensor-54: Removed yet another redundant check --- pytensor/tensor/rewriting/math.py | 146 +++++++++++++++--------------- 1 file changed, 72 insertions(+), 74 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 928ec74f1b..7cf495b56a 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -431,38 +431,37 @@ def local_mul_exp_to_exp_add(fgraph, node): This rewrite detects e^x * e^y and converts it to e^(x+y). Similarly, e^x / e^y becomes e^(x-y). """ - if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)): - exps = [ - n.owner.inputs[0] + exps = [ + n.owner.inputs[0] + for n in node.inputs + if n.owner + and hasattr(n.owner.op, "scalar_op") + and isinstance(n.owner.op.scalar_op, aes.Exp) + ] + # Can only do any rewrite if there are at least two exp-s + if len(exps) >= 2: + # Mul -> add; TrueDiv -> sub + orig_op, new_op = mul, add + if isinstance(node.op.scalar_op, aes.TrueDiv): + orig_op, new_op = true_div, sub + new_out = exp(new_op(*exps)) + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + # The original Mul may have more than two factors, some of which may not be exp nodes. + # If so, we keep multiplying them with the new exp(sum) node. + # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w + rest = [ + n for n in node.inputs - if n.owner - and hasattr(n.owner.op, "scalar_op") - and isinstance(n.owner.op.scalar_op, aes.Exp) + if not n.owner + or not hasattr(n.owner.op, "scalar_op") + or not isinstance(n.owner.op.scalar_op, aes.Exp) ] - # Can only do any rewrite if there are at least two exp-s - if len(exps) >= 2: - # Mul -> add; TrueDiv -> sub - orig_op, new_op = mul, add - if isinstance(node.op.scalar_op, aes.TrueDiv): - orig_op, new_op = true_div, sub - new_out = exp(new_op(*exps)) + if len(rest) > 0: + new_out = orig_op(new_out, *rest) if new_out.dtype != node.outputs[0].dtype: new_out = cast(new_out, dtype=node.outputs[0].dtype) - # The original Mul may have more than two factors, some of which may not be exp nodes. - # If so, we keep multiplying them with the new exp(sum) node. - # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w - rest = [ - n - for n in node.inputs - if not n.owner - or not hasattr(n.owner.op, "scalar_op") - or not isinstance(n.owner.op.scalar_op, aes.Exp) - ] - if len(rest) > 0: - new_out = orig_op(new_out, *rest) - if new_out.dtype != node.outputs[0].dtype: - new_out = cast(new_out, dtype=node.outputs[0].dtype) - return [new_out] + return [new_out] @register_specialize @@ -472,52 +471,51 @@ def local_mul_pow_to_pow_add(fgraph, node): This rewrite detects a^x * a^y and converts it to a^(x+y). Similarly, a^x / a^y becomes a^(x-y). """ - if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)): - # search for pow-s and group them by their bases - pow_nodes = defaultdict(list) - rest = [] - for n in node.inputs: - if ( - n.owner - and hasattr(n.owner.op, "scalar_op") - and isinstance(n.owner.op.scalar_op, aes.Pow) - ): - base_node = n.owner.inputs[0] - # exponent is at n.owner.inputs[1], but we need to store the full node - # in case this particular power node remains alone and can't be rewritten - pow_nodes[base_node].append(n) - else: - rest.append(n) - - # Can only do any rewrite if there are at least two pow-s with the same base - can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2] - if len(can_rewrite) >= 1: - # Mul -> add; TrueDiv -> sub - orig_op, new_op = mul, add - if isinstance(node.op.scalar_op, aes.TrueDiv): - orig_op, new_op = true_div, sub - pow_factors = [] - # Rewrite pow-s having the same base for each different base - # E.g.: a^x * a^y --> a^(x+y) - for base in can_rewrite: - exponents = [n.owner.inputs[1] for n in pow_nodes[base]] - new_node = base ** new_op(*exponents) - if new_node.dtype != node.outputs[0].dtype: - new_node = cast(new_node, dtype=node.outputs[0].dtype) - pow_factors.append(new_node) - # Don't forget about those sole pow-s that couldn't be rewriten - sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite] - # Combine the rewritten pow-s and other, non-pow factors of the original Mul - # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v - if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0: - new_out = orig_op(*pow_factors, *sole_pows, *rest) - if new_out.dtype != node.outputs[0].dtype: - new_out = cast(new_out, dtype=node.outputs[0].dtype) - else: - # if all factors of the original mul were pows-s with the same base, - # we can get rid of the mul completely. - new_out = pow_factors[0] - return [new_out] + # search for pow-s and group them by their bases + pow_nodes = defaultdict(list) + rest = [] + for n in node.inputs: + if ( + n.owner + and hasattr(n.owner.op, "scalar_op") + and isinstance(n.owner.op.scalar_op, aes.Pow) + ): + base_node = n.owner.inputs[0] + # exponent is at n.owner.inputs[1], but we need to store the full node + # in case this particular power node remains alone and can't be rewritten + pow_nodes[base_node].append(n) + else: + rest.append(n) + + # Can only do any rewrite if there are at least two pow-s with the same base + can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2] + if len(can_rewrite) >= 1: + # Mul -> add; TrueDiv -> sub + orig_op, new_op = mul, add + if isinstance(node.op.scalar_op, aes.TrueDiv): + orig_op, new_op = true_div, sub + pow_factors = [] + # Rewrite pow-s having the same base for each different base + # E.g.: a^x * a^y --> a^(x+y) + for base in can_rewrite: + exponents = [n.owner.inputs[1] for n in pow_nodes[base]] + new_node = base ** new_op(*exponents) + if new_node.dtype != node.outputs[0].dtype: + new_node = cast(new_node, dtype=node.outputs[0].dtype) + pow_factors.append(new_node) + # Don't forget about those sole pow-s that couldn't be rewriten + sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite] + # Combine the rewritten pow-s and other, non-pow factors of the original Mul + # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v + if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0: + new_out = orig_op(*pow_factors, *sole_pows, *rest) + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + else: + # if all factors of the original mul were pows-s with the same base, + # we can get rid of the mul completely. + new_out = pow_factors[0] + return [new_out] @register_stabilize From 5ac8d92e4967f5e547c44639bfe952eb4a151800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Wed, 1 Feb 2023 12:45:39 +0100 Subject: [PATCH 6/8] pytensor-54: Exempt fusion rewrites in the test cases of local_mul_exp_to_exp_add and local_mul_pow_to_pow_add, so that the checks in the test cases also work in FAST_COMPILE mode --- tests/tensor/rewriting/test_math.py | 114 ++++++++++++++-------------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index da1cf8ec7e..ef99e0dabd 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4015,6 +4015,11 @@ def test_local_sumsqr2dot(): def test_local_mul_exp_to_exp_add(): + # Default and FAST_RUN modes put a Composite op into the final graph, + # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs, + # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites + mode = get_default_mode().excluding("fusion") + x = scalar("x") y = scalar("y") z = scalar("z") @@ -4026,86 +4031,85 @@ def test_local_mul_exp_to_exp_add(): # e^x * e^y * e^z * e^w = e^(x+y+z+w) op = expx * expy * expz * expw - f = function([x, y, z, w], op) + f = function([x, y, z, w], op, mode) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Add) for n in inner_graph) - assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) # e^x * e^y * e^z / e^w = e^(x+y+z-w) op = expx * expy * expz / expw - f = function([x, y, z, w], op) + f = function([x, y, z, w], op, mode) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Add) for n in inner_graph) - assert any(isinstance(n.op, aes.Sub) for n in inner_graph) - assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) - assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph) # e^x * e^y / e^z * e^w = e^(x+y-z+w) op = expx * expy / expz * expw - f = function([x, y, z, w], op) + f = function([x, y, z, w], op, mode) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Add) for n in inner_graph) - assert any(isinstance(n.op, aes.Sub) for n in inner_graph) - assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) - assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph) # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z) op = expx / expy / expz - f = function([x, y, z], op) + f = function([x, y, z], op, mode) utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Sub) for n in inner_graph) - assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph) # e^x * y * e^z * w = e^(x+z) * y * w op = expx * y * expz * w - f = function([x, y, z, w], op) + f = function([x, y, z, w], op, mode) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Add) for n in inner_graph) - assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) # expect same for matrices as well mx = matrix("mx") my = matrix("my") - f = function([mx, my], exp(mx) * exp(my)) + f = function([mx, my], exp(mx) * exp(my), mode) M1 = np.array([[1.0, 2.0], [3.0, 4.0]]) M2 = np.array([[5.0, 6.0], [7.0, 8.0]]) utt.assert_allclose(f(M1, M2), np.exp(M1 + M2)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Add) for n in inner_graph) - assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) # checking whether further rewrites can proceed after this one as one would expect # e^x * e^(-x) = e^(x-x) = e^0 = 1 - f = function([x], expx * exp(neg(x))) + f = function([x], expx * exp(neg(x)), mode) utt.assert_allclose(f(42), 1) graph = f.maker.fgraph.toposort() assert isinstance(graph[0].inputs[0], TensorConstant) # e^x / e^x = e^(x-x) = e^0 = 1 - f = function([x], expx / expx) + f = function([x], expx / expx, mode) utt.assert_allclose(f(42), 1) graph = f.maker.fgraph.toposort() assert isinstance(graph[0].inputs[0], TensorConstant) def test_local_mul_pow_to_pow_add(): + # Default and FAST_RUN modes put a Composite op into the final graph, + # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs, + # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites + mode = get_default_mode().excluding("fusion") + x = scalar("x") y = scalar("y") z = scalar("z") @@ -4120,47 +4124,43 @@ def test_local_mul_pow_to_pow_add(): # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w) op = 2**x * 2**y * 2**z * 2**w - f = function([x, y, z, w], op) + f = function([x, y, z, w], op, mode) utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert any(isinstance(n.op, aes.Add) for n in inner_graph) - assert not any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph) + assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t - f = function([x, y, z, w, v, u, t, s, a, b, c], op) + f = function([x, y, z, w, v, u, t, s, a, b, c], op, mode) utt.assert_allclose( f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5), 2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11, ) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert len([True for n in inner_graph if isinstance(n.op, aes.Add)]) == 3 - assert len([True for n in inner_graph if isinstance(n.op, aes.Pow)]) == 4 - assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 3 + assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Pow)]) == 4 + assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w) op = 2**x / 2**y * (a**z / a**w) - f = function([x, y, z, w, a], op) + f = function([x, y, z, w, a], op, mode) utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert len([True for n in inner_graph if isinstance(n.op, aes.Sub)]) == 2 - assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Sub)]) == 2 + assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w) op = a**x * a**y * exp(z) * exp(w) - f = function([x, y, z, w, a], op) + f = function([x, y, z, w, a], op, mode) utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6)) graph = f.maker.fgraph.toposort() - assert isinstance(graph[0].op, Elemwise) - inner_graph = graph[0].op.scalar_op.fgraph.toposort() - assert len([True for n in inner_graph if isinstance(n.op, aes.Add)]) == 2 - assert any(isinstance(n.op, aes.Mul) for n in inner_graph) + assert all(isinstance(n.op, Elemwise) for n in graph) + assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 2 + assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph) def test_local_expm1(): From db41813e1ebe878a59076fa2be702ca6bfa2fa1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Wed, 1 Feb 2023 13:50:49 +0100 Subject: [PATCH 7/8] pytensor-54: Attempt to fix test failing in float32 mode due to implicit downcast of testing constants --- tests/tensor/rewriting/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index ef99e0dabd..11a2eb540c 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4081,7 +4081,7 @@ def test_local_mul_exp_to_exp_add(): # expect same for matrices as well mx = matrix("mx") my = matrix("my") - f = function([mx, my], exp(mx) * exp(my), mode) + f = function([mx, my], exp(mx) * exp(my), mode, allow_input_downcast=True) M1 = np.array([[1.0, 2.0], [3.0, 4.0]]) M2 = np.array([[5.0, 6.0], [7.0, 8.0]]) utt.assert_allclose(f(M1, M2), np.exp(M1 + M2)) From 28d1ca63b90ce07b8edbf7fb41662ef65c41b3df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C5=91k=C3=A9s?= Date: Wed, 1 Feb 2023 22:58:17 +0100 Subject: [PATCH 8/8] pytensor-54: Explicitly include the tested rewrites in the tests, so that the FAST_COMPILE test runs also work properly --- tests/tensor/rewriting/test_math.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 11a2eb540c..74a6624077 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4018,7 +4018,7 @@ def test_local_mul_exp_to_exp_add(): # Default and FAST_RUN modes put a Composite op into the final graph, # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs, # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites - mode = get_default_mode().excluding("fusion") + mode = get_default_mode().excluding("fusion").including("local_mul_exp_to_exp_add") x = scalar("x") y = scalar("y") @@ -4032,6 +4032,7 @@ def test_local_mul_exp_to_exp_add(): # e^x * e^y * e^z * e^w = e^(x+y+z+w) op = expx * expy * expz * expw f = function([x, y, z, w], op, mode) + pytensor.dprint(f) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6)) graph = f.maker.fgraph.toposort() assert all(isinstance(n.op, Elemwise) for n in graph) @@ -4108,7 +4109,12 @@ def test_local_mul_pow_to_pow_add(): # Default and FAST_RUN modes put a Composite op into the final graph, # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs, # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites - mode = get_default_mode().excluding("fusion") + mode = ( + get_default_mode() + .excluding("fusion") + .including("local_mul_exp_to_exp_add") + .including("local_mul_pow_to_pow_add") + ) x = scalar("x") y = scalar("y")