-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir-python] Emit only dialect EnumAttr
registrations.
#117918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
67d9107
to
9fce7ad
Compare
9fce7ad
to
5d3785e
Compare
✅ With the latest revision this PR passed the Python code formatter. |
1e19c9c
to
90c0b70
Compare
EnumAttr
registrations.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-backend-amdgpu Author: Maksim Levental (makslevental) ChangesProblemCurrently we emit
Note, the duplication is due to the fact that tablegen has no "memory" across emitted files so it has no knowledge that it has already emitted The first isn't an issue because even if they're duplicated, they're mechanically kept in sync. The second is an issue because
SolutionThe only solution is to cease generation of indiscriminate emission of registration calls. In this PR I do that, i.e., I add
becomes
Patch is 45.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117918.diff 26 Files Affected:
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7b91f43e2d57fd..d06fc927ea44d4 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bb..9949743b9bf09c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
- DIALECT_NAME affine
- GEN_ENUM_BINDINGS)
+ DIALECT_NAME affine)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d6..22fc588e180b1d 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
else op
)
+
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481cc..9b8beaa5571a26 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -2,5 +2,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
+
+
+@register_attribute_builder("builtin.AMDGPU_DPPPerm")
+def _amdgpu_dppperm(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
+def _amdgpu_mfmapermb(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
+def _amdgpu_schedbarrieropopt(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce665..32ba832260d129 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -108,3 +108,38 @@ def constant(
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
+def _arith_cmpfpredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
+def _arith_cmpipredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_DenormalMode")
+def _arith_denormalmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
+def _arith_integeroverflowflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_RoundingModeAttr")
+def _arith_roundingmodeattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicRMWKindAttr")
+def _atomicrmwkindattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FastMathFlags")
+def _fastmathflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff7..6ad76c729ed2dc 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -2,5 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._bufferization_ops_gen import *
from ._bufferization_enum_gen import *
+
+
+@register_attribute_builder("builtin.LayoutMapOption")
+def _layoutmapoption(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca85..e0bb07c5dad8be 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -2,6 +2,62 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
from .._gpu_ops_gen import *
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
+
+
+@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
+def _gpu_addressspaceenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_AllReduceOperation")
+def _gpu_allreduceoperation(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
+def _gpu_compilationtargetenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_Dimension")
+def _gpu_dimension(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
+def _gpu_prune2to4spmatflag(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_ShuffleMode")
+def _gpu_shufflemode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
+def _gpu_spgemmworkestimationorcomputekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_TransposeMode")
+def _gpu_transposemode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAElementWise")
+def _mmaelementwise(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MappingIdEnum")
+def _mappingidenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ProcessorEnum")
+def _processorenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py
index 73708c7d71a8c8..f00c397965c97c 100644
--- a/mlir/python/mlir/dialects/index.py
+++ b/mlir/python/mlir/dialects/index.py
@@ -2,5 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._index_ops_gen import *
from ._index_enum_gen import *
+
+
+@register_attribute_builder("builtin.IndexCmpPredicate")
+def _indexcmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..4fe9cc40ee910a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -102,3 +102,28 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+@register_attribute_builder("builtin.BinaryFn")
+def _binaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.IteratorType")
+def _iteratortype(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TernaryFn")
+def _ternaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TypeFn")
+def _typefn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.UnaryFn")
+def _unaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..f87b25e8416023 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
- TypeFn.cast_signed(U, IZp)
) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
+
@linalg_structured_op
def conv_2d_nchw_fchw(
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 941a584966dcde..456df39ae650ff 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,7 +5,7 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
-from ..ir import Value
+from ..ir import Value, IntegerAttr, IntegerType, register_attribute_builder
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
@@ -13,3 +13,108 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value:
return _get_op_result_or_op_results(
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
)
+
+
+@register_attribute_builder("builtin.AsmATTOrIntel")
+def _asmattorintel(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicBinOp")
+def _atomicbinop(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicOrdering")
+def _atomicordering(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.CConvEnum")
+def _cconvenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Comdat")
+def _comdat(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.DIFlags")
+def _diflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.DISubprogramFlags")
+def _disubprogramflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FCmpPredicate")
+def _fcmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FPExceptionBehaviorAttr")
+def _fpexceptionbehaviorattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FastmathFlags")
+def _fastmathflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FramePointerKindEnum")
+def _framepointerkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ICmpPredicate")
+def _icmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.IntegerOverflowFlags")
+def _integeroverflowflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LLVM_DIEmissionKind")
+def _llvm_diemissionkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LLVM_DINameTableKind")
+def _llvm_dinametablekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LinkageEnum")
+def _linkageenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ModRefInfoEnum")
+def _modrefinfoenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.RoundingModeAttr")
+def _roundingmodeattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TailCallKindEnum")
+def _tailcallkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.UnnamedAddr")
+def _unnamedaddr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Visibility")
+def _visibility(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py
index d6a54f2772f40d..eea132adb0484e 100644
--- a/mlir/python/mlir/dialects/nvgpu.py
+++ b/mlir/python/mlir/dialects/nvgpu.py
@@ -2,6 +2,32 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._nvgpu_ops_gen import *
from ._nvgpu_enum_gen import *
from .._mlir_libs._mlirDialectsNVGPU import *
+
+
+@register_attribute_builder("builtin.RcpRoundingMode")
+def _rcproundingmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapInterleaveKind")
+def _tensormapinterleavekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapL2PromoKind")
+def _tensormapl2promokind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapOOBKind")
+def _tensormapoobkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapSwizzleKind")
+def _tensormapswizzlekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py
index 9477de39c9ead7..21bf24cb73fdab 100644
--- a/mlir/python/mlir/dialects/nvvm.py
+++ b/mlir/python/mlir/dialects/nvvm.py
@@ -2,5 +2,81 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._nvvm_ops_gen import *
from ._nvvm_enum_gen import *
+
+
+@register_attribute_builder("builtin.LoadCacheModifierKind")
+def _loadcachemodifierkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAB1Op")
+def _mmab1op(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAFrag")
+def _mmafrag(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAIntOverflow")
+def _mmaintoverflow(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMALayout")
+def _mmalayout(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMATypes")
+def _mmatypes(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MemScopeKind")
+def _memscopekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ProxyKind")
+def _proxykind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ReduxKind")
+def _reduxkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SetMaxRegisterAction")
+def _setmaxregisteraction(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SharedSpace")
+def _sharedspace(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ShflKind")
+def _shflkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMAScaleIn")
+def _wgmmascalein(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMAScaleOut")
+def _wgmmascaleout(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMATypes")
+def _wgmmatypes(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 209ecc95fa8fc8..8f1b83f9d514fd 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -2,7 +2,23 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._sparse_tensor_ops_gen import *
from ._sparse_tensor_enum_gen import *
from .._mlir_libs._mlirDialectsSparseTensor import *
from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
+
+
+@register_attribute_builder("builtin.SparseTensorCrdTransDirectionEnum")
+def _sparsetensorcrdtransdirectionenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_...
[truncated]
|
EnumAttr
registrations.EnumAttr
registrations.
mlir/python/CMakeLists.txt
Outdated
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings( | |||
TD_FILE dialects/AffineOps.td | |||
SOURCES | |||
dialects/affine.py | |||
DIALECT_NAME affine | |||
GEN_ENUM_BINDINGS) | |||
DIALECT_NAME affine) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_affine_enums_gen.py
had nothing but arith
enums and registrations so best to just turn it off to keep from confusing people.
The other is also that these should be prefixed with dialect. Problem 2 is resolved by requiring folks to prefix with namespace as is done most places. |
mlir/python/mlir/dialects/amdgpu.py
Outdated
|
||
@register_attribute_builder("builtin.AMDGPU_DPPPerm") | ||
def _amdgpu_dppperm(x, context): | ||
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be nice to add a 'reflection' function to the tablegened enum classes to get the underlying type from the enum definition
Yes but it's upstream where the problem lies - there are many attributes that are in builtin and aren't prefixed. Such as all these in linalg and FastMathFlags in arith and the particular It actually wouldn't solve the problem anyway because of how the generation logic currently works: static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitAttributeBuilder(enumAttr, os);
}
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumAttr"))
emitDialectEnumAttributeBuilder(
attr.getName(),
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
}
So the "builtin" will always be emitted every single time it's used as the So the real root cause here is simply the global This PR actually does rectify this original flaw in the design (that first |
Yes I know Linalg is a offender here: I had made a lint check and started changing these a while ago. It's not that many that needs changes. The one where it was a bit more work is updating a generator generator here which relief more on names (it required introducing same behavior as we do for op definitions though, which is a win). Built-in being special is known, it has no prefix, that is known and also why we keep it small. The FastMathFlags one I easily addressed, but it ends up a bit confusing as it results in looking like a dialect dependency as it is intentionally reused elsewhere, but it doesn't affect the C++ side or dependencies, so in practice wasn't too much hard.
Didnt we resolve this behavior in another instance by looking at the main file? I'm pretty sure we made a change to type and attribute emission.to address this by only considering the main file and using that as source of truth as to what should be emitted. This may end up that one has to use a different top file, but it also fits that one emits for the file given as input. |
Yes we have As far as I know there's no way to limit which TD files are considered during generation because they behave just like headers right? So the "preprocessor" injects the headers and only then do we get the TU (or whatever). You can put a breakpoint of The last comment on the original patch in phabricator highlighted the issue. |
90c0b70
to
de851b0
Compare
I have added a failing case that demonstrates the problem already exists upstream: # mlir/test/python/dialects/affine.py
@constructAndPrintInModule
def test_double_AtomicRMWKindAttr_registration():
from mlir.dialects import _affine_enum_gen We should see
Note that llvm-project/mlir/python/mlir/dialects/affine.py Lines 5 to 6 in d50fbe4
|
Close, TD files aren't textually included as one would with preprocessor. The origin is retained by SourceMgr. What I was thinking/suggestion above was something like https://github.com/jpienaar/llvm-project/pull/new/pyenum , so filtering on origin file and making it act on it (that change is larger than it should be as GPU dialect has an enum section ... and then has most of their enums in a different file and that just felt wrong). This is how attributes and types are handled. Unfortunately neither enum generator does that. I think it is a nice organization mechanism that makes things more regular, makes what is acted on explicit/unambigous (especially as many of the enums don't have a dialect associated trivially). One could still do the dialect prefixing with this in the registry (although the example you showed, shows an attribute from LinalgTransformEnums.td as in builtin dialect so not sure about the heuristic :)). And I was thinking if one could do something more automatic by associating ops with the nearest Dialect definition in include path. But that heuristic could be too magical. WDYT about making it explicit via file acted on as the other? (could also just be focused on Python side). |
which one?
yea definitely too magical.
In general I think automatically inferring via file location is just as magical/error-prone. As you commented, GPU would now have attributes in the "wrong" place but according to what semantic? Only the semantic introduced by such a change (prior they were just conventions). My change also introduces a new semantic but the semantic "every dialect gets dialect attributes generated/emitted automatically and But I'll just say, since I'm no longer andactive user of these bindings (I'm just pitching in to help @kuhar), I won't stand in the way of something that unblocks some fix. If you think pinning to file location is fine and it works for you guys then we can land that instead (it'll work for us too). |
Hack/patch to workaround overlapping `AttrBuilder` registrations. Needed until iree-org/iree#19324 and llvm/llvm-project#117918 get resolved.
Hack/patch to workaround overlapping `AttrBuilder` registrations. Needed until iree-org/iree#19324 and llvm/llvm-project#117918 get resolved.
Hack/patch to workaround overlapping `AttrBuilder` registrations. Needed until iree-org/iree#19324 and llvm/llvm-project#117918 get resolved.
Problem
Currently we emit
EnumAttr
bindings indiscriminately (i.e., without considering-bind-dialect
). This leads to two thingsclass <OTHERDIALECT>Enum
s for every<DIALECT>_enums_gen.py
that uses<OTHERDIALECT>Enum
;@register_attribute_builder("<OTHERDIALECT>EnumAttr")
Note, the duplication is due to the fact that tablegen has no "memory" across emitted files so it has no knowledge that it has already emitted
XYZEnumAttr
in both_<DIALECT>_enums_gen.py
and_<OTHERDIALECT>_enums_gen.py
.The first isn't an issue because even if they're duplicated, they're mechanically kept in sync.
The second is an issue because
attributeBuilderMap
doesn't automatically replace/override registrations (nor should it). For example, we havelinalg.iterator_type
and IREE hasiree_gpu.iterator_type
and thus if you try to use both_linalg_enums_gen.py
and_iree_gpu_enums_gen.py
you will getSolution
The only solution is to cease indiscriminate emission of registration calls. In this PR I do that, i.e., I add
-bind-dialect
to-gen-python-enum-bindings
and filter whichregister_attribute_builder
calls are emitted. The effect is thatbuiltin
dialect attribute builder registration calls are no longer emitted anywhere. So those have to be written by hand somewhere. Here I put them adjacent to where they were being emitted prior (i.e., in the<DIALECT>.py
) . I also actually use the dialect to prefix the builder lookup key. Sobecomes