4242from mlir .extras .util import mlir_type_to_np_dtype
4343
4444
45+ INDENT = 0
46+ # OUTPUT_BUF = io.StringIO()
47+ OUTPUT_BUF = sys .stdout
48+ ATTR_ALIASES = {}
49+
50+
4551def normalize_ssa (ssa : str | Value ):
4652 if isinstance (ssa , Value ):
47- ssa = ssa .get_name ()
53+ ssa = ssa .get_name (use_name_loc_as_prefix = True )
4854 if ssa [1 ].isnumeric ():
4955 ssa = ssa .replace ("%" , "v" )
5056 else :
@@ -74,6 +80,8 @@ def np_array_from_shape_type(shape, dtype, splat_value=None):
7480
7581
7682def map_attr (attr ):
83+ if attr in ATTR_ALIASES :
84+ return ATTR_ALIASES [attr ]
7785 attr = attr .maybe_downcast ()
7886 if isinstance (attr , (IntegerAttr , BoolAttr , FloatAttr )):
7987 return attr .value
@@ -130,11 +138,6 @@ def map_type(type):
130138 return f"Type.parse('{ type } ')"
131139
132140
133- indent = 0
134- # OUTPUT_BUF = io.StringIO()
135- OUTPUT_BUF = sys .stdout
136-
137-
138141def get_init_args (opview ):
139142 klass = opview .__class__
140143 while not klass .__base__ is OpView :
@@ -168,7 +171,7 @@ def underscore(word: str) -> str:
168171
169172
170173def print_opview (opview , name = None ):
171- print (" " * indent , file = OUTPUT_BUF , end = "" )
174+ print (" " * INDENT , file = OUTPUT_BUF , end = "" )
172175 if len (opview .results ):
173176 print (
174177 ", " .join ([normalize_ssa (r ) for r in opview .results ]),
@@ -249,15 +252,15 @@ def print_opview(opview, name=None):
249252 else :
250253 owner = f"{ op_idx_owner_name } "
251254 print (
252- " " * indent
255+ " " * INDENT
253256 + f"{ owner } .attributes['OpIdx'] = amdgpu.OpIdxAttr.get({ attrs ['OpIdx' ].value } )" ,
254257 file = OUTPUT_BUF ,
255258 )
256259
257260
258261def print_func_op (func_op : func .FuncOp ):
259262 # op.print(print_generic_op_form=True)
260- print (" " * indent , file = OUTPUT_BUF , end = "" )
263+ print (" " * INDENT , file = OUTPUT_BUF , end = "" )
261264 print ("@func.func(" , file = OUTPUT_BUF , end = "" )
262265 if len (func_op .attributes ):
263266 attrs = []
@@ -283,7 +286,7 @@ def print_func_op(func_op: func.FuncOp):
283286
284287
285288def print_arith_constant (constop : arith .ConstantOp ):
286- print (" " * indent , file = OUTPUT_BUF , end = "" )
289+ print (" " * INDENT , file = OUTPUT_BUF , end = "" )
287290 print (
288291 f"{ normalize_ssa (constop .result )} = arith.constant({ map_attr (constop .value )} , { map_type (constop .result .type )} )" ,
289292 file = OUTPUT_BUF ,
@@ -305,7 +308,7 @@ def print_scf_for(for_op: scf.ForOp):
305308 )
306309 init_args = [normalize_ssa (a ) for a in for_op .initArgs ]
307310 print (
308- (" " * indent )
311+ (" " * INDENT )
309312 + f"for { opers_str } in scf.for_({ start } , { stop } , { step } , iter_args=[{ ', ' .join (init_args )} ]):" ,
310313 file = OUTPUT_BUF ,
311314 )
@@ -315,12 +318,12 @@ def print_scf_if(if_op: scf.IfOp):
315318 assert len (if_op .results ) == 1
316319 res = if_op .results [0 ]
317320 res_name = normalize_ssa (res )
318- global indent
321+ global INDENT
319322
320323 def print_yield_as_return (yield_op : scf .YieldOp ):
321324 opers = [normalize_ssa (a ) for a in yield_op .operands ]
322325 print (
323- (" " * indent ) + f"return { ', ' .join (opers )} " ,
326+ (" " * INDENT ) + f"return { ', ' .join (opers )} " ,
324327 file = OUTPUT_BUF ,
325328 )
326329
@@ -332,17 +335,17 @@ def print_yield_as_return(yield_op: scf.YieldOp):
332335 def { res_name } ():\
333336 """
334337 ),
335- " " * indent ,
338+ " " * INDENT ,
336339 ),
337340 file = OUTPUT_BUF ,
338341 )
339- indent += 1
342+ INDENT += 1
340343 for bodyop in if_op .thenRegion .blocks [0 ].operations :
341344 if isinstance (bodyop , scf .YieldOp ):
342345 print_yield_as_return (bodyop )
343346 else :
344347 bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
345- indent -= 1
348+ INDENT -= 1
346349 print (
347350 textwrap .indent (
348351 textwrap .dedent (
@@ -351,17 +354,17 @@ def {res_name}():\
351354 def { res_name } _else():\
352355 """ ,
353356 ),
354- " " * indent ,
357+ " " * INDENT ,
355358 ),
356359 file = OUTPUT_BUF ,
357360 )
358- indent += 1
361+ INDENT += 1
359362 for bodyop in if_op .elseRegion .blocks [0 ].operations :
360363 if isinstance (bodyop , scf .YieldOp ):
361364 print_yield_as_return (bodyop )
362365 else :
363366 bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
364- indent -= 1
367+ INDENT -= 1
365368
366369
367370def generic_print_walk_callback (op ):
@@ -392,16 +395,26 @@ def generic_print_walk_callback(op):
392395 print_opview (opview )
393396
394397 if len (op .regions ):
395- global indent
396- indent += 1
398+ global INDENT
399+ INDENT += 1
397400 for bodyop in op .regions [0 ].blocks [0 ].operations :
398401 bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
399- indent -= 1
402+ INDENT -= 1
400403 return WalkResult .SKIP
401404
402405 return WalkResult .ADVANCE
403406
404407
408+ def print_attr_alias (attr_line : str ):
409+ print (attr_line )
410+ alias_name , attr_str = attr_line .split (" = " , maxsplit = 1 )
411+ assert alias_name .startswith ("#" )
412+ alias_name = alias_name [1 :]
413+ attr = Attribute .parse (attr_str )
414+ print (f"{ alias_name } = { map_attr (attr )} " , file = OUTPUT_BUF )
415+ ATTR_ALIASES [attr ] = alias_name
416+
417+
405418def main () -> None :
406419 parser = argparse .ArgumentParser ()
407420 parser .add_argument ("input_file" , type = Path )
0 commit comments