@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126
126
return consts , origconsts , nonconsts
127
127
128
128
129
- def get_constant (v ):
130
- """
131
-
132
- Returns
133
- -------
134
- object
135
- A numeric constant if v is a Constant or, well, a
136
- numeric constant. If v is a plain Variable, returns None.
137
-
138
- """
139
- if isinstance (v , TensorConstant ):
140
- return v .unique_value
141
- elif isinstance (v , Variable ):
142
- return None
143
- else :
144
- return v
145
-
146
-
147
129
@register_canonicalize
148
130
@register_stabilize
149
131
@node_rewriter ([Dot ])
@@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
994
976
"""
995
977
Find all constants and put them together into a single constant.
996
978
997
- Finds all constants in orig_num and orig_denum (using
998
- get_constant) and puts them together into a single
979
+ Finds all constants in orig_num and orig_denum
980
+ and puts them together into a single
999
981
constant. The constant is inserted as the first element of the
1000
982
numerator. If the constant is the neutral element, it is
1001
983
removed from the numerator.
@@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1016
998
numct , denumct = [], []
1017
999
1018
1000
for v in orig_num :
1019
- ct = get_constant (v )
1020
- if ct is not None :
1001
+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
1021
1002
# We found a constant in the numerator!
1022
1003
# We add it to numct
1023
- numct .append (ct )
1004
+ numct .append (v . unique_value )
1024
1005
else :
1025
1006
num .append (v )
1026
1007
for v in orig_denum :
1027
- ct = get_constant (v )
1028
- if ct is not None :
1029
- denumct .append (ct )
1008
+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
1009
+ denumct .append (v .unique_value )
1030
1010
else :
1031
1011
denum .append (v )
1032
1012
@@ -1050,10 +1030,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1050
1030
1051
1031
if orig_num and len (numct ) == 1 and len (denumct ) == 0 and ct :
1052
1032
# In that case we should only have one constant in `ct`.
1053
- assert len (ct ) == 1
1054
- first_num_ct = get_constant (orig_num [0 ])
1055
- if first_num_ct is not None and ct [0 ].type .values_eq (
1056
- ct [0 ].data , first_num_ct
1033
+ [var_ct ] = ct
1034
+ first_num_var = orig_num [0 ]
1035
+ first_num_ct = (
1036
+ first_num_var .unique_value
1037
+ if isinstance (first_num_var , TensorConstant )
1038
+ else None
1039
+ )
1040
+ if first_num_ct is not None and var_ct .type .values_eq (
1041
+ var_ct .data , first_num_ct
1057
1042
):
1058
1043
# This is an important trick :( if it so happens that:
1059
1044
# * there's exactly one constant on the numerator and none on
@@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node):
1840
1825
return [new_out ]
1841
1826
1842
1827
# Check if it is a negative constant
1843
- const = get_constant (second )
1844
- if const is not None and const < 0 :
1845
- new_out = sub (first , np .abs (const ))
1828
+ if (
1829
+ isinstance (second , TensorConstant )
1830
+ and second .unique_value is not None
1831
+ and second .unique_value < 0
1832
+ ):
1833
+ new_out = sub (first , np .abs (second .data ))
1846
1834
return [new_out ]
1847
1835
1848
1836
@@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node):
1871
1859
@register_specialize
1872
1860
@node_rewriter ([true_div ])
1873
1861
def local_div_to_reciprocal (fgraph , node ):
1874
- if np .all (get_constant (node .inputs [0 ]) == 1.0 ):
1862
+ if (
1863
+ get_underlying_scalar_constant_value (
1864
+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1865
+ )
1866
+ == 1.0
1867
+ ):
1875
1868
out = node .outputs [0 ]
1876
1869
new_out = reciprocal (local_mul_canonizer .merge_num_denum (node .inputs [1 :], []))
1877
1870
# The ones could have forced upcasting
@@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node):
1892
1885
@register_canonicalize
1893
1886
@node_rewriter ([pt_pow ])
1894
1887
def local_pow_canonicalize (fgraph , node ):
1895
- cst = get_constant (node .inputs [1 ])
1888
+ cst = get_underlying_scalar_constant_value (
1889
+ node .inputs [1 ], only_process_constants = True , raise_not_constant = False
1890
+ )
1896
1891
if cst == 0 :
1897
1892
return [alloc_like (1 , node .outputs [0 ], fgraph )]
1898
1893
if cst == 1 :
@@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node):
1923
1918
@node_rewriter ([int_div , true_div ])
1924
1919
def local_zero_div (fgraph , node ):
1925
1920
"""0 / x -> 0"""
1926
- if get_constant (node .inputs [0 ]) == 0 :
1921
+ if (
1922
+ get_underlying_scalar_constant_value (
1923
+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1924
+ )
1925
+ == 0
1926
+ ):
1927
1927
ret = alloc_like (0 , node .outputs [0 ], fgraph )
1928
1928
ret .tag .values_eq_approx = values_eq_approx_remove_nan
1929
1929
return [ret ]
@@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node):
1936
1936
odtype = node .outputs [0 ].dtype
1937
1937
xsym = node .inputs [0 ]
1938
1938
ysym = node .inputs [1 ]
1939
- y = get_constant (ysym )
1940
- if (y is not None ) and not broadcasted_by (xsym , ysym ):
1939
+ try :
1940
+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1941
+ except NotScalarConstantError :
1942
+ return
1943
+
1944
+ if not broadcasted_by (xsym , ysym ):
1941
1945
rval = None
1942
1946
1943
1947
if np .all (y == 2 ):
@@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node):
1971
1975
"""
1972
1976
1973
1977
# the idea here is that we have pow(x, y)
1978
+ xsym , ysym = node .inputs
1979
+
1980
+ try :
1981
+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1982
+ except NotScalarConstantError :
1983
+ return
1984
+
1974
1985
odtype = node .outputs [0 ].dtype
1975
- xsym = node .inputs [0 ]
1976
- ysym = node .inputs [1 ]
1977
- y = get_constant (ysym )
1978
1986
1979
1987
# the next line is needed to fix a strange case that I don't
1980
1988
# know how to make a separate test.
@@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node):
1990
1998
y = y [0 ]
1991
1999
except IndexError :
1992
2000
pass
1993
- if ( y is not None ) and not broadcasted_by (xsym , ysym ):
2001
+ if not broadcasted_by (xsym , ysym ):
1994
2002
rval = None
1995
2003
# 512 is too small for the cpu and too big for some gpu!
1996
2004
if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
@@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node):
2057
2065
nb_neg_node += 1
2058
2066
2059
2067
# remove special case arguments of 1, -1 or 0
2060
- y = get_constant (inp )
2068
+ y = get_underlying_scalar_constant_value (
2069
+ inp , only_process_constants = True , raise_not_constant = False
2070
+ )
2061
2071
if y == 1.0 :
2062
2072
nb_cst += 1
2063
2073
elif y == - 1.0 :
0 commit comments