4242from  mlir .extras .util  import  mlir_type_to_np_dtype 
4343
4444
45+ INDENT  =  0 
46+ # OUTPUT_BUF = io.StringIO() 
47+ OUTPUT_BUF  =  sys .stdout 
48+ ATTR_ALIASES  =  {}
49+ 
50+ 
4551def  normalize_ssa (ssa : str  |  Value ):
4652    if  isinstance (ssa , Value ):
47-         ssa  =  ssa .get_name ()
53+         ssa  =  ssa .get_name (use_name_loc_as_prefix = True )
4854    if  ssa [1 ].isnumeric ():
4955        ssa  =  ssa .replace ("%" , "v" )
5056    else :
@@ -74,6 +80,8 @@ def np_array_from_shape_type(shape, dtype, splat_value=None):
7480
7581
7682def  map_attr (attr ):
83+     if  attr  in  ATTR_ALIASES :
84+         return  ATTR_ALIASES [attr ]
7785    attr  =  attr .maybe_downcast ()
7886    if  isinstance (attr , (IntegerAttr , BoolAttr , FloatAttr )):
7987        return  attr .value 
@@ -130,11 +138,6 @@ def map_type(type):
130138    return  f"Type.parse('{ type }  ')" 
131139
132140
133- indent  =  0 
134- # OUTPUT_BUF = io.StringIO() 
135- OUTPUT_BUF  =  sys .stdout 
136- 
137- 
138141def  get_init_args (opview ):
139142    klass  =  opview .__class__ 
140143    while  not  klass .__base__  is  OpView :
@@ -168,7 +171,7 @@ def underscore(word: str) -> str:
168171
169172
170173def  print_opview (opview , name = None ):
171-     print ("    "  *  indent , file = OUTPUT_BUF , end = "" )
174+     print ("    "  *  INDENT , file = OUTPUT_BUF , end = "" )
172175    if  len (opview .results ):
173176        print (
174177            ", " .join ([normalize_ssa (r ) for  r  in  opview .results ]),
@@ -249,15 +252,15 @@ def print_opview(opview, name=None):
249252        else :
250253            owner  =  f"{ op_idx_owner_name }  " 
251254        print (
252-             "    "  *  indent 
255+             "    "  *  INDENT 
253256            +  f"{ owner }  .attributes['OpIdx'] = amdgpu.OpIdxAttr.get({ attrs ['OpIdx' ].value }  )" ,
254257            file = OUTPUT_BUF ,
255258        )
256259
257260
258261def  print_func_op (func_op : func .FuncOp ):
259262    # op.print(print_generic_op_form=True) 
260-     print ("    "  *  indent , file = OUTPUT_BUF , end = "" )
263+     print ("    "  *  INDENT , file = OUTPUT_BUF , end = "" )
261264    print ("@func.func(" , file = OUTPUT_BUF , end = "" )
262265    if  len (func_op .attributes ):
263266        attrs  =  []
@@ -283,7 +286,7 @@ def print_func_op(func_op: func.FuncOp):
283286
284287
285288def  print_arith_constant (constop : arith .ConstantOp ):
286-     print ("    "  *  indent , file = OUTPUT_BUF , end = "" )
289+     print ("    "  *  INDENT , file = OUTPUT_BUF , end = "" )
287290    print (
288291        f"{ normalize_ssa (constop .result )}   = arith.constant({ map_attr (constop .value )}  , { map_type (constop .result .type )}  )" ,
289292        file = OUTPUT_BUF ,
@@ -305,7 +308,7 @@ def print_scf_for(for_op: scf.ForOp):
305308    )
306309    init_args  =  [normalize_ssa (a ) for  a  in  for_op .initArgs ]
307310    print (
308-         ("    "  *  indent )
311+         ("    "  *  INDENT )
309312        +  f"for { opers_str }   in scf.for_({ start }  , { stop }  , { step }  , iter_args=[{ ', ' .join (init_args )}  ]):" ,
310313        file = OUTPUT_BUF ,
311314    )
@@ -315,12 +318,12 @@ def print_scf_if(if_op: scf.IfOp):
315318    assert  len (if_op .results ) ==  1 
316319    res  =  if_op .results [0 ]
317320    res_name  =  normalize_ssa (res )
318-     global  indent 
321+     global  INDENT 
319322
320323    def  print_yield_as_return (yield_op : scf .YieldOp ):
321324        opers  =  [normalize_ssa (a ) for  a  in  yield_op .operands ]
322325        print (
323-             ("    "  *  indent ) +  f"return { ', ' .join (opers )}  " ,
326+             ("    "  *  INDENT ) +  f"return { ', ' .join (opers )}  " ,
324327            file = OUTPUT_BUF ,
325328        )
326329
@@ -332,17 +335,17 @@ def print_yield_as_return(yield_op: scf.YieldOp):
332335                    def { res_name }  ():\  
333336                 """
334337            ),
335-             "    "  *  indent ,
338+             "    "  *  INDENT ,
336339        ),
337340        file = OUTPUT_BUF ,
338341    )
339-     indent  +=  1 
342+     INDENT  +=  1 
340343    for  bodyop  in  if_op .thenRegion .blocks [0 ].operations :
341344        if  isinstance (bodyop , scf .YieldOp ):
342345            print_yield_as_return (bodyop )
343346        else :
344347            bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
345-     indent  -=  1 
348+     INDENT  -=  1 
346349    print (
347350        textwrap .indent (
348351            textwrap .dedent (
@@ -351,17 +354,17 @@ def {res_name}():\
351354                    def { res_name }  _else():\  
352355                 """ ,
353356            ),
354-             "    "  *  indent ,
357+             "    "  *  INDENT ,
355358        ),
356359        file = OUTPUT_BUF ,
357360    )
358-     indent  +=  1 
361+     INDENT  +=  1 
359362    for  bodyop  in  if_op .elseRegion .blocks [0 ].operations :
360363        if  isinstance (bodyop , scf .YieldOp ):
361364            print_yield_as_return (bodyop )
362365        else :
363366            bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
364-     indent  -=  1 
367+     INDENT  -=  1 
365368
366369
367370def  generic_print_walk_callback (op ):
@@ -392,16 +395,26 @@ def generic_print_walk_callback(op):
392395        print_opview (opview )
393396
394397    if  len (op .regions ):
395-         global  indent 
396-         indent  +=  1 
398+         global  INDENT 
399+         INDENT  +=  1 
397400        for  bodyop  in  op .regions [0 ].blocks [0 ].operations :
398401            bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
399-         indent  -=  1 
402+         INDENT  -=  1 
400403        return  WalkResult .SKIP 
401404
402405    return  WalkResult .ADVANCE 
403406
404407
408+ def  print_attr_alias (attr_line : str ):
409+     print (attr_line )
410+     alias_name , attr_str  =  attr_line .split (" = " , maxsplit = 1 )
411+     assert  alias_name .startswith ("#" )
412+     alias_name  =  alias_name [1 :]
413+     attr  =  Attribute .parse (attr_str )
414+     print (f"{ alias_name }   = { map_attr (attr )}  " , file = OUTPUT_BUF )
415+     ATTR_ALIASES [attr ] =  alias_name 
416+ 
417+ 
405418def  main () ->  None :
406419    parser  =  argparse .ArgumentParser ()
407420    parser .add_argument ("input_file" , type = Path )
0 commit comments