[mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation #79494

Merged · 2 commits · Jan 30, 2024
176 changes: 156 additions & 20 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -642,19 +642,19 @@ struct BitCastRewriter {

BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

/// Verify that the preconditions for the rewrite are met.
LogicalResult precondition(PatternRewriter &rewriter,
VectorType preconditionVectorType, Operation *op);
/// Verify that general preconditions for the rewrite are met.
LogicalResult commonPrecondition(PatternRewriter &rewriter,
VectorType preconditionType, Operation *op);

/// Precompute the metadata for the rewrite.
SmallVector<BitCastRewriter::Metadata>
precomputeMetadata(IntegerType shuffledElementType);

/// Rewrite one step of the sequence:
/// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
Value runningResult,
const BitCastRewriter::Metadata &metadata);
Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
Value initialValue, Value runningResult,
const BitCastRewriter::Metadata &metadata);

private:
/// Underlying enumerator that encodes the provenance of the bits in each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
LDBG("\n" << enumerator.sourceElementRanges);
}

LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
VectorType precondition,
Operation *op) {
if (precondition.getRank() != 1 || precondition.isScalable())
/// Verify that the precondition type meets the common preconditions for any
/// conversion.
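/// These require a fixed-length 1-D vector whose element bitwidth is a multiple
/// of 8 (e.g., vector<8xi8> qualifies, whereas a scalable or multi-dimensional
/// vector, or one with sub-byte elements such as vector<8xi4>, does not).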
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
if (!preconditionType || preconditionType.getRank() != 1 ||
preconditionType.isScalable())
return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");

// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
int64_t resultBitwidth = precondition.getElementTypeBitWidth();
unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
if (resultBitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
return rewriter.notifyMatchFailure(op, "types are not vector");

return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that source and destination element types meet the precondition for
/// the supported aligned conversion cases. Alignment means that either the
/// source element bitwidth is a multiple of the destination element bitwidth or
/// the other way around.
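///
/// For example, i4 -> i8, i4 -> i16, and i4 -> f32 are aligned cases, whereas
/// i3 -> i8 (non-i4 source) or i4 -> i4 (sub-byte destination) are rejected by
/// the checks below.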
///
/// NOTE: This method assumes that common conversion preconditions are met.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
VectorType srcType,
VectorType dstType,
Operation *op) {
if (!srcType || !dstType)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
unsigned byteBitwidth = 8;

// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");

return success();
}

SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
return result;
}

Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
Value initialValue, Value runningResult,
const BitCastRewriter::Metadata &metadata) {
Value BitCastRewriter::genericRewriteStep(
PatternRewriter &rewriter, Location loc, Value initialValue,
Value runningResult, const BitCastRewriter::Metadata &metadata) {
// Create vector.shuffle from the metadata.
auto shuffleOp = rewriter.create<vector::ShuffleOp>(
loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
return runningResult;
}

/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
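/// For example, a vector<8xi4> source is bitcast to vector<4xi8>, the low and
/// high nibbles are sign-extended in place with shift pairs, and a single
/// shuffle interleaves them back into a vector<8xi8> result.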
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");

// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
int64_t vecDimSize = srcVecType.getShape().back();
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

// 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
// byte are placed in one vector and the high i4 elements in another vector.
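// For example, for the byte 0x7A (high nibble 0x7 = 7, low nibble 0xA = -6):
//   shl(0x7A, 4) = 0xA0 and shrsi(0xA0, 4) = 0xFA (-6) recover the low nibble,
//   while shrsi(0x7A, 4) = 0x07 (7) recovers the high nibble.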
constexpr int8_t bitsToShift = 4;
auto shiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, bitsToShift));
Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);

// 3. Interleave low and high i8 elements using a shuffle.
SmallVector<int64_t> interleaveMaskValues;
interleaveMaskValues.reserve(vecDimSize);
for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
interleaveMaskValues.push_back(i);
interleaveMaskValues.push_back(i + (vecDimSize / 2));
}
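// E.g., for a vector<8xi4> source (vector<4xi8> after the bitcast), the mask
// is [0, 4, 1, 5, 2, 6, 3, 7].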

return rewriter.create<vector::ShuffleOp>(
loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
}

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
BitCastRewriter bcr(sourceVectorType, targetVectorType);
if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
return failure();

// Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
Value runningResult;
for (const BitCastRewriter ::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
runningResult, metadata);
runningResult = bcr.genericRewriteStep(
rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
}

// Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
BitCastRewriter bcr(sourceVectorType, targetVectorType);
if (failed(bcr.precondition(
if (failed(bcr.commonPrecondition(
rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
return failure();

@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
for (const BitCastRewriter::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
sourceValue, runningResult, metadata);
runningResult = bcr.genericRewriteStep(
rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
}

// Finalize the rewrite.
Expand All @@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
return success();
}
};

/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
///
/// For example:
/// arith.extsi %in : vector<8xi4> to vector<8xi32>
/// is rewritten as
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
/// : vector<4xi8>, vector<4xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
/// is rewritten as
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
/// : vector<4xi8>, vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
template <typename ConversionOpType>
struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
using OpRewritePattern<ConversionOpType>::OpRewritePattern;

LogicalResult matchAndRewrite(ConversionOpType conversionOp,
PatternRewriter &rewriter) const override {
// Set up the BitCastRewriter and verify the preconditions.
Value srcValue = conversionOp.getIn();
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
if (failed(
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
return failure();

// Check general alignment preconditions.
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
conversionOp)))
return failure();

// Perform the rewrite.
Value subByteExt =
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);

// Finalize the rewrite.
rewriter.replaceOpWithNewOp<ConversionOpType>(
conversionOp, conversionOp.getType(), subByteExt);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
benefit);

// Patterns for aligned cases. We set a higher benefit so they take precedence
// over the generic patterns above, as they are expected to generate better
// code for the aligned cases.
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
patterns.getContext(), benefit.getBenefit() + 1);
}
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}

// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
// CHECK: arith.shli
// CHECK: arith.shrsi
// CHECK: arith.shrsi
// CHECK: vector.shuffle
// CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
return %0 : vector<8xi32>
}

// CHECK-LABEL: func.func @aligned_extsi_base_case(
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
// CHECK: arith.shli
// CHECK: arith.shrsi
// CHECK: arith.shrsi
// CHECK: vector.shuffle
// CHECK-NOT: arith.extsi
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
return %0 : vector<8xi8>
}

// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
// CHECK: arith.shli
// CHECK: arith.shrsi
// CHECK: arith.shrsi
// CHECK: shuffle
// CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
return %0 : vector<8xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op