Skip to content

Commit 90c0b70

Browse files
committed
[python] fix enum ambiguity
1 parent 9d55e86 commit 90c0b70

26 files changed

+513
-87
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

+2-2
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
318318
set(LLVM_TARGET_DEFINITIONS ${td_file})
319319
endif()
320320
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
321-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
321+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
322322
list(APPEND _sources ${enum_filename})
323323
endif()
324324

@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
390390
set(LLVM_TARGET_DEFINITIONS ${td_file})
391391
endif()
392392
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
393-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
393+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
394394
list(APPEND _sources ${enum_filename})
395395
endif()
396396

mlir/python/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
6363
TD_FILE dialects/AffineOps.td
6464
SOURCES
6565
dialects/affine.py
66-
DIALECT_NAME affine
67-
GEN_ENUM_BINDINGS)
66+
DIALECT_NAME affine)
6867

6968
declare_mlir_dialect_python_bindings(
7069
ADD_TO_PARENT MLIRPythonSources.Dialects

mlir/python/mlir/dialects/_ods_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
143143
else op
144144
)
145145

146+
146147
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
147148
ResultValueT = _Union[ResultValueTypeTuple]
148149
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/amdgpu.py

+16
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,21 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._amdgpu_ops_gen import *
67
from ._amdgpu_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.AMDGPU_DPPPerm")
11+
def _amdgpu_dppperm(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
13+
14+
15+
@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
16+
def _amdgpu_mfmapermb(x, context):
17+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
18+
19+
20+
@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
21+
def _amdgpu_schedbarrieropopt(x, context):
22+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/arith.py

+35
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,38 @@ def constant(
108108
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
109109
) -> Value:
110110
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
111+
112+
113+
@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
114+
def _arith_cmpfpredicateattr(x, context):
115+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
116+
117+
118+
@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
119+
def _arith_cmpipredicateattr(x, context):
120+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
121+
122+
123+
@register_attribute_builder("builtin.Arith_DenormalMode")
124+
def _arith_denormalmode(x, context):
125+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
126+
127+
128+
@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
129+
def _arith_integeroverflowflags(x, context):
130+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
131+
132+
133+
@register_attribute_builder("builtin.Arith_RoundingModeAttr")
134+
def _arith_roundingmodeattr(x, context):
135+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
136+
137+
138+
@register_attribute_builder("builtin.AtomicRMWKindAttr")
139+
def _atomicrmwkindattr(x, context):
140+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
141+
142+
143+
@register_attribute_builder("builtin.FastMathFlags")
144+
def _fastmathflags(x, context):
145+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/bufferization.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._bufferization_ops_gen import *
67
from ._bufferization_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.LayoutMapOption")
11+
def _layoutmapoption(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/gpu/__init__.py

+56
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,62 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ...ir import IntegerAttr, IntegerType, register_attribute_builder
56
from .._gpu_ops_gen import *
67
from .._gpu_enum_gen import *
78
from ..._mlir_libs._mlirDialectsGPU import *
9+
10+
11+
@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
12+
def _gpu_addressspaceenum(x, context):
13+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
14+
15+
16+
@register_attribute_builder("builtin.GPU_AllReduceOperation")
17+
def _gpu_allreduceoperation(x, context):
18+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
19+
20+
21+
@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
22+
def _gpu_compilationtargetenum(x, context):
23+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
24+
25+
26+
@register_attribute_builder("builtin.GPU_Dimension")
27+
def _gpu_dimension(x, context):
28+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
29+
30+
31+
@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
32+
def _gpu_prune2to4spmatflag(x, context):
33+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
34+
35+
36+
@register_attribute_builder("builtin.GPU_ShuffleMode")
37+
def _gpu_shufflemode(x, context):
38+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
39+
40+
41+
@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
42+
def _gpu_spgemmworkestimationorcomputekind(x, context):
43+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
44+
45+
46+
@register_attribute_builder("builtin.GPU_TransposeMode")
47+
def _gpu_transposemode(x, context):
48+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
49+
50+
51+
@register_attribute_builder("builtin.MMAElementWise")
52+
def _mmaelementwise(x, context):
53+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
54+
55+
56+
@register_attribute_builder("builtin.MappingIdEnum")
57+
def _mappingidenum(x, context):
58+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
59+
60+
61+
@register_attribute_builder("builtin.ProcessorEnum")
62+
def _processorenum(x, context):
63+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))

mlir/python/mlir/dialects/index.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._index_ops_gen import *
67
from ._index_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.IndexCmpPredicate")
11+
def _indexcmppredicate(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/linalg/__init__.py

+25
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,28 @@ def broadcast(
102102
)
103103
fill_builtin_region(op.operation)
104104
return op
105+
106+
107+
@register_attribute_builder("builtin.BinaryFn")
108+
def _binaryfn(x, context):
109+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
110+
111+
112+
@register_attribute_builder("builtin.IteratorType")
113+
def _iteratortype(x, context):
114+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
115+
116+
117+
@register_attribute_builder("builtin.TernaryFn")
118+
def _ternaryfn(x, context):
119+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
120+
121+
122+
@register_attribute_builder("builtin.TypeFn")
123+
def _typefn(x, context):
124+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
125+
126+
127+
@register_attribute_builder("builtin.UnaryFn")
128+
def _unaryfn(x, context):
129+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
888888
- TypeFn.cast_signed(U, IZp)
889889
) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
890890

