File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed
Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -33,16 +33,22 @@ def rewrite_random_normal(g, ops):
3333 match_results = list (matcher .match_ops (ops ))
3434 for match in match_results :
3535 output = match .get_op ('output' )
36- if output .type == 'Add' :
36+ input2 = match .get_op ('input2' )
37+ is_output = False
38+ for output_name in g .outputs :
39+ # input2 and output can not be output node.
40+ if input2 .name in output_name or output .name in output_name :
41+ is_output = True
42+ break
43+ if is_output :
44+ continue
45+ if output .type == 'Add' and input2 .type == 'Mul' :
3746 # pattern 1
3847 mean = output .inputs [1 ].get_tensor_value ()
48+ scale = input2 .inputs [1 ].get_tensor_value ()
3949 else :
4050 # pattern 2
4151 mean = 0.0
42- input2 = match .get_op ('input2' )
43- if input2 .type == 'Mul' :
44- scale = input2 .inputs [1 ].get_tensor_value ()
45- else :
4652 scale = 1.0
4753 dtype = g .get_dtype (output .output [0 ])
4854 op_name = utils .make_name ("RandomNormal" )
You can’t perform that action at this time.
0 commit comments