
Commit 6dfaecf

[mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation (#89131)
This PR builds on #79494 with an additional path for efficient unsigned `i4 -> i8` type extension for 1D/2D operations. This will impact any i4 -> i8/i16/i32/i64 unsigned extensions as well as sitofp i4 -> f8/f16/f32/f64.
1 parent a00bbcb commit 6dfaecf
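
As a scalar mental model of what the new path emits (a C++ sketch for illustration only, not code from this commit): after the bitcast, each i8 byte holds two i4 elements, so a mask with 15 recovers one set of elements, a logical shift by 4 recovers the other, and vector.interleave restores the original element order. The nibble-to-element mapping assumed here (low nibble holds the lower-indexed element) is the one the tests below check.

    #include <cstdint>
    #include <vector>

    // Scalar sketch of the rewrite: decode N unsigned i4 elements from N/2
    // packed bytes. The alternating push_back plays the role of
    // vector.interleave in the actual lowering.
    std::vector<uint8_t> decodeUnsignedI4(const std::vector<uint8_t> &packed) {
      std::vector<uint8_t> out;
      out.reserve(packed.size() * 2);
      for (uint8_t byte : packed) {
        out.push_back(byte & 0x0F); // low nibble  -> arith.andi with 15
        out.push_back(byte >> 4);   // high nibble -> arith.shrui by 4
      }
      return out;
    }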

File tree

2 files changed: +101 -11 lines


mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 60 additions & 10 deletions
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
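
Why masking and a logical shift suffice here, while the signed helper above (rewriteI4ToI8SignedExt, from #79494) uses a left shift followed by an arithmetic right shift: zero-extension only has to clear the high bits of each byte, whereas sign-extension must replicate each nibble's top bit into bits 4..7. A standalone C++ comparison (an illustrative sketch, not from the patch), using byte 0xBA, which packs the i4 elements 0xA and 0xB:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint8_t byte = 0xBA; // packs i4 elements 0xA (low) and 0xB (high)

      // Unsigned decode, as in rewriteI4ToI8UnsignedExt: mask + logical shift.
      unsigned lowU = byte & 0x0F; // 10
      unsigned highU = byte >> 4;  // 11

      // Signed decode, as in rewriteI4ToI8SignedExt: shift each nibble's sign
      // bit up to bit 7, then arithmetic-shift it back down to replicate it.
      int lowS = static_cast<int8_t>(byte << 4) >> 4; // i4 0xA == -6
      int highS = static_cast<int8_t>(byte) >> 4;     // i4 0xB == -5

      std::printf("unsigned: %u %u  signed: %d %d\n", lowU, highU, lowS, highS);
      return 0;
    }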
@@ -1048,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
 ///
-/// For example:
+/// For example (signed):
 ///   arith.extsi %in : vector<8xi4> to vector<8xi32>
 /// is rewritten as
 ///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1069,16 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///   %4 = vector.interleave %2, %3 : vector<4xi8>
 ///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+/// Example (unsigned):
+///   arith.extui %in : vector<8xi4> to vector<8xi32>
+/// is rewritten as
+///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///   %1 = arith.andi %0, 15 : vector<4xi8>
+///   %2 = arith.shrui %0, 4 : vector<4xi8>
+///   %3 = vector.interleave %1, %2 : vector<4xi8>
+///   %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    auto srcVecType = cast<VectorType>(srcValue.getType());
+    auto dstVecType = cast<VectorType>(conversionOp.getType());
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1089,8 +1131,14 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    Value subByteExt;
+    if (isSigned) {
+      subByteExt =
+          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    } else {
+      subByteExt =
+          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    }
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
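
Because isSigned is a non-type template parameter, the branch above is constant within each instantiation and the compiler folds it away, so a plain if costs nothing at runtime. A minimal standalone illustration of the same dispatch style (a sketch, not MLIR code; extendI4 is a hypothetical name):

    #include <cstdio>

    // A bool template parameter selects the extension semantics per
    // instantiation, mirroring how RewriteAlignedSubByteIntExt picks the
    // signed or unsigned rewrite helper.
    template <bool isSigned>
    int extendI4(int nibble) {
      if (isSigned)
        return nibble >= 8 ? nibble - 16 : nibble; // sign-extend an i4 value
      return nibble;                               // zero-extend an i4 value
    }

    int main() {
      std::printf("%d %d\n", extendI4<true>(0xB), extendI4<false>(0xB)); // -5 11
      return 0;
    }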
@@ -1229,10 +1277,12 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
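
For context on how these patterns are consumed: clients pull them in through populateVectorNarrowTypeRewritePatterns and run a greedy rewrite. A minimal sketch under the usual MLIR driver entry points (runNarrowTypeEmulation is a hypothetical caller, not part of this commit):

    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Hypothetical helper: populate the narrow-type rewrites (now including
    // the unsigned i4 -> i8 path) and apply them greedily to `root`.
    LogicalResult runNarrowTypeEmulation(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      // The populate function registers the aligned ext/trunc patterns with
      // benefit + 1 so they fire before the generic unaligned rewrites.
      vector::populateVectorNarrowTypeRewritePatterns(patterns);
      return applyPatternsAndFoldGreedily(root, std::move(patterns));
    }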

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 41 additions & 1 deletion
@@ -324,6 +324,47 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
   return %0 : vector<16x8xi7>
 }
 
+// CHECK-LABEL: func.func @aligned_extui(
+func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_2d(
+func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_base_case(
+func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +376,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
