@@ -33,6 +33,7 @@ def _alloc(
3333 sizes : Sequence [Union [int , Value ]],
3434 element_type : Type ,
3535 memory_space = None ,
36+ alignment = None ,
3637 loc = None ,
3738 ip = None ,
3839):
@@ -52,21 +53,56 @@ def _alloc(
5253
5354 symbol_operands = []
5455 return get_op_result_or_op_results (
55- op_ctor (result_type , dynamic_sizes , symbol_operands , loc = loc , ip = ip )
56+ op_ctor (
57+ result_type ,
58+ dynamic_sizes ,
59+ symbol_operands ,
60+ alignment = alignment ,
61+ loc = loc ,
62+ ip = ip ,
63+ )
5664 )
5765
5866
59- def alloc (sizes : Union [int , Value ], element_type : Type = None , memory_space = None ):
60- loc = get_user_code_loc ()
67+ def alloc (
68+ sizes : Union [int , Value ],
69+ element_type : Type = None ,
70+ memory_space = None ,
71+ alignment = None ,
72+ loc = None ,
73+ ip = None ,
74+ ):
75+ if loc is None :
76+ loc = get_user_code_loc ()
6177 return _alloc (
62- AllocOp , sizes , element_type , memory_space = memory_space , loc = loc , ip = None
78+ AllocOp ,
79+ sizes ,
80+ element_type ,
81+ memory_space = memory_space ,
82+ alignment = alignment ,
83+ loc = loc ,
84+ ip = ip ,
6385 )
6486
6587
66- def alloca (sizes : Union [int , Value ], element_type : Type = None , memory_space = None ):
67- loc = get_user_code_loc ()
88+ def alloca (
89+ sizes : Union [int , Value ],
90+ element_type : Type = None ,
91+ memory_space = None ,
92+ alignment = None ,
93+ loc = None ,
94+ ip = None ,
95+ ):
96+ if loc is None :
97+ loc = get_user_code_loc ()
6898 return _alloc (
69- AllocaOp , sizes , element_type , memory_space = memory_space , loc = loc , ip = None
99+ AllocaOp ,
100+ sizes ,
101+ element_type ,
102+ memory_space = memory_space ,
103+ alignment = alignment ,
104+ loc = loc ,
105+ ip = ip ,
70106 )
71107
72108
@@ -113,8 +149,9 @@ def __getitem__(self, idx: tuple) -> "MemRef":
113149 if idx is None :
114150 return expand_shape (self , (0 ,), loc = loc )
115151
116- idx = list ((idx ,) if isinstance (idx , (int , slice )) else idx )
152+ idx = list ((idx ,) if isinstance (idx , (int , Scalar , slice )) else idx )
117153 for i , d in enumerate (idx ):
154+ # TODO(max): rethink this since subview and etc probably take constant attributes?
118155 if isinstance (d , int ):
119156 idx [i ] = constant (d , index = True , loc = loc )
120157
@@ -123,7 +160,7 @@ def __getitem__(self, idx: tuple) -> "MemRef":
123160 else :
124161 return _subview (self , tuple (idx ), loc = loc )
125162
126- def __setitem__ (self , idx , source ):
163+ def __setitem__ (self , idx , val ):
127164 loc = get_user_code_loc ()
128165
129166 if not self .has_rank ():
@@ -135,12 +172,10 @@ def __setitem__(self, idx, source):
135172 idx [i ] = constant (d , index = True , loc = loc )
136173
137174 if all (isinstance (d , Scalar ) for d in idx ) and len (idx ) == len (self .shape ):
138- assert isinstance (
139- source , Scalar
140- ), "coordinate insert requires scalar element"
141- store (source , self , idx , loc = loc )
175+ assert isinstance (val , Scalar ), "coordinate insert requires scalar element"
176+ store (val , self , idx , loc = loc )
142177 else :
143- _copy_to_subview (self , source , tuple (idx ), loc = loc )
178+ _copy_to_subview (self , val , tuple (idx ), loc = loc )
144179
145180
146181def expand_shape (
0 commit comments