@@ -1859,6 +1859,7 @@ def mkconsts(values, dtype=np.int64):
18591859 const_zero_float , const_neg_one_float = mkconsts ([[0 ], [- 1 ]], np .float32 )
18601860 const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
18611861 mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
1862+ const_zero_scalar , const_one_scalar , const_neg_one_scalar = mkconsts ([0 , 1 , - 1 ])
18621863
18631864 m_shape = ctx .make_node ('Shape' , [node .input [0 ]]).output [0 ]
18641865 xlen = ctx .make_node ('Gather' , [m_shape , const_neg_one ]).output [0 ]
@@ -1882,24 +1883,24 @@ def mkconsts(values, dtype=np.int64):
18821883 input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
18831884 k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]).output [0 ]
18841885 k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]).output [0 ]
1886+ k0_scalar = ctx .make_node ('Squeeze' , [k0 ]).output [0 ]
18851887 k1_scalar = ctx .make_node ('Squeeze' , [k1 ]).output [0 ]
18861888 m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
18871889
18881890 # starting indexes for super diagonals
1889- xstart_0 = ctx .make_node ('Cast' , [k0 ], attr = {'to' : TensorProto .FLOAT })
1891+ xstart_0 = ctx .make_node ('Cast' , [k0_scalar ], attr = {'to' : TensorProto .FLOAT })
18901892 xstart_1 = ctx .make_node ('Max' , [const_zero_float , xstart_0 .output [0 ]])
18911893 xstart_2 = ctx .make_node ('Cast' , [xstart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1892- xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one ])
1893- xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one ])
1894+ xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one_scalar ])
1895+ xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one_scalar ])
18941896 xstart = ctx .make_node ('Reshape' , [xstart_4 .output [0 ], const_t ])
18951897
18961898 # starting indexes for sub diagonals
1897- ystart_0 = ctx .make_node ('Cast' , [k1 ], attr = {'to' : TensorProto .FLOAT })
1899+ ystart_0 = ctx .make_node ('Cast' , [k1_scalar ], attr = {'to' : TensorProto .FLOAT })
18981900 ystart_1 = ctx .make_node ('Min' , [const_neg_one_float , ystart_0 .output [0 ]])
18991901 ystart_2 = ctx .make_node ('Cast' , [ystart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1900- ystart_2_scalar = ctx .make_node ('Squeeze' , [ystart_2 .output [0 ]])
1901- ystart_3 = ctx .make_node ('Add' , [k0 , const_neg_one ])
1902- ystart_4 = ctx .make_node ('Range' , [ystart_2_scalar .output [0 ], ystart_3 .output [0 ], const_neg_one ])
1902+ ystart_3 = ctx .make_node ('Add' , [k0_scalar , const_neg_one_scalar ])
1903+ ystart_4 = ctx .make_node ('Range' , [ystart_2 .output [0 ], ystart_3 .output [0 ], const_neg_one_scalar ])
19031904 ystart = ctx .make_node ('Reshape' , [ystart_4 .output [0 ], const_t ])
19041905
19051906 xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], xlenp ])
@@ -1920,7 +1921,7 @@ def mkconsts(values, dtype=np.int64):
19201921 maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
19211922 maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
19221923
1923- diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
1924+ diagdistances_0 = ctx .make_node ('Range' , [const_zero_scalar , maxsize_scalar .output [0 ], const_one_scalar ])
19241925 diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], stride ])
19251926
19261927 def right_align (sizes , indices , starts , maxval ):
@@ -1976,7 +1977,7 @@ def compute_out_shape(k0_k1_same=False):
19761977 if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
19771978 if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
19781979
1979- shapes = [ - 1 ] * m_rank
1980+ shapes = ctx . get_shape ( node . output [ 0 ])
19801981 dtypes = node .output_dtypes
19811982 ctx .remove_node (node .name )
19821983 ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
0 commit comments