@@ -10,37 +10,40 @@ The Transform dialect uses the dialect extension mechanism to allow additional o
10
10
// In MyExtension.cpp.
11
11
#include " mlir/Dialect/Transform/IR/TransformDialect.h"
12
12
13
- // Define a new Transform dialect extension. This uses the CRTP idiom to identify
14
- // extensions.
13
+ // Define a new Transform dialect extension. This uses the CRTP idiom to
14
+ // identify extensions.
15
15
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension > {
16
16
public:
17
17
// The extension must derive the base constructor.
18
18
using Base::Base;
19
19
20
- // This function initializes the extension, similarly to ` initialize ` in dialect
21
- // definitions. List individual operations and dependent dialects here.
20
+ // This function initializes the extension, similarly to ` initialize ` in
21
+ // dialect definitions. List individual operations and dependent dialects
22
+ // here.
22
23
void init();
23
24
};
24
25
25
26
void MyExtension::init() {
26
- // Similarly to dialects, an extension can declare a dependent dialect. This dialect
27
- // will be loaded along with the extension and, therefore, along with the Transform
28
- // dialect. Only declare as dependent the dialects that contain the attributes or
29
- // types used by transform operations. Do NOT declare as dependent the dialects
30
- // produced during the transformation.
27
+ // Similarly to dialects, an extension can declare a dependent dialect. This
28
+ // dialect will be loaded along with the extension and, therefore, along with
29
+ // the Transform dialect. Only declare as dependent the dialects that contain
30
+ // the attributes or types used by transform operations. Do NOT declare as
31
+ // dependent the dialects produced during the transformation.
32
+ //
31
33
// declareDependentDialect<MyDialect >();
32
34
33
- // When transformations are applied, they may produce new operations from previously
34
- // unloaded dialects. Typically, a pass would need to declare itself dependent on
35
- // the dialects containing such new operations. To avoid confusion with the dialects
36
- // the extension itself depends on, the Transform dialects differentiates between:
35
+ // When transformations are applied, they may produce new operations from
36
+ // previously unloaded dialects. Typically, a pass would need to declare
37
+ // itself dependent on the dialects containing such new operations. To avoid
38
+ // confusion with the dialects the extension itself depends on, the Transform
39
+ // dialects differentiates between:
37
40
// - dependent dialects, which are used by the transform operations, and
38
- // - generated dialects, which contain the entities (attributes, operations,
39
- // types) that may be produced by applying the transformation even when not
40
- // present in the original payload IR.
41
- // In the following chapter, we will be add operations that generate function calls
42
- // and structured control flow operations, so let's declare the corresponding
43
- // dialects as generated.
41
+ // - generated dialects, which contain the entities (attributes, operations,
42
+ // types) that may be produced by applying the transformation even when
43
+ // not present in the original payload IR.
44
+ // In the following chapter, we will be add operations that generate function
45
+ // calls and structured control flow operations, so let's declare the
46
+ // corresponding dialects as generated.
44
47
declareGeneratedDialect<::mlir::scf::SCFDialect>();
45
48
declareGeneratedDialect<::mlir::func::FuncDialect>();
46
49
@@ -89,7 +92,7 @@ mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)
89
92
# Add a CMakeTarget we can depend on to ensure the generation happens before the compilation.
90
93
add_public_tablegen_target(MyExtensionIncGen)
91
94
92
- # Don't forget to generate the documentation, this will produce a MyExtension.md under
95
+ # Don't forget to generate the documentation, this will produce a MyExtension.md under
93
96
# Dialects.
94
97
add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
95
98
```
@@ -103,7 +106,8 @@ add_mlir_library(
103
106
# Built from the following source files.
104
107
MyExtension.cpp
105
108
106
- # Make sure ODS declaration and definitions are generated before compiling this.
109
+ # Make sure ODS declaration and definitions are generated before compiling
110
+ # this.
107
111
DEPENDS
108
112
MyExtensionIncGen
109
113
@@ -136,10 +140,10 @@ This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, tha
136
140
void MyExtension::init () {
137
141
// …
138
142
139
- // Finally, we register the additional transform operations with the dialect. List all
140
- // operations generated from ODS. This call will perform additional checks that the
141
- // operations implement the transform and memory effect interfaces required by the
142
- // dialect interpreter and assert if they do not.
143
+ // Finally, we register the additional transform operations with the dialect.
144
+ // List all operations generated from ODS. This call will perform additional
145
+ // checks that the operations implement the transform and memory effect
146
+ // interfaces required by the dialect interpreter and assert if they do not.
143
147
registerTransformOps<
144
148
#define GET_OP_LIST
145
149
#include "MyExtension.cpp.inc"
@@ -154,34 +158,36 @@ With this setup, we are now ready to define the new transform operation to rewri
154
158
``` tablegen
155
159
// In MyExtension.td.
156
160
157
- // Define the new operation. By convention, prefix its name with the name of the dialect
158
- // extension, "my.". The full operation name will be further prefixed with "transform.".
161
+ // Define the new operation. By convention, prefix its name with the name of the
162
+ // dialect extension, "my.". The full operation name will be further prefixed
163
+ // with "transform.".
159
164
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
160
- // Indicate that the operation implements the required TransformOpInterface and
161
- // MemoryEffectsOpInterface.
165
+ // Indicate that the operation implements the required TransformOpInterface
166
+ // and MemoryEffectsOpInterface.
162
167
[DeclareOpInterfaceMethods<TransformOpInterface>,
163
168
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
164
- // Provide a brief and a full description. It is recommended that the latter describes
165
- // the effects on the operands and how the operation processes various failure modes.
169
+ // Provide a brief and a full description. It is recommended that the latter
170
+ // describes the effects on the operands and how the operation processes
171
+ // various failure modes.
166
172
let summary = "Changes the callee of a call operation to the specified one";
167
173
let description = [{
168
- For each `func.call` payload operation associated with the handle, changes its
169
- callee to be the symbol whose name is provided as an attribute to this operation.
174
+ For each `func.call` payload operation associated with the handle, changes
175
+ its callee to be the symbol whose name is provided as an attribute to this operation.
170
176
171
- Generates a silenceable failure if the operand is associated with payload operations
172
- that are not `func.call`.
173
- Only reads the operand.
177
+ Generates a silenceable failure if the operand is associated with payload operations that are not `func.call`. Only reads the operand.
174
178
}];
175
179
176
- // The arguments include the handle to the payload operations and the attribute that
177
- // specifies the new callee. The handle must implement TransformHandleTypeInterface.
178
- // We use a string attribute as the symbol may not exist in the transform IR so the
179
- // verification may fail.
180
+ // The arguments include the handle to the payload operations and the
181
+ // attribute that specifies the new callee. The handle must implement
182
+ // TransformHandleTypeInterface.
183
+ // We use a string attribute as the symbol may not exist in the transform IR
184
+ // so the verification may fail.
180
185
let arguments = (ins
181
186
TransformHandleTypeInterface:$call,
182
187
StrAttr:$new_target);
183
188
184
- // The results are empty as the transformation does not produce any new payload.
189
+ // The results are empty as the transformation does not produce any new
190
+ // payload.
185
191
let results = (outs);
186
192
187
193
// Provide nice syntax.
@@ -224,8 +230,8 @@ must be modified with the provided rewriter.
224
230
// It can also carry additional user-defined state.
225
231
::mlir::transform::TransformState &state) {
226
232
227
- // First, we need to obtain the list of payload operations that are associated with
228
- // the operand handle.
233
+ // First, we need to obtain the list of payload operations that are associated
234
+ // with the operand handle.
229
235
auto payload = state.getPayloadOps(getCall());
230
236
231
237
// Then, we iterate over the list of operands and call the actual IR-mutating
@@ -280,56 +286,66 @@ void registerMyExtension(::mlir::DialectRegistry ®istry) {
280
286
After registering the extension, it becomes possible to use our new operation in the Transform dialect interpreter. The upstream testing pass can be used as is.
281
287
282
288
```mlir
283
- transform.sequence failures(propagate) {
284
- ^bb0(%arg0: !transform.any_op,
285
- %arg1: !transform.op<"linalg.matmul">,
286
- %arg2: !transform.op<"linalg.elemwise_binary">):
287
- // Since the %arg2 handle is associated with both elementwise operations,
288
- // we need to split it into two handles so we can target only the second
289
- // elementwise operation.
290
- %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
291
- -> (!transform.any_op, !transform.any_op)
292
-
293
- // The actual tiling transformation takes tile sizes as attributes. It produces a
294
- // handle to the loop generated during tiling.
295
- %loop, %tiled = transform.structured.tile_using_forall %max tile_sizes [8, 32]
296
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
297
-
298
- // We can now fuse the other operations into the loop. Here, we fuse
299
- // operations one-by-one. This requires the operation that is being fused
300
- // to define the value used within the loop, so the order of such fusions
301
- // is important. We could also use "transform.merge_handles" to obtain
302
- // a single handle to all operations and give it to `fuse_into_containing_op`
303
- // that would take care of the ordering in this case.
304
- %add_fused = transform.structured.fuse_into_containing_op %add into %loop
305
- : (!transform.any_op, !transform.any_op) -> !transform.any_op
306
- %matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
307
- : (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op
308
-
309
- // Tile again to get the desired size. Note that this time this tiles the
310
- // "add" operation and fuses matmul into the loop, but doesn't affect the
311
- // "max" operation. This illustrates the precise targeting with the transform
312
- // dialect. Otherwise, it is difficult to differentiate "add" and "max", both
313
- // of which having the same kind.
314
- %loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
315
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
316
- %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
317
- : (!transform.any_op, !transform.any_op) -> !transform.any_op
318
-
319
- // Since outlining is currently only implemented for region-holding operations
320
- // such as loops, use tiling to size 1 to materialize the outer loop that is
321
- // going to be outlined.
322
- %outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
323
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
324
- transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
325
- : (!transform.any_op, !transform.any_op) -> !transform.any_op
326
- %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
327
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
328
-
329
- // Rewrite the call target.
330
- transform.my.change_call_target %call, "microkernel" : !transform.any_op
331
-
332
- transform.yield
289
+ module attributes {transform.with_named_sequence} {
290
+ transform.named_sequence @__transform_main(
291
+ %arg0: !transform.any_op,
292
+ %arg1: !transform.op<"linalg.matmul">,
293
+ %arg2: !transform.op<"linalg.elemwise_binary">) {
294
+ // Since the %arg2 handle is associated with both elementwise operations,
295
+ // we need to split it into two handles so we can target only the second
296
+ // elementwise operation.
297
+ %add, %max = transform.split_handle %arg2
298
+ : (!transform.op<"linalg.elemwise_binary">)
299
+ -> (!transform.any_op, !transform.any_op)
300
+
301
+ // The actual tiling transformation takes tile sizes as attributes. It
302
+ // produces a handle to the loop generated during tiling.
303
+ %loop, %tiled = transform.structured.tile_using_forall %max
304
+ tile_sizes [8, 32]
305
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
306
+
307
+ // We can now fuse the other operations into the loop. Here, we fuse
308
+ // operations one-by-one. This requires the operation that is being fused
309
+ // to define the value used within the loop, so the order of such fusions
310
+ // is important. We could also use "transform.merge_handles" to obtain
311
+ // a single handle to all operations and give it to
312
+ // `fuse_into_containing_op` that would take care of the ordering in this
313
+ // case.
314
+ %add_fused = transform.structured.fuse_into_containing_op %add into %loop
315
+ : (!transform.any_op, !transform.any_op) -> !transform.any_op
316
+ %matmul_fused = transform.structured.fuse_into_containing_op %arg1
317
+ into %loop
318
+ : (!transform.op<"linalg.matmul">, !transform.any_op)
319
+ -> !transform.any_op
320
+
321
+ // Tile again to get the desired size. Note that this time this tiles the
322
+ // "add" operation and fuses matmul into the loop, but doesn't affect the
323
+ // "max" operation. This illustrates the precise targeting with the
324
+ // transform dialect. Otherwise, it is difficult to differentiate "add" and
325
+ // "max", both of which having the same kind.
326
+ %loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused
327
+ tile_sizes [4, 4]
328
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
329
+ %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused
330
+ into %loop_2
331
+ : (!transform.any_op, !transform.any_op) -> !transform.any_op
332
+
333
+ // Since outlining is currently only implemented for region-holding
334
+ // operations such as loops, use tiling to size 1 to materialize the outer
335
+ // loop that is going to be outlined.
336
+ %outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
337
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
338
+ transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
339
+ : (!transform.any_op, !transform.any_op) -> !transform.any_op
340
+ %func, %call = transform.loop.outline %outline_target
341
+ {func_name = "outlined"}
342
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
343
+
344
+ // Rewrite the call target.
345
+ transform.my.change_call_target %call, "microkernel" : !transform.any_op
346
+
347
+ transform.yield
348
+ }
333
349
}
334
350
```
335
351
0 commit comments