Skip to content

Commit 6626ed6

Browse files
authored
[MLIR] Fix BubbleDownVectorBitCastForExtract crash on non-static index (#116518)
Previously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash. This patch is to make sure it return gracefully instead of crashing.
1 parent c51786b commit 6626ed6

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -596,12 +596,11 @@ struct BubbleDownVectorBitCastForExtract
596596
unsigned expandRatio =
597597
castDstType.getNumElements() / castSrcType.getNumElements();
598598

599-
auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
600-
assert(values[0].is<Attribute>() && "Unexpected non-constant index");
601-
return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
602-
};
603-
604-
uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
599+
// Get the first element of the mixed position as integer.
600+
auto mixedPos = extractOp.getMixedPosition();
601+
if (mixedPos.size() > 0 && !mixedPos[0].is<Attribute>())
602+
return failure();
603+
uint64_t index = cast<IntegerAttr>(mixedPos[0].get<Attribute>()).getInt();
605604

606605
// Get the single scalar (as a vector) in the source value that packs the
607606
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>

mlir/test/Dialect/Vector/vector-transforms.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
433433
%0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
434434
return %0 : vector<i32>
435435
}
436+
437+
// Make sure not crash on dynamic index `vector.extract`:
438+
func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
439+
%0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
440+
%1 = vector.extract %0[%index] : i16 from vector<8xi16>
441+
return %1 : i16
442+
}
443+
444+
// CHECK-LABEL: func.func @vector_extract_dynamic_index
445+
// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
446+
// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
447+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
448+
// CHECK: return %[[EXTRACT]]

0 commit comments

Comments
 (0)