@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+  //    byte are placed in one vector and the high i4 elements in another vector.
+  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1048,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
 ///
-/// For example:
+/// For example (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
 ///      is rewriten as
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1069,16 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+/// Example (unsigned):
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.andi %0, 15 : vector<4xi8>
+///        %2 = arith.shrui %0, 4 : vector<4xi8>
+///        %3 = vector.interleave %1, %2 : vector<4xi8>
+///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    auto srcVecType = cast<VectorType>(srcValue.getType());
+    auto dstVecType = cast<VectorType>(conversionOp.getType());
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1089,8 +1131,14 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    Value subByteExt;
+    if (isSigned) {
+      subByteExt =
+          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    } else {
+      subByteExt =
+          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    }
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1229,10 +1277,12 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
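
For context (not part of the patch): a minimal sketch of how these narrow-type rewrites are typically driven, assuming the standard greedy pattern driver. Only populateVectorNarrowTypeRewritePatterns comes from the code above; the helper name, header paths, and surrounding boilerplate are illustrative.

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collects the narrow-type rewrites (now including the
// unsigned i4 -> i8 extension pattern for arith.extui) and applies them
// greedily to `root`.
static LogicalResult applyNarrowTypeRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Aligned sub-byte patterns are registered with benefit + 1, so they are
  // preferred whenever they match.
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

The benefit bump mirrors the registration in the last hunk: the aligned signed and unsigned extension patterns take precedence over the other narrow-type patterns registered at the base benefit.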