@@ -30,30 +30,31 @@ class REWRITER_RESULT(Enum):
3030
3131
3232# TensorFlow LSTMCell/BasicLSTMCell computation graph matching
33- xc_pattern = OpTypePattern ('Split' , inputs = [
34- OpTypePattern ("Const" ), # axis for split
35- OpTypePattern ("BiasAdd" , name = "bias_add" , inputs = [
36- OpTypePattern ("MatMul" , inputs = [
37- OpTypePattern ("ConcatV2|Concat" , name = "xh" ),
33+
34+ xc_pattern = \
35+ OpTypePattern ('Split' , inputs = [
36+ OpTypePattern ("Const" ), # axis for split
37+ OpTypePattern ("BiasAdd" , name = "bias_add" , inputs = [
38+ OpTypePattern ("MatMul" , inputs = [
39+ OpTypePattern ("ConcatV2|Concat" , name = "xh" ),
40+ OpTypePattern ("Enter" , inputs = [
41+ OpTypePattern ("*" , name = "cell_kernel" ),
42+ ]),
43+ ]),
3844 OpTypePattern ("Enter" , inputs = [
39- OpTypePattern ("*" , name = "cell_kernel " ),
45+ OpTypePattern ("*" , name = "cell_bias " ),
4046 ]),
4147 ]),
42- OpTypePattern ("Enter" , inputs = [
43- OpTypePattern ("*" , name = "cell_bias" ),
44- ]),
45- ]),
46- ])
47-
48+ ])
4849
4950lstmcell_pattern = \
5051 OpTypePattern ('Mul' , name = 'ht' , inputs = [
5152 OpTypePattern ("Sigmoid" , name = "ot" , inputs = [xc_pattern ]),
5253 OpTypePattern ('Tanh' , inputs = [
53- OpTypePattern ("Add" , name = "ct" , inputs = [
54+ OpTypePattern ("Add|AddV2 " , name = "ct" , inputs = [
5455 OpTypePattern ("Mul" , name = "ct_identity_consumer" , inputs = [
5556 OpTypePattern ("Sigmoid" , name = "ft" , inputs = [
56- OpTypePattern ("Add" , inputs = [
57+ OpTypePattern ("Add|AddV2 " , inputs = [
5758 xc_pattern ,
5859 OpTypePattern ("*" , name = "ft_bias" ),
5960 ]),
@@ -68,6 +69,39 @@ class REWRITER_RESULT(Enum):
6869 ]),
6970 ])
7071
72+ xc_pattern_optimized = \
73+ OpTypePattern ('Split' , inputs = [
74+ OpTypePattern ("Const" ),
75+ OpTypePattern ("Identity" , inputs = [
76+ OpTypePattern ("MatMul" , inputs = [
77+ OpTypePattern ("ConcatV2|Concat" , name = "xh" ),
78+ OpTypePattern ("Const" , name = "cell_kernel" ),
79+ ]),
80+ ]),
81+ ])
82+
83+ lstmcell_pattern_optimized = \
84+ OpTypePattern ('Mul' , name = 'ht' , inputs = [
85+ OpTypePattern ("Sigmoid" , name = "ot" , inputs = [xc_pattern_optimized ]),
86+ OpTypePattern ('Tanh' , inputs = [
87+ OpTypePattern ("Add|AddV2" , name = "ct" , inputs = [
88+ OpTypePattern ("Mul" , name = "ct_identity_consumer" , inputs = [
89+ OpTypePattern ("Sigmoid" , name = "ft" , inputs = [
90+ OpTypePattern ("Add|AddV2" , inputs = [
91+ xc_pattern_optimized ,
92+ OpTypePattern ("*" , name = "ft_bias" ),
93+ ]),
94+ ]),
95+ OpTypePattern ("*" ),
96+ ]),
97+ OpTypePattern ("Mul" , inputs = [
98+ OpTypePattern ("Sigmoid" , name = "it" , inputs = [xc_pattern_optimized ]),
99+ OpTypePattern ("Tanh" , name = "gt" , inputs = [xc_pattern_optimized ]),
100+ ]),
101+ ]),
102+ ]),
103+ ])
104+
71105# input sequence: top to down, left to right
72106# split into update gate and reset gate
73107gru_split_pattern = \
@@ -237,7 +271,7 @@ class RNNUnitType(Enum):
237271
238272
239273rnn_cell_patterns = {
240- RNNUnitType .LSTMCell : [lstmcell_pattern ],
274+ RNNUnitType .LSTMCell : [lstmcell_pattern , lstmcell_pattern_optimized ],
241275 RNNUnitType .LSTMBlockCell : [lstmblockcell_pattern ],
242276 RNNUnitType .GRUCell : [grucell_pattern ],
243277 RNNUnitType .GRUBlockCell : [grublockcell_pattern0 , grublockcell_pattern1 ],
0 commit comments