Skip to content

[mlir][vector] Add more tests for ConvertVectorToLLVM (5/n) #106510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Aug 29, 2024

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.extract_strided_slice

In practice, this has meant adding 1 new test:

  • @extract_strided_slice_f32_1d_from_2d_scalable,

and replacing @extract_strided_slice_scalable with
@extract_strided_slice_f32_2d_from_2d_scalable.

For consistency with other tests, I have also removed "vector" from test
functions names for:

  • vector.print
  • vector.type_cast
  • vector.extract_strided_slice

The first two Ops precede vector.extract_strided_slice in the test
file, i.e. those should be next to be updated in this series of patches.
However,

  • For vector.print, we don't use vectors in this test file (that
    would require running VectorToSCF).
  • For vector.type_cast, the existing tests assume fixed-width sizes.
    We need to write new tests and I am leaving that as a TODO.

Note, I've also updated test function names to be more descriptive and
consistent with other tests, e.g.

  • @extract_strided_slice3 -> @extract_strided_slice_f32_2d_from_2d

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.extract_strided_slice

For consistency with other tests, I have also removed "vector" from test
functions names for:
  * vector.print
  * vector.type_cast
  * vector.extract_strided_slice

The first two Ops precede `vector.extract_strided_slice` in the test
file, i.e. those should be next to be updated in this series of patches.
However,
  * For `vector.print`, we don't use vectors in this test file (that
    would require running VectorToSCF).
  * For `vector.type_cast`, the existing tests assume fixed-width sizes.
    We need to write new tests and I am leaving that as a TODO.

Note, I've also updated test function names to be more descriptive and
consistent with other tests, e.g.
  * `@extract_strided_slice3` -> `@extract_strided_slice_f32_2d_from_2d`
@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.extract_strided_slice

For consistency with other tests, I have also removed "vector" from test
functions names for:

  • vector.print
  • vector.type_cast
  • vector.extract_strided_slice

The first two Ops precede vector.extract_strided_slice in the test
file, i.e. those should be next to be updated in this series of patches.
However,

  • For vector.print, we don't use vectors in this test file (that
    would require running VectorToSCF).
  • For vector.type_cast, the existing tests assume fixed-width sizes.
    We need to write new tests and I am leaving that as a TODO.

Note, I've also updated test function names to be more descriptive and
consistent with other tests, e.g.

  • @<!-- -->extract_strided_slice3 -> @<!-- -->extract_strided_slice_f32_2d_from_2d

Full diff: https://github.com/llvm/llvm-project/pull/106510.diff

1 Files Affected:

  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+88-65)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 63bcecd863e95d..2c8b1a2a6ff1f6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s
 
+// TODO: Add tests for for vector.type_cast that would cover scalable vectors
+
 func.func @bitcast_f32_to_i32_vector_0d(%input: vector<f32>) -> vector<i32> {
   %0 = vector.bitcast %input : vector<f32> to vector<i32>
   return %0 : vector<i32>
@@ -1467,8 +1469,6 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
 //       CHECK:   vector.insert
 
-// -----
-
 func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: f32, %idx: index)
                                         -> vector<1x[16]xf32> {
   %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x[16]xf32>
@@ -1482,11 +1482,11 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
 
 // -----
 
-func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
+func.func @type_cast_f32(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
   %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
   return %0 : memref<vector<8x8x8xf32>>
 }
-// CHECK-LABEL: @vector_type_cast
+// CHECK-LABEL: @type_cast_f32
 //       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)>
 //       CHECK:   %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:   llvm.insertvalue %[[allocated]], {{.*}}[0] : !llvm.struct<(ptr, ptr, i64)>
@@ -1495,18 +1495,22 @@ func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64)>
 
+// NOTE: No test for scalable vectors - the input memref is fixed size.
+
 // -----
 
