@@ -1835,94 +1835,97 @@ class MatrixDiagPartV2V3:
18351835 @classmethod
18361836 def version_11 (cls , ctx , node , ** kwargs ):
18371837
1838- def mkconst (npval , desc ):
1839- name = utils .make_name (node .name ) + f'_{ desc } '
1840- return ctx .make_const (name , npval ).output [0 ]
1838+ def mkconsts (values , dtype = np .int64 ):
1839+ ret = []
1840+ for value in values :
1841+ name = utils .make_name (node .name + '_const' )
1842+ ret .append (ctx .make_const (name , np .array (value , dtype = dtype )).output [0 ])
1843+ return ret
18411844
18421845 # assemble MatrixDiagPart V2&V3
18431846 m = node .input [0 ]
18441847 m_shape = ctx .get_shape (m )
1845- utils .make_sure (- 1 not in m_shape , 'At least one dim is unknown %s' , str (m_shape ))
1846-
1847- xlen = m_shape [- 1 ]
1848- ylen = m_shape [- 2 ]
1849- xlenp = xlen + 1
1850- pads = np .zeros (2 * len (m_shape ), dtype = np .int64 )
1848+ m_rank = len (m_shape )
1849+ pads = np .zeros (2 * m_rank , dtype = np .int64 )
18511850 pads [- 2 :] = [1 , 1 ]
1851+ utils .make_sure (m_rank > 1 , 'Input data should be at least 2D %s' , str (m_shape ))
18521852
18531853 align = 'LEFT_LEFT'
18541854 if node .op .op_type == 'MatrixDiagPartV3' :
18551855 align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
18561856 xalign , yalign = align .split ('_' )
18571857
18581858 # consts
1859- const_zero = mkconst (np .array ([0 ], np .int64 ), 'const_zero_dtype' )
1860- const_zero_float = mkconst (np .array ([0 ], np .float32 ), 'const_zero_dtype_f' )
1861- const_one = mkconst (np .array ([1 ], np .int64 ), 'const_one_dtype' )
1862- const_neg_one = mkconst (np .array ([- 1 ]).astype (np .int64 ), 'const_neg_one' )
1863- const_neg_one_float = mkconst (np .array ([- 1 ]).astype (np .float32 ), 'const_neg_one_f' )
1864- const_pad_vals = mkconst (pads , 'pads' )
1865- const_t = mkconst (np .array ([- 1 , 1 ], np .int64 ), 'const_t' )
1866- const_xlen = mkconst (np .array ([xlen ], np .int64 ), 'const_xlen' )
1867- const_ylen = mkconst (np .array ([ylen ], np .int64 ), 'const_ylen' )
1868- const_xlenp = mkconst (np .array ([xlenp ], np .int64 ), 'const_xlenp' )
1869- const_stride = mkconst (np .array ([xlenp + 1 ], np .int64 ), 'const_stride' )
1870- const_minxy_float = mkconst (np .array ([min (xlen , ylen )], np .float32 ), 'const_minxy_f' )
1871- const_xmax = mkconst (np .array ([xlen * xlenp + xlenp - 1 ], np .int64 ), 'const_xmax' )
1872- const_ymax = mkconst (np .array ([xlenp * ylen - 1 ], np .int64 ), 'const_ymax' )
1873- const_ymax_float = mkconst (np .array ([xlenp * ylen - 1 ], np .float32 ), 'const_ymax_f' )
1874- const_partial_shape = mkconst (np .asarray (m_shape [:- 2 ], np .int64 ), 'partial_shape' )
1875- const_m2_shape = mkconst (np .asarray (m_shape [:- 2 ] + [- 1 ], np .int64 ), 'm2_shape' )
1876- const_gather_shape = mkconst (np .asarray (m_shape [:- 2 ] + [1 ], np .int64 ), 'gather_shape' )
1859+ const_zero_float , const_neg_one_float = mkconsts ([[0 ], [- 1 ]], np .float32 )
1860+ const_zero , const_one , const_neg_one , const_neg_two , const_pad_vals , const_t = \
1861+ mkconsts ([[0 ], [1 ], [- 1 ], [- 2 ], pads , [- 1 , 1 ]])
1862+
1863+ m_shape = ctx .make_node ('Shape' , [node .input [0 ]]).output [0 ]
1864+ xlen = ctx .make_node ('Gather' , [m_shape , const_neg_one ]).output [0 ]
1865+ ylen = ctx .make_node ('Gather' , [m_shape , const_neg_two ]).output [0 ]
1866+ xlenp = ctx .make_node ('Add' , [xlen , const_one ]).output [0 ]
1867+ stride = ctx .make_node ('Add' , [xlenp , const_one ]).output [0 ]
1868+ minxy_0 = ctx .make_node ('Concat' , [xlen , ylen ], attr = {'axis' : 0 }).output [0 ]
1869+ minxy = ctx .make_node ('ReduceMin' , [minxy_0 ]).output [0 ]
1870+ minxy_float = ctx .make_node ('Cast' , [minxy ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1871+ xmax_0 = ctx .make_node ('Mul' , [xlen , xlenp ]).output [0 ]
1872+ xmax_1 = ctx .make_node ('Add' , [xmax_0 , xlenp ]).output [0 ]
1873+ xmax = ctx .make_node ('Add' , [xmax_1 , const_neg_one ]).output [0 ]
1874+ ymax_0 = ctx .make_node ('Mul' , [xlenp , ylen ]).output [0 ]
1875+ ymax = ctx .make_node ('Add' , [ymax_0 , const_neg_one ]).output [0 ]
1876+ ymax_float = ctx .make_node ('Cast' , [ymax ], attr = {'to' : TensorProto .FLOAT }).output [0 ]
1877+ partial_shape = ctx .make_node ('Slice' , [m_shape , const_zero , const_neg_two ]).output [0 ]
1878+ m2_shape = ctx .make_node ('Concat' , [partial_shape , const_neg_one ], attr = {'axis' : 0 }).output [0 ]
1879+ gather_shape = ctx .make_node ('Concat' , [partial_shape , const_one ], attr = {'axis' : 0 }).output [0 ]
18771880
18781881 # get k0, k1 values. diags to be extracted
18791882 input1 = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 })
1880- k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]])
1881- k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]])
1882- k1_scalar = ctx .make_node ('Squeeze' , [k1 .output [0 ]])
1883+ k0 = ctx .make_node ('ReduceMin' , [input1 .output [0 ]]). output [ 0 ]
1884+ k1 = ctx .make_node ('ReduceMax' , [input1 .output [0 ]]). output [ 0 ]
1885+ k1_scalar = ctx .make_node ('Squeeze' , [k1 ]) .output [0 ]
18831886 m_padded = ctx .make_node ('Pad' , [m , const_pad_vals , node .input [2 ]])
18841887
18851888 # starting indexes for super diagonals
1886- xstart_0 = ctx .make_node ('Cast' , [k0 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1889+ xstart_0 = ctx .make_node ('Cast' , [k0 ], attr = {'to' : TensorProto .FLOAT })
18871890 xstart_1 = ctx .make_node ('Max' , [const_zero_float , xstart_0 .output [0 ]])
18881891 xstart_2 = ctx .make_node ('Cast' , [xstart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
18891892 xstart_3 = ctx .make_node ('Add' , [xstart_2 .output [0 ], const_neg_one ])
1890- xstart_4 = ctx .make_node ('Range' , [k1_scalar . output [ 0 ] , xstart_3 .output [0 ], const_neg_one ])
1893+ xstart_4 = ctx .make_node ('Range' , [k1_scalar , xstart_3 .output [0 ], const_neg_one ])
18911894 xstart = ctx .make_node ('Reshape' , [xstart_4 .output [0 ], const_t ])
18921895
18931896 # starting indexes for sub diagonals
1894- ystart_0 = ctx .make_node ('Cast' , [k1 . output [ 0 ] ], attr = {'to' : TensorProto .FLOAT })
1897+ ystart_0 = ctx .make_node ('Cast' , [k1 ], attr = {'to' : TensorProto .FLOAT })
18951898 ystart_1 = ctx .make_node ('Min' , [const_neg_one_float , ystart_0 .output [0 ]])
18961899 ystart_2 = ctx .make_node ('Cast' , [ystart_1 .output [0 ]], attr = {'to' : TensorProto .INT64 })
18971900 ystart_2_scalar = ctx .make_node ('Squeeze' , [ystart_2 .output [0 ]])
1898- ystart_3 = ctx .make_node ('Add' , [k0 . output [ 0 ] , const_neg_one ])
1901+ ystart_3 = ctx .make_node ('Add' , [k0 , const_neg_one ])
18991902 ystart_4 = ctx .make_node ('Range' , [ystart_2_scalar .output [0 ], ystart_3 .output [0 ], const_neg_one ])
19001903 ystart = ctx .make_node ('Reshape' , [ystart_4 .output [0 ], const_t ])
19011904
1902- xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], const_xlenp ])
1903- xmax = ctx .make_node ('Sub' , [const_xmax , xmax_0 .output [0 ]])
1905+ xmax_0 = ctx .make_node ('Mul' , [xstart .output [0 ], xlenp ])
1906+ xmax = ctx .make_node ('Sub' , [xmax , xmax_0 .output [0 ]])
19041907 xmax_float = ctx .make_node ('Cast' , [xmax .output [0 ]], attr = {'to' : TensorProto .FLOAT })
19051908
19061909 # lengths of super/sub diags to extract
1907- xsize_0 = ctx .make_node ('Sub' , [const_xlen , xstart .output [0 ]])
1910+ xsize_0 = ctx .make_node ('Sub' , [xlen , xstart .output [0 ]])
19081911 xsize_1 = ctx .make_node ('Cast' , [xsize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1909- xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], const_minxy_float ])
1912+ xsize_2 = ctx .make_node ('Min' , [xsize_1 .output [0 ], minxy_float ])
19101913 xsize = ctx .make_node ('Cast' , [xsize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
1911- ysize_0 = ctx .make_node ('Add' , [const_ylen , ystart .output [0 ]])
1914+ ysize_0 = ctx .make_node ('Add' , [ylen , ystart .output [0 ]])
19121915 ysize_1 = ctx .make_node ('Cast' , [ysize_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
1913- ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], const_minxy_float ])
1916+ ysize_2 = ctx .make_node ('Min' , [ysize_1 .output [0 ], minxy_float ])
19141917 ysize = ctx .make_node ('Cast' , [ysize_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
19151918 diagsize = ctx .make_node ('Concat' , [xsize .output [0 ], ysize .output [0 ]], attr = {'axis' : 0 })
19161919 maxsize = ctx .make_node ('ReduceMax' , [diagsize .output [0 ]], attr = {'keep_dims' : 0 })
19171920 maxsize_0 = ctx .make_node ('Reshape' , [maxsize .output [0 ], const_neg_one ])
19181921 maxsize_scalar = ctx .make_node ('Squeeze' , [maxsize .output [0 ]])
19191922
19201923 diagdistances_0 = ctx .make_node ('Range' , [const_zero , maxsize_scalar .output [0 ], const_one ])
1921- diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], const_stride ])
1924+ diagdistances = ctx .make_node ('Mul' , [diagdistances_0 .output [0 ], stride ])
19221925
19231926 def right_align (sizes , indices , starts , maxval ):
19241927 op1 = ctx .make_node ('Sub' , [maxsize .output [0 ], sizes .output [0 ]])
1925- op2 = ctx .make_node ('Mul' , [op1 .output [0 ], const_stride ])
1928+ op2 = ctx .make_node ('Mul' , [op1 .output [0 ], stride ])
19261929 op3 = ctx .make_node ('Sub' , [indices .output [0 ], op2 .output [0 ]])
19271930 op4 = ctx .make_node ('Less' , [op3 .output [0 ], starts .output [0 ]])
19281931 op5 = ctx .make_node ('Where' , [op4 .output [0 ], maxval , op3 .output [0 ]])
@@ -1932,48 +1935,48 @@ def right_align(sizes, indices, starts, maxval):
19321935 xdiags_0 = ctx .make_node ('Add' , [xstart .output [0 ], diagdistances .output [0 ]])
19331936 xdiags_1 = ctx .make_node ('Cast' , [xdiags_0 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
19341937 if xalign == 'RIGHT' :
1935- xdiags = right_align (xsize , xdiags_0 , xstart , const_ymax )
1938+ xdiags = right_align (xsize , xdiags_0 , xstart , ymax )
19361939 else :
19371940 xdiags_2 = ctx .make_node ('Min' , [xdiags_1 .output [0 ], xmax_float .output [0 ]])
19381941 xdiags = ctx .make_node ('Cast' , [xdiags_2 .output [0 ]], attr = {'to' : TensorProto .INT64 })
19391942
19401943 ydiags_0_ = ctx .make_node ('Abs' , [ystart .output [0 ]])
1941- ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], const_xlenp ])
1944+ ydiags_1 = ctx .make_node ('Mul' , [ydiags_0_ .output [0 ], xlenp ])
19421945 ydiags_2 = ctx .make_node ('Add' , [ydiags_1 .output [0 ], diagdistances .output [0 ]])
19431946 ydiags_3 = ctx .make_node ('Cast' , [ydiags_2 .output [0 ]], attr = {'to' : TensorProto .FLOAT })
19441947 if yalign == 'RIGHT' :
1945- ydiags = right_align (ysize , ydiags_2 , ydiags_1 , const_ymax )
1948+ ydiags = right_align (ysize , ydiags_2 , ydiags_1 , ymax )
19461949 else :
1947- ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], const_ymax_float ])
1950+ ydiags_4 = ctx .make_node ('Min' , [ydiags_3 .output [0 ], ymax_float ])
19481951 ydiags = ctx .make_node ('Cast' , [ydiags_4 .output [0 ]], attr = {'to' : TensorProto .INT64 })
19491952
19501953 # flatten last dimension of matrix
1951- m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], const_m2_shape ])
1954+ m2 = ctx .make_node ('Reshape' , [m_padded .output [0 ], m2_shape ])
19521955
19531956 diags_0 = ctx .make_node ('Concat' , [xdiags .output [0 ], ydiags .output [0 ]], attr = {'axis' : 0 })
19541957 diags_1 = ctx .make_node ('Reshape' , [diags_0 .output [0 ], const_neg_one ])
1955- diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], const_gather_shape ])
1958+ diags_2 = ctx .make_node ('Expand' , [diags_1 .output [0 ], gather_shape ])
19561959 diags = ctx .make_node ('GatherElements' , [m2 .output [0 ], diags_2 .output [0 ]], attr = {'axis' : - 1 })
19571960
19581961 def compute_out_shape (k0_k1_same = False ):
19591962 g = ctx .create_new_graph_with_same_config ()
19601963 g .parent_graph = ctx
19611964 if k0_k1_same :
1962- dims = [const_partial_shape , maxsize_0 .output [0 ]]
1965+ dims = [partial_shape , maxsize_0 .output [0 ]]
19631966 else :
1964- dims = [const_partial_shape , const_neg_one , maxsize_0 .output [0 ]]
1967+ dims = [partial_shape , const_neg_one , maxsize_0 .output [0 ]]
19651968 outshape = g .make_node ('Concat' , dims , attr = {'axis' : 0 })
19661969 g .add_graph_output (outshape .output [0 ], TensorProto .INT64 , [- 1 ])
19671970 return g
19681971
19691972 # if k0=k1, rank of output matrix is 1 less than usual
19701973 # hence, need 'If' to compute right output matrix shape
1971- k0_k1_same = ctx .make_node ('Equal' , [k1 . output [ 0 ] , k0 . output [ 0 ] ])
1974+ k0_k1_same = ctx .make_node ('Equal' , [k1 , k0 ])
19721975 if_node = ctx .make_node ('If' , [k0_k1_same .output [0 ]])
19731976 if_node .set_body_graph_as_attr ('then_branch' , compute_out_shape (True ))
19741977 if_node .set_body_graph_as_attr ('else_branch' , compute_out_shape (False ))
19751978
1976- shapes = [- 1 ] * len ( m_shape )
1979+ shapes = [- 1 ] * m_rank
19771980 dtypes = node .output_dtypes
19781981 ctx .remove_node (node .name )
19791982 ctx .make_node ('Reshape' , [diags .output [0 ], if_node .output [0 ]], name = node .name , outputs = node .output ,
0 commit comments