@@ -1835,94 +1835,98 @@ class MatrixDiagPartV2V3:
18351835 @classmethod
18361836 def version_11 (cls , ctx , node , ** kwargs ):
18371837
1838- def mkconst (npval , desc ):
1839- name = utils .make_name (node .name ) + f'_{ desc } '
1840- return ctx .make_const (name , npval ).output [0 ]
1838+ def mkconsts (values , dtype = np .int64 ):
1839+ ret = []
1840+ for value in values :
1841+ name = utils .make_name (node .name + '_const' )
1842+ ret .append (ctx .make_const (name , np .array (value , dtype = dtype )).output [0 ])
1843+ return ret
18411844
18421845 # assemble MatrixDiagPart V2&V3
18431846 m = node .input [0 ]
18441847 m_shape = ctx .get_shape (m )
1845- utils .make_sure (- 1 not in m_shape , 'At least one dim is unknown %s' , str (m_shape ))
1846-
1847- xlen = m_shape [- 1 ]
1848- ylen = m_shape [- 2 ]
1849- xlenp = xlen + 1
1850- pads = np .zeros (2 * len (m_shape ), dtype = np .int64 )
1848+ m_rank = len (m_shape )
1849+ pads = np .zeros (2 * m_rank , dtype = np .int64 )
18511850 pads [- 2 :] = [1 , 1 ]
1851+ utils .make_sure (m_rank > 1 , 'Input data should be at least 2D %s' , str (m_shape ))
18521852
18531853 align = 'LEFT_LEFT'
18541854 if node .op .op_type == 'MatrixDiagPartV3' :
18551855 align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
18561856 xalign , yalign = align .split ('_' )
18571857
18581858 # consts
1859- const_zero = mkconst (np .array ([0 ], np .int64 ), 'const_zero_dtype' )
1860- const_zero_float = mkconst (np .array ([0 ], np .float32 ), 'const_zero_dtype_f' )
1861- const_one = mkconst (np .array ([1 ], np .int64 ), 'const_one_dtype' )
1862- const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
1863- const_neg_one_float = mkconst (np .array ([- 1 ]).astype (np .float32 ), 'const_neg_one_f' )
1864- const_pad_vals = mkconst (pads , 'pads' )
1865- const_t = mkconst (np .array ([- 1 , 1 ], np .int64 ), 'const_t' )
1866- const_xlen = mkconst (np .array ([xlen ], np .int64 ), 'const_xlen' )
1867- const_ylen = mkconst (np .array ([ylen ], np .int64 ), 'const_ylen' )
1868- const_xlenp = mkconst (np .array ([xlenp ], np .int64 ), 'const_xlenp' )
1869- const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
1870- const_minxy_float = mkconst (np .array ([min (xlen , ylen )], np .float32 ), 'const_minxy_f' )
1871- const_xmax = mkconst (np .array ([xlen * xlenp + xlenp - 1 ], np .int64 ), 'const_xmax' )
1872- const_ymax = mkconst (np .array ([xlenp * ylen - 1 ], np .int64 ), 'const_ymax' )
1873- const_ymax_float = mkconst (np .array ([xlenp * ylen - 1 ], np .float32 ), 'const_ymax_f' )
1874- const_partial_shape = mkconst (np .asarray (m_shape [:- 2 ], np .int64 ), 'partial_shape' )
1875- const_m2_shape = mkconst (np .asarray (m_shape [:- 2 ] + [- 1 ], np .int64 ), 'm2_shape' )
1876- const_gather_shape = mkconst (np .asarray (m_shape [:- 2 ] + [1 ], np .int64 ), 'gather_shape' )
1859+ const_zero_float , const_neg_one_float = mkconsts ([[0 ], [- 1 ]], np .float32 )
1860+ const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
1861+ mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
1862+ const_zero_scalar , const_one_scalar , const_neg_one_scalar = mkconsts ([0 , 1 , - 1 ])
1863+
1864+ m_shape = ctx .make_node ('Shape' , [node .input [0 ]]).output [0 ]
1865+ xlen = ctx .make_node ('Gather' , [m_shape , const_neg_one ]).output [0 ]
1866+ ylen = ctx .make_node ('Gather' , [m_shape , const_neg_two ]).output [0 ]
1867+ xlenp = ctx .make_node ('Add' , [xlen , const_one ]).output [0 ]
1868+ stride = ctx .make_node ('Add' , [xlenp , const_one ]).output [0 ]
1869+ minxy_0 = ctx .make_node ('Concat' , [xlen , ylen ], attr = {'axis' : 0 }).output [0 ]
1870+ minxy = ctx .make_node ('ReduceMin' , [minxy_0 ]).output [0 ]
1871+ minxy_float = ctx .make_node ('Cast' , [minxy ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1872+ xmax_0 = ctx .make_node ('Mul' , [xlen , xlenp ]).output [0 ]
1873+ xmax_1 = ctx .make_node ('Add' , [xmax_0 , xlenp ]).output [0 ]
1874+ xmax = ctx .make_node ('Add' , [xmax_1 , const_neg_one ]).output [0 ]
1875+ ymax_0 = ctx .make_node ('Mul' , [xlenp , ylen ]).output [0 ]
1876+ ymax = ctx .make_node ('Add' , [ymax_0 , const_neg_one ]).output [0 ]
1877+ ymax_float = ctx .make_node ('Cast' , [ymax ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1878+ partial_shape = ctx .make_node ('Slice' , [m_shape , const_zero , const_neg_two ]).output [0 ]
1879+ m2_shape = ctx .make_node ('Concat' , [partial_shape , const_neg_one ], attr = {'axis' : 0 }).output [0 ]
1880+ gather_shape = ctx .make_node ('Concat' , [partial_shape , const_one ], attr = {'axis' : 0 }).output [0 ]
18771881
18781882 # get k0, k1 values. diags to be extracted
18791883 input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
1880- k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]])
1881- k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]])
1882- k1_scalar = ctx .make_node ('Squeeze' , [k1 .output [0 ]])
1884+ k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]).output [0 ]
1885+ k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]).output [0 ]
1886+ k0_scalar = ctx .make_node ('Squeeze' , [k0 ]).output [0 ]
1887+ k1_scalar = ctx .make_node ('Squeeze' , [k1 ]).output [0 ]
18831888 m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
18841889
18851890 # starting indexes for super diagonals
1886- xstart_0 = ctx .make_node ('Cast' , [k0 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1891+ xstart_0 = ctx .make_node ('Cast' , [k0_scalar ], attr = {'to' : TensorProto .FLOAT })
18871892 xstart_1 = ctx .make_node ('Max' , [const_zero_float , xstart_0 .output [0 ]])
18881893 xstart_2 = ctx .make_node ('Cast' , [xstart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1889- xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one ])
1890- xstart_4 = ctx .make_node ('Range' , [k1_scalar . output [ 0 ] , xstart_3 .output [0 ], const_neg_one ])
1894+ xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one_scalar ])
1895+ xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one_scalar ])
18911896 xstart = ctx .make_node ('Reshape' , [xstart_4 .output [0 ], const_t ])
18921897
18931898 # starting indexes for sub diagonals
1894- ystart_0 = ctx .make_node ('Cast' , [k1 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1899+ ystart_0 = ctx .make_node ('Cast' , [k1_scalar ], attr = {'to' : TensorProto .FLOAT })
18951900 ystart_1 = ctx .make_node ('Min' , [const_neg_one_float , ystart_0 .output [0 ]])
18961901 ystart_2 = ctx .make_node ('Cast' , [ystart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1897- ystart_2_scalar = ctx .make_node ('Squeeze' , [ystart_2 .output [0 ]])
1898- ystart_3 = ctx .make_node ('Add' , [k0 .output [0 ], const_neg_one ])
1899- ystart_4 = ctx .make_node ('Range' , [ystart_2_scalar .output [0 ], ystart_3 .output [0 ], const_neg_one ])
1902+ ystart_3 = ctx .make_node ('Add' , [k0_scalar , const_neg_one_scalar ])
1903+ ystart_4 = ctx .make_node ('Range' , [ystart_2 .output [0 ], ystart_3 .output [0 ], const_neg_one_scalar ])
19001904 ystart = ctx .make_node ('Reshape' , [ystart_4 .output [0 ], const_t ])
19011905
1902- xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], const_xlenp ])
1903- xmax = ctx .make_node ('Sub' , [const_xmax , xmax_0 .output [0 ]])
1906+ xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], xlenp ])
1907+ xmax = ctx .make_node ('Sub' , [xmax , xmax_0 .output [0 ]])
19041908 xmax_float = ctx .make_node ('Cast' , [xmax .output [0 ]], attr = {'to' : TensorProto .FLOAT })
19051909
19061910 # lengths of super/sub diags to extract
1907- xsize_0 = ctx .make_node ('Sub' , [const_xlen , xstart .output [0 ]])
1911+ xsize_0 = ctx .make_node ('Sub' , [xlen , xstart .output [0 ]])
19081912 xsize_1 = ctx .make_node ('Cast' , [xsize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1909- xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], const_minxy_float ])
1913+ xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], minxy_float ])
19101914 xsize = ctx .make_node ('Cast' , [xsize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1911- ysize_0 = ctx .make_node ('Add' , [const_ylen , ystart .output [0 ]])
1915+ ysize_0 = ctx .make_node ('Add' , [ylen , ystart .output [0 ]])
19121916 ysize_1 = ctx .make_node ('Cast' , [ysize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1913- ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], const_minxy_float ])
1917+ ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], minxy_float ])
19141918 ysize = ctx .make_node ('Cast' , [ysize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
19151919 diagsize = ctx .make_node ('Concat' , [xsize .output [0 ], ysize .output [0 ]], attr = {'axis' : 0 })
19161920 maxsize = ctx .make_node ('ReduceMax' , [diagsize .output [0 ]], attr = {'keep_dims' : 0 })
19171921 maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
19181922 maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
19191923
1920- diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
1921- diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], const_stride ])
1924+ diagdistances_0 = ctx .make_node ('Range' , [const_zero_scalar , maxsize_scalar .output [0 ], const_one_scalar ])
1925+ diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], stride ])
19221926
19231927 def right_align (sizes , indices , starts , maxval ):
19241928 op1 = ctx .make_node ('Sub' , [maxsize .output [0 ], sizes .output [0 ]])
1925- op2 = ctx .make_node ('Mul' , [op1 .output [0 ], const_stride ])
1929+ op2 = ctx .make_node ('Mul' , [op1 .output [0 ], stride ])
19261930 op3 = ctx .make_node ('Sub' , [indices .output [0 ], op2 .output [0 ]])
19271931 op4 = ctx .make_node ('Less' , [op3 .output [0 ], starts .output [0 ]])
19281932 op5 = ctx .make_node ('Where' , [op4 .output [0 ], maxval , op3 .output [0 ]])
@@ -1932,48 +1936,48 @@ def right_align(sizes, indices, starts, maxval):
19321936 xdiags_0 = ctx .make_node ('Add' , [xstart .output [0 ], diagdistances .output [0 ]])
19331937 xdiags_1 = ctx .make_node ('Cast' , [xdiags_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
19341938 if xalign == 'RIGHT' :
1935- xdiags = right_align (xsize , xdiags_0 , xstart , const_ymax )
1939+ xdiags = right_align (xsize , xdiags_0 , xstart , ymax )
19361940 else :
19371941 xdiags_2 = ctx .make_node ('Min' , [xdiags_1 .output [0 ], xmax_float .output [0 ]])
19381942 xdiags = ctx .make_node ('Cast' , [xdiags_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
19391943
19401944 ydiags_0_ = ctx .make_node ('Abs' , [ystart .output [0 ]])
1941- ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], const_xlenp ])
1945+ ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], xlenp ])
19421946 ydiags_2 = ctx .make_node ('Add' , [ydiags_1 .output [0 ], diagdistances .output [0 ]])
19431947 ydiags_3 = ctx .make_node ('Cast' , [ydiags_2 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
19441948 if yalign == 'RIGHT' :
1945- ydiags = right_align (ysize , ydiags_2 , ydiags_1 , const_ymax )
1949+ ydiags = right_align (ysize , ydiags_2 , ydiags_1 , ymax )
19461950 else :
1947- ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], const_ymax_float ])
1951+ ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], ymax_float ])
19481952 ydiags = ctx .make_node ('Cast' , [ydiags_4 .output [0 ]], attr = {'to' : TensorProto .INT64 })
19491953
19501954 # flatten last dimension of matrix
1951- m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], const_m2_shape ])
1955+ m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], m2_shape ])
19521956
19531957 diags_0 = ctx .make_node ('Concat' , [xdiags .output [0 ], ydiags .output [0 ]], attr = {'axis' : 0 })
19541958 diags_1 = ctx .make_node ('Reshape' , [diags_0 .output [0 ], const_neg_one ])
1955- diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
1959+ diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], gather_shape ])
19561960 diags = ctx .make_node ('GatherElements' , [m2 .output [0 ], diags_2 .output [0 ]], attr = {'axis' : - 1 })
19571961
19581962 def compute_out_shape (k0_k1_same = False ):
19591963 g = ctx .create_new_graph_with_same_config ()
19601964 g .parent_graph = ctx
19611965 if k0_k1_same :
1962- dims = [const_partial_shape , maxsize_0 .output [0 ]]
1966+ dims = [partial_shape , maxsize_0 .output [0 ]]
19631967 else :
1964- dims = [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]]
1968+ dims = [partial_shape , const_neg_one , maxsize_0 .output [0 ]]
19651969 outshape = g .make_node ('Concat' , dims , attr = {'axis' : 0 })
19661970 g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
19671971 return g
19681972
19691973 # if k0=k1, rank of output matrix is 1 less than usual
19701974 # hence, need 'If' to compute right output matrix shape
1971- k0_k1_same = ctx .make_node ('Equal' , [k1 . output [ 0 ] , k0 . output [ 0 ] ])
1975+ k0_k1_same = ctx .make_node ('Equal' , [k1 , k0 ])
19721976 if_node = ctx .make_node ('If' , [k0_k1_same .output [0 ]])
19731977 if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
19741978 if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
19751979
1976- shapes = [ - 1 ] * len ( m_shape )
1980+ shapes = ctx . get_shape ( node . output [ 0 ] )
19771981 dtypes = node .output_dtypes
19781982 ctx .remove_node (node .name )
19791983 ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
0 commit comments