Skip to content

[mlir-python] Emit only dialect EnumAttr registrations. #117918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Nov 27, 2024

Problem

Currently we emit EnumAttr bindings indiscriminately (i.e., without considering -bind-dialect). This leads to two things

  1. duplicated class <OTHERDIALECT>Enums for every <DIALECT>_enums_gen.py that uses <OTHERDIALECT>Enum;
  2. similarly, duplicated @register_attribute_builder("<OTHERDIALECT>EnumAttr")

Note, the duplication is due to the fact that tablegen has no "memory" across emitted files so it has no knowledge that it has already emitted XYZEnumAttr in both _<DIALECT>_enums_gen.py and _<OTHERDIALECT>_enums_gen.py.

The first isn't an issue because even if they're duplicated, they're mechanically kept in sync.

The second is an issue because attributeBuilderMap doesn't automatically replace/override registrations (nor should it). For example, we have linalg.iterator_type and IREE has iree_gpu.iterator_type and thus if you try to use both _linalg_enums_gen.py and _iree_gpu_enums_gen.py you will get

RuntimeError: Attribute builder for 'IteratorType' is already registered with func: <function _iteratortype at 0x7f4f678cc5e0>

Solution

The only solution is to cease indiscriminate emission of registration calls. In this PR I do that, i.e., I add -bind-dialect to -gen-python-enum-bindings and filter which register_attribute_builder calls are emitted. The effect is that builtin dialect attribute builder registration calls are no longer emitted anywhere. So those have to be written by hand somewhere. Here I put them adjacent to where they were being emitted prior (i.e., in the <DIALECT>.py) . I also actually use the dialect to prefix the builder lookup key. So

@register_attribute_builder("MatchInterfaceEnum")

becomes

@register_attribute_builder("builtin.MatchInterfaceEnum")

@makslevental makslevental force-pushed the makslevental/fix-enums branch 2 times, most recently from 67d9107 to 9fce7ad Compare November 27, 2024 20:30
@makslevental makslevental changed the title [python] fix enum ambiguity [python] fix enum collision Nov 27, 2024
@makslevental makslevental force-pushed the makslevental/fix-enums branch from 9fce7ad to 5d3785e Compare November 27, 2024 22:55
Copy link

github-actions bot commented Nov 27, 2024

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the makslevental/fix-enums branch 2 times, most recently from 1e19c9c to 90c0b70 Compare November 27, 2024 23:28
@makslevental makslevental changed the title [python] fix enum collision [python] Emit only dialect EnumAttr registrations. Nov 28, 2024
@makslevental makslevental marked this pull request as ready for review November 28, 2024 00:08
@llvmbot llvmbot added backend:AMDGPU mlir:core MLIR Core Infrastructure mlir:linalg mlir:python MLIR Python bindings mlir labels Nov 28, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 28, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-backend-amdgpu

Author: Maksim Levental (makslevental)

Changes

Problem

Currently we emit EnumAttr bindings indiscriminately (i.e., without considering -bind-dialect). This leads to two things

  1. duplicated class &lt;OTHERDIALECT&gt;Enums for every &lt;DIALECT&gt;_enums_gen.py that uses &lt;OTHERDIALECT&gt;Enum;
  2. similarly, duplicated @<!-- -->register_attribute_builder("&lt;OTHERDIALECT&gt;EnumAttr")

Note, the duplication is due to the fact that tablegen has no "memory" across emitted files so it has no knowledge that it has already emitted XYZEnumAttr in both _&lt;DIALECT&gt;_enums_gen.py and _&lt;OTHERDIALECT&gt;_enums_gen.py.

The first isn't an issue because even if they're duplicated, they're mechanically kept in sync.

The second is an issue because attributeBuilderMap doesn't automatically replace/override registrations (nor should it). For example, we have linalg.iterator_type and IREE has iree_gpu.iterator_type and thus if you try to use both _linalg_enums_gen.py and _iree_gpu_enums_gen.py you will get

RuntimeError: Attribute builder for 'IteratorType' is already registered with func: &lt;function _iteratortype at 0x7f4f678cc5e0&gt;

