From 384d72fe5e8a7a863d82c3c33f4f609ee56f6a5e Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 14 Jun 2024 12:33:24 +0100 Subject: [PATCH 1/4] [mlir][vector] Add tests for `TransferWritePermutationLowering` Adds more tests to "vector-transfer-permutation-lowering.mlir", specifically for the `TransferWritePermutationLowering` pattern - such tests seem to be missing ATM. The following edge cases are covered: * plain fixed-width (supported) * scalable vectors with mask (supported) * plain fixed-width, masked (not supported) This is a part of a larger effort to make sure that all key cases for patterns under `populateVectorTransferPermutationMapLoweringPatterns` (*) are tested. I also want to make sure that tests use consistent function and variable names. (*) `transform.apply_patterns.vector.transfer_permutation_patterns` in TD parlance) --- .../vector-transfer-permutation-lowering.mlir | 83 +++++++++++++++++-- 1 file changed, 77 insertions(+), 6 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index 0cd134717b1a0..ac5041d13f893 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -1,14 +1,81 @@ // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s ///---------------------------------------------------------------------------------------- -/// vector.transfer_write +/// vector.transfer_write -> vector.transpose + vector.transfer_read ///---------------------------------------------------------------------------------------- -/// Input: -/// * vector.transfer_write op with a map which _is not_ the permutation of a -/// minor identity +/// Input: +/// * vector.transfer_write op with a permutation that under a transpose +/// _would be_ a permutation of a minor identity /// Output: -/// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a +/// * vector.transpose + vector.transfer_write with a map which _is_ a +/// permutation of a minor identity + +// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>, +// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) { +// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16> +// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16> +func.func @xfer_write_perm_minor_id_with_transpose( + %arg0: vector<4x8xi16>, + %mem: memref<2x2x8x4xi16>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] { + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + } : vector<4x8xi16>, memref<2x2x8x4xi16> + + return +} + +// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>, +// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>, +// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) { +// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16> +// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16> +func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable( + %arg0: vector<4x[8]xi16>, + %mem: memref<2x2x?x4xi16>, + %mask: vector<[8]x4xi1>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask { + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + } : vector<4x[8]xi16>, memref<2x2x?x4xi16> + + return +} + +// Masked version is not supported +// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose_masked +// CHECK-NOT: vector.transpose +func.func @xfer_write_perm_minor_id_with_transpose_masked( + %arg0: vector<4x8xi16>, + %mem: memref<2x2x8x4xi16>, + %mask: vector<8x4xi1>) { + + %c0 = arith.constant 0 : index + vector.mask %mask { + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] { + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + } : vector<4x8xi16>, memref<2x2x8x4xi16> + } : vector<8x4xi1> + + return +} + +///---------------------------------------------------------------------------------------- +/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_read +///---------------------------------------------------------------------------------------- +/// Input: +/// * vector.transfer_write op with a map which _is not_ a permutation of a /// minor identity +/// Output: +/// * vector.broadcast + vector.transpose + vector.transfer_write with a map +/// which _is_ a permutation of a minor identity // CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width( // CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32> @@ -94,7 +161,7 @@ func.func @masked_non_permutation_xfer_write_fixed_width( ///---------------------------------------------------------------------------------------- /// vector.transfer_read ///---------------------------------------------------------------------------------------- -/// Input: +/// Input: /// * vector.transfer_read op with a permutation map /// Output: /// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy + @@ -190,6 +257,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// vector.transfer_read +///---------------------------------------------------------------------------------------- +/// TODO: Review and categorize // CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> // CHECK: func.func @transfer_read_reduce_rank_scalable( From 42b374591583254d39343b9604b0ceead9b2ee04 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 14 Jun 2024 16:12:27 +0100 Subject: [PATCH 2/4] !fixup [mlir][vector] Add tests for `TransferWritePermutationLowering` * Update test names * Added CHECK-NOT for permutation map that shouldn't be present * Refine comments --- .../vector-transfer-permutation-lowering.mlir | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index ac5041d13f893..c038baae72e78 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -2,20 +2,23 @@ ///---------------------------------------------------------------------------------------- /// vector.transfer_write -> vector.transpose + vector.transfer_read +/// [Pattern: TransferWritePermutationLowering] ///---------------------------------------------------------------------------------------- /// Input: /// * vector.transfer_write op with a permutation that under a transpose -/// _would be_ a permutation of a minor identity +/// _would be_ a minor identity permutation map /// Output: -/// * vector.transpose + vector.transfer_write with a map which _is_ a -/// permutation of a minor identity +/// * vector.transpose + vector.transfer_write with a permutation map which +/// _is_ a minor identity -// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose( +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map // CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>, // CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) { // CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16> -// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16> -func.func @xfer_write_perm_minor_id_with_transpose( +// CHECK: vector.transfer_write +// CHECK-NOT: permutation_map +// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16> +func.func @xfer_write_transposing_permutation_map %arg0: vector<4x8xi16>, %mem: memref<2x2x8x4xi16>) { @@ -28,13 +31,15 @@ func.func @xfer_write_perm_minor_id_with_transpose( return } -// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable( +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable // CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>, // CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>, // CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) { // CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16> -// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16> -func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable( +// CHECK: vector.transfer_write +// CHECK-NOT: permutation_map +// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16> +func.func @xfer_write_transposing_permutation_map_with_mask_scalable( %arg0: vector<4x[8]xi16>, %mem: memref<2x2x?x4xi16>, %mask: vector<[8]x4xi1>) { @@ -49,9 +54,9 @@ func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable( } // Masked version is not supported -// CHECK-LABEL: func.func @xfer_write_perm_minor_id_with_transpose_masked +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_transpose_masked // CHECK-NOT: vector.transpose -func.func @xfer_write_perm_minor_id_with_transpose_masked( +func.func @xfer_write_transposing_permutation_map_with_transpose_masked( %arg0: vector<4x8xi16>, %mem: memref<2x2x8x4xi16>, %mask: vector<8x4xi1>) { @@ -59,8 +64,8 @@ func.func @xfer_write_perm_minor_id_with_transpose_masked( %c0 = arith.constant 0 : index vector.mask %mask { vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] { - in_bounds = [true, true], - permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> } : vector<4x8xi16>, memref<2x2x8x4xi16> } : vector<8x4xi1> From c0752d00f6fb742ee3e338fd5c7efbf598235a15 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 17 Jun 2024 16:47:16 +0100 Subject: [PATCH 3/4] fixup! !fixup [mlir][vector] Add tests for `TransferWritePermutationLowering` Add missing ( --- .../Dialect/Vector/vector-transfer-permutation-lowering.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index c038baae72e78..2682b08dee117 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -18,7 +18,7 @@ // CHECK: vector.transfer_write // CHECK-NOT: permutation_map // CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16> -func.func @xfer_write_transposing_permutation_map +func.func @xfer_write_transposing_permutation_map( %arg0: vector<4x8xi16>, %mem: memref<2x2x8x4xi16>) { From a05c18787fee8d081e81e46df214e49b5ab93424 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 18 Jun 2024 10:09:31 +0100 Subject: [PATCH 4/4] fixup! fixup! !fixup [mlir][vector] Add tests for `TransferWritePermutationLowering` Refine comments --- .../Vector/vector-transfer-permutation-lowering.mlir | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index 2682b08dee117..35418b38df9b2 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s ///---------------------------------------------------------------------------------------- -/// vector.transfer_write -> vector.transpose + vector.transfer_read +/// vector.transfer_write -> vector.transpose + vector.transfer_write /// [Pattern: TransferWritePermutationLowering] ///---------------------------------------------------------------------------------------- /// Input: @@ -54,9 +54,9 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable( } // Masked version is not supported -// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_transpose_masked +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked // CHECK-NOT: vector.transpose -func.func @xfer_write_transposing_permutation_map_with_transpose_masked( +func.func @xfer_write_transposing_permutation_map_masked( %arg0: vector<4x8xi16>, %mem: memref<2x2x8x4xi16>, %mask: vector<8x4xi1>) { @@ -73,7 +73,8 @@ func.func @xfer_write_transposing_permutation_map_with_transpose_masked( } ///---------------------------------------------------------------------------------------- -/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_read +/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write +/// [Patterns: TransferWriteNonPermutationLowering + TransferWritePermutationLowering] ///---------------------------------------------------------------------------------------- /// Input: /// * vector.transfer_write op with a map which _is not_ a permutation of a