diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 465e01d8b658f..5d5ee945b6686 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -1,5 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s +import functools +from typing import Callable + from mlir.ir import * from mlir.dialects import transform from mlir.dialects import pdl @@ -18,33 +21,40 @@ def run(f): return f +def create_sequence(func: Callable) -> Callable: + @functools.wraps(func) + def decorated() -> None: + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + func(sequence.bodyTarget) + transform.YieldOp() + + return decorated + + @run -def testBufferizeToAllocationOpCompact(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.BufferizeToAllocationOp(sequence.bodyTarget) - transform.YieldOp() +@create_sequence +def testBufferizeToAllocationOpCompact(target): + structured.BufferizeToAllocationOp(target) # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact # CHECK: transform.sequence # CHECK: transform.structured.bufferize_to_allocation @run -def testBufferizeToAllocationOpArgs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() +@create_sequence +def testBufferizeToAllocationOpArgs(target): + structured.BufferizeToAllocationOp( + target, + memory_space=3, + memcpy_op="memref.copy", + alloc_op="memref.alloca", + bufferize_destination_only=True, ) - with InsertionPoint(sequence.body): - structured.BufferizeToAllocationOp( - sequence.bodyTarget, - memory_space=3, - memcpy_op="memref.copy", - alloc_op="memref.alloca", - bufferize_destination_only=True, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs # CHECK: transform.sequence # CHECK: transform.structured.bufferize_to_allocation @@ -55,78 +65,54 @@ def testBufferizeToAllocationOpArgs(): @run -def testDecompose(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.DecomposeOp(sequence.bodyTarget) - transform.YieldOp() +@create_sequence +def testDecompose(target): + structured.DecomposeOp(target) # CHECK-LABEL: TEST: testDecompose # CHECK: transform.sequence # CHECK: transform.structured.decompose @run -def testFuseIntoContainingOpTypes(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() +@create_sequence +def testFuseIntoContainingOpTypes(target): + fused = structured.MatchOp.match_op_names(target, ["test.dummy"]) + containing = structured.MatchOp.match_op_names(target, ["test.dummy"]) + structured.FuseIntoContainingOp( + transform.OperationType.get("test.dummy"), + transform.OperationType.get("test.dummy"), + fused, + containing, ) - with InsertionPoint(sequence.body): - fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) - containing = structured.MatchOp.match_op_names( - sequence.bodyTarget, ["test.dummy"] - ) - structured.FuseIntoContainingOp( - transform.OperationType.get("test.dummy"), - transform.OperationType.get("test.dummy"), - fused, - containing, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes # CHECK: = transform.structured.fuse_into_containing_op # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">) @run -def testFuseIntoContainingOpCompact(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) - containing = structured.MatchOp.match_op_names( - sequence.bodyTarget, ["test.dummy"] - ) - structured.FuseIntoContainingOp(fused, containing) - transform.YieldOp() +@create_sequence +def testFuseIntoContainingOpCompact(target): + fused = structured.MatchOp.match_op_names(target, ["test.dummy"]) + containing = structured.MatchOp.match_op_names(target, ["test.dummy"]) + structured.FuseIntoContainingOp(fused, containing) # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact # CHECK: = transform.structured.fuse_into_containing_op # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) @run -def testGeneralize(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.GeneralizeOp(sequence.bodyTarget) - transform.YieldOp() +@create_sequence +def testGeneralize(target): + structured.GeneralizeOp(target) # CHECK-LABEL: TEST: testGeneralize # CHECK: transform.sequence # CHECK: transform.structured.generalize @run -def testInterchange(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0]) - transform.YieldOp() +@create_sequence +def testInterchange(target): + structured.InterchangeOp(target, iterator_interchange=[1, 0]) # CHECK-LABEL: TEST: testInterchange # CHECK: transform.sequence # CHECK: transform.structured.interchange @@ -134,15 +120,11 @@ def testInterchange(): @run -def testMapCopyToThreadsOpCompact(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() +@create_sequence +def testMapCopyToThreadsOpCompact(target): + structured.MapCopyToThreadsOp( + target, total_num_threads=32, desired_bit_alignment=128 ) - with InsertionPoint(sequence.body): - structured.MapCopyToThreadsOp( - sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128 - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact # CHECK: = transform.structured.gpu.map_copy_to_threads # CHECK-SAME: total_num_threads = 32 @@ -151,19 +133,15 @@ def testMapCopyToThreadsOpCompact(): @run -def testMapCopyToThreadsOpTypes(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() +@create_sequence +def testMapCopyToThreadsOpTypes(target): + structured.MapCopyToThreadsOp( + transform.OperationType.get("test.opA"), + transform.OperationType.get("test.opB"), + target, + total_num_threads=32, + desired_bit_alignment=128, ) - with InsertionPoint(sequence.body): - structured.MapCopyToThreadsOp( - transform.OperationType.get("test.opA"), - transform.OperationType.get("test.opB"), - sequence.bodyTarget, - total_num_threads=32, - desired_bit_alignment=128, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes # CHECK: = transform.structured.gpu.map_copy_to_threads # CHECK-SAME: total_num_threads = 32 @@ -172,13 +150,9 @@ def testMapCopyToThreadsOpTypes(): @run -def testMatchOpNamesString(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy") - transform.YieldOp() +@create_sequence +def testMatchOpNamesString(target): + structured.MatchOp.match_op_names(target, "test.dummy") # CHECK-LABEL: TEST: testMatchOpNamesString # CHECK: transform.structured.match ops # CHECK-SAME: ["test.dummy"] @@ -186,13 +160,9 @@ def testMatchOpNamesString(): @run -def testMatchOpNamesList(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) - transform.YieldOp() +@create_sequence +def testMatchOpNamesList(target): + structured.MatchOp.match_op_names(target, ["test.dummy"]) # CHECK-LABEL: TEST: testMatchOpNamesList # CHECK: transform.structured.match ops # CHECK-SAME: ["test.dummy"] @@ -200,13 +170,9 @@ def testMatchOpNamesList(): @run -def testMaskedVectorizeStatic(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4]) - transform.YieldOp() +@create_sequence +def testMaskedVectorizeStatic(target): + structured.MaskedVectorizeOp(target, [16, 4]) # CHECK-LABEL: TEST: testMaskedVectorizeStatic # CHECK: transform.sequence # CHECK: transform.structured.masked_vectorize @@ -214,14 +180,10 @@ def testMaskedVectorizeStatic(): @run -def testMaskedVectorizeArray(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - sizes = Attribute.parse("[16, 4]") - structured.MaskedVectorizeOp(sequence.bodyTarget, sizes) - transform.YieldOp() +@create_sequence +def testMaskedVectorizeArray(target): + sizes = Attribute.parse("[16, 4]") + structured.MaskedVectorizeOp(target, sizes) # CHECK-LABEL: TEST: testMaskedVectorizeArray # CHECK: transform.sequence # CHECK: transform.structured.masked_vectorize @@ -229,15 +191,11 @@ def testMaskedVectorizeArray(): @run -def testMaskedVectorizeMixed(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"]) - sz2 = Attribute.parse("4") - structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2]) - transform.YieldOp() +@create_sequence +def testMaskedVectorizeMixed(target): + sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + sz2 = Attribute.parse("4") + structured.MaskedVectorizeOp(target, [sz1, sz2]) # CHECK-LABEL: TEST: testMaskedVectorizeMixed # CHECK: transform.sequence # CHECK: %[[V0:.*]] = transform.structured.match @@ -246,15 +204,11 @@ def testMaskedVectorizeMixed(): @run -def testMaskedVectorizeScalable(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"]) - sz2 = Attribute.parse("4") - structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]]) - transform.YieldOp() +@create_sequence +def testMaskedVectorizeScalable(target): + sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + sz2 = Attribute.parse("4") + structured.MaskedVectorizeOp(target, [16, [sz1], [sz2], [8]]) # CHECK-LABEL: TEST: testMaskedVectorizeScalable # CHECK: transform.sequence # CHECK-DAG: %[[V0:.*]] = transform.structured.match @@ -263,15 +217,9 @@ def testMaskedVectorizeScalable(): @run -def testMaskedVectorizeArgs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.MaskedVectorizeOp( - sequence.bodyTarget, [16, 4], vectorize_nd_extract=True - ) - transform.YieldOp() +@create_sequence +def testMaskedVectorizeArgs(target): + structured.MaskedVectorizeOp(target, [16, 4], vectorize_nd_extract=True) # CHECK-LABEL: TEST: testMaskedVectorizeArgs # CHECK: transform.sequence # CHECK: transform.structured.masked_vectorize @@ -279,17 +227,13 @@ def testMaskedVectorizeArgs(): @run -def testMatchOpNamesTyped(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() +@create_sequence +def testMatchOpNamesTyped(target): + structured.MatchOp.match_op_names( + transform.OperationType.get("test.dummy"), + target, + ["test.dummy"], ) - with InsertionPoint(sequence.body): - structured.MatchOp.match_op_names( - transform.OperationType.get("test.dummy"), - sequence.bodyTarget, - ["test.dummy"], - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMatchOpNamesTyped # CHECK: transform.structured.match ops # CHECK-SAME: ["test.dummy"] @@ -297,15 +241,11 @@ def testMatchOpNamesTyped(): @run -def testMultitileSizesCompact(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() +@create_sequence +def testMultitileSizesCompact(target): + structured.MultiTileSizesOp( + transform.AnyOpType.get(), target, dimension=1, target_size=42 ) - with InsertionPoint(sequence.body): - structured.MultiTileSizesOp( - pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42 - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMultitileSizes # CHECK: transform.sequence # CHECK-NOT: divisor @@ -318,19 +258,15 @@ def testMultitileSizesCompact(): @run -def testMultitileSizesAllArgs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() +@create_sequence +def testMultitileSizesAllArgs(target): + structured.MultiTileSizesOp( + transform.AnyOpType.get(), + target, + dimension=1, + target_size=42, + divisor=2, ) - with InsertionPoint(sequence.body): - structured.MultiTileSizesOp( - pdl.OperationType.get(), - sequence.bodyTarget, - dimension=1, - target_size=42, - divisor=2, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMultitileSizes # CHECK: transform.sequence # CHECK: transform.structured.multitile_sizes @@ -340,13 +276,9 @@ def testMultitileSizesAllArgs(): @run -def testPadOpNoArgs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.PadOp(sequence.bodyTarget) - transform.YieldOp() +@create_sequence +def testPadOpNoArgs(target): + structured.PadOp(target) # CHECK-LABEL: TEST: testPadOpNoArgs # CHECK: transform.sequence # CHECK: transform.structured.pad @@ -359,21 +291,17 @@ def testPadOpNoArgs(): @run -def testPadOpArgs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() +@create_sequence +def testPadOpArgs(target): + structured.PadOp( + target, + padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")], + padding_dimensions=Attribute.parse("[1]"), + pad_to_multiple_of=[128], + pack_paddings=[0], + transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")], + copy_back_op="linalg.copy", ) - with InsertionPoint(sequence.body): - structured.PadOp( - sequence.bodyTarget, - padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")], - padding_dimensions=Attribute.parse("[1]"), - pad_to_multiple_of=[128], - pack_paddings=[0], - transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")], - copy_back_op="linalg.copy", - ) - transform.YieldOp() # CHECK-LABEL: TEST: testPadOpArgs # CHECK: transform.sequence # CHECK: transform.structured.pad @@ -386,39 +314,27 @@ def testPadOpArgs(): @run -def testScalarize(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.ScalarizeOp(sequence.bodyTarget) - transform.YieldOp() +@create_sequence +def testScalarize(target): + structured.ScalarizeOp(target) # CHECK-LABEL: TEST: testScalarize # CHECK: transform.structured.scalarize @run -def testSplit(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42) - structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1]) - transform.YieldOp() +@create_sequence +def testSplit(target): + split = structured.SplitOp(target, dimension=1, split_point=42) + structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1]) # CHECK-LABEL: TEST: testSplit # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3 @run -def testTileCompact(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]) - transform.YieldOp() +@create_sequence +def testTileCompact(target): + structured.TileOp(target, sizes=[4, 8], interchange=[0, 1]) # CHECK-LABEL: TEST: testTileCompact # CHECK: transform.sequence # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] @@ -426,15 +342,11 @@ def testTileCompact(): @run -def testTileAttributes(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) +@create_sequence +def testTileAttributes(target): attr = DenseI64ArrayAttr.get([4, 8]) ichange = DenseI64ArrayAttr.get([0, 1]) - with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange) - transform.YieldOp() + structured.TileOp(target, sizes=attr, interchange=ichange) # CHECK-LABEL: TEST: testTileAttributes # CHECK: transform.sequence # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] @@ -442,15 +354,9 @@ def testTileAttributes(): @run -def testTileZero(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - structured.TileOp( - sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3] - ) - transform.YieldOp() +@create_sequence +def testTileZero(target): + structured.TileOp(target, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]) # CHECK-LABEL: TEST: testTileZero # CHECK: transform.sequence # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0] @@ -480,32 +386,22 @@ def testTileDynamic(): @run -def testTileExplicitLoopTypeSingle(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - structured.TileOp( - transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4] - ) - transform.YieldOp() +@create_sequence +def testTileExplicitLoopTypeSingle(target): + structured.TileOp(transform.OperationType.get("scf.for"), target, sizes=[2, 3, 4]) # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) -> # CHECK-COUNT-3: !transform.op<"scf.for"> @run -def testTileExplicitLoopTypeAll(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) +@create_sequence +def testTileExplicitLoopTypeAll(target): types = [ transform.OperationType.get(x) for x in ["scf.for", "scf.parallel", "scf.forall"] ] - with InsertionPoint(sequence.body): - structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4]) - transform.YieldOp() + structured.TileOp(types, target, sizes=[2, 3, 4]) # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll # CHECK: = transform.structured.tile # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, @@ -513,31 +409,22 @@ def testTileExplicitLoopTypeAll(): @run -def testTileScalable(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() +@create_sequence +def testTileScalable(target): + structured.TileOp( + target, + sizes=[4, [2]], ) - with InsertionPoint(sequence.body): - structured.TileOp( - sequence.bodyTarget, - sizes=[4, [2]], - ) - transform.YieldOp() # CHECK-LABEL: TEST: testTileScalable # CHECK: transform.sequence # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]] @run -def testTileToForallCompact(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, - [], - transform.OperationType.get("linalg.matmul"), - ) - with InsertionPoint(sequence.body): - structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4]) - transform.YieldOp() +@create_sequence +def testTileToForallCompact(target): + matmul = transform.CastOp(transform.OperationType.get("linalg.matmul"), target) + structured.TileToForallOp(matmul, num_threads=[2, 3, 4]) # CHECK-LABEL: TEST: testTileToForallCompact # CHECK: = transform.structured.tile_to_forall_op # CHECK-SAME: num_threads [2, 3, 4] tile_sizes [] @@ -545,18 +432,14 @@ def testTileToForallCompact(): @run -def testTileToForallLoopsAndTileOpTypes(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() +@create_sequence +def testTileToForallLoopsAndTileOpTypes(target): + structured.TileToForallOp( + transform.OperationType.get("scf.forall"), # loops_type + transform.OperationType.get("linalg.matmul"), # tiled_op_type + target, + num_threads=[2, 3, 4], ) - with InsertionPoint(sequence.body): - structured.TileToForallOp( - transform.OperationType.get("scf.forall"), # loops_type - transform.OperationType.get("linalg.matmul"), # tiled_op_type - sequence.bodyTarget, - num_threads=[2, 3, 4], - ) - transform.YieldOp() # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes # CHECK: = transform.structured.tile_to_forall_op # CHECK-SAME: num_threads [2, 3, 4] tile_sizes [] @@ -564,76 +447,54 @@ def testTileToForallLoopsAndTileOpTypes(): @run -def testTileToForallTileSizes(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4]) - transform.YieldOp() +@create_sequence +def testTileToForallTileSizes(target): + structured.TileToForallOp(target, tile_sizes=[2, 3, 4]) # CHECK-LABEL: TEST: testTileToForallTileSizes # CHECK: = transform.structured.tile_to_forall_op # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4] @run -def testTileToForallMixedDynamic(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) - structured.TileToForallOp(sequence.bodyTarget, num_threads=[n, 3, 4]) - transform.YieldOp() +@create_sequence +def testTileToForallMixedDynamic(target): + n = structured.MatchOp.match_op_names(target, ["test.dummy"]) + structured.TileToForallOp(target, num_threads=[n, 3, 4]) # CHECK-LABEL: TEST: testTileToForallMixedDynamic # CHECK: = transform.structured.tile_to_forall_op # CHECK-SAME: num_threads [%{{.*}} : !transform.any_op, 3, 4] @run -def testTileToForallPackedDynamic(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) - structured.TileToForallOp(sequence.bodyTarget, num_threads=n) - transform.YieldOp() +@create_sequence +def testTileToForallPackedDynamic(target): + n = structured.MatchOp.match_op_names(target, ["test.dummy"]) + structured.TileToForallOp(target, num_threads=n) # CHECK-LABEL: TEST: testTileToForallPackedDynamic # CHECK: = transform.structured.tile_to_forall_op # CHECK-SAME: num_threads *(%0 : !transform.any_op) @run -def testTileToForallMapping(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - mapping = Attribute.parse("[ #gpu.thread, #gpu.thread ]") - structured.TileToForallOp( - sequence.bodyTarget, num_threads=[2, 3], mapping=mapping - ) - transform.YieldOp() +@create_sequence +def testTileToForallMapping(target): + mapping = Attribute.parse("[ #gpu.thread, #gpu.thread ]") + structured.TileToForallOp(target, num_threads=[2, 3], mapping=mapping) # CHECK-LABEL: TEST: testTileToForallMapping # CHECK: = transform.structured.tile_to_forall_op # CHECK-SAME: mapping = [#gpu.thread, #gpu.thread] @run -def testVectorizeAllAttrs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() +@create_sequence +def testVectorizeAllAttrs(target): + structured.VectorizeOp( + target, + disable_multi_reduction_to_contract_patterns=True, + disable_transfer_permutation_map_lowering_patterns=True, + vectorize_nd_extract=True, + vectorize_padding=True, ) - with InsertionPoint(sequence.body): - structured.VectorizeOp( - sequence.bodyTarget, - disable_multi_reduction_to_contract_patterns=True, - disable_transfer_permutation_map_lowering_patterns=True, - vectorize_nd_extract=True, - vectorize_padding=True, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testVectorizeAllAttrs # CHECK: transform.sequence # CHECK: = transform.structured.vectorize @@ -644,19 +505,15 @@ def testVectorizeAllAttrs(): @run -def testVectorizeNoAttrs(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() +@create_sequence +def testVectorizeNoAttrs(target): + structured.VectorizeOp( + target, + disable_multi_reduction_to_contract_patterns=False, + disable_transfer_permutation_map_lowering_patterns=False, + vectorize_nd_extract=False, + vectorize_padding=False, ) - with InsertionPoint(sequence.body): - structured.VectorizeOp( - sequence.bodyTarget, - disable_multi_reduction_to_contract_patterns=False, - disable_transfer_permutation_map_lowering_patterns=False, - vectorize_nd_extract=False, - vectorize_padding=False, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testVectorizeNoAttrs # CHECK: transform.sequence # CHECK: = transform.structured.vectorize @@ -667,20 +524,16 @@ def testVectorizeNoAttrs(): @run -def testMatchInterfaceEnum(): +@create_sequence +def testMatchInterfaceEnum(target): names = ArrayAttr.get([StringAttr.get("test.dummy")]) result_type = transform.AnyOpType.get() - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + fused = structured.MatchOp.__base__( + result_type, + target, + ops=names, + interface=structured.MatchInterfaceEnum.LinalgOp, ) - with InsertionPoint(sequence.body): - fused = structured.MatchOp.__base__( - result_type, - sequence.bodyTarget, - ops=names, - interface=structured.MatchInterfaceEnum.LinalgOp, - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMatchInterfaceEnum # CHECK: transform.sequence # CHECK: = transform.structured.match @@ -688,7 +541,8 @@ def testMatchInterfaceEnum(): @run -def testMatchInterfaceEnumReplaceAttributeBuilder(): +@create_sequence +def testMatchInterfaceEnumReplaceAttributeBuilder(target): @register_attribute_builder("MatchInterfaceEnum", replace=True) def match_interface_enum(x, context): if x == "LinalgOp": @@ -699,17 +553,12 @@ def match_interface_enum(x, context): names = ArrayAttr.get([StringAttr.get("test.dummy")]) result_type = transform.AnyOpType.get() - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + fused = structured.MatchOp.__base__( + result_type, + target, + ops=names, + interface="TilingInterface", ) - with InsertionPoint(sequence.body): - fused = structured.MatchOp.__base__( - result_type, - sequence.bodyTarget, - ops=names, - interface="TilingInterface", - ) - transform.YieldOp() # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder # CHECK: transform.sequence # CHECK: = transform.structured.match