-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector][test] Update tests for vector.xfter_{read|write} #91943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector][test] Update tests for vector.xfter_{read|write} #91943
Conversation
Updates tests "vector-transfer-permutation-lowering.mlir" to make a clearer split into tests for : * xfer_read vs xfer_write * fixed-width vs scalable tests This is in preparation for llvm#90835 and also for adding more tests for scalable vectors.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesUpdates tests "vector-transfer-permutation-lowering.mlir" to make a
This is in preparation for #90835 and also for adding more tests for Full diff: https://github.com/llvm/llvm-project/pull/91943.diff 1 Files Affected:
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 31bd19c0be8e8..2bd5c33faf506 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -1,13 +1,24 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-// CHECK-LABEL: func @lower_permutation_with_mask_fixed_width(
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+///----------------------------------------------------------------------------------------
+/// Input:
+/// * vector.transfer_write op with a map which _is not_ the permutation of a
+/// minor identity
+/// Output:
+/// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
+/// minor identity
+
+// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
-func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1 : index,
- %base2 : index) {
+func.func @permutation_with_mask_xfer_write_fixed_width(%A : memref<?x?xf32>, %base1 : index,
+ %base2 : index) {
+
%fn1 = arith.constant -2.0 : f32
%vf0 = vector.splat %fn1 : vector<7xf32>
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
@@ -17,7 +28,35 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
return
}
-// CHECK-LABEL: func.func @permutation_with_mask_scalable(
+// CHECK: func.func @permutation_with_mask_xfer_write_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1xi16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
+// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
+// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
+// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [1, 2, 0] : vector<1x4x[8]xi16> to vector<4x[8]x1xi16>
+// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{.*}}, %[[TRANSPOSE_1]] {in_bounds = [true, true, true]} : vector<4x[8]x1xi16>, memref<1x4x?x1xi16>
+// CHECK: return
+func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1xi16>
+
+ return
+}
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+/// Input:
+/// * vector.transfer_read op with a permutation map
+/// Output:
+/// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
+/// vector.transpose op
+
+// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[IDX_1:.*]]: index,
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
@@ -29,7 +68,7 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
// CHECK: }
-func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
+func.func @permutation_with_mask_xfer_read_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
@@ -41,24 +80,6 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
-// CHECK: func.func @permutation_with_mask_transfer_write_scalable(
-// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
-// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
-// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
-// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
-// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
-// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
-// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
-// CHECK: return
-func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
- %c0 = arith.constant 0 : index
- vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
-} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
-
- return
-}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
CC @nujaa |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, simply one debatable comment for consistency.
I'll adapt my PR when this one lands.
@@ -29,7 +68,7 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1 | |||
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32> | |||
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32> | |||
// CHECK: } | |||
func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) { | |||
func.func @permutation_with_mask_xfer_read_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would adding permutation_with_mask_xfer_read_fixed_width
make sense as well ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. In fact, I didn't realise it was missing 😅 I'll add it before landing this, thanks!
Add test for fixed-width cased, more consistent var names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG (bar some nits)
P.s. I'd change the tag from [nfc]
to [test]
-- especially since you're adding a new test.
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Outdated
Show resolved
Hide resolved
…|write} Simplify test
(the summary should also be updated to mention the new test case) |
Good point, thanks, updated! |
…8cedb5f80 Local branch amd-gfx d168ced Merged main:c4e9e41199127bb288e84e9477da99f28941edb3 into amd-gfx:c5f22f9383e1 Remote branch main 9f858c7 [mlir][vector][test] Update tests for vector.xfter_{read|write} (llvm#91943)
Updates tests in "vector-transfer-permutation-lowering.mlir" to make a
clearer split into cases for :
A new test case is added for fixed-width vectors for
vector.transfer_read
.This is to complement an existing test for scalable vectors.
This is in preparation for #90835 and also for adding more tests for
scalable vectors.