Skip to content

Commit 6588586

Browse files
committed
[mlir][vector] Add more tests for ConvertVectorToLLVM (3/n)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.extractelement * vector.extract I have also renamed some function names from `@extract_element{}` to `@extractelement{}` - that's to make a clearer distinction between tests for `vector.extractelement` (tested by `@extractelement{}`) and `vector.extract` (tested by `@extract_element{}`).
1 parent dd094b2 commit 6588586

File tree

1 file changed

+85
-6
lines changed

1 file changed

+85
-6
lines changed

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,8 +1049,8 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3
10491049

10501050
// -----
10511051

1052-
// CHECK-LABEL: @extract_element_0d
1053-
func.func @extract_element_0d(%a: vector<f32>) -> f32 {
1052+
// CHECK-LABEL: @extractelement_0d
1053+
func.func @extractelement_0d(%a: vector<f32>) -> f32 {
10541054
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
10551055
// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
10561056
%1 = vector.extractelement %a[] : vector<f32>
@@ -1059,31 +1059,54 @@ func.func @extract_element_0d(%a: vector<f32>) -> f32 {
10591059

10601060
// -----
10611061

1062-
func.func @extract_element(%arg0: vector<16xf32>) -> f32 {
1062+
func.func @extractelement(%arg0: vector<16xf32>) -> f32 {
10631063
%0 = arith.constant 15 : i32
10641064
%1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
10651065
return %1 : f32
10661066
}
1067-
// CHECK-LABEL: @extract_element(
1067+
// CHECK-LABEL: @extractelement(
10681068
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
10691069
// CHECK: %[[c:.*]] = arith.constant 15 : i32
10701070
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<16xf32>
10711071
// CHECK: return %[[x]] : f32
10721072

1073+
func.func @extractelement_scalable(%arg0: vector<[16]xf32>) -> f32 {
1074+
%0 = arith.constant 15 : i32
1075+
%1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32>
1076+
return %1 : f32
1077+
}
1078+
// CHECK-LABEL: @extractelement_scalable(
1079+
// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
1080+
// CHECK: %[[c:.*]] = arith.constant 15 : i32
1081+
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : i32] : vector<[16]xf32>
1082+
// CHECK: return %[[x]] : f32
1083+
10731084
// -----
10741085

1075-
func.func @extract_element_index(%arg0: vector<16xf32>) -> f32 {
1086+
func.func @extractelement_index(%arg0: vector<16xf32>) -> f32 {
10761087
%0 = arith.constant 15 : index
10771088
%1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
10781089
return %1 : f32
10791090
}
1080-
// CHECK-LABEL: @extract_element_index(
1091+
// CHECK-LABEL: @extractelement_index(
10811092
// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
10821093
// CHECK: %[[c:.*]] = arith.constant 15 : index
10831094
// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
10841095
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<16xf32>
10851096
// CHECK: return %[[x]] : f32
10861097

1098+
func.func @extractelement_index_scalable(%arg0: vector<[16]xf32>) -> f32 {
1099+
%0 = arith.constant 15 : index
1100+
%1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32>
1101+
return %1 : f32
1102+
}
1103+
// CHECK-LABEL: @extractelement_index_scalable(
1104+
// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
1105+
// CHECK: %[[c:.*]] = arith.constant 15 : index
1106+
// CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
1107+
// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<[16]xf32>
1108+
// CHECK: return %[[x]] : f32
1109+
10871110
// -----
10881111

10891112
func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
@@ -1095,6 +1118,15 @@ func.func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
10951118
// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
10961119
// CHECK: return {{.*}} : f32
10971120

1121+
func.func @extract_element_from_vec_1d_scalable(%arg0: vector<[16]xf32>) -> f32 {
1122+
%0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
1123+
return %0 : f32
1124+
}
1125+
// CHECK-LABEL: @extract_element_from_vec_1d_scalable
1126+
// CHECK: llvm.mlir.constant(15 : i64) : i64
1127+
// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
1128+
// CHECK: return {{.*}} : f32
1129+
10981130
// -----
10991131

11001132
func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
@@ -1109,6 +1141,18 @@ func.func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
11091141
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
11101142
// CHECK: return %[[T3]] : index
11111143

1144+
func.func @extract_index_element_from_vec_1d_scalable(%arg0: vector<[16]xindex>) -> index {
1145+
%0 = vector.extract %arg0[15]: index from vector<[16]xindex>
1146+
return %0 : index
1147+
}
1148+
// CHECK-LABEL: @extract_index_element_from_vec_1d_scalable(
1149+
// CHECK-SAME: %[[A:.*]]: vector<[16]xindex>)
1150+
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[16]xindex> to vector<[16]xi64>
1151+
// CHECK: %[[T1:.*]] = llvm.mlir.constant(15 : i64) : i64
1152+
// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<[16]xi64>
1153+
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
1154+
// CHECK: return %[[T3]] : index
1155+
11121156
// -----
11131157

11141158
func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
@@ -1119,6 +1163,14 @@ func.func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16x
11191163
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<16xf32>>>
11201164
// CHECK: return {{.*}} : vector<3x16xf32>
11211165

1166+
func.func @extract_vec_2d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
1167+
%0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
1168+
return %0 : vector<3x[16]xf32>
1169+
}
1170+
// CHECK-LABEL: @extract_vec_2d_from_vec_3d_scalable
1171+
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
1172+
// CHECK: return {{.*}} : vector<3x[16]xf32>
1173+
11221174
// -----
11231175

11241176
func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
@@ -1129,6 +1181,14 @@ func.func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf3
11291181
// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<16xf32>>>
11301182
// CHECK: return {{.*}} : vector<16xf32>
11311183

1184+
func.func @extract_vec_1d_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<[16]xf32> {
1185+
%0 = vector.extract %arg0[0, 0]: vector<[16]xf32> from vector<4x3x[16]xf32>
1186+
return %0 : vector<[16]xf32>
1187+
}
1188+
// CHECK-LABEL: @extract_vec_1d_from_vec_3d_scalable
1189+
// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
1190+
// CHECK: return {{.*}} : vector<[16]xf32>
1191+
11321192
// -----
11331193

11341194
func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
@@ -1141,6 +1201,16 @@ func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
11411201
// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
11421202
// CHECK: return {{.*}} : f32
11431203

1204+
func.func @extract_element_from_vec_3d_scalable(%arg0: vector<4x3x[16]xf32>) -> f32 {
1205+
%0 = vector.extract %arg0[0, 0, 0]: f32 from vector<4x3x[16]xf32>
1206+
return %0 : f32
1207+
}
1208+
// CHECK-LABEL: @extract_element_from_vec_3d_scalable
1209+
// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm.array<4 x array<3 x vector<[16]xf32>>>
1210+
// CHECK: llvm.mlir.constant(0 : i64) : i64
1211+
// CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<[16]xf32>
1212+
// CHECK: return {{.*}} : f32
1213+
11441214
// -----
11451215

11461216
func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 {
@@ -1152,6 +1222,15 @@ func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) ->
11521222
// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
11531223
// CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
11541224

1225+
func.func @extract_element_with_value_1d_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 {
1226+
%0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32>
1227+
return %0 : f32
1228+
}
1229+
// CHECK-LABEL: @extract_element_with_value_1d_scalable
1230+
// CHECK-SAME: %[[VEC:.+]]: vector<[16]xf32>, %[[INDEX:.+]]: index
1231+
// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
1232+
// CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<[16]xf32>
1233+
11551234
// -----
11561235

11571236
func.func @extract_element_with_value_2d(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {

0 commit comments

Comments
 (0)