-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Add tests for xfer-permute-lowering (1/n)(nfc) #95529
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add tests for xfer-permute-lowering (1/n)(nfc) #95529
Conversation
Adds more tests to "vector-transfer-permutation-lowering.mlir", specifically for the `TransferWritePermutationLowering` pattern - such tests seem to be missing ATM. The following edge cases are covered: * plain fixed-width (supported) * scalable vectors with mask (supported) * plain fixed-width, masked (not supported) This is a part of a larger effort to make sure that all key cases for patterns under `populateVectorTransferPermutationMapLoweringPatterns` (*) are tested. I also want to make sure that tests use consistent function and variable names. (*) `transform.apply_patterns.vector.transfer_permutation_patterns` in TD parlance)
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesAdds more tests to "vector-transfer-permutation-lowering.mlir", The following edge cases are covered:
This is a part of a larger effort to make sure that all key cases for (*) Full diff: https://github.com/llvm/llvm-project/pull/95529.diff 1 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 0cd134717b1a0..ac5041d13f893 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -1,14 +1,81 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
///----------------------------------------------------------------------------------------
-/// vector.transfer_write
+/// vector.transfer_write -> vector.transpose + vector.transfer_read
///----------------------------------------------------------------------------------------
-/// Input:
-/// * vector.transfer_write op with a map which _is not_ the permutation of a
-/// minor identity
+/// Input:
+/// * vector.transfer_write op with a permutation that under a transpose
+/// _would be_ a permutation of a minor identity
/// Output:
-/// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
+/// * vector.transpose + vector.transfer_write with a map which _is_ a
+/// permutation of a minor identity
+
+// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
+// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
+// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
+// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
+func.func @xfer_write_perm_minor_id_with_transpose(
+ %arg0: vector<4x8xi16>,
+ %mem: memref<2x2x8x4xi16>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+ in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ } : vector<4x8xi16>, memref<2x2x8x4xi16>
+
+ return
+}
+
+// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) {
+// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
+// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
+func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+ %arg0: vector<4x[8]xi16>,
+ %mem: memref<2x2x?x4xi16>,
+ %mask: vector<[8]x4xi1>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask {
+ in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ } : vector<4x[8]xi16>, memref<2x2x?x4xi16>
+
+ return
+}
+
+// Masked version is not supported
+// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose_masked
+// CHECK-NOT: vector.transpose
+func.func @xfer_write_perm_minor_id_with_transpose_masked(
+ %arg0: vector<4x8xi16>,
+ %mem: memref<2x2x8x4xi16>,
+ %mask: vector<8x4xi1>) {
+
+ %c0 = arith.constant 0 : index
+ vector.mask %mask {
+ vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+ in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ } : vector<4x8xi16>, memref<2x2x8x4xi16>
+ } : vector<8x4xi1>
+
+ return
+}
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_read
+///----------------------------------------------------------------------------------------
+/// Input:
+/// * vector.transfer_write op with a map which _is not_ a permutation of a
/// minor identity
+/// Output:
+/// * vector.broadcast + vector.transpose + vector.transfer_write with a map
+/// which _is_ a permutation of a minor identity
// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
@@ -94,7 +161,7 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------
-/// Input:
+/// Input:
/// * vector.transfer_read op with a permutation map
/// Output:
/// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
@@ -190,6 +257,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+/// TODO: Review and categorize
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK: func.func @transfer_read_reduce_rank_scalable(
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, thanks, it's a good start, I think there is a misunderstanding on the scope of those patterns. Feel free to tell me I'm wrong.
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
* Update test names * Added CHECK-NOT for permutation map that shouldn't be present * Refine comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good points, I think the pipeline fails because of that missed parenthesis. Else, LGTM.
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
…owering` Add missing (
…tationLowering` Refine comments
TransferWritePermutationLowering
…5529) Adds more tests to "vector-transfer-permutation-lowering.mlir", specifically for the `TransferWritePermutationLowering` pattern - such tests seem to be missing ATM. The following edge cases are covered: * plain fixed-width (supported) * scalable vectors with mask (supported) * plain fixed-width, masked (not supported) This is a part of a larger effort to make sure that all key cases for patterns under `populateVectorTransferPermutationMapLoweringPatterns` (*) are tested. I also want to make sure that tests use consistent function and variable names. (*) `transform.apply_patterns.vector.transfer_permutation_patterns` in TD parlance)
Adds more tests to "vector-transfer-permutation-lowering.mlir",
specifically for the
TransferWritePermutationLowering
pattern - suchtests seem to be missing ATM.
The following edge cases are covered:
This is a part of a larger effort to make sure that all key cases for
patterns under
populateVectorTransferPermutationMapLoweringPatterns
(*) are tested. I also want to make sure that tests use consistent
function and variable names.
(*)
transform.apply_patterns.vector.transfer_permutation_patterns
inTD parlance)