Commit c5174eb

fix nanobind differences (and others)

1 parent 8984cf8

10 files changed: +195 −61 lines changed

examples/mwe.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -131,6 +131,8 @@ def pats():
         .finalize_memref_to_llvm()
         # Convert Func to LLVM (always needed).
         .convert_func_to_llvm()
+        .convert_arith_to_llvm()
+        .convert_cf_to_llvm()
         # Convert Index to LLVM (always needed).
         .convert_index_to_llvm()
         # Convert remaining unrealized_casts (always needed).
```
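Background for the two new conversion calls: upstream MLIR removed the legacy `arith`/`cf` lowering patterns from `convert-func-to-llvm` (see the docstring deletion in `mlir/extras/runtime/passes.py` below), so pipelines that previously relied on them must now lower those dialects explicitly. A minimal sketch of the same lowering chain written against the upstream MLIR Python bindings (the tiny module is illustrative):

```python
from mlir.ir import Context, Module
from mlir.passmanager import PassManager

with Context():
    module = Module.parse("module { func.func @f() { return } }")
    # Same chain as above, spelled as a textual pass pipeline; note that
    # convert-arith-to-llvm and convert-cf-to-llvm are now explicit.
    pm = PassManager.parse(
        "builtin.module("
        "finalize-memref-to-llvm,"
        "convert-func-to-llvm,"
        "convert-arith-to-llvm,"
        "convert-cf-to-llvm,"
        "convert-index-to-llvm,"
        "reconcile-unrealized-casts)"
    )
    pm.run(module.operation)
    print(module)
```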

examples/vectorization_e2e.ipynb

Lines changed: 2 additions & 0 deletions
```diff
@@ -424,6 +424,8 @@
     " .finalize_memref_to_llvm()\n",
     " # Convert Func to LLVM (always needed).\n",
     " .convert_func_to_llvm()\n",
+    " .convert_arith_to_llvm()\n",
+    " .convert_cf_to_llvm()\n",
     " # Convert Index to LLVM (always needed).\n",
     " .convert_index_to_llvm()\n",
     " # Convert remaining unrealized_casts (always needed).\n",
```

mlir/extras/dialects/ext/arith.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -7,6 +7,7 @@
 from typing import Optional, Tuple, Union
 
 from bytecode import ConcreteBytecode
+from einspect.structs import PyTypeObject
 
 from ...ast.canonicalize import StrictTransformer, Canonicalizer, BytecodePatcher
 from ...ast.util import ast_call
@@ -138,7 +139,13 @@ def index_cast(
     )
 
 
-class ArithValueMeta(type(Value)):
+nb_meta_cls = type(Value)
+
+_Py_TPFLAGS_BASETYPE = 1 << 10
+PyTypeObject.from_object(nb_meta_cls).tp_flags |= _Py_TPFLAGS_BASETYPE
+
+
+class ArithValueMeta(nb_meta_cls):
     """Metaclass that orchestrates the Python object protocol
     (i.e., calling __new__ and __init__) for Indexing dialect extension values
     (created using `mlir_value_subclass`).
```
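Why the `tp_flags` poke: the commit works around nanobind's metaclass not having `Py_TPFLAGS_BASETYPE` set, which makes `class ArithValueMeta(type(Value))` fail under nanobind-built bindings; the patch uses `einspect` to flip the flag on `type(Value)` before subclassing it. A small standalone sketch of how the flag can be observed from pure Python (only CPython's documented `__flags__` attribute is used):

```python
# Py_TPFLAGS_BASETYPE marks a type as allowing Python-level subclassing.
_Py_TPFLAGS_BASETYPE = 1 << 10

def is_subclassable(cls: type) -> bool:
    # tp_flags is exposed read-only on every type object as __flags__.
    return bool(cls.__flags__ & _Py_TPFLAGS_BASETYPE)

print(is_subclassable(int))   # True: most built-in types set the flag
print(is_subclassable(bool))  # False: bool explicitly forbids subclassing
```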

mlir/extras/dialects/ext/memref.py

Lines changed: 36 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 import inspect
-from typing import Sequence, Union
+from functools import cached_property, reduce
+from typing import Sequence, Union, Tuple
 
 import numpy as np
 
@@ -129,7 +130,40 @@ def store(
 
 
 @register_value_caster(MemRefType.static_typeid)
-class MemRef(Value, ShapedValue):
+class MemRef(Value):
+    @cached_property
+    def literal_value(self) -> np.ndarray:
+        if not self.is_constant:
+            raise ValueError("Can't build literal from non-constant value")
+        return np.array(DenseElementsAttr(self.owner.opview.value), copy=False)
+
+    @cached_property
+    def _shaped_type(self) -> ShapedType:
+        return ShapedType(self.type)
+
+    def has_static_shape(self) -> bool:
+        return self._shaped_type.has_static_shape
+
+    def has_rank(self) -> bool:
+        return self._shaped_type.has_rank
+
+    @cached_property
+    def rank(self) -> int:
+        return self._shaped_type.rank
+
+    @cached_property
+    def shape(self) -> Tuple[int, ...]:
+        return tuple(self._shaped_type.shape)
+
+    @cached_property
+    def n_elements(self) -> int:
+        assert self.has_static_shape()
+        return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1)
+
+    @cached_property
+    def dtype(self) -> Type:
+        return self._shaped_type.element_type
+
     def __str__(self):
         return f"{self.__class__.__name__}({self.get_name()}, {self.type})"
```
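The same block of shape helpers replaces the `ShapedValue` mixin here and in `Tensor`/`Vector` below; under nanobind the old multiple-inheritance arrangement (a bound `Value` base plus a plain-Python mixin) no longer cooperates, so the helpers are inlined per class. A runnable stand-in showing the `cached_property` + `reduce` pattern these helpers rely on (the class is hypothetical, not part of the repo):

```python
from functools import cached_property, reduce
from typing import Tuple

class ShapedThing:
    """Hypothetical stand-in mirroring the helpers added to MemRef above."""

    def __init__(self, shape):
        self._shape = tuple(shape)

    @cached_property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @cached_property
    def n_elements(self) -> int:
        # Same reduce-based product as MemRef.n_elements.
        return reduce(lambda acc, v: acc * v, self._shape, 1)

t = ShapedThing((4, 8))
assert t.n_elements == 32
assert "n_elements" in t.__dict__  # cached_property memoizes per instance
```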

mlir/extras/dialects/ext/tensor.py

Lines changed: 36 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 import inspect
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import Any, List, Optional, Tuple, Union, Sequence
 
 # noinspection PyUnresolvedReferences
@@ -20,7 +21,7 @@
 from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
 from ....dialects.tensor import *
 from ....dialects.transform.structured import _get_int_array_array_attr
-from ....ir import RankedTensorType, ShapedType, Type, Value
+from ....ir import RankedTensorType, ShapedType, Type, Value, DenseElementsAttr
 
 S = ShapedType.get_dynamic_size()
 
@@ -109,7 +110,40 @@ def insert_slice(
 
 # TODO(max): unify vector/memref/tensor
 @register_value_caster(RankedTensorType.static_typeid)
-class Tensor(ShapedValue, ArithValue):
+class Tensor(ArithValue):
+    @cached_property
+    def literal_value(self) -> np.ndarray:
+        if not self.is_constant:
+            raise ValueError("Can't build literal from non-constant value")
+        return np.array(DenseElementsAttr(self.owner.opview.value), copy=False)
+
+    @cached_property
+    def _shaped_type(self) -> ShapedType:
+        return ShapedType(self.type)
+
+    def has_static_shape(self) -> bool:
+        return self._shaped_type.has_static_shape
+
+    def has_rank(self) -> bool:
+        return self._shaped_type.has_rank
+
+    @cached_property
+    def rank(self) -> int:
+        return self._shaped_type.rank
+
+    @cached_property
+    def shape(self) -> Tuple[int, ...]:
+        return tuple(self._shaped_type.shape)
+
+    @cached_property
+    def n_elements(self) -> int:
+        assert self.has_static_shape()
+        return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1)
+
+    @cached_property
+    def dtype(self) -> Type:
+        return self._shaped_type.element_type
+
     def __getitem__(self, idx: tuple) -> "Tensor":
         loc = get_user_code_loc()
```
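`literal_value` converts a constant tensor's `DenseElementsAttr` straight into a NumPy array without copying, relying on the buffer protocol the MLIR Python bindings implement for dense attributes. A minimal sketch against the upstream bindings showing the same zero-copy round trip:

```python
import numpy as np
from mlir.ir import Context, DenseElementsAttr

with Context():
    a = np.arange(6, dtype=np.int32).reshape(2, 3)
    attr = DenseElementsAttr.get(a)    # dense<...> : tensor<2x3xi32>
    back = np.array(attr, copy=False)  # buffer-protocol view, no copy
    assert (back == a).all()
```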

mlir/extras/dialects/ext/vector.py

Lines changed: 41 additions & 4 deletions
```diff
@@ -1,5 +1,8 @@
 import inspect
-from typing import List
+from functools import cached_property, reduce
+from typing import List, Tuple, Type
+
+import numpy as np
 
 from ._shaped_value import ShapedValue
 from .arith import ArithValue, FastMathFlags, constant, Scalar
@@ -10,11 +13,45 @@
 # noinspection PyUnresolvedReferences
 from ....dialects.vector import *
 from ....extras import types as T
-from ....ir import AffineMap, VectorType, Value
+from ....ir import AffineMap, VectorType, Value, DenseElementsAttr, ShapedType
 
 
 @register_value_caster(VectorType.static_typeid)
-class Vector(ShapedValue, ArithValue):
+class Vector(ArithValue):
+
+    @cached_property
+    def literal_value(self) -> np.ndarray:
+        if not self.is_constant:
+            raise ValueError("Can't build literal from non-constant value")
+        return np.array(DenseElementsAttr(self.owner.opview.value), copy=False)
+
+    @cached_property
+    def _shaped_type(self) -> ShapedType:
+        return ShapedType(self.type)
+
+    def has_static_shape(self) -> bool:
+        return self._shaped_type.has_static_shape
+
+    def has_rank(self) -> bool:
+        return self._shaped_type.has_rank
+
+    @cached_property
+    def rank(self) -> int:
+        return self._shaped_type.rank
+
+    @cached_property
+    def shape(self) -> Tuple[int, ...]:
+        return tuple(self._shaped_type.shape)
+
+    @cached_property
+    def n_elements(self) -> int:
+        assert self.has_static_shape()
+        return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1)
+
+    @cached_property
+    def dtype(self) -> Type:
+        return self._shaped_type.element_type
+
     def __getitem__(self, idx: tuple) -> "Vector":
         loc = get_user_code_loc()
 
@@ -105,7 +142,7 @@ def transfer_read(
     if isinstance(padding, int):
         padding = constant(padding, type=source.type.element_type)
     if in_bounds is None:
-        in_bounds = [None] * len(permutation_map.results)
+        raise ValueError("in_bounds cannot be None")
 
     return _transfer_read(
         vector=vector_t,
```
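The `transfer_read` change is behavioral, not cosmetic: the old fallback built `in_bounds = [None] * n`, likely rejected by the nanobind-built builders where pybind11 used to coerce it, so callers must now pass an explicit list of booleans. A pure-Python sketch of the new contract (the helper is illustrative, not the repo's function):

```python
from typing import List, Optional

def transfer_read_sketch(n_results: int, in_bounds: Optional[List[bool]] = None):
    # New contract: fail loudly instead of materializing [None] * n_results.
    if in_bounds is None:
        raise ValueError("in_bounds cannot be None")
    assert len(in_bounds) == n_results
    return in_bounds

transfer_read_sketch(2, in_bounds=[True, True])  # ok
# transfer_read_sketch(2)  # -> ValueError: in_bounds cannot be None
```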

mlir/extras/runtime/passes.py

Lines changed: 44 additions & 27 deletions
```diff
@@ -252,6 +252,11 @@ def affine_expand_index_ops(self):
         self.add_pass("affine-expand-index-ops")
         return self
 
+    def affine_expand_index_ops_as_affine(self):
+        """Lower affine operations operating on indices into affine.apply operations"""
+        self.add_pass("affine-expand-index-ops-as-affine")
+        return self
+
     def affine_loop_coalescing(self):
         """Coalesce nested loops with independent bounds into a single loop"""
         self.add_pass("affine-loop-coalescing")
@@ -1363,10 +1368,6 @@ def convert_func_to_llvm(
         returns are updated accordingly. Block argument types are updated to use
         LLVM IR types.
 
-        Note that until https://github.com/llvm/llvm-project/issues/70982 is resolved,
-        this pass includes patterns that lower `arith` and `cf` to LLVM. This is legacy
-        code due to when they were all converted in the same pass.
-
         Args:
             use-bare-ptr-memref-call-conv: Replace FuncOp's MemRef arguments with bare pointers to the MemRef element types
             index-bitwidth: Bitwidth of the index type, 0 to use size of machine word
@@ -1398,12 +1399,12 @@ def convert_gpu_launch_to_vulkan_launch(self):
         self.add_pass("convert-gpu-launch-to-vulkan-launch")
         return self
 
-    def convert_gpu_to_llvm_spv(self, index_bitwidth: int = None):
+    def convert_gpu_to_llvm_spv(self, use_64bit_index: bool = None):
         """Generate LLVM operations to be ingested by a SPIR-V backend for gpu operations
         Args:
-            index-bitwidth: Bitwidth of the index type, 0 to use size of machine word
+            use-64bit-index: Use 64-bit integers to convert index types
         """
-        self.add_pass("convert-gpu-to-llvm-spv", index_bitwidth=index_bitwidth)
+        self.add_pass("convert-gpu-to-llvm-spv", use_64bit_index=use_64bit_index)
         return self
 
     def convert_gpu_to_nvvm(
@@ -1597,6 +1598,20 @@ def convert_memref_to_spirv(
         )
         return self
 
+    def convert_mesh_to_mpi(self):
+        """Convert Mesh dialect to MPI dialect.
+
+        This pass converts communication operations from the Mesh dialect to the
+        MPI dialect.
+        If it finds a global named "static_mpi_rank" it will use that splat value
+        instead of calling MPI_Comm_rank. This allows optimizations like constant
+        shape propagation and fusion because shard/partition sizes depend on the
+        rank.
+
+        """
+        self.add_pass("convert-mesh-to-mpi")
+        return self
+
     def convert_nvgpu_to_nvvm(self):
         """Convert NVGPU dialect to NVVM dialect
 
@@ -1715,17 +1730,26 @@ def convert_tensor_to_spirv(self, emulate_lt_32_bit_scalar_types: bool = None):
         )
         return self
 
-    def convert_to_llvm(self, filter_dialects: List[str] = None):
+    def convert_to_llvm(self, filter_dialects: List[str] = None, dynamic: bool = None):
         """Convert to LLVM via dialect interfaces found in the input IR
 
         This is a generic pass to convert to LLVM, it uses the
         `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
         the injection of conversion patterns.
 
+        If `dynamic` is set to `true`, the pass will look for
+        `ConvertToLLVMAttrInterface` attributes and use them to further configure
+        the conversion process. This option also uses the `DataLayoutAnalysis`
+        analysis to configure the type converter. Enabling this option incurs in
+        extra overhead.
+
         Args:
             filter-dialects: Test conversion patterns of only the specified dialects
+            dynamic: Use op conversion attributes to configure the conversion
         """
-        self.add_pass("convert-to-llvm", filter_dialects=filter_dialects)
+        self.add_pass(
+            "convert-to-llvm", filter_dialects=filter_dialects, dynamic=dynamic
+        )
         return self
 
     def convert_to_spirv(
@@ -2082,23 +2106,6 @@ def finalize_memref_to_llvm(
         )
         return self
 
-    def finalizing_bufferize(self):
-        """Finalize a partial bufferization
-
-        A bufferize pass that finalizes a partial bufferization by removing
-        remaining `bufferization.to_tensor` and `bufferization.to_buffer` operations.
-
-        The removal of those operations is only possible if the operations only
-        exist in pairs, i.e., all uses of `bufferization.to_tensor` operations are
-        `bufferization.to_buffer` operations.
-
-        This pass will fail if not all operations can be removed or if any operation
-        with tensor typed operands remains.
-
-        """
-        self.add_pass("finalizing-bufferize")
-        return self
-
     def fold_memref_alias_ops(self):
         """Fold memref alias ops into consumer load/store ops
 
@@ -2201,6 +2208,7 @@ def gpu_module_to_binary(
         l: List[str] = None,
         opts: str = None,
         format: str = None,
+        section: str = None,
     ):
         """Transforms a GPU module into a GPU binary.
@@ -2219,9 +2227,15 @@
             l: Extra files to link to.
             opts: Command line options to pass to the tools.
             format: The target representation of the compilation process.
+            section: ELF section where binary is to be located.
         """
         self.add_pass(
-            "gpu-module-to-binary", toolkit=toolkit, l=l, opts=opts, format=format
+            "gpu-module-to-binary",
+            toolkit=toolkit,
+            l=l,
+            opts=opts,
+            format=format,
+            section=section,
         )
         return self
 
@@ -2893,6 +2907,7 @@ def one_shot_bufferize(
         no_analysis_func_filter: List[str] = None,
         function_boundary_type_conversion: str = None,
         must_infer_memory_space: bool = None,
+        use_encoding_for_memory_space: bool = None,
         test_analysis_only: bool = None,
         print_conflicts: bool = None,
         unknown_type_conversion: str = None,
@@ -3017,6 +3032,7 @@ def one_shot_bufferize(
             no-analysis-func-filter: Skip analysis of functions with these symbol names.Set copyBeforeWrite to true when bufferizing them.
             function-boundary-type-conversion: Controls layout maps when bufferizing function signatures.
             must-infer-memory-space: The memory space of an memref types must always be inferred. If unset, a default memory space of 0 is used otherwise.
+            use-encoding-for-memory-space: Use the Tensor encoding attribute for the memory space. Exclusive to the 'must-infer-memory-space' option
             test-analysis-only: Test only: Only run inplaceability analysis and annotate IR
             print-conflicts: Test only: Annotate IR with RaW conflicts. Requires test-analysis-only.
             unknown-type-conversion: Controls layout maps for non-inferrable memref types.
@@ -3036,6 +3052,7 @@ def one_shot_bufferize(
             no_analysis_func_filter=no_analysis_func_filter,
             function_boundary_type_conversion=function_boundary_type_conversion,
             must_infer_memory_space=must_infer_memory_space,
+            use_encoding_for_memory_space=use_encoding_for_memory_space,
             test_analysis_only=test_analysis_only,
             print_conflicts=print_conflicts,
             unknown_type_conversion=unknown_type_conversion,
```
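Two of the new pass-builder knobs in action, as a hedged sketch (it assumes the chainable builder exported by this module is named `Pipeline` and that `run_pipeline` is the repo's driver, as the examples suggest):

```python
from mlir.extras.runtime.passes import Pipeline, run_pipeline

# Illustrative: bufferize using tensor encodings as memory spaces, then take
# the generic interface-driven route to LLVM with dynamic attribute lookup.
pipeline = (
    Pipeline()
    .one_shot_bufferize(use_encoding_for_memory_space=True)
    .convert_to_llvm(dynamic=True)
    .reconcile_unrealized_casts()
)
# module = run_pipeline(module, pipeline)  # module: an mlir.ir.Module
```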

tests/test_async.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -67,6 +67,7 @@ def test_simple_parfor(ctx: MLIRContext, backend: LLVMJITBackend):
         .convert_arith_to_llvm()
         .finalize_memref_to_llvm()
         .convert_func_to_llvm()
+        .convert_cf_to_llvm()
         .reconcile_unrealized_casts(),
     generate_kernel_wrapper=True,
     generate_return_consumer=True,
```
