-
Notifications
You must be signed in to change notification settings - Fork 15k
Closed
Labels
Description
Vectorizing the following IR with `vectorize_nd_extract` enabled produces invalid IR:
#map = affine_map<(d0, d1) -> (d0, d1)> // identity indexing map used for the 2-D output
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)> // sums its three index operands
func.func @gather_failure(%arg0: tensor<8x128x768xf32>, %arg2: tensor<8x1xf32>, %arg3 : index) -> tensor<8x1xf32> { // reproducer: fills an 8x1 result with elements extracted from %arg0
%c0 = arith.constant 0 : index
%1 = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel", "parallel"]
} outs(%arg2 : tensor<8x1xf32>) { // no ins operands; each output element comes from the tensor.extract below
^bb0(%arg5: f32):
%2 = linalg.index 0 : index // iteration index along dim 0 (extent 8, from the 8x1 output)
%3 = linalg.index 1 : index // iteration index along dim 1 (extent 1, from the 8x1 output)
%4 = affine.apply #map1(%arg3, %3, %arg3) // computes %arg3 + %3 + %arg3
%extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32> // reads %arg0[%2, 0, %4]; the reported diagnostic is attached to this op
linalg.yield %extracted : f32
} -> tensor<8x1xf32>
return %1 : tensor<8x1xf32>
}
// Transform script: vectorizes the function that encloses the linalg.generic op,
// with n-D tensor.extract vectorization enabled.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
// Collect every linalg.generic op nested under the payload root %arg2.
%0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
// Move up to the closest isolated-from-above ancestor of the matched op.
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
// Vectorize the ops nested in that parent; vectorize_nd_extract opts in to vectorizing tensor.extract.
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
Run:
mlir-opt -transform-interpreter -split-input-file test_gather_core.mlir
This produces the following error:
within split at test_gather_core.mlir:1 offset :13:18: error: 'vector.shape_cast' op source/result number of elements must match
%extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
^
within split at test_gather_core.mlir:1 offset :13:18: note: see current operation: %7 = "vector.shape_cast"(%5) : (vector<8x1xindex>) -> vector<1xindex>
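With `--verify-each=0` you can see the following output IR. Note the invalid `vector.shape_cast` from `vector<8x1xindex>` (8 elements) to `vector<1xindex>` (1 element):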
#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
"builtin.module"() ({
"func.func"() <{function_type = (tensor<8x128x768xf32>, tensor<8x1xf32>, index) -> tensor<8x1xf32>, sym_name = "gather_failure"}> ({
^bb0(%arg1: tensor<8x128x768xf32>, %arg2: tensor<8x1xf32>, %arg3: index):
%3 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
%4 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%5 = "arith.constant"() <{value = 0 : index}> : () -> index
%6 = "arith.constant"() <{value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>}> : () -> vector<8xindex>
%7 = "vector.broadcast"(%6) : (vector<8xindex>) -> vector<1x8xindex>
%8 = "vector.transpose"(%7) <{permutation = array<i64: 1, 0>}> : (vector<1x8xindex>) -> vector<8x1xindex>
%9 = "arith.addi"(%arg3, %arg3) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
%10 = "vector.shape_cast"(%8) : (vector<8x1xindex>) -> vector<1xindex>
%11 = "vector.extractelement"(%10, %4) : (vector<1xindex>, i32) -> index
%12 = "vector.transfer_read"(%arg1, %11, %5, %9, %3) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 3, 1, 0>, permutation_map = #map}> : (tensor<8x128x768xf32>, index, index, index, f32) -> vector<8x1xf32>
%13 = "vector.transfer_write"(%12, %arg2, %5, %5) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = #map1}> : (vector<8x1xf32>, tensor<8x1xf32>, index, index) -> tensor<8x1xf32>
"func.return"(%13) : (tensor<8x1xf32>) -> ()
}) : () -> ()