-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[MLIR][Transform] apply_registered_pass op's options as a dict #143159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesIn particular, use similar syntax for providing options as in the (pretty-)printed IR. Full diff: https://github.com/llvm/llvm-project/pull/143159.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+ def __init__(
+ self,
+ result: Type,
+ pass_name: Union[str, StringAttr],
+ target: Value,
+ *,
+ options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+ loc=None,
+ ip=None,
+ ):
+ static_options = []
+ dynamic_options = []
+ for opt in options:
+ if isinstance(opt, str):
+ static_options.append(StringAttr.get(opt))
+ elif isinstance(opt, StringAttr):
+ static_options.append(opt)
+ elif isinstance(opt, Value):
+ static_options.append(UnitAttr.get())
+ dynamic_options.append(_get_op_result_or_value(opt))
+ else:
+ raise TypeError(f"Unsupported option type: {type(opt)}")
+ super().__init__(
+ result,
+ pass_name,
+ dynamic_options,
+ target=_get_op_result_or_value(target),
+ options=static_options,
+ loc=loc,
+ ip=ip,
+ )
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+ )
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+ )
+ max_iter = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+ )
+ max_rewrites = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+ )
+ transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(),
+ "canonicalize",
+ mod,
+ options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+ # CHECK: transform.sequence
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
+ # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # CHECK-SAME: with options = "top-down=false" %[[MAX_ITER]]
+ # CHECK-SAME: "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
|
✅ With the latest revision this PR passed the Python code formatter. |
Context: #142683 |
transform.AnyOpType.get(), | ||
"canonicalize", | ||
mod, | ||
options=("top-down=false", max_iter, "test-convergence=true", max_rewrites), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather works toward a dictionary here that would make it Python-friendly, but I see the actual ops allows for "max-iterations=10" style of parameter... Though even for the op itself, it may be wise separating the pass parameter name (which is a literal) from the value it takes (which may be a constant/value or also a literal).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i have this "widget":
def add_pass(self, pass_name, **kwargs):
kwargs = {
k.replace("_", "-"): int(v) if isinstance(v, bool) else v
for k, v in kwargs.items()
if v is not None
}
if kwargs:
args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())
string interpolation of python values does the right thing for the kinds of args seen in passes (ints, strings, lists, etc) except for bools True/False
, which is handled by int(True)/int(False) -> 0/1
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did a major overhaul. Most notably values can now be passed via params (without needing to be strings nor needing to include key=
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The overhaul addresses this point.
Will now merge with the current mechanism for allowing a reference to a param from the options dict. Can iterate in-tree on that design if someone has a better suggestion.
Hi @ftynse, @makslevental, partially based on your comments, I decided to go whole hog on this and improve the op on the C++ side as well. That is, the options are now provided in a dictionary and it is possible to pass the values for the options via params. I re-purposed this PR (prev. was solely fixing up Python bindings for this op) for this more substantial change. Looking forward to your re-review! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to dict, looks neat
Minor comments
In particular, use similar syntax for providing options as in the (pretty-)printed IR.
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
…143159) Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs). Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the `addToPipeline`-pass API.
…143159) Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs). Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the `addToPipeline`-pass API.
**Context:** Update llvm, mhlo and enzyme, 2025 Q3. The latest pair of good versions, indicated by mhlo, is tensorflow/mlir-hlo@1dd2e71 ``` mhlo=1dd2e71331014ae0373f6bf900ce6be393357190 llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 ``` For Enzyme, we go to the latest release https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186 ``` enzyme=v0.0.186 ``` with commit `8c1a596158f6194f10e8ffd56a1660a61c54337e` **Description of the Change:** Miscellaneous: 1. `GreedyRewriteConfig.stuff = blah` -> `GreedyRewriteConfig.setStuff(blah)` llvm/llvm-project#137122 2. llvm gep op `inbounds` attribute is subsumed under a gep sign wrap enum flag llvm/llvm-project#137272 3. `arith::Constant[Int, Float]Op` builders now have the same argument order as other ops (output type first, then arguments) llvm/llvm-project#144636 (note that Enzyme also noticed this EnzymeAD/Enzyme#2379 😆 ) 4. The `lookupOrCreateFn` functions now take in a builder instead of instantiating a new one llvm/llvm-project#136421 5. `getStridedElementPtr` now takes in `rewriter` as the first argument (instead of the last), like all the other utils llvm/llvm-project#138984 6. The following functions now return a `LogicalResult`, and will be caught by warnings as errors as `-Wunused-result`: - `func::FuncOp.[insert, erase]Argument(s)` llvm/llvm-project#137130 - `getBackwardSlice()` llvm/llvm-project#140961 Things related to `transform.apply_registered_pass` op: 1. It now takes in a `dynamic_options` llvm/llvm-project#142683. We don't need to use this as all our pass options are static. 2. The options it takes in are now dictionaries instead of strings llvm/llvm-project#143159 Bufferization: 1. `bufferization.to_memref` op is renamed to `bufferization.to_buffer` llvm/llvm-project#137180 3. `bufferization.to_tensor` op's builder now needs the result type to be explicit llvm/llvm-project#142986. This is also needed by a patched mhlo pass. 4. The `getBuffer()` methods take in a new arg for `BufferizationState` llvm/llvm-project#141019, llvm/llvm-project#141466 5. `UnknownTypeConverterFn` in bufferization options now takes in just a type instead of a full value llvm/llvm-project#144658 **Related GitHub Issues:** [sc-95176] [sc-95664] --------- Co-authored-by: Mehrdad Malek <[email protected]>
Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs).
Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the
addToPipeline
-pass API.