Skip to content

Commit 246f7a8

Browse files
committed
use upstream types
1 parent ecaa7c4 commit 246f7a8

22 files changed

+332
-483
lines changed

examples/mlir_python_extras.ipynb

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range\n",
6060
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
6161
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n",
62+
"from mlir.ir import StridedLayoutAttr\n",
6263
"\n",
6364
"# you need this to register the memref value caster\n",
6465
"# noinspection PyUnresolvedReferences\n",
@@ -91,15 +92,15 @@
9192
"outputs": [],
9293
"source": [
9394
"K = 10\n",
94-
"memref_i64 = T.memref(K, K, T.i64)\n",
95+
"memref_i64 = T.memref(K, K, T.i64())\n",
9596
"\n",
9697
"@func(emit=True)\n",
9798
"@canonicalize(using=scf)\n",
9899
"def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):\n",
99100
" one = constant(1)\n",
100101
" two = constant(2)\n",
101102
" if one > two:\n",
102-
" C[0, 0] = constant(3, T.i64)\n",
103+
" C[0, 0] = constant(3, T.i64())\n",
103104
" else:\n",
104105
" for i in range(0, K):\n",
105106
" for j in range(0, K):\n",
@@ -447,8 +448,9 @@
447448
"D = 32\n",
448449
"\n",
449450
"F = K // D\n",
450-
"ranked_memref_kxk_f32 = T.memref(K, K, T.f32)\n",
451-
"ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))\n",
451+
"ranked_memref_kxk_f32 = T.memref(K, K, T.f32())\n",
452+
"layout = StridedLayoutAttr.get(S, (K, 1))\n",
453+
"ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n",
452454
"\n",
453455
"@func(emit=True)\n",
454456
"@canonicalize(using=scf)\n",
@@ -784,8 +786,9 @@
784786
"ctx_man = mlir_mod_ctx()\n",
785787
"ctx = ctx_man.__enter__()\n",
786788
"\n",
787-
"ranked_memref_kxk_f32 = T.memref(K, K, T.f32)\n",
788-
"ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))\n",
789+
"ranked_memref_kxk_f32 = T.memref(K, K, T.f32())\n",
790+
"layout = StridedLayoutAttr.get(S, (K, 1))\n",
791+
"ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n",
789792
"\n",
790793
"from mlir.extras.dialects.ext import linalg\n",
791794
"\n",

examples/mwe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
def setting_memref(ctx: MLIRContext, backend: LLVMJITBackend):
1717
K = 10
18-
memref_i64 = T.memref(K, K, T.i64)
18+
memref_i64 = T.memref(K, K, T.i64())
1919

2020
@func
2121
@canonicalize(using=scf)

mlir/extras/dialects/ext/arith.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@
4343
FloatAttr,
4444
)
4545

46-
from ...util import get_user_code_loc
46+
from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype
4747
from ...._mlir_libs._mlir import register_value_caster
48-
from ...types import infer_mlir_type, mlir_type_to_np_dtype
4948

5049

5150
def constant(

mlir/extras/dialects/ext/gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def __init__(
227227
loc=loc,
228228
ip=ip,
229229
)
230-
self.regions[0].blocks.append(*[T.index for _ in range(12)])
230+
self.regions[0].blocks.append(*[T.index() for _ in range(12)])
231231

232232

233233
def launch_(

mlir/extras/dialects/ext/llvm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ....ir import Type
2+
3+
4+
def llvm_ptr_t():
5+
return Type.parse("!llvm.ptr")

mlir/extras/dialects/ext/scf.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,17 @@
4646
IndexType,
4747
_denseI64ArrayAttr,
4848
Attribute,
49+
OpaqueType,
4950
)
5051

5152
logger = logging.getLogger(__name__)
5253

54+
opaque = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
55+
56+
57+
def placeholder_opaque_t():
58+
return opaque("scf", "placeholder")
59+
5360

5461
def _for(
5562
start,
@@ -372,7 +379,7 @@ def yield_(*args):
372379
unpacked_args = list(unpacked_args[0])
373380

374381
for i, r in enumerate(results):
375-
if r.type == T.placeholder_opaque():
382+
if r.type == placeholder_opaque_t():
376383
r.set_type(unpacked_args[i].type)
377384

378385
if len(results) > 1:
@@ -582,7 +589,7 @@ def visit_If(self, updated_node: ast.If) -> ast.With | list[ast.With, ast.With]:
582589
and isinstance(last_statement.targets[0], ast.Tuple)
583590
else 0,
584591
)
585-
results = [ast_call(T.placeholder_opaque.__name__) for _ in range(num_results)]
592+
results = [ast_call(placeholder_opaque_t.__name__) for _ in range(num_results)]
586593
results = ast.fix_missing_locations(
587594
ast.copy_location(ast.Tuple(results, ctx=ast.Load()), test)
588595
)
@@ -620,7 +627,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
620627
f.__globals__[yield_.__name__] = yield_
621628
f.__globals__[if_ctx_manager.__name__] = if_ctx_manager
622629
f.__globals__[else_ctx_manager.__name__] = else_ctx_manager
623-
f.__globals__[T.placeholder_opaque.__name__] = T.placeholder_opaque
630+
f.__globals__[placeholder_opaque_t.__name__] = placeholder_opaque_t
624631
return code
625632

626633

mlir/extras/dialects/ext/transform.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
SequenceOp,
1414
FailurePropagationMode,
1515
YieldOp,
16+
AnyOpType,
17+
OperationType,
1618
)
1719
from ....dialects.transform.loop import LoopUnrollOp
1820
from ....dialects.transform import GetParentOp
@@ -23,6 +25,18 @@
2325
StringAttr,
2426
)
2527
from ....dialects._ods_common import get_op_result_or_op_results
28+
from ....dialects import pdl
29+
30+
31+
pdl_operation_t = lambda: pdl.OperationType.get()
32+
33+
34+
def transform_any_op_t():
35+
return AnyOpType.get()
36+
37+
38+
def transform_op_t(name):
39+
return OperationType.get(name)
2640

2741

2842
def sequence_(
@@ -40,7 +54,7 @@ def sequence_(
4054
if results_ is None:
4155
results_ = []
4256
if target is None:
43-
target = T.pdl_operation
57+
target = pdl_operation_t()
4458
# this is a misnomer - it's not about targeting a particular op
4559
# but about picking which transform sequence runs using
4660
# transform_dialect_interpreter(debug_transform_root_tag="")
@@ -52,7 +66,7 @@ def sequence_(
5266
failure_propagation_mode = FailurePropagationMode.Propagate
5367

5468
if isinstance(target, str):
55-
target = T.transform_op(target)
69+
target = transform_op_t(target)
5670

5771
seq_op = SequenceOp(
5872
failure_propagation_mode,
@@ -85,7 +99,7 @@ def get_parent(
8599

86100
return get_op_result_or_op_results(
87101
GetParentOp(
88-
T.pdl_operation,
102+
pdl_operation_t(),
89103
target,
90104
isolated_from_above=isolated_from_above,
91105
op_name=op_name,
@@ -119,7 +133,7 @@ def match(
119133
loc = get_user_code_loc()
120134
return get_op_result_or_op_results(
121135
MatchOp(
122-
T.transform_any_op,
136+
transform_any_op_t(),
123137
target,
124138
ops=ops,
125139
interface=interface,

mlir/extras/runtime/refbackend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .. import types as T
2626
from ...dialects.memref import cast
2727
from ..runtime.passes import Pipeline, run_pipeline
28-
from ..types import (
28+
from ..util import (
2929
memref_type_to_np_dtype,
3030
mlir_type_to_ctype,
3131
np_dtype_to_mlir_type,

0 commit comments

Comments
 (0)