Commit df1cd5d

[mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation
This PR adds new patterns to improve the generated vector code for the emulation of any conversion that has to go through an i4 -> i8 type extension (only signed extensions are supported for now). This will impact any i4 -> i8/i16/i32/i64 signed extension as well as sitofp i4 -> f8/f16/f32/f64. The asm code generated for the supported cases is significantly better after this PR for both x86 and aarch64.
1 parent 0cb024b commit df1cd5d
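For orientation, this is the shape of IR the new pattern produces for the extsi case (a sketch reconstructed from the rewrite and the new tests below; SSA names are illustrative):

  // After the rewrite of `arith.extsi %a : vector<8xi4> to vector<8xi32>`:
  %0    = vector.bitcast %a : vector<8xi4> to vector<4xi8>   // pack pairs of i4 into bytes
  %c4   = arith.constant dense<4> : vector<4xi8>
  %shl  = arith.shli %0, %c4 : vector<4xi8>
  %low  = arith.shrsi %shl, %c4 : vector<4xi8>               // sign-extended low nibbles
  %high = arith.shrsi %0, %c4 : vector<4xi8>                 // sign-extended high nibbles
  %i8   = vector.shuffle %low, %high [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>
  %r    = arith.extsi %i8 : vector<8xi8> to vector<8xi32>    // single byte-aligned extension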

File tree

2 files changed (+176, −20 lines)

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 143 additions & 20 deletions
@@ -642,19 +642,19 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
   precomputeMetadata(IntegerType shuffledElementType);
 
   /// Rewrite one step of the sequence:
   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in the each
@@ -719,21 +719,54 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (preconditionType.getRank() != 1 || preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +808,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +843,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
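A worked constant trace of the shift trick in rewriteI4ToI8SignedExt above (values invented for this note): a byte packing the i4 pair (3, -6), element 0 in the low nibble, is 0xA3:

  %b    = arith.constant dense<-93> : vector<1xi8>  // 0xA3: low nibble 0x3 (3), high nibble 0xA (-6)
  %c4   = arith.constant dense<4> : vector<1xi8>
  %shl  = arith.shli %b, %c4 : vector<1xi8>         // 0x30
  %low  = arith.shrsi %shl, %c4 : vector<1xi8>      // 0x03 =  3: shl then ashr sign-extends the low nibble
  %high = arith.shrsi %b, %c4 : vector<1xi8>        // 0xFA = -6: ashr alone sign-extends the high nibble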
@@ -829,7 +900,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +910,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -885,7 +956,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
 
@@ -896,8 +967,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -915,6 +986,52 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
 };
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    extsi vector<8xi4> -> vector<8xi32>
+///      is rewritten as
+///        sequence of shuffles and bitwise ops for i4 -> i8
+///        extsi vector<8xi8> -> vector<8xi32>
+///
+///    sitofp vector<8xi4> -> vector<8xf32>
+///      is rewritten as
+///        sequence of shuffles and bitwise ops for i4 -> i8
+///        sitofp vector<8xi8> -> vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Set up the BitCastRewriter and verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -936,4 +1053,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
+
+  // Patterns for aligned cases. We set a higher priority as they are expected
+  // to generate better performance for aligned cases.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
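For contrast, conversions like the following (inputs invented for this note) should fail the preconditions above and keep their pre-existing lowering paths:

  // Rejected by alignedConversionPrecondition: source element type is not i4.
  %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
  // Not matched: the new pattern is registered only for arith.extsi and arith.sitofp.
  %1 = arith.extui %b : vector<8xi4> to vector<8xi32>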

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 33 additions & 0 deletions
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK: arith.shli
+// CHECK: arith.shrsi
+// CHECK: arith.shrsi
+// CHECK: vector.shuffle
+// CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK: arith.shli
+// CHECK: arith.shrsi
+// CHECK: arith.shrsi
+// CHECK: vector.shuffle
+// CHECK-NOT: arith.extsi
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK: arith.shli
+// CHECK: arith.shrsi
+// CHECK: arith.shrsi
+// CHECK: shuffle
+// CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

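For reference, a minimal transform-dialect driver in the style of this test file, sketched under the assumption that transform.apply_patterns.vector.rewrite_narrow_types is the op wrapping populateVectorNarrowTypeRewritePatterns:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
      %f = transform.structured.match ops{["func.func"]} in %module_op
          : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %f {
        transform.apply_patterns.vector.rewrite_narrow_types
      } : !transform.any_op
      transform.yield
    }
  }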