-func.func @vector_index_type_cast(%arg0: memref<8x8x8xindex>) -> memref<vector<8x8x8xindex>> {
+func.func @type_cast_index(%arg0: memref<8x8x8xindex>) -> memref<vector<8x8x8xindex>> {
   %0 = vector.type_cast %arg0: memref<8x8x8xindex> to memref<vector<8x8x8xindex>>
   return %0 : memref<vector<8x8x8xindex>>
 }
-// CHECK-LABEL: @vector_index_type_cast(
+// CHECK-LABEL: @type_cast_index(
 // CHECK-SAME: %[[A:.*]]: memref<8x8x8xindex>)
 //       CHECK:   %{{.*}} = builtin.unrealized_conversion_cast %[[A]] : memref<8x8x8xindex> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 
 //       CHECK:   %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64)> to memref<vector<8x8x8xindex>>
 
+// NOTE: No test for scalable vectors - the input memref is fixed size.
+
 // -----
 
 func.func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
@@ -1522,16 +1526,18 @@ func.func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> m
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
 
+// NOTE: No test for scalable vectors - the input memref is fixed size.
+
 // -----
 
-func.func @vector_print_scalar_i1(%arg0: i1) {
+func.func @print_scalar_i1(%arg0: i1) {
   vector.print %arg0 : i1
   return
 }
 //
 // Type "boolean" always uses zero extension.
 //
-// CHECK-LABEL: @vector_print_scalar_i1(
+// CHECK-LABEL: @print_scalar_i1(
 // CHECK-SAME: %[[A:.*]]: i1)
 //       CHECK: %[[S:.*]] = arith.extui %[[A]] : i1 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1539,11 +1545,11 @@ func.func @vector_print_scalar_i1(%arg0: i1) {
 
 // -----
 
-func.func @vector_print_scalar_i4(%arg0: i4) {
+func.func @print_scalar_i4(%arg0: i4) {
   vector.print %arg0 : i4
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i4(
+// CHECK-LABEL: @print_scalar_i4(
 // CHECK-SAME: %[[A:.*]]: i4)
 //       CHECK: %[[S:.*]] = arith.extsi %[[A]] : i4 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1551,11 +1557,11 @@ func.func @vector_print_scalar_i4(%arg0: i4) {
 
 // -----
 
-func.func @vector_print_scalar_si4(%arg0: si4) {
+func.func @print_scalar_si4(%arg0: si4) {
   vector.print %arg0 : si4
   return
 }
-// CHECK-LABEL: @vector_print_scalar_si4(
+// CHECK-LABEL: @print_scalar_si4(
 // CHECK-SAME: %[[A:.*]]: si4)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : si4 to i4
 //       CHECK: %[[S:.*]] = arith.extsi %[[C]] : i4 to i64
@@ -1564,11 +1570,11 @@ func.func @vector_print_scalar_si4(%arg0: si4) {
 
 // -----
 
-func.func @vector_print_scalar_ui4(%arg0: ui4) {
+func.func @print_scalar_ui4(%arg0: ui4) {
   vector.print %arg0 : ui4
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui4(
+// CHECK-LABEL: @print_scalar_ui4(
 // CHECK-SAME: %[[A:.*]]: ui4)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui4 to i4
 //       CHECK: %[[S:.*]] = arith.extui %[[C]] : i4 to i64
@@ -1577,11 +1583,11 @@ func.func @vector_print_scalar_ui4(%arg0: ui4) {
 
 // -----
 
-func.func @vector_print_scalar_i32(%arg0: i32) {
+func.func @print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i32(
+// CHECK-LABEL: @print_scalar_i32(
 // CHECK-SAME: %[[A:.*]]: i32)
 //       CHECK: %[[S:.*]] = arith.extsi %[[A]] : i32 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1589,11 +1595,11 @@ func.func @vector_print_scalar_i32(%arg0: i32) {
 
 // -----
 
-func.func @vector_print_scalar_ui32(%arg0: ui32) {
+func.func @print_scalar_ui32(%arg0: ui32) {
   vector.print %arg0 : ui32
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui32(
+// CHECK-LABEL: @print_scalar_ui32(
 // CHECK-SAME: %[[A:.*]]: ui32)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui32 to i32
 //       CHECK: %[[S:.*]] = arith.extui %[[C]] : i32 to i64
@@ -1601,11 +1607,11 @@ func.func @vector_print_scalar_ui32(%arg0: ui32) {
 
 // -----
 
-func.func @vector_print_scalar_i40(%arg0: i40) {
+func.func @print_scalar_i40(%arg0: i40) {
   vector.print %arg0 : i40
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i40(
+// CHECK-LABEL: @print_scalar_i40(
 // CHECK-SAME: %[[A:.*]]: i40)
 //       CHECK: %[[S:.*]] = arith.extsi %[[A]] : i40 to i64
 //       CHECK: llvm.call @printI64(%[[S]]) : (i64) -> ()
@@ -1613,11 +1619,11 @@ func.func @vector_print_scalar_i40(%arg0: i40) {
 
 // -----
 
-func.func @vector_print_scalar_si40(%arg0: si40) {
+func.func @print_scalar_si40(%arg0: si40) {
   vector.print %arg0 : si40
   return
 }
-// CHECK-LABEL: @vector_print_scalar_si40(
+// CHECK-LABEL: @print_scalar_si40(
 // CHECK-SAME: %[[A:.*]]: si40)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : si40 to i40
 //       CHECK: %[[S:.*]] = arith.extsi %[[C]] : i40 to i64
@@ -1626,11 +1632,11 @@ func.func @vector_print_scalar_si40(%arg0: si40) {
 
 // -----
 
-func.func @vector_print_scalar_ui40(%arg0: ui40) {
+func.func @print_scalar_ui40(%arg0: ui40) {
   vector.print %arg0 : ui40
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui40(
+// CHECK-LABEL: @print_scalar_ui40(
 // CHECK-SAME: %[[A:.*]]: ui40)
 //       CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui40 to i40
 //       CHECK: %[[S:.*]] = arith.extui %[[C]] : i40 to i64
@@ -1639,22 +1645,22 @@ func.func @vector_print_scalar_ui40(%arg0: ui40) {
 
 // -----
 
-func.func @vector_print_scalar_i64(%arg0: i64) {
+func.func @print_scalar_i64(%arg0: i64) {
   vector.print %arg0 : i64
   return
 }
-// CHECK-LABEL: @vector_print_scalar_i64(
+// CHECK-LABEL: @print_scalar_i64(
 // CHECK-SAME: %[[A:.*]]: i64)
 //       CHECK:    llvm.call @printI64(%[[A]]) : (i64) -> ()
 //       CHECK:    llvm.call @printNewline() : () -> ()
 
 // -----
 
-func.func @vector_print_scalar_ui64(%arg0: ui64) {
+func.func @print_scalar_ui64(%arg0: ui64) {
   vector.print %arg0 : ui64
   return
 }
-// CHECK-LABEL: @vector_print_scalar_ui64(
+// CHECK-LABEL: @print_scalar_ui64(
 // CHECK-SAME: %[[A:.*]]: ui64)
 //       CHECK:    %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : ui64 to i64
 //       CHECK:    llvm.call @printU64(%[[C]]) : (i64) -> ()
@@ -1662,11 +1668,11 @@ func.func @vector_print_scalar_ui64(%arg0: ui64) {
 
 // -----
 
-func.func @vector_print_scalar_index(%arg0: index) {
+func.func @print_scalar_index(%arg0: index) {
   vector.print %arg0 : index
   return
 }
-// CHECK-LABEL: @vector_print_scalar_index(
+// CHECK-LABEL: @print_scalar_index(
 // CHECK-SAME: %[[A:.*]]: index)
 //       CHECK:    %[[C:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
 //       CHECK:    llvm.call @printU64(%[[C]]) : (i64) -> ()
@@ -1674,22 +1680,22 @@ func.func @vector_print_scalar_index(%arg0: index) {
 
 // -----
 
-func.func @vector_print_scalar_f32(%arg0: f32) {
+func.func @print_scalar_f32(%arg0: f32) {
   vector.print %arg0 : f32
   return
 }
-// CHECK-LABEL: @vector_print_scalar_f32(
+// CHECK-LABEL: @print_scalar_f32(
 // CHECK-SAME: %[[A:.*]]: f32)
 //       CHECK:    llvm.call @printF32(%[[A]]) : (f32) -> ()
 //       CHECK:    llvm.call @printNewline() : () -> ()
 
 // -----
 
-func.func @vector_print_scalar_f64(%arg0: f64) {
+func.func @print_scalar_f64(%arg0: f64) {
   vector.print %arg0 : f64
   return
 }
-// CHECK-LABEL: @vector_print_scalar_f64(
+// CHECK-LABEL: @print_scalar_f64(
 // CHECK-SAME: %[[A:.*]]: f64)
 //       CHECK:    llvm.call @printF64(%[[A]]) : (f64) -> ()
 //       CHECK:    llvm.call @printNewline() : () -> ()
@@ -1699,46 +1705,50 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
 // CHECK-LABEL: module {
 // CHECK: llvm.func @printString(!llvm.ptr)
 // CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
-// CHECK: @vector_print_string
+// CHECK: @print_string
 //       CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
 //       CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
 //       CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
-func.func @vector_print_string() {
+func.func @print_string() {
   vector.print str "Hello, World!"
   return
 }
 
 // -----
 
-func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
+func.func @extract_strided_slice_f32(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
 }
-// CHECK-LABEL: @extract_strided_slice1(
+// CHECK-LABEL: @extract_strided_slice_f32(
 //  CHECK-SAME:    %[[A:.*]]: vector<4xf32>)
 //       CHECK:    %[[T0:.*]] = llvm.shufflevector %[[A]], %[[A]] [2, 3] : vector<4xf32>
 //       CHECK:    return %[[T0]] : vector<2xf32>
 
+// NOTE: For scalable vectors we could only extract vector<[4]xf32> from vector<[4]xf32>, but that would be a NOP.
+
 // -----
 
-func.func @extract_strided_index_slice1(%arg0: vector<4xindex>) -> vector<2xindex> {
+func.func @extract_strided_index_slice_index(%arg0: vector<4xindex>) -> vector<2xindex> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xindex> to vector<2xindex>
   return %0 : vector<2xindex>
 }
-// CHECK-LABEL: @extract_strided_index_slice1(
+// CHECK-LABEL: @extract_strided_index_slice_index(
 //  CHECK-SAME:    %[[A:.*]]: vector<4xindex>)
 //       CHECK:    %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4xindex> to vector<4xi64>
 //       CHECK:    %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T0]] [2, 3] : vector<4xi64>
 //       CHECK:    %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex>
 //       CHECK:    return %[[T3]] : vector<2xindex>
 
+// NOTE: For scalable vectors we could only extract vector<[4]xindex> from vector<[4]xindex>, but that would be a NOP.
+
 // -----
 
-func.func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
+func.func @extract_strided_slice_f32_1d_from_2d(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
   return %0 : vector<2x8xf32>
 }
-// CHECK-LABEL: @extract_strided_slice2(
+// CHECK-LABEL: @extract_strided_slice_f32_1d_from_2d(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x8xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
 //       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<8xf32>>
@@ -1749,13 +1759,28 @@ func.func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
 //       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<8xf32>> to vector<2x8xf32>
 //       CHECK:    return %[[T5]]
 
+func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+  return %0 : vector<2x[8]xf32>
+}
+// CHECK-LABEL:   func.func @extract_strided_slice_f32_1d_from_2d_scalable(
+//  CHECK-SAME:    %[[ARG:.*]]: vector<4x[8]xf32>)
+//       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+//       CHECK:    return %[[T5]]
+
 // -----
 
-func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @extract_strided_slice_f32_2d_from_2d(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
-// CHECK-LABEL: @extract_strided_slice3(
+// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x8xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
 //       CHECK:    %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
@@ -1769,27 +1794,25 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 //       CHECK:    %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[T7]] : !llvm.array<2 x vector<2xf32>> to vector<2x2xf32>
 //       CHECK:    return %[[VAL_12]] : vector<2x2xf32>
 
-// -----
-
-func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
-  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
-  return %0 : vector<1x1x[4]xi32>
-}
-
-// CHECK-LABEL:   func.func @extract_strided_slice_scalable(
-// CHECK-SAME:      %[[ARG_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
-
-//      CHECK:      %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
-//      CHECK:      %[[CST:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
-//      CHECK:      %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
-//      CHECK:      %[[CST_1:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
-//      CHECK:      %[[CAST_3:.*]] = builtin.unrealized_conversion_cast %[[CST_1]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
-
-//      CHECK:      %[[EXT:.*]] = llvm.extractvalue %[[CAST_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
-//      CHECK:      %[[INS_1:.*]] = llvm.insertvalue %[[EXT]], %[[CAST_3]][0] : !llvm.array<1 x vector<[4]xi32>>
-//      CHECK:      %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
-
-//      CHECK:      builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+// NOTE: For scalable vectors, we can only extract "full" scalable dimensions
+// (e.g. [8] from [8], but not [4] from [8]).
+
+func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+  return %0 : vector<2x[8]xf32>
+}
+// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
+//  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
+// CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
+// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK:           %[[T6:.*]] = llvm.insertvalue %[[T5]], %[[T4]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK:           %[[T7:.*]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK:           %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T6]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK:           %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T8]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK:           return %[[T9]] : vector<2x[8]xf32>
 
 // -----
 

Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Welcome back, just a little NIT.

// -----

func.func @extract_strided_index_slice1(%arg0: vector<4xindex>) -> vector<2xindex> {
func.func @extract_strided_index_slice_index(%arg0: vector<4xindex>) -> vector<2xindex> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func.func @extract_strided_index_slice_index(%arg0: vector<4xindex>) -> vector<2xindex> {
func.func @extract_strided_slice_index(%arg0: vector<4xindex>) -> vector<2xindex> {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your eagle-eye saves me once again - thank you 🙏🏻 :)

@banach-space
Copy link
Contributor Author

Welcome back, just a little NIT.

Great to be back, though only for this week. I'm OOO next two weeks, but I will monitor my GitHub inbox (should you send sth ;-) )

@banach-space
Copy link
Contributor Author

Ping @nujaa :) Please feel free to land if this looks good to you, thanks!

@nujaa
Copy link
Contributor

nujaa commented Sep 2, 2024

LGTM!

@banach-space banach-space merged commit a9c71d3 into llvm:main Sep 2, 2024
8 checks passed
@banach-space banach-space deleted the andrzej/extend_vector_to_llvm_test_5 branch November 22, 2024 08:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants