Skip to content

Commit 9fce7ad

Browse files
committed
[python] fix enum ambiguity
1 parent 9d55e86 commit 9fce7ad

19 files changed

+2849
-35
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

+2-2
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
318318
set(LLVM_TARGET_DEFINITIONS ${td_file})
319319
endif()
320320
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
321-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
321+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
322322
list(APPEND _sources ${enum_filename})
323323
endif()
324324

@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
390390
set(LLVM_TARGET_DEFINITIONS ${td_file})
391391
endif()
392392
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
393-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
393+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
394394
list(APPEND _sources ${enum_filename})
395395
endif()
396396

mlir/python/CMakeLists.txt

+5-18
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
6363
TD_FILE dialects/AffineOps.td
6464
SOURCES
6565
dialects/affine.py
66-
DIALECT_NAME affine
67-
GEN_ENUM_BINDINGS)
66+
DIALECT_NAME affine)
6867

6968
declare_mlir_dialect_python_bindings(
7069
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -88,10 +87,7 @@ declare_mlir_dialect_python_bindings(
8887
TD_FILE dialects/BufferizationOps.td
8988
SOURCES
9089
dialects/bufferization.py
91-
DIALECT_NAME bufferization
92-
GEN_ENUM_BINDINGS_TD_FILE
93-
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
94-
)
90+
DIALECT_NAME bufferization)
9591

9692
declare_mlir_dialect_python_bindings(
9793
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -178,10 +174,7 @@ declare_mlir_dialect_python_bindings(
178174
SOURCES
179175
dialects/transform/__init__.py
180176
_mlir_libs/_mlir/dialects/transform/__init__.pyi
181-
DIALECT_NAME transform
182-
GEN_ENUM_BINDINGS_TD_FILE
183-
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
184-
)
177+
DIALECT_NAME transform)
185178

186179
declare_mlir_python_sources(
187180
MLIRPythonSources.Dialects.transform.extras
@@ -250,10 +243,7 @@ declare_mlir_dialect_extension_python_bindings(
250243
SOURCES
251244
dialects/transform/structured.py
252245
DIALECT_NAME transform
253-
EXTENSION_NAME structured_transform
254-
GEN_ENUM_BINDINGS_TD_FILE
255-
"../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
256-
)
246+
EXTENSION_NAME structured_transform)
257247

258248
declare_mlir_dialect_extension_python_bindings(
259249
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -280,10 +270,7 @@ declare_mlir_dialect_extension_python_bindings(
280270
SOURCES
281271
dialects/transform/vector.py
282272
DIALECT_NAME transform
283-
EXTENSION_NAME vector_transform
284-
GEN_ENUM_BINDINGS_TD_FILE
285-
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
286-
)
273+
EXTENSION_NAME vector_transform)
287274

288275
declare_mlir_dialect_python_bindings(
289276
ADD_TO_PARENT MLIRPythonSources.Dialects

mlir/python/mlir/dialects/amdgpu.py

+145
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,151 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from enum import IntEnum, IntFlag
45

6+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
57
from ._amdgpu_ops_gen import *
68
from ._amdgpu_enum_gen import *
9+
10+
11+
class DPPPerm(IntEnum):
12+
"""The possible permutations for a DPP operation"""
13+
14+
quad_perm = 0
15+
row_shl = 1
16+
row_shr = 2
17+
row_ror = 3
18+
wave_shl = 4
19+
wave_shr = 5
20+
wave_ror = 6
21+
wave_rol = 7
22+
row_mirror = 8
23+
row_half_mirror = 9
24+
row_bcast_15 = 10
25+
row_bcast_31 = 11
26+
27+
def __str__(self):
28+
if self is DPPPerm.quad_perm:
29+
return "quad_perm"
30+
if self is DPPPerm.row_shl:
31+
return "row_shl"
32+
if self is DPPPerm.row_shr:
33+
return "row_shr"
34+
if self is DPPPerm.row_ror:
35+
return "row_ror"
36+
if self is DPPPerm.wave_shl:
37+
return "wave_shl"
38+
if self is DPPPerm.wave_shr:
39+
return "wave_shr"
40+
if self is DPPPerm.wave_ror:
41+
return "wave_ror"
42+
if self is DPPPerm.wave_rol:
43+
return "wave_rol"
44+
if self is DPPPerm.row_mirror:
45+
return "row_mirror"
46+
if self is DPPPerm.row_half_mirror:
47+
return "row_half_mirror"
48+
if self is DPPPerm.row_bcast_15:
49+
return "row_bcast_15"
50+
if self is DPPPerm.row_bcast_31:
51+
return "row_bcast_31"
52+
raise ValueError("Unknown DPPPerm enum entry.")
53+
54+
55+
@register_attribute_builder("AMDGPU_DPPPerm")
56+
def _amdgpu_dppperm(x, context):
57+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
58+
59+
60+
class MFMAPermB(IntEnum):
61+
"""The possible permutations of the lanes storing B available in an MFMA"""
62+
63+
none = 0
64+
bcast_first_32 = 1
65+
bcast_second_32 = 2
66+
rotate_16_right = 3
67+
bcast_first_16 = 4
68+
bcast_second_16 = 5
69+
bcast_third_16 = 6
70+
bcast_fourth_16 = 7
71+
72+
def __str__(self):
73+
if self is MFMAPermB.none:
74+
return "none"
75+
if self is MFMAPermB.bcast_first_32:
76+
return "bcast_first_32"
77+
if self is MFMAPermB.bcast_second_32:
78+
return "bcast_second_32"
79+
if self is MFMAPermB.rotate_16_right:
80+
return "rotate_16_right"
81+
if self is MFMAPermB.bcast_first_16:
82+
return "bcast_first_16"
83+
if self is MFMAPermB.bcast_second_16:
84+
return "bcast_second_16"
85+
if self is MFMAPermB.bcast_third_16:
86+
return "bcast_third_16"
87+
if self is MFMAPermB.bcast_fourth_16:
88+
return "bcast_fourth_16"
89+
raise ValueError("Unknown MFMAPermB enum entry.")
90+
91+
92+
@register_attribute_builder("AMDGPU_MFMAPermB")
93+
def _amdgpu_mfmapermb(x, context):
94+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
95+
96+
97+
class sched_barrier_opt_enum(IntFlag):
98+
"""The possible options for scheduling barriers"""
99+
100+
none = 0
101+
non_mem_non_sideffect = 1
102+
valu = 2
103+
salu = 4
104+
mfma_wmma = 8
105+
all_vmem = 16
106+
vmem_read = 32
107+
vmem_write = 64
108+
all_ds = 128
109+
ds_read = 256
110+
ds_write = 512
111+
transcendental = 1024
112+
113+
def __iter__(self):
114+
return iter([case for case in type(self) if (self & case) is case])
115+
116+
def __len__(self):
117+
return bin(self).count("1")
118+
119+
def __str__(self):
120+
if len(self) > 1:
121+
return "|".join(map(str, self))
122+
if self is sched_barrier_opt_enum.none:
123+
return "none"
124+
if self is sched_barrier_opt_enum.non_mem_non_sideffect:
125+
return "non_mem_non_sideffect"
126+
if self is sched_barrier_opt_enum.valu:
127+
return "valu"
128+
if self is sched_barrier_opt_enum.salu:
129+
return "salu"
130+
if self is sched_barrier_opt_enum.mfma_wmma:
131+
return "mfma_wmma"
132+
if self is sched_barrier_opt_enum.all_vmem:
133+
return "all_vmem"
134+
if self is sched_barrier_opt_enum.vmem_read:
135+
return "vmem_read"
136+
if self is sched_barrier_opt_enum.vmem_write:
137+
return "vmem_write"
138+
if self is sched_barrier_opt_enum.all_ds:
139+
return "all_ds"
140+
if self is sched_barrier_opt_enum.ds_read:
141+
return "ds_read"
142+
if self is sched_barrier_opt_enum.ds_write:
143+
return "ds_write"
144+
if self is sched_barrier_opt_enum.transcendental:
145+
return "transcendental"
146+
raise ValueError("Unknown sched_barrier_opt_enum enum entry.")
147+
148+
149+
@register_attribute_builder("AMDGPU_SchedBarrierOpOpt")
150+
def _amdgpu_schedbarrieropopt(x, context):
151+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

0 commit comments

Comments
 (0)