1- import re
21from functools import cached_property , reduce
3- from typing import Tuple , Sequence , Optional , Union
2+ from typing import Tuple , Sequence , Union
43
5- from ....ir import Type , Value , MemRefType , ShapedType , MLIRError
6-
7- from ... import types as T
8- from ....dialects .memref import *
9- from ....dialects import memref , arith
104from .arith import Scalar , constant
115from .tensor import _indices_to_indexer , compute_result_shape_reassoc_list
6+ from ... import types as T
127from ...meta import region_op
13- from ...._mlir_libs ._mlir import register_value_caster
148from ...util import get_user_code_loc
9+ from ...._mlir_libs ._mlir import register_value_caster
10+ from ....dialects import memref , arith
1511from ....dialects ._ods_common import get_op_result_or_op_results
12+ from ....dialects .memref import *
13+ from ....ir import Type , Value , MemRefType , ShapedType
1614
1715S = ShapedType .get_dynamic_size ()
1816
@@ -70,71 +68,6 @@ def store(
7068 return get_op_result_or_op_results (StoreOp (value , mem , indices , loc = loc , ip = ip ))
7169
7270
73- def subview (
74- source : "MemRef" ,
75- offsets : Optional [Sequence [Value ]] = None ,
76- strides : Optional [Sequence [Value ]] = None ,
77- static_offsets : Optional [Sequence [int ]] = None ,
78- static_sizes : Optional [Sequence [int ]] = None ,
79- static_strides : Optional [Sequence [int ]] = None ,
80- * ,
81- loc = None ,
82- ip = None ,
83- ):
84- if loc is None :
85- loc = get_user_code_loc ()
86- if offsets is None :
87- offsets = []
88- if static_offsets is None :
89- static_offsets = []
90- if strides is None :
91- strides = []
92- if static_strides is None :
93- static_strides = []
94- assert static_sizes , f"this convenience method only handles static sizes"
95- sizes = []
96- wrong_type = T .memref (* static_sizes , source .dtype )
97- if offsets and static_offsets :
98- assert all (s == S for s in static_offsets )
99- if strides and static_strides :
100- assert all (s == S for s in static_strides )
101- val = memref .subview (
102- wrong_type ,
103- source ,
104- offsets ,
105- sizes ,
106- strides ,
107- static_offsets ,
108- static_sizes ,
109- static_strides ,
110- loc = loc ,
111- ip = ip ,
112- )
113- # dumbest hack ever - the default builder doesn't connect to inferReturnTypes
114- # but the diag message does
115- try :
116- val .owner .verify ()
117- return val
118- except MLIRError as e :
119- diag = str (e .error_diagnostics [0 ])
120- correct_type = re .findall (r"'memref<(.*)>'" , diag )
121- assert len (correct_type ) == 1
122- correct_type = Type .parse (f"memref<{ correct_type [0 ]} >" )
123- val .owner .erase ()
124- return memref .subview (
125- correct_type ,
126- source ,
127- offsets ,
128- sizes ,
129- strides ,
130- static_offsets ,
131- static_sizes ,
132- static_strides ,
133- loc = loc ,
134- ip = ip ,
135- )
136-
137-
13871@register_value_caster (MemRefType .static_typeid )
13972class MemRef (Value ):
14073 def __str__ (self ):
@@ -266,16 +199,15 @@ def _subview(
266199 if indexer .is_constant ():
267200 out = subview (
268201 out ,
269- static_offsets = indexer .static_offsets (),
270- static_sizes = indexer .static_sizes (),
271- static_strides = indexer .static_strides (),
202+ offsets = indexer .static_offsets (),
203+ sizes = indexer .static_sizes (),
204+ strides = indexer .static_strides (),
272205 loc = loc ,
273206 ip = ip ,
274207 )
275208 else :
276209 # special tile case
277210 offsets = [None ] * len (indexer .in_shape )
278- static_offsets = [None ] * len (indexer .in_shape )
279211 static_sizes = [None ] * len (indexer .in_shape )
280212 static_strides = [None ] * len (indexer .in_shape )
281213 for i , ind in enumerate (indexer .indices ):
@@ -292,15 +224,13 @@ def _subview(
292224 and ind .step .is_constant ()
293225 ):
294226 offsets [i ] = ind .start
295- static_offsets [i ] = S
296227 static_sizes [i ] = maybe_size .literal_value
297228 static_strides [i ] = (
298229 ind .step .literal_value if isinstance (ind .step , Scalar ) else ind .step
299230 )
300231 else :
301232 raise RuntimeError (f"indexing not supported { indexer .indices } " )
302233 offsets = list (filter (None , offsets ))
303- static_offsets = list (filter (None , static_offsets ))
304234 static_sizes = list (filter (None , static_sizes ))
305235 static_strides = list (filter (None , static_strides ))
306236 assert (
@@ -312,9 +242,8 @@ def _subview(
312242 out = subview (
313243 out ,
314244 offsets = offsets ,
315- static_offsets = static_offsets ,
316- static_sizes = static_sizes ,
317- static_strides = static_strides ,
245+ sizes = static_sizes ,
246+ strides = static_strides ,
318247 loc = loc ,
319248 ip = ip ,
320249 )
0 commit comments