Solution

The only solution is to cease generation of indiscriminate emission of registration calls. In this PR I do that, i.e., I add -bind-dialect to -gen-python-enum-bindings and filter which register_attribute_builder calls are emitted. The effect is that builtin dialect attribute builder registration calls are no longer emitted anywhere. So those have to be written by hand somewhere. Here I put them adjacent (in the &lt;DIALECT&gt;.py) to where they were being emitted prior. I also actually use the dialect to prefix the builder lookup key. So

@<!-- -->register_attribute_builder("MatchInterfaceEnum")

becomes

@<!-- -->register_attribute_builder("builtin.MatchInterfaceEnum")

Patch is 45.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117918.diff

26 Files Affected:

  • (modified) mlir/cmake/modules/AddMLIRPython.cmake (+2-2)
  • (modified) mlir/python/CMakeLists.txt (+1-2)
  • (modified) mlir/python/mlir/dialects/_ods_common.py (+1)
  • (modified) mlir/python/mlir/dialects/amdgpu.py (+16)
  • (modified) mlir/python/mlir/dialects/arith.py (+35)
  • (modified) mlir/python/mlir/dialects/bufferization.py (+6)
  • (modified) mlir/python/mlir/dialects/gpu/init.py (+56)
  • (modified) mlir/python/mlir/dialects/index.py (+6)
  • (modified) mlir/python/mlir/dialects/linalg/init.py (+25)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+1)
  • (modified) mlir/python/mlir/dialects/llvm.py (+106-1)
  • (modified) mlir/python/mlir/dialects/nvgpu.py (+26)
  • (modified) mlir/python/mlir/dialects/nvvm.py (+76)
  • (modified) mlir/python/mlir/dialects/sparse_tensor.py (+16)
  • (modified) mlir/python/mlir/dialects/transform/init.py (+10)
  • (modified) mlir/python/mlir/dialects/transform/extras/init.py (+1)
  • (modified) mlir/python/mlir/dialects/transform/structured.py (+10)
  • (modified) mlir/python/mlir/dialects/transform/vector.py (+21)
  • (modified) mlir/python/mlir/dialects/vector.py (+16)
  • (modified) mlir/python/mlir/ir.py (+54-54)
  • (modified) mlir/test/mlir-tblgen/enums-python-bindings.td (+2-14)
  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+2-2)
  • (modified) mlir/test/python/dialects/index_dialect.py (+1-1)
  • (modified) mlir/test/python/dialects/transform_structured_ext.py (+1-1)
  • (modified) mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (+16-7)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+6-3)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7b91f43e2d57fd..d06fc927ea44d4 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
       list(APPEND _sources ${enum_filename})
     endif()
 
@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
       list(APPEND _sources ${enum_filename})
     endif()
 
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bb..9949743b9bf09c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/AffineOps.td
   SOURCES
     dialects/affine.py
-  DIALECT_NAME affine
-  GEN_ENUM_BINDINGS)
+  DIALECT_NAME affine)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d6..22fc588e180b1d 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
         else op
     )
 
+
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]
 VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481cc..9b8beaa5571a26 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -2,5 +2,21 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._amdgpu_ops_gen import *
 from ._amdgpu_enum_gen import *
+
+
+@register_attribute_builder("builtin.AMDGPU_DPPPerm")
+def _amdgpu_dppperm(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
+def _amdgpu_mfmapermb(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
+def _amdgpu_schedbarrieropopt(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce665..32ba832260d129 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -108,3 +108,38 @@ def constant(
     result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
+def _arith_cmpfpredicateattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
+def _arith_cmpipredicateattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_DenormalMode")
+def _arith_denormalmode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
+def _arith_integeroverflowflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_RoundingModeAttr")
+def _arith_roundingmodeattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicRMWKindAttr")
+def _atomicrmwkindattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FastMathFlags")
+def _fastmathflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff7..6ad76c729ed2dc 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -2,5 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._bufferization_ops_gen import *
 from ._bufferization_enum_gen import *
+
+
+@register_attribute_builder("builtin.LayoutMapOption")
+def _layoutmapoption(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca85..e0bb07c5dad8be 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -2,6 +2,62 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
 from .._gpu_ops_gen import *
 from .._gpu_enum_gen import *
 from ..._mlir_libs._mlirDialectsGPU import *
+
+
+@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
+def _gpu_addressspaceenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_AllReduceOperation")
+def _gpu_allreduceoperation(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
+def _gpu_compilationtargetenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_Dimension")
+def _gpu_dimension(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
+def _gpu_prune2to4spmatflag(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_ShuffleMode")
+def _gpu_shufflemode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
+def _gpu_spgemmworkestimationorcomputekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_TransposeMode")
+def _gpu_transposemode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAElementWise")
+def _mmaelementwise(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MappingIdEnum")
+def _mappingidenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ProcessorEnum")
+def _processorenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py
index 73708c7d71a8c8..f00c397965c97c 100644
--- a/mlir/python/mlir/dialects/index.py
+++ b/mlir/python/mlir/dialects/index.py
@@ -2,5 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._index_ops_gen import *
 from ._index_enum_gen import *
+
+
+@register_attribute_builder("builtin.IndexCmpPredicate")
+def _indexcmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..4fe9cc40ee910a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -102,3 +102,28 @@ def broadcast(
     )
     fill_builtin_region(op.operation)
     return op
+
+
+@register_attribute_builder("builtin.BinaryFn")
+def _binaryfn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.IteratorType")
+def _iteratortype(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TernaryFn")
+def _ternaryfn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TypeFn")
+def _typefn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.UnaryFn")
+def _unaryfn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..f87b25e8416023 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
         - TypeFn.cast_signed(U, IZp)
     ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
 
+
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 941a584966dcde..456df39ae650ff 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,7 +5,7 @@
 from ._llvm_ops_gen import *
 from ._llvm_enum_gen import *
 from .._mlir_libs._mlirDialectsLLVM import *
-from ..ir import Value
+from ..ir import Value, IntegerAttr, IntegerType, register_attribute_builder
 from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
 
 
@@ -13,3 +13,108 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value:
     return _get_op_result_or_op_results(
         ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
     )
+
+
+@register_attribute_builder("builtin.AsmATTOrIntel")
+def _asmattorintel(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicBinOp")
+def _atomicbinop(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicOrdering")
+def _atomicordering(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.CConvEnum")
+def _cconvenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Comdat")
+def _comdat(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.DIFlags")
+def _diflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.DISubprogramFlags")
+def _disubprogramflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FCmpPredicate")
+def _fcmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FPExceptionBehaviorAttr")
+def _fpexceptionbehaviorattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FastmathFlags")
+def _fastmathflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FramePointerKindEnum")
+def _framepointerkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ICmpPredicate")
+def _icmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.IntegerOverflowFlags")
+def _integeroverflowflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LLVM_DIEmissionKind")
+def _llvm_diemissionkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LLVM_DINameTableKind")
+def _llvm_dinametablekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LinkageEnum")
+def _linkageenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ModRefInfoEnum")
+def _modrefinfoenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.RoundingModeAttr")
+def _roundingmodeattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TailCallKindEnum")
+def _tailcallkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.UnnamedAddr")
+def _unnamedaddr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Visibility")
+def _visibility(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py
index d6a54f2772f40d..eea132adb0484e 100644
--- a/mlir/python/mlir/dialects/nvgpu.py
+++ b/mlir/python/mlir/dialects/nvgpu.py
@@ -2,6 +2,32 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._nvgpu_ops_gen import *
 from ._nvgpu_enum_gen import *
 from .._mlir_libs._mlirDialectsNVGPU import *
+
+
+@register_attribute_builder("builtin.RcpRoundingMode")
+def _rcproundingmode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapInterleaveKind")
+def _tensormapinterleavekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapL2PromoKind")
+def _tensormapl2promokind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapOOBKind")
+def _tensormapoobkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapSwizzleKind")
+def _tensormapswizzlekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py
index 9477de39c9ead7..21bf24cb73fdab 100644
--- a/mlir/python/mlir/dialects/nvvm.py
+++ b/mlir/python/mlir/dialects/nvvm.py
@@ -2,5 +2,81 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._nvvm_ops_gen import *
 from ._nvvm_enum_gen import *
+
+
+@register_attribute_builder("builtin.LoadCacheModifierKind")
+def _loadcachemodifierkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAB1Op")
+def _mmab1op(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAFrag")
+def _mmafrag(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAIntOverflow")
+def _mmaintoverflow(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMALayout")
+def _mmalayout(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMATypes")
+def _mmatypes(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MemScopeKind")
+def _memscopekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ProxyKind")
+def _proxykind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ReduxKind")
+def _reduxkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SetMaxRegisterAction")
+def _setmaxregisteraction(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SharedSpace")
+def _sharedspace(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ShflKind")
+def _shflkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMAScaleIn")
+def _wgmmascalein(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMAScaleOut")
+def _wgmmascaleout(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMATypes")
+def _wgmmatypes(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 209ecc95fa8fc8..8f1b83f9d514fd 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -2,7 +2,23 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._sparse_tensor_ops_gen import *
 from ._sparse_tensor_enum_gen import *
 from .._mlir_libs._mlirDialectsSparseTensor import *
 from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
+
+
+@register_attribute_builder("builtin.SparseTensorCrdTransDirectionEnum")
+def _sparsetensorcrdtransdirectionenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_...
[truncated]

@makslevental makslevental changed the title [python] Emit only dialect EnumAttr registrations. [mlir-python] Emit only dialect EnumAttr registrations. Nov 28, 2024
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
DIALECT_NAME affine
GEN_ENUM_BINDINGS)
DIALECT_NAME affine)
Copy link
Contributor Author

@makslevental makslevental Nov 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_affine_enums_gen.py had nothing but arith enums and registrations so best to just turn it off to keep from confusing people.

@jpienaar
Copy link
Member

The other is also that these should be prefixed with dialect. Problem 2 is resolved by requiring folks to prefix with namespace as is done most places.


@register_attribute_builder("builtin.AMDGPU_DPPPerm")
def _amdgpu_dppperm(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to add a 'reflection' function to the tablegened enum classes to get the underlying type from the enum definition

@makslevental
Copy link
Contributor Author

makslevental commented Nov 28, 2024

The other is also that these should be prefixed with dialect. Problem 2 is resolved by requiring folks to prefix with namespace as is done most places.

Yes but it's upstream where the problem lies - there are many attributes that are in builtin and aren't prefixed. Such as all these in linalg and FastMathFlags in arith and the particular IteratorType offender here.

It actually wouldn't solve the problem anyway because of how the generation logic currently works:

https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp#L137

static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
  for (const Record *it :
       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
    EnumAttr enumAttr(*it);
    emitAttributeBuilder(enumAttr, os);
  }
  for (const Record *it :
       records.getAllDerivedDefinitionsIfDefined("EnumAttr"))
      emitDialectEnumAttributeBuilder(
          attr.getName(),
          formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
}

https://github.com/llvm/llvm-project/blob/main/mlir/test/mlir-tblgen/enums-python-bindings.td#L99-L105

def TestBitEnum
    : I32BitEnumAttr<"TestBitEnum", "", [
        I32BitEnumAttrCaseBit<"User", 0, "user">,
        I32BitEnumAttrCaseBit<"Group", 1, "group">,
        I32BitEnumAttrCaseBit<"Other", 2, "other">,
      ]> {
  let genSpecializedAttr = 0;
  let separator = " | ";
}

def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;

// CHECK: @register_attribute_builder("TestBitEnum")
// CHECK: def _testbitenum(x, context):
// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))

// CHECK: @register_attribute_builder("TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)

So the "builtin" will always be emitted every single time it's used as the enumInfo of an EnumAttr (because again, tablegen has no knowledge that it's already emitted the registration in some other file).

So the real root cause here is simply the global attributeBuilderMap vs. "local" tablegen.

This PR actually does rectify this original flaw in the design (that first @register_attribute_builder("TestBitEnum") is no longer emitted).

@jpienaar
Copy link
Member

The other is also that these should be prefixed with dialect. Problem 2 is resolved by requiring folks to prefix with namespace as is done most places.

Yes but it's upstream where the problem lies - there are many attributes that are in builtin and aren't prefixed. Such as all these in linalg and FastMathFlags in arith and the particular IteratorType offender here.

Yes I know Linalg is a offender here: I had made a lint check and started changing these a while ago. It's not that many that needs changes. The one where it was a bit more work is updating a generator generator here which relief more on names (it required introducing same behavior as we do for op definitions though, which is a win). Built-in being special is known, it has no prefix, that is known and also why we keep it small. The FastMathFlags one I easily addressed, but it ends up a bit confusing as it results in looking like a dialect dependency as it is intentionally reused elsewhere, but it doesn't affect the C++ side or dependencies, so in practice wasn't too much hard.

It actually wouldn't solve the problem anyway because of how the generation logic currently works:

https://github.com/llvm/llvm-project/blob/main/mlir/test/mlir-tblgen/enums-python-bindings.td#L99-L105

def TestBitEnum
    : I32BitEnumAttr<"TestBitEnum", "", [
        I32BitEnumAttrCaseBit<"User", 0, "user">,
        I32BitEnumAttrCaseBit<"Group", 1, "group">,
        I32BitEnumAttrCaseBit<"Other", 2, "other">,
      ]> {
  let genSpecializedAttr = 0;
  let separator = " | ";
}

def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;

// CHECK: @register_attribute_builder("TestBitEnum")
// CHECK: def _testbitenum(x, context):
// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))

// CHECK: @register_attribute_builder("TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)

So the "builtin" will always be emitted every single time it's used as the enumInfo of an EnumAttr (because again, tablegen has no knowledge that it's already emitted the registration in some other file).

Didnt we resolve this behavior in another instance by looking at the main file? I'm pretty sure we made a change to type and attribute emission.to address this by only considering the main file and using that as source of truth as to what should be emitted. This may end up that one has to use a different top file, but it also fits that one emits for the file given as input.

@makslevental
Copy link
Contributor Author

makslevental commented Nov 28, 2024

Didnt we resolve this behavior in another instance by looking at the main file? I'm pretty sure we made a change to type and attribute emission.to address this by only considering the main file and using that as source of truth as to what should be emitted. This may end up that one has to use a different top file, but it also fits that one emits for the file given as input.

Yes we have GEN_ENUM_BINDINGS_TD_FILE to prevent generating for unintended attributes that might be used on some op but in these cases (IteratorType) the TD file with the attribute itself includes the other attribute TD files.

As far as I know there's no way to limit which TD files are considered during generation because they behave just like headers right? So the "preprocessor" injects the headers and only then do we get the TU (or whatever). You can put a breakpoint of printf anywhere in the generation code itself and see that the RecordKeeper always has all of the records for the file and its includes.

The last comment on the original patch in phabricator highlighted the issue.

@makslevental makslevental force-pushed the makslevental/fix-enums branch from 90c0b70 to de851b0 Compare December 3, 2024 01:46
@makslevental
Copy link
Contributor Author

makslevental commented Dec 3, 2024

I have added a failing case that demonstrates the problem already exists upstream:

# mlir/test/python/dialects/affine.py
@constructAndPrintInModule
def test_double_AtomicRMWKindAttr_registration():
    from mlir.dialects import _affine_enum_gen

We should see

# RUN: at line 1
//mlir/test/python/dialects/affine.py | /cmake-build-debug/bin/FileCheck /mlir/test/python/dialects/affine.py
# executed command: //mlir/test/python/dialects/affine.py
# .---command stderr------------
# | Traceback (most recent call last):
# |   File "/mlir/test/python/dialects/affine.py", line 338, in <module>
# |     @constructAndPrintInModule
# |      ^^^^^^^^^^^^^^^^^^^^^^^^^
# |   File "/mlir/test/python/dialects/affine.py", line 16, in constructAndPrintInModule
# |     f()
# |   File "/mlir/test/python/dialects/affine.py", line 340, in test_double_AtomicRMWKindAttr_registration
# |     from mlir.dialects import _affine_enum_gen
# |   File "/cmake-build-debug/tools/mlir/python_packages/mlir_core/mlir/dialects/_affine_enum_gen.py", line 66, in <module>
# |     @register_attribute_builder("Arith_CmpFPredicateAttr")
# |      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# |   File "/cmake-build-debug/tools/mlir/python_packages/mlir_core/mlir/ir.py", line 14, in decorator_builder
# |     AttrBuilder.insert(kind, func, replace=replace)
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
# `-----------------------------
# error: command failed with exit status: 1
# executed command: /cmake-build-debug/bin/FileCheck /mlir/test/python/dialects/affine.py

Note that affine.py itself doesn't import _affine_enum_gen as a workaround/hack for this very problem:

from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect

@jpienaar
Copy link
Member

jpienaar commented Dec 6, 2024

As far as I know there's no way to limit which TD files are considered during generation because they behave just like headers right? So the "preprocessor" injects the headers and only then do we get the TU (or whatever).

Close, TD files aren't textually included as one would with preprocessor. The origin is retained by SourceMgr.

What I was thinking/suggestion above was something like https://github.com/jpienaar/llvm-project/pull/new/pyenum , so filtering on origin file and making it act on it (that change is larger than it should be as GPU dialect has an enum section ... and then has most of their enums in a different file and that just felt wrong). This is how attributes and types are handled. Unfortunately neither enum generator does that. I think it is a nice organization mechanism that makes things more regular, makes what is acted on explicit/unambigous (especially as many of the enums don't have a dialect associated trivially).

One could still do the dialect prefixing with this in the registry (although the example you showed, shows an attribute from LinalgTransformEnums.td as in builtin dialect so not sure about the heuristic :)). And I was thinking if one could do something more automatic by associating ops with the nearest Dialect definition in include path. But that heuristic could be too magical.

WDYT about making it explicit via file acted on as the other? (could also just be focused on Python side).

@makslevental
Copy link
Contributor Author

makslevental commented Dec 6, 2024

lthough the example you showed, shows an attribute from LinalgTransformEnums.td as in builtin dialect so not sure about the heuristic

which one? MatchInterfaceEnum and TransposeMatmulInput? these are both I32EnumAttrs so they are in builtin? LinalgTransformEnums.td

And I was thinking if one could do something more automatic by associating ops with the nearest Dialect definition in include path. But that heuristic could be too magical.

yea definitely too magical.

WDYT about making it explicit via file acted on as the other?

In general I think automatically inferring via file location is just as magical/error-prone. As you commented, GPU would now have attributes in the "wrong" place but according to what semantic? Only the semantic introduced by such a change (prior they were just conventions).

My change also introduces a new semantic but the semantic "every dialect gets dialect attributes generated/emitted automatically and builtin continues to be treated special (requires hand-written register_attribute_builder calls)" is more conservative than the alternative.

But I'll just say, since I'm no longer andactive user of these bindings (I'm just pitching in to help @kuhar), I won't stand in the way of something that unblocks some fix. If you think pinning to file location is fine and it works for you guys then we can land that instead (it'll work for us too).

makslevental added a commit to nod-ai/shark-ai that referenced this pull request Dec 10, 2024
Hack/patch to workaround overlapping `AttrBuilder` registrations. Needed
until iree-org/iree#19324 and
llvm/llvm-project#117918 get resolved.
IanNod pushed a commit to IanNod/SHARK-Platform that referenced this pull request Dec 17, 2024
Hack/patch to workaround overlapping `AttrBuilder` registrations. Needed
until iree-org/iree#19324 and
llvm/llvm-project#117918 get resolved.
monorimet pushed a commit to nod-ai/shark-ai that referenced this pull request Jan 8, 2025
Hack/patch to workaround overlapping `AttrBuilder` registrations. Needed
until iree-org/iree#19324 and
llvm/llvm-project#117918 get resolved.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU mlir:core MLIR Core Infrastructure mlir:linalg mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants