Skip to content

Commit 8dcbc2b

Browse files
committed
Remove internal get_constant helper
Fixes bug in `local_add_neg_to_sub` reported in #584
1 parent a2521ad commit 8dcbc2b

File tree

2 files changed

+59
-47
lines changed

2 files changed

+59
-47
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126126
return consts, origconsts, nonconsts
127127

128128

129-
def get_constant(v):
130-
"""
131-
132-
Returns
133-
-------
134-
object
135-
A numeric constant if v is a Constant or, well, a
136-
numeric constant. If v is a plain Variable, returns None.
137-
138-
"""
139-
if isinstance(v, TensorConstant):
140-
return v.unique_value
141-
elif isinstance(v, Variable):
142-
return None
143-
else:
144-
return v
145-
146-
147129
@register_canonicalize
148130
@register_stabilize
149131
@node_rewriter([Dot])
@@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
994976
"""
995977
Find all constants and put them together into a single constant.
996978
997-
Finds all constants in orig_num and orig_denum (using
998-
get_constant) and puts them together into a single
979+
Finds all constants in orig_num and orig_denum
980+
and puts them together into a single
999981
constant. The constant is inserted as the first element of the
1000982
numerator. If the constant is the neutral element, it is
1001983
removed from the numerator.
@@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1016998
numct, denumct = [], []
1017999

10181000
for v in orig_num:
1019-
ct = get_constant(v)
1020-
if ct is not None:
1001+
if isinstance(v, TensorConstant) and v.unique_value is not None:
10211002
# We found a constant in the numerator!
10221003
# We add it to numct
1023-
numct.append(ct)
1004+
numct.append(v.unique_value)
10241005
else:
10251006
num.append(v)
10261007
for v in orig_denum:
1027-
ct = get_constant(v)
1028-
if ct is not None:
1029-
denumct.append(ct)
1008+
if isinstance(v, TensorConstant) and v.unique_value is not None:
1009+
denumct.append(v.unique_value)
10301010
else:
10311011
denum.append(v)
10321012

@@ -1050,10 +1030,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10501030

10511031
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
10521032
# In that case we should only have one constant in `ct`.
1053-
assert len(ct) == 1
1054-
first_num_ct = get_constant(orig_num[0])
1055-
if first_num_ct is not None and ct[0].type.values_eq(
1056-
ct[0].data, first_num_ct
1033+
[var_ct] = ct
1034+
first_num_var = orig_num[0]
1035+
first_num_ct = (
1036+
first_num_var.unique_value
1037+
if isinstance(first_num_var, TensorConstant)
1038+
else None
1039+
)
1040+
if first_num_ct is not None and var_ct.type.values_eq(
1041+
var_ct.data, first_num_ct
10571042
):
10581043
# This is an important trick :( if it so happens that:
10591044
# * there's exactly one constant on the numerator and none on
@@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node):
18401825
return [new_out]
18411826

18421827
# Check if it is a negative constant
1843-
const = get_constant(second)
1844-
if const is not None and const < 0:
1845-
new_out = sub(first, np.abs(const))
1828+
if (
1829+
isinstance(second, TensorConstant)
1830+
and second.unique_value is not None
1831+
and second.unique_value < 0
1832+
):
1833+
new_out = sub(first, np.abs(second.data))
18461834
return [new_out]
18471835

18481836

@@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node):
18711859
@register_specialize
18721860
@node_rewriter([true_div])
18731861
def local_div_to_reciprocal(fgraph, node):
1874-
if np.all(get_constant(node.inputs[0]) == 1.0):
1862+
if (
1863+
get_underlying_scalar_constant_value(
1864+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1865+
)
1866+
== 1.0
1867+
):
18751868
out = node.outputs[0]
18761869
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
18771870
# The ones could have forced upcasting
@@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node):
18921885
@register_canonicalize
18931886
@node_rewriter([pt_pow])
18941887
def local_pow_canonicalize(fgraph, node):
1895-
cst = get_constant(node.inputs[1])
1888+
cst = get_underlying_scalar_constant_value(
1889+
node.inputs[1], only_process_constants=True, raise_not_constant=False
1890+
)
18961891
if cst == 0:
18971892
return [alloc_like(1, node.outputs[0], fgraph)]
18981893
if cst == 1:
@@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node):
19231918
@node_rewriter([int_div, true_div])
19241919
def local_zero_div(fgraph, node):
19251920
"""0 / x -> 0"""
1926-
if get_constant(node.inputs[0]) == 0:
1921+
if (
1922+
get_underlying_scalar_constant_value(
1923+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1924+
)
1925+
== 0
1926+
):
19271927
ret = alloc_like(0, node.outputs[0], fgraph)
19281928
ret.tag.values_eq_approx = values_eq_approx_remove_nan
19291929
return [ret]
@@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node):
19361936
odtype = node.outputs[0].dtype
19371937
xsym = node.inputs[0]
19381938
ysym = node.inputs[1]
1939-
y = get_constant(ysym)
1940-
if (y is not None) and not broadcasted_by(xsym, ysym):
1939+
try:
1940+
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
1941+
except NotScalarConstantError:
1942+
return
1943+
1944+
if not broadcasted_by(xsym, ysym):
19411945
rval = None
19421946

19431947
if np.all(y == 2):
@@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node):
19711975
"""
19721976

19731977
# the idea here is that we have pow(x, y)
1978+
xsym, ysym = node.inputs
1979+
1980+
try:
1981+
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
1982+
except NotScalarConstantError:
1983+
return
1984+
19741985
odtype = node.outputs[0].dtype
1975-
xsym = node.inputs[0]
1976-
ysym = node.inputs[1]
1977-
y = get_constant(ysym)
19781986

19791987
# the next line is needed to fix a strange case that I don't
19801988
# know how to make a separate test.
@@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node):
19901998
y = y[0]
19911999
except IndexError:
19922000
pass
1993-
if (y is not None) and not broadcasted_by(xsym, ysym):
2001+
if not broadcasted_by(xsym, ysym):
19942002
rval = None
19952003
# 512 is too small for the cpu and too big for some gpu!
19962004
if abs(y) == int(abs(y)) and abs(y) <= 512:
@@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node):
20572065
nb_neg_node += 1
20582066

20592067
# remove special case arguments of 1, -1 or 0
2060-
y = get_constant(inp)
2068+
y = get_underlying_scalar_constant_value(
2069+
inp, only_process_constants=True, raise_not_constant=False
2070+
)
20612071
if y == 1.0:
20622072
nb_cst += 1
20632073
elif y == -1.0:

tests/tensor/rewriting/test_math.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4440,16 +4440,18 @@ def test_local_add_neg_to_sub(first_negative):
44404440
assert np.allclose(f(x_test, y_test), exp)
44414441

44424442

4443-
def test_local_add_neg_to_sub_const():
4443+
@pytest.mark.parametrize("const_left", (True, False))
4444+
def test_local_add_neg_to_sub_const(const_left):
44444445
x = vector("x")
4445-
const = 5.0
4446+
const = np.full((3, 2), 5.0)
4447+
out = -const + x if const_left else x + (-const)
44464448

4447-
f = function([x], x + (-const), mode=Mode("py"))
4449+
f = function([x], out, mode=Mode("py"))
44484450

44494451
nodes = [
44504452
node.op
44514453
for node in f.maker.fgraph.toposort()
4452-
if not isinstance(node.op, DimShuffle)
4454+
if not isinstance(node.op, DimShuffle | Alloc)
44534455
]
44544456
assert nodes == [pt.sub]
44554457

0 commit comments

Comments
 (0)