Skip to content

[mlir][vector][test] Update tests for vector.xfter_{read|write} #91943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented May 13, 2024

Updates tests in "vector-transfer-permutation-lowering.mlir" to make a
clearer split into cases for :

  • xfer_read vs xfer_write
  • fixed-width vs scalable tests

A new test case is added for fixed-width vectors for vector.transfer_read.
This is to complement an existing test for scalable vectors.

This is in preparation for #90835 and also for adding more tests for
scalable vectors.

Updates tests "vector-transfer-permutation-lowering.mlir" to make a
clearer split into tests for :
  * xfer_read vs xfer_write
  * fixed-width vs scalable tests

This is in preparation for llvm#90835 and also for adding more tests for
scalable vectors.
@llvmbot
Copy link
Member

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Updates tests "vector-transfer-permutation-lowering.mlir" to make a
clearer split into tests for :

  • xfer_read vs xfer_write
  • fixed-width vs scalable tests

This is in preparation for #90835 and also for adding more tests for
scalable vectors.


Full diff: https://github.com/llvm/llvm-project/pull/91943.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+44-23)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 31bd19c0be8e8..2bd5c33faf506 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -1,13 +1,24 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
-// CHECK-LABEL: func @lower_permutation_with_mask_fixed_width(
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+///----------------------------------------------------------------------------------------
+/// Input: 
+///   * vector.transfer_write op with a map which _is not_ the permutation of a
+///     minor identity
+/// Output:
+///   * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
+///     minor identity
+
+// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
 //       CHECK:   %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
 //       CHECK:   %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
 //       CHECK:   %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
 //       CHECK:   %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
 //       CHECK:   vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
-func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1 : index,
-                                       %base2 : index) {
+func.func @permutation_with_mask_xfer_write_fixed_width(%A : memref<?x?xf32>, %base1 : index,
+                                                   %base2 : index) {
+
   %fn1 = arith.constant -2.0 : f32
   %vf0 = vector.splat %fn1 : vector<7xf32>
   %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
@@ -17,7 +28,35 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
   return
 }
 
-// CHECK-LABEL:   func.func @permutation_with_mask_scalable(
+// CHECK:           func.func @permutation_with_mask_xfer_write_scalable(
+// CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:        %[[ARG_1:.*]]: memref<1x4x?x1xi16>,
+// CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK:             %[[C0:.*]] = arith.constant 0 : index
+// CHECK:             %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
+// CHECK:             %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
+// CHECK:             %[[TRANSPOSE_1:.*]] =  vector.transpose %[[BCAST_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
+// CHECK:             %[[TRANSPOSE_2:.*]] =  vector.transpose %[[BCAST_1]], [1, 2, 0] : vector<1x4x[8]xi16> to vector<4x[8]x1xi16>
+// CHECK:             vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{.*}}, %[[TRANSPOSE_1]] {in_bounds = [true, true, true]} : vector<4x[8]x1xi16>, memref<1x4x?x1xi16>
+// CHECK:             return
+func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1xi16>, %mask:  vector<4x[8]xi1>){
+     %c0 = arith.constant 0 : index
+      vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1xi16>
+
+    return
+}
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+/// Input: 
+///   * vector.transfer_read op with a permutation map
+/// Output:
+///   * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
+///     vector.transpose op
+
+// CHECK-LABEL:   func.func @permutation_with_mask_xfer_read_scalable(
 // CHECK-SAME:      %[[ARG_0:.*]]: memref<?x?xf32>,
 // CHECK-SAME:      %[[IDX_1:.*]]: index,
 // CHECK-SAME:      %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
@@ -29,7 +68,7 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
 // CHECK:           %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
 // CHECK:           return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
 // CHECK:         }
-func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
+func.func @permutation_with_mask_xfer_read_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
 
   %c0 = arith.constant 0 : index
   %cst_0 = arith.constant 0.000000e+00 : f32
@@ -41,24 +80,6 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
   return %1 : vector<8x[4]x2xf32>
 }
 
-// CHECK:           func.func @permutation_with_mask_transfer_write_scalable(
-// CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
-// CHECK-SAME:        %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
-// CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>) {
-// CHECK:             %[[C0:.*]] = arith.constant 0 : index
-// CHECK:             %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
-// CHECK:             %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
-// CHECK:             %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
-// CHECK:             %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
-// CHECK:             vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
-// CHECK:             return
-func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask:  vector<4x[8]xi1>){
-     %c0 = arith.constant 0 : index
-      vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
-} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
-
-    return
-}
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@banach-space banach-space requested a review from MacDue May 13, 2024 10:22
@banach-space
Copy link
Contributor Author

CC @nujaa

Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, simply one debatable comment for consistency.
I'll adapt my PR when this one lands.

@@ -29,7 +68,7 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
// CHECK: }
func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
func.func @permutation_with_mask_xfer_read_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would adding permutation_with_mask_xfer_read_fixed_width make sense as well ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. In fact, I didn't realise it was missing 😅 I'll add it before landing this, thanks!

Add test for fixed-width cased, more consistent var names
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG (bar some nits)

P.s. I'd change the tag from [nfc] to [test] -- especially since you're adding a new test.

@MacDue
Copy link
Member

MacDue commented May 13, 2024

(the summary should also be updated to mention the new test case)

@banach-space banach-space changed the title [mlir][vector][nfc] Update tests for vector.xfter_{read|write} [mlir][vector][test] Update tests for vector.xfter_{read|write} May 13, 2024
@banach-space
Copy link
Contributor Author

(the summary should also be updated to mention the new test case)

Good point, thanks, updated!

@banach-space banach-space merged commit 9f858c7 into llvm:main May 13, 2024
3 of 4 checks passed
@banach-space banach-space deleted the andrzej/update_xfer_permute_tests branch May 13, 2024 18:17
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Jun 5, 2024
…8cedb5f80

Local branch amd-gfx d168ced Merged main:c4e9e41199127bb288e84e9477da99f28941edb3 into amd-gfx:c5f22f9383e1
Remote branch main 9f858c7 [mlir][vector][test] Update tests for vector.xfter_{read|write} (llvm#91943)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants