Skip to content

Commit 5d3785e

Browse files
committed
[python] fix enum ambiguity
1 parent 9d55e86 commit 5d3785e

26 files changed

+573
-137
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

+2-2
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
318318
set(LLVM_TARGET_DEFINITIONS ${td_file})
319319
endif()
320320
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
321-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
321+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
322322
list(APPEND _sources ${enum_filename})
323323
endif()
324324

@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
390390
set(LLVM_TARGET_DEFINITIONS ${td_file})
391391
endif()
392392
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
393-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
393+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
394394
list(APPEND _sources ${enum_filename})
395395
endif()
396396

mlir/python/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
6363
TD_FILE dialects/AffineOps.td
6464
SOURCES
6565
dialects/affine.py
66-
DIALECT_NAME affine
67-
GEN_ENUM_BINDINGS)
66+
DIALECT_NAME affine)
6867

6968
declare_mlir_dialect_python_bindings(
7069
ADD_TO_PARENT MLIRPythonSources.Dialects

mlir/python/mlir/dialects/_ods_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
143143
else op
144144
)
145145

146+
146147
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
147148
ResultValueT = _Union[ResultValueTypeTuple]
148149
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/amdgpu.py

+16
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,21 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._amdgpu_ops_gen import *
67
from ._amdgpu_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.AMDGPU_DPPPerm")
11+
def _amdgpu_dppperm(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
13+
14+
15+
@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
16+
def _amdgpu_mfmapermb(x, context):
17+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
18+
19+
20+
@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
21+
def _amdgpu_schedbarrieropopt(x, context):
22+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/arith.py

+35
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,38 @@ def constant(
108108
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
109109
) -> Value:
110110
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
111+
112+
113+
@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
114+
def _arith_cmpfpredicateattr(x, context):
115+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
116+
117+
118+
@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
119+
def _arith_cmpipredicateattr(x, context):
120+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
121+
122+
123+
@register_attribute_builder("builtin.Arith_DenormalMode")
124+
def _arith_denormalmode(x, context):
125+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
126+
127+
128+
@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
129+
def _arith_integeroverflowflags(x, context):
130+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
131+
132+
133+
@register_attribute_builder("builtin.Arith_RoundingModeAttr")
134+
def _arith_roundingmodeattr(x, context):
135+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
136+
137+
138+
@register_attribute_builder("builtin.AtomicRMWKindAttr")
139+
def _atomicrmwkindattr(x, context):
140+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
141+
142+
143+
@register_attribute_builder("builtin.FastMathFlags")
144+
def _fastmathflags(x, context):
145+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/bufferization.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._bufferization_ops_gen import *
67
from ._bufferization_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.LayoutMapOption")
11+
def _layoutmapoption(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/gpu/__init__.py

+56
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,62 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ...ir import IntegerAttr, IntegerType, register_attribute_builder
56
from .._gpu_ops_gen import *
67
from .._gpu_enum_gen import *
78
from ..._mlir_libs._mlirDialectsGPU import *
9+
10+
11+
@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
12+
def _gpu_addressspaceenum(x, context):
13+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
14+
15+
16+
@register_attribute_builder("builtin.GPU_AllReduceOperation")
17+
def _gpu_allreduceoperation(x, context):
18+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
19+
20+
21+
@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
22+
def _gpu_compilationtargetenum(x, context):
23+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
24+
25+
26+
@register_attribute_builder("builtin.GPU_Dimension")
27+
def _gpu_dimension(x, context):
28+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
29+
30+
31+
@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
32+
def _gpu_prune2to4spmatflag(x, context):
33+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
34+
35+
36+
@register_attribute_builder("builtin.GPU_ShuffleMode")
37+
def _gpu_shufflemode(x, context):
38+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
39+
40+
41+
@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
42+
def _gpu_spgemmworkestimationorcomputekind(x, context):
43+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
44+
45+
46+
@register_attribute_builder("builtin.GPU_TransposeMode")
47+
def _gpu_transposemode(x, context):
48+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
49+
50+
51+
@register_attribute_builder("builtin.MMAElementWise")
52+
def _mmaelementwise(x, context):
53+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
54+
55+
56+
@register_attribute_builder("builtin.MappingIdEnum")
57+
def _mappingidenum(x, context):
58+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
59+
60+
61+
@register_attribute_builder("builtin.ProcessorEnum")
62+
def _processorenum(x, context):
63+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))

mlir/python/mlir/dialects/index.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._index_ops_gen import *
67
from ._index_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.IndexCmpPredicate")
11+
def _indexcmppredicate(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/linalg/__init__.py

+25
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,28 @@ def broadcast(
102102
)
103103
fill_builtin_region(op.operation)
104104
return op
105+
106+
107+
@register_attribute_builder("builtin.BinaryFn")
108+
def _binaryfn(x, context):
109+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
110+
111+
112+
@register_attribute_builder("builtin.IteratorType")
113+
def _iteratortype(x, context):
114+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
115+
116+
117+
@register_attribute_builder("builtin.TernaryFn")
118+
def _ternaryfn(x, context):
119+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
120+
121+
122+
@register_attribute_builder("builtin.TypeFn")
123+
def _typefn(x, context):
124+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
125+
126+
127+
@register_attribute_builder("builtin.UnaryFn")
128+
def _unaryfn(x, context):
129+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

+66-50
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
888888
- TypeFn.cast_signed(U, IZp)
889889
) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
890890

891+
891892
@linalg_structured_op
892893
def conv_2d_nchw_fchw(
893894
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
@@ -1082,16 +1083,19 @@ def conv_3d_ndhwc_dhwcf(
10821083
"""
10831084
implements(ConvolutionOpInterface)
10841085
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
1085-
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
1086-
U,
1087-
I[
1088-
D.n,
1089-
D.od * S.SD + D.kd * S.DD,
1090-
D.oh * S.SH + D.kh * S.DH,
1091-
D.ow * S.SW + D.kw * S.DW,
1092-
D.c,
1093-
],
1094-
) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
1086+
O[D.n, D.od, D.oh, D.ow, D.f] += (
1087+
TypeFn.cast_signed(
1088+
U,
1089+
I[
1090+
D.n,
1091+
D.od * S.SD + D.kd * S.DD,
1092+
D.oh * S.SH + D.kh * S.DH,
1093+
D.ow * S.SW + D.kw * S.DW,
1094+
D.c,
1095+
],
1096+
)
1097+
* TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
1098+
)
10951099

10961100

10971101
@linalg_structured_op
@@ -1159,16 +1163,19 @@ def conv_3d_ncdhw_fcdhw(
11591163
"""
11601164
implements(ConvolutionOpInterface)
11611165
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
1162-
O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
1163-
U,
1164-
I[
1165-
D.n,
1166-
D.c,
1167-
D.od * S.SD + D.kd * S.DD,
1168-
D.oh * S.SH + D.kh * S.DH,
1169-
D.ow * S.SW + D.kw * S.DW,
1170-
],
1171-
) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
1166+
O[D.n, D.f, D.od, D.oh, D.ow] += (
1167+
TypeFn.cast_signed(
1168+
U,
1169+
I[
1170+
D.n,
1171+
D.c,
1172+
D.od * S.SD + D.kd * S.DD,
1173+
D.oh * S.SH + D.kh * S.DH,
1174+
D.ow * S.SW + D.kw * S.DW,
1175+
],
1176+
)
1177+
* TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
1178+
)
11721179

11731180

11741181
@linalg_structured_op
@@ -1368,16 +1375,19 @@ def depthwise_conv_3d_ndhwc_dhwc(
13681375
"""
13691376
implements(ConvolutionOpInterface)
13701377
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
1371-
O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
1372-
U,
1373-
I[
1374-
D.n,
1375-
D.od * S.SD + D.kd * S.DD,
1376-
D.oh * S.SH + D.kh * S.DH,
1377-
D.ow * S.SW + D.kw * S.DW,
1378-
D.ic,
1379-
],
1380-
) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
1378+
O[D.n, D.od, D.oh, D.ow, D.ic] += (
1379+
TypeFn.cast_signed(
1380+
U,
1381+
I[
1382+
D.n,
1383+
D.od * S.SD + D.kd * S.DD,
1384+
D.oh * S.SH + D.kh * S.DH,
1385+
D.ow * S.SW + D.kw * S.DW,
1386+
D.ic,
1387+
],
1388+
)
1389+
* TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
1390+
)
13811391

13821392

13831393
@linalg_structured_op
@@ -1403,16 +1413,19 @@ def depthwise_conv_3d_ncdhw_cdhw(
14031413
"""
14041414
implements(ConvolutionOpInterface)
14051415
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
1406-
O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
1407-
U,
1408-
I[
1409-
D.n,
1410-
D.ic,
1411-
D.od * S.SD + D.kd * S.DD,
1412-
D.oh * S.SH + D.kh * S.DH,
1413-
D.ow * S.SW + D.kw * S.DW,
1414-
],
1415-
) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
1416+
O[D.n, D.ic, D.od, D.oh, D.ow] += (
1417+
TypeFn.cast_signed(
1418+
U,
1419+
I[
1420+
D.n,
1421+
D.ic,
1422+
D.od * S.SD + D.kd * S.DD,
1423+
D.oh * S.SH + D.kh * S.DH,
1424+
D.ow * S.SW + D.kw * S.DW,
1425+
],
1426+
)
1427+
* TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
1428+
)
14161429

14171430

14181431
@linalg_structured_op
@@ -1437,16 +1450,19 @@ def depthwise_conv_3d_ndhwc_dhwcm(
14371450
"""
14381451
implements(ConvolutionOpInterface)
14391452
domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
1440-
O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
1441-
U,
1442-
I[
1443-
D.n,
1444-
D.od * S.SD + D.kd * S.DD,
1445-
D.oh * S.SH + D.kh * S.DH,
1446-
D.ow * S.SW + D.kw * S.DW,
1447-
D.ic,
1448-
],
1449-
) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
1453+
O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += (
1454+
TypeFn.cast_signed(
1455+
U,
1456+
I[
1457+
D.n,
1458+
D.od * S.SD + D.kd * S.DD,
1459+
D.oh * S.SH + D.kh * S.DH,
1460+
D.ow * S.SW + D.kw * S.DW,
1461+
D.ic,
1462+
],
1463+
)
1464+
* TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
1465+
)
14501466

14511467

14521468
@linalg_structured_op

0 commit comments

Comments
 (0)