891+
891892
@linalg_structured_op
892893
def conv_2d_nchw_fchw(
893894
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),

mlir/python/mlir/dialects/llvm.py

+106-1
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,116 @@
55
from ._llvm_ops_gen import *
66
from ._llvm_enum_gen import *
77
from .._mlir_libs._mlirDialectsLLVM import *
8-
from ..ir import Value
8+
from ..ir import Value, IntegerAttr, IntegerType, register_attribute_builder
99
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
1010

1111

1212
def mlir_constant(value, *, loc=None, ip=None) -> Value:
1313
return _get_op_result_or_op_results(
1414
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
1515
)
16+
17+
18+
@register_attribute_builder("builtin.AsmATTOrIntel")
19+
def _asmattorintel(x, context):
20+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
21+
22+
23+
@register_attribute_builder("builtin.AtomicBinOp")
24+
def _atomicbinop(x, context):
25+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
26+
27+
28+
@register_attribute_builder("builtin.AtomicOrdering")
29+
def _atomicordering(x, context):
30+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
31+
32+
33+
@register_attribute_builder("builtin.CConvEnum")
34+
def _cconvenum(x, context):
35+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
36+
37+
38+
@register_attribute_builder("builtin.Comdat")
39+
def _comdat(x, context):
40+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
41+
42+
43+
@register_attribute_builder("builtin.DIFlags")
44+
def _diflags(x, context):
45+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
46+
47+
48+
@register_attribute_builder("builtin.DISubprogramFlags")
49+
def _disubprogramflags(x, context):
50+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
51+
52+
53+
@register_attribute_builder("builtin.FCmpPredicate")
54+
def _fcmppredicate(x, context):
55+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
56+
57+
58+
@register_attribute_builder("builtin.FPExceptionBehaviorAttr")
59+
def _fpexceptionbehaviorattr(x, context):
60+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
61+
62+
63+
@register_attribute_builder("builtin.FastmathFlags")
64+
def _fastmathflags(x, context):
65+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
66+
67+
68+
@register_attribute_builder("builtin.FramePointerKindEnum")
69+
def _framepointerkindenum(x, context):
70+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
71+
72+
73+
@register_attribute_builder("builtin.ICmpPredicate")
74+
def _icmppredicate(x, context):
75+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
76+
77+
78+
@register_attribute_builder("builtin.IntegerOverflowFlags")
79+
def _integeroverflowflags(x, context):
80+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
81+
82+
83+
@register_attribute_builder("builtin.LLVM_DIEmissionKind")
84+
def _llvm_diemissionkind(x, context):
85+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
86+
87+
88+
@register_attribute_builder("builtin.LLVM_DINameTableKind")
89+
def _llvm_dinametablekind(x, context):
90+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
91+
92+
93+
@register_attribute_builder("builtin.LinkageEnum")
94+
def _linkageenum(x, context):
95+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
96+
97+
98+
@register_attribute_builder("builtin.ModRefInfoEnum")
99+
def _modrefinfoenum(x, context):
100+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
101+
102+
103+
@register_attribute_builder("builtin.RoundingModeAttr")
104+
def _roundingmodeattr(x, context):
105+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
106+
107+
108+
@register_attribute_builder("builtin.TailCallKindEnum")
109+
def _tailcallkindenum(x, context):
110+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
111+
112+
113+
@register_attribute_builder("builtin.UnnamedAddr")
114+
def _unnamedaddr(x, context):
115+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
116+
117+
118+
@register_attribute_builder("builtin.Visibility")
119+
def _visibility(x, context):
120+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))

mlir/python/mlir/dialects/nvgpu.py

+26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,32 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._nvgpu_ops_gen import *
67
from ._nvgpu_enum_gen import *
78
from .._mlir_libs._mlirDialectsNVGPU import *
9+
10+
11+
@register_attribute_builder("builtin.RcpRoundingMode")
12+
def _rcproundingmode(x, context):
13+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
14+
15+
16+
@register_attribute_builder("builtin.TensorMapInterleaveKind")
17+
def _tensormapinterleavekind(x, context):
18+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
19+
20+
21+
@register_attribute_builder("builtin.TensorMapL2PromoKind")
22+
def _tensormapl2promokind(x, context):
23+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
24+
25+
26+
@register_attribute_builder("builtin.TensorMapOOBKind")
27+
def _tensormapoobkind(x, context):
28+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
29+
30+
31+
@register_attribute_builder("builtin.TensorMapSwizzleKind")
32+
def _tensormapswizzlekind(x, context):
33+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

0 commit comments

Comments
 (0)