Skip to content

[MLIR][Transform] apply_registered_pass op's options as a dict #143159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 11, 2025

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Jun 6, 2025

Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the addToPipeline-pass API.

@llvmbot
Copy link
Member

llvmbot commented Jun 6, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

In particular, use similar syntax for providing options as in the (pretty-)printed IR.


Full diff: https://github.com/llvm/llvm-project/pull/143159.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/transform/init.py (+35)
  • (modified) mlir/test/python/dialects/transform.py (+36)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+    def __init__(
+        self,
+        result: Type,
+        pass_name: Union[str, StringAttr],
+        target: Value,
+        *,
+        options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+        loc=None,
+        ip=None,
+    ):
+        static_options = []
+        dynamic_options = []
+        for opt in options:
+            if isinstance(opt, str):
+                static_options.append(StringAttr.get(opt))
+            elif isinstance(opt, StringAttr):
+                static_options.append(opt)
+            elif isinstance(opt, Value):
+                static_options.append(UnitAttr.get())
+                dynamic_options.append(_get_op_result_or_value(opt))
+            else:
+                raise TypeError(f"Unsupported option type: {type(opt)}")
+        super().__init__(
+            result,
+            pass_name,
+            dynamic_options,
+            target=_get_op_result_or_value(target),
+            options=static_options,
+            loc=loc,
+            ip=ip,
+        )
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
     # CHECK: %[[FIRST:.+]] = pdl_match
     # CHECK: %[[SECOND:.+]] = pdl_match
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+        )
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+        )
+        max_iter = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+        )
+        max_rewrites = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+        )
+        transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod,
+            options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+    # CHECK: transform.sequence
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
+    # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+    # CHECK-SAME:    with options = "top-down=false" %[[MAX_ITER]]
+    # CHECK-SAME:   "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

Copy link

github-actions bot commented Jun 6, 2025

✅ With the latest revision this PR passed the Python code formatter.

@rolfmorel
Copy link
Contributor Author

Context: #142683

transform.AnyOpType.get(),
"canonicalize",
mod,
options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather works toward a dictionary here that would make it Python-friendly, but I see the actual ops allows for "max-iterations=10" style of parameter... Though even for the op itself, it may be wise separating the pass parameter name (which is a literal) from the value it takes (which may be a constant/value or also a literal).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have this "widget":

def add_pass(self, pass_name, **kwargs):
    kwargs = {
        k.replace("_", "-"): int(v) if isinstance(v, bool) else v
        for k, v in kwargs.items()
        if v is not None
    }
    if kwargs:
        args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())

string interpolation of python values does the right thing for the kinds of args seen in passes (ints, strings, lists, etc) except for bools True/False, which is handled by int(True)/int(False) -> 0/1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a major overhaul. Most notably values can now be passed via params (without needing to be strings nor needing to include key=).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overhaul addresses this point.

Will now merge with the current mechanism for allowing a reference to a param from the options dict. Can iterate in-tree on that design if someone has a better suggestion.

@rolfmorel rolfmorel changed the title [MLIR][Transform] friendlier Python-bindings apply_registered_pass op [MLIR][Transform] apply_registered_pass op's options as a dict Jun 7, 2025
@rolfmorel
Copy link
Contributor Author

rolfmorel commented Jun 7, 2025

Hi @ftynse, @makslevental, partially based on your comments, I decided to go whole hog on this and improve the op on the C++ side as well. That is, the options are now provided in a dictionary and it is possible to pass the values for the options via params.

I re-purposed this PR (prev. was solely fixing up Python bindings for this op) for this more substantial change. Looking forward to your re-review!

@rolfmorel rolfmorel requested a review from rengolin June 7, 2025 20:38
Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to dict, looks neat
Minor comments

@rolfmorel rolfmorel merged commit fe7bf4b into llvm:main Jun 11, 2025
7 checks passed
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…143159)

Improve ApplyRegisteredPassOp's support for taking options by taking
them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params
(which pass in attributes at interpreter runtime). In either case, the
keys and value attributes are converted to strings and a single
options-string, in the format used on the commandline, is constructed to
pass to the `addToPipeline`-pass API.
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
…143159)

Improve ApplyRegisteredPassOp's support for taking options by taking
them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params
(which pass in attributes at interpreter runtime). In either case, the
keys and value attributes are converted to strings and a single
options-string, in the format used on the commandline, is constructed to
pass to the `addToPipeline`-pass API.
paul0403 added a commit to PennyLaneAI/catalyst that referenced this pull request Jul 28, 2025
**Context:**
Update llvm, mhlo and enzyme, 2025 Q3.
The latest pair of good versions, indicated by mhlo, is
tensorflow/mlir-hlo@1dd2e71
```
mhlo=1dd2e71331014ae0373f6bf900ce6be393357190
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
```

For Enzyme, we go to the latest release
https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186
```
enzyme=v0.0.186
```
with commit `8c1a596158f6194f10e8ffd56a1660a61c54337e`

**Description of the Change:**
Miscellaneous:
1. `GreedyRewriteConfig.stuff = blah` ->
`GreedyRewriteConfig.setStuff(blah)`
llvm/llvm-project#137122
2. llvm gep op `inbounds` attribute is subsumed under a gep sign wrap
enum flag llvm/llvm-project#137272
3. `arith::Constant[Int, Float]Op` builders now have the same argument
order as other ops (output type first, then arguments)
llvm/llvm-project#144636 (note that Enzyme also
noticed this EnzymeAD/Enzyme#2379 😆 )
4. The `lookupOrCreateFn` functions now take in a builder instead of
instantiating a new one llvm/llvm-project#136421
5. `getStridedElementPtr` now takes in `rewriter` as the first argument
(instead of the last), like all the other utils
llvm/llvm-project#138984
6. The following functions now return a `LogicalResult`, and will be
caught by warnings as errors as `-Wunused-result`:
- `func::FuncOp.[insert, erase]Argument(s)`
llvm/llvm-project#137130
- `getBackwardSlice()` llvm/llvm-project#140961

Things related to `transform.apply_registered_pass` op:
1. It now takes in a `dynamic_options`
llvm/llvm-project#142683. We don't need to use
this as all our pass options are static.
2. The options it takes in are now dictionaries instead of strings
llvm/llvm-project#143159

Bufferization:
1. `bufferization.to_memref` op is renamed to `bufferization.to_buffer`
llvm/llvm-project#137180
3. `bufferization.to_tensor` op's builder now needs the result type to
be explicit llvm/llvm-project#142986. This is
also needed by a patched mhlo pass.
4. The `getBuffer()` methods take in a new arg for `BufferizationState`
llvm/llvm-project#141019,
llvm/llvm-project#141466
5. `UnknownTypeConverterFn` in bufferization options now takes in just a
type instead of a full value
llvm/llvm-project#144658

**Related GitHub Issues:** 
[sc-95176]
[sc-95664]

---------

Co-authored-by: Mehrdad Malek <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants