55from ....ir import Type , Value , MemRefType , ShapedType , MLIRError
66
77from ... import types as T
8+ from ....dialects .memref import *
89from ....dialects import memref , arith
9- from ...dialects .ext .arith import Scalar , constant
10- from ...dialects .ext .tensor import (
11- _indices_to_indexer ,
12- compute_result_shape_reassoc_list ,
13- )
10+ from .arith import Scalar , constant
11+ from .tensor import _indices_to_indexer , compute_result_shape_reassoc_list
1412from ...meta import region_op
1513from ...._mlir_libs ._mlir import register_value_caster
1614from ...util import get_user_code_loc
@@ -39,7 +37,7 @@ def _alloc(
3937def alloc (sizes : Sequence [Union [int , Value ]], element_type : Type , * , loc = None , ip = None ):
4038 if loc is None :
4139 loc = get_user_code_loc ()
42- return _alloc (memref . AllocOp , sizes , element_type , loc = loc , ip = ip )
40+ return _alloc (AllocOp , sizes , element_type , loc = loc , ip = ip )
4341
4442
4543def alloca (
@@ -48,7 +46,7 @@ def alloca(
4846 if loc is None :
4947 loc = get_user_code_loc ()
5048 return get_op_result_or_op_results (
51- _alloc (memref . AllocaOp , sizes , element_type , loc = loc , ip = ip )
49+ _alloc (AllocaOp , sizes , element_type , loc = loc , ip = ip )
5250 )
5351
5452
@@ -59,7 +57,7 @@ def load(mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None):
5957 for idx , i in enumerate (indices ):
6058 if isinstance (i , int ):
6159 indices [idx ] = constant (i , index = True )
62- return get_op_result_or_op_results (memref . LoadOp (mem , indices , loc = loc , ip = ip ))
60+ return get_op_result_or_op_results (LoadOp (mem , indices , loc = loc , ip = ip ))
6361
6462
6563def store (
@@ -71,9 +69,7 @@ def store(
7169 for idx , i in enumerate (indices ):
7270 if isinstance (i , int ):
7371 indices [idx ] = constant (i , index = True )
74- return get_op_result_or_op_results (
75- memref .StoreOp (value , mem , indices , loc = loc , ip = ip )
76- )
72+ return get_op_result_or_op_results (StoreOp (value , mem , indices , loc = loc , ip = ip ))
7773
7874
7975def subview (
@@ -345,4 +341,4 @@ def _copy_to_subview(
345341 return memref .copy (source , dest_subview , loc = loc , ip = ip )
346342
347343
348- alloca_scope = region_op (memref . AllocaScopeOp )
344+ alloca_scope = region_op (AllocaScopeOp )
0 commit comments