From 703c9e9abfa866139752ec47086291fb0ec1de66 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Fri, 8 Sep 2023 11:28:55 +0200
Subject: [PATCH 1/3] [mlir][Vector] Add a rewrite pattern for better
 low-precision ext(bitcast) expansion

This revision adds a rewrite for sequences of vector `ext(bitcast)` to
use a more efficient sequence of vector operations comprising `shuffle`
and `bitwise` ops. Such patterns appear naturally when writing
quantization / dequantization functionality with the vector dialect.

The implementation is 90% a refactoring of the existing
`trunci(bitcast)` pattern into a common BitCastRewriter.

The rewrite performs a simple enumeration of each of the bits in the
result vector and determines its provenance in the source vector. The
enumeration is used to generate the proper sequence of `shuffle`,
`andi`, `ori` and shift ops.

The rewrite currently only applies to 1-D non-scalable vectors and
bails out if the bitwidth of the final vector element type is not a
multiple of 8. This is a failsafe heuristic determined empirically: if
the resulting type is not an even number of bytes, further complexities
arise that are not improved by this pattern: the heavy lifting still
needs to be done by LLVM.
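For instance (mirroring the `f1ext` test added below; the `%mask0`,
`%mask1`, `%shr0` and `%shl1` names stand for the materialized
`arith.constant` mask and shift vectors and are elided for brevity), the
sequence

```mlir
%0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
%1 = arith.extsi %0 : vector<8xi5> to vector<8xi16>
```

is rewritten into:

```mlir
%v0 = vector.shuffle %a, %a [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
%a0 = arith.andi %v0, %mask0 : vector<8xi8>
%s0 = arith.shrui %a0, %shr0 : vector<8xi8>
%v1 = vector.shuffle %a, %a [0, 1, 0, 2, 3, 0, 4, 0] : vector<5xi8>, vector<5xi8>
%a1 = arith.andi %v1, %mask1 : vector<8xi8>
%s1 = arith.shli %a1, %shl1 : vector<8xi8>
%o1 = arith.ori %s0, %s1 : vector<8xi8>
%r  = arith.extsi %o1 : vector<8xi8> to vector<8xi16>
```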
---
 .../Vector/Transforms/VectorRewritePatterns.h |   7 +
 .../Transforms/VectorEmulateNarrowType.cpp    | 428 ++++++++++++------
 .../Vector/vector-rewrite-narrow-types.mlir   |  47 ++
 .../Vector/CPU/test-rewrite-narrow-types.mlir |  46 ++
 4 files changed, 379 insertions(+), 149 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 8652fc7f5e5c6..eb561ba3b2355 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 class TruncIOp;
 } // namespace arith
@@ -309,6 +310,12 @@ FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                         arith::TruncIOp truncOp,
                                         vector::BroadcastOp maybeBroadcastOp);
 
+/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and `bitwise` ops.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
 /// Appends patterns for rewriting vector operations over narrow types with
 /// ops over wider types.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9d659bf694a24..d2524d7e35cf9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -224,6 +224,106 @@ struct BitCastBitsEnumerator {
   SmallVector<SourceElementRangeList> sourceElementRanges;
 };
 
+/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+///
+/// BitCastBitsEnumerator encodes for each element of the target vector the
+/// provenance of the bits in the source vector. We can "transpose" this
+/// information to build a sequence of shuffles and bitwise ops that will
+/// produce the desired result.
+///
+/// Consider the following motivating example:
+/// ```
+///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+/// ```
+///
+/// BitCastBitsEnumerator contains the following information:
+/// ```
+///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
+///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
+///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
+///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
+///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
+///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
+///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
+///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
+///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
+///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
+///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
+///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
+///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
+///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
+/// ```
+///
+/// In the above, each row represents one target vector element and each
+/// column represents one bit contribution from a source vector element.
+/// The algorithm creates vector.shuffle operations (in this case there are 3
+/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
+/// algorithm populates the bits as follows:
+/// ```
+///     src bits 0 ...
+/// 1st shuffle |xxxxx   |xx      |...
+/// 2nd shuffle |     xxx|  xxxxx |...
+/// 3rd shuffle |        |       x|...
+/// ```
+///
+/// The algorithm proceeds as follows:
+///   1. for each vector.shuffle, collect the source vectors that participate
+///      in this shuffle. One source vector per target element of the resulting
+///      vector.shuffle. If there is no source element contributing bits for
+///      the current vector.shuffle, take 0 (i.e. row 0 in the above example
+///      has only 2 columns).
+///   2. represent the bitrange in the source vector as a mask. If there is no
+///      source element contributing bits for the current vector.shuffle,
+///      take 0.
+///   3. shift right by the proper amount to align the source bitrange at
+///      position 0. This is exactly the low end of the bitrange. For instance,
+///      the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs
+///      to shift right by 3 to get the bits contributed by the source element
+///      #1 into position 0.
+///   4. shift left by the proper amount to align to the desired position in
+///      the result element vector. For instance, the contribution of the
+///      second source element for the first row needs to be shifted by `5` to
+///      form the first i8 result element.
+///
+/// Eventually, we end up building the sequence
+/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
+/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
+/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
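+///
+/// As an illustrative sketch (pseudo-code, not verbatim compiler output),
+/// the first i8 result element of the example above is thus assembled as:
+/// ```
+///   res[0] = ((src[0] & 0b11111) >> 0) << 0    // { 0: b@[0..5) lshl: 0}
+///          | ((src[1] & 0b00111) >> 0) << 5    // { 1: b@[0..3) lshl: 5}
+/// ```
+/// where each column of such terms, taken across all rows, is materialized
+/// by one vector.shuffle followed by vectorized `andi`, `shrui`, `shli` and
+/// `ori` ops.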
+struct BitCastRewriter {
+  /// Helper metadata struct to hold the static quantities for the rewrite.
+  struct Metadata {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+  };
+
+  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
+
+  /// Verify that the preconditions for the rewrite are met.
+  LogicalResult precondition(PatternRewriter &rewriter,
+                             VectorType preconditionVectorType, Operation *op);
+
+  /// Precompute the metadata for the rewrite.
+  SmallVector<BitCastRewriter::Metadata>
+  precomputeMetadata(IntegerType shuffledElementType);
+
+  /// Rewrite one step of the sequence:
+  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
+  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
+                    Value runningResult,
+                    const BitCastRewriter::Metadata &metadata);
+
+private:
+  /// Underlying enumerator that encodes the provenance of the bits in each
+  /// element of the result vector.
+  BitCastBitsEnumerator enumerator;
+};
+
 } // namespace
 
 static raw_ostream &operator<<(raw_ostream &os,
@@ -256,7 +356,7 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   LDBG("targetVectorType: " << targetVectorType);
 
   int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
-  (void) mostMinorSourceDim;
+  (void)mostMinorSourceDim;
   assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
          "source and target bitwidths must match");
 
@@ -275,79 +375,107 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   }
 }
 
+BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
+                                 VectorType targetVectorType)
+    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
+  LDBG("\n" << enumerator.sourceElementRanges);
+}
+
+LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
+                                            VectorType precondition,
+                                            Operation *op) {
+  if (precondition.getRank() != 1 || precondition.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+
+  // TODO: consider relaxing this restriction in the future if we find ways
+  // to really work with subbyte elements across the MLIR/LLVM boundary.
+  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  if (resultBitwidth % 8 != 0)
+    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
+
+  return success();
+}
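+
+// Note (illustrative, mirroring the lit tests added with this revision):
+// under this precondition, `arith.extsi %0 : vector<8xi5> to vector<8xi16>`
+// is rewritten because i16 is byte-aligned, while an extension to
+// vector<8xi17> bails out because 17 is not a multiple of 8.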
+
+SmallVector<BitCastRewriter::Metadata>
+BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  SmallVector<BitCastRewriter::Metadata> result;
+  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
+       shuffleIdx < e; ++shuffleIdx) {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+    // Create the attribute quantities for the shuffle / mask / shift ops.
+    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
+      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
+                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
+                                  : 0;
+      shuffles.push_back(sourceElement);
+
+      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
+                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
+                          : 0;
+      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
+                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
+                          : 0;
+      IntegerAttr mask = IntegerAttr::get(
+          shuffledElementType,
+          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
+                                  bitLo, bitHi));
+      masks.push_back(mask);
+
+      int64_t shiftRight = bitLo;
+      shiftRightAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftRight));
+
+      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
+      shiftLeftAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftLeft));
+    }
+
+    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
+  }
+  return result;
+}
+
+Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
+                                   Value initialValue, Value runningResult,
+                                   const BitCastRewriter::Metadata &metadata) {
+  // Create vector.shuffle from the metadata.
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, initialValue, initialValue, metadata.shuffles);
+
+  // Intersect with the mask.
+  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+  auto constOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+  // Align right on 0.
+  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
+  Value shiftedRight =
+      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+  // Shift bits left into their final position.
+  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
+  Value shiftedLeft =
+      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+  runningResult =
+      runningResult
+          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
+          : shiftedLeft;
+
+  return runningResult;
+}
+
 namespace {
 
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
 /// peephole optimizations.
-
-// BitCastBitsEnumerator encodes for each element of the target vector the
-// provenance of the bits in the source vector. We can "transpose" this
-// information to build a sequence of shuffles and bitwise ops that will
-// produce the desired result.
-//
-// Let's take the following motivating example to explain the algorithm:
-// ```
-// %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
-// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
-// ```
-//
-// BitCastBitsEnumerator contains the following information:
-// ```
-// { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
-// { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
-// { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
-// { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
-// { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
-// { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
-// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
-// { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
-// { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
-// { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
-// { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
-// { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
-// { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
-// { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
-// { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
-// { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
-// { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
-// { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
-// { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
-// { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
-// ```
-//
-// In the above, each row represents one target vector element and each
-// column represents one bit contribution from a source vector element.
-// The algorithm creates vector.shuffle operations (in this case there are 3
-// shuffles (i.e. the max number of columns in BitCastBitsEnumerator). The
-// algorithm populates the bits as follows:
-// ```
-//     src bits 0 ...
-// 1st shuffle |xxxxx   |xx      |...
-// 2nd shuffle |     xxx|  xxxxx |...
-// 3rd shuffle |        |       x|...
-// ```
-//
-// The algorithm proceeds as follows:
-//   1. for each vector.shuffle, collect the source vectors that participate in
-//      this shuffle. One source vector per target element of the resulting
-//      vector.shuffle. If there is no source element contributing bits for the
-//      current vector.shuffle, take 0 (i.e. row 0 in the above example has
-//      only 2 columns).
-//   2. represent the bitrange in the source vector as a mask. If there is no
-//      source element contributing bits for the current vector.shuffle, take 0.
-//   3. shift right by the proper amount to align the source bitrange at
-//      position 0. This is exactly the low end of the bitrange. For instance,
-//      the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
-//      shift right by 3 to get the bits contributed by the source element #1
-//      into position 0.
-//   4. shift left by the proper amount to to align to the desired position in
-//      the result element vector. For instance, the contribution of the second
-//      source element for the first row needs to be shifted by `5` to form the
-//      first i8 result element.
-// Eventually, we end up building the sequence
-// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
-// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
-// bits extracted from the source vector (i.e. the `shuffle -> and` part).
 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -359,93 +487,93 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     if (!truncOp)
       return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
 
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
-    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
-      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
-    // TODO: consider relaxing this restriction in the future if we find ways
-    // to really work with subbyte elements across the MLIR/LLVM boundary.
-    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
-    if (resultBitwidth % 8 != 0)
-      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+      return failure();
 
-    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
-    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
-    LDBG("\n" << be.sourceElementRanges);
-
-    Value initialValue = truncOp.getIn();
-    auto initalVectorType = initialValue.getType().cast<VectorType>();
-    auto initalElementType = initalVectorType.getElementType();
-    auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
-    Value res;
-    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
-         ++shuffleIdx) {
-      SmallVector<int64_t> shuffles;
-      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
-      // Create the attribute quantities for the shuffle / mask / shift ops.
-      for (auto &srcEltRangeList : be.sourceElementRanges) {
-        bool idxContributesBits =
-            (shuffleIdx < (int64_t)srcEltRangeList.size());
-        int64_t sourceElementIdx =
-            idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx
-                               : 0;
-        shuffles.push_back(sourceElementIdx);
-
-        int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitBegin
-                            : 0;
-        int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitEnd
-                            : 0;
-        IntegerAttr mask = IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth),
-            llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
-        masks.push_back(mask);
-
-        int64_t shiftRight = bitLo;
-        shiftRightAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
-        int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
-        shiftLeftAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
-      }
-
-      // Create vector.shuffle #shuffleIdx.
-      auto shuffleOp = rewriter.create<vector::ShuffleOp>(
-          bitCastOp.getLoc(), initialValue, initialValue, shuffles);
-      // And with the mask.
-      VectorType vt = VectorType::Builder(initalVectorType)
-                          .setDim(initalVectorType.getRank() - 1, masks.size());
-      auto constOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
-      Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
-                                                      shuffleOp, constOp);
-      // Align right on 0.
-      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
-      Value shiftedRight = rewriter.create<arith::ShRUIOp>(
-          bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
-      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
-      Value shiftedLeft = rewriter.create<arith::ShLIOp>(
-          bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
-      res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
-                                                shiftedLeft)
-                : shiftedLeft;
+    // Perform the rewrite.
+    Value truncValue = truncOp.getIn();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+    Value runningResult;
+    for (const BitCastRewriter::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                      runningResult, metadata);
     }
 
-    bool narrowing = resultBitwidth <= initalElementBitWidth;
+    // Finalize the rewrite.
+    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+                     shuffledElementType.getIntOrFloatBitWidth();
     if (narrowing) {
       rewriter.replaceOpWithNewOp<arith::TruncIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     } else {
       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     }
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    // The source must be a bitcast op.
+    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(
+            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value runningResult;
+    Value sourceValue = bitCastOp.getSource();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
+    for (const BitCastRewriter::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
+                                      sourceValue, runningResult, metadata);
+    }
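+
+    // Note: the loop above accumulates the result in the shuffled element
+    // type (the element type of the bitcast source, e.g. i8 in the tests
+    // below); the cast emitted next narrows or widens that accumulator to
+    // the requested result element type.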
+
+    // Finalize the rewrite.
+    bool narrowing =
+        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
+        shuffledElementType.getIntOrFloatBitWidth();
+    if (narrowing) {
+      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
+          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+    } else {
+      rewriter.replaceOpWithNewOp<ExtOpType>(
+          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
+    }
+
     return success();
   }
 };
@@ -466,5 +594,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<RewriteBitCastOfTruncI>(patterns.getContext(), benefit);
+  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtSIOp>,
+               RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
+                                                    benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index ba6efde40f36c..7754c70458d32 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -146,6 +146,53 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
   return %1 : vector<8xi6>
 }
 
+// CHECK-LABEL: func.func @f1ext(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
+func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, -32, 124, -128, -16, 62, -64, -8]> : vector<8xi8>
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[0, 3, 0, 15, 1, 0, 7, 0]> : vector<8xi8>
+  // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 5, 2, 7, 4, 1, 6, 3]> : vector<8xi8>
+  // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 3, 5, 1, 4, 5, 2, 5]> : vector<8xi8>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<8xi8>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<8xi8>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 0, 2, 3, 0, 4, 0] : vector<5xi8>, vector<5xi8>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<8xi8>
+  // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<8xi8>
+  // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<8xi8>
+  // CHECK: %[[RES:.*]] = arith.extsi %[[O1]] : vector<8xi8> to vector<8xi16>
+  // return %[[RES]] : vector<8xi16>
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extsi %0 : vector<8xi5> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
+
+// CHECK-LABEL: func.func @f2ext(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
+func.func @f2ext(%a: vector<5xi8>) -> vector<8xi16> {
+  // CHECK-NOT: arith.extsi {{.*}} : vector<8xi8> to vector<8xi16>
+  // CHECK: %[[RES:.*]] = arith.extui {{.*}} : vector<8xi8> to vector<8xi16>
+  // return %[[RES]] : vector<8xi16>
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
+
+// CHECK-LABEL: func.func @f3ext(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi17> {
+func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
+  // CHECK: bitcast
+  // CHECK: extsi
+  // CHECK-NOT: shuffle
+  // CHECK-NOT: andi
+  // CHECK-NOT: ori
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  %1 = arith.extsi %0 : vector<8xi5> to vector<8xi17>
+  return %1 : vector<8xi17>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %f = transform.structured.match ops{["func.func"]} in %module_op
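
The trailing diff context above cuts the transform script short; spelled out
in full, the driver has roughly the following shape (a sketch based on the
`ApplyRewriteNarrowTypePatternsOp` transform op touched in PATCH 3/3; the
exact spelling in the test file may differ):

```mlir
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
  %f = transform.structured.match ops{["func.func"]} in %module_op
    : (!transform.any_op) -> !transform.any_op
  transform.apply_patterns to %f {
    transform.apply_patterns.vector.rewrite_narrow_types
  } : !transform.any_op
}
```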
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index 44c608726f135..7d15e2e2e3ef5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,47 @@ func.func @f3(%v: vector<2xi48>) {
   return
 }
 
+func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
+  %bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
+  vector.print %bitsi40 : vector<40xi1>
+  return
+}
+
+func.func @print_as_i1_8xi16(%v : vector<8xi16>) {
+  %bitsi128 = vector.bitcast %v : vector<8xi16> to vector<128xi1>
+  vector.print %bitsi128 : vector<128xi1>
+  return
+}
+
+func.func @fext(%a: vector<5xi8>) {
+  %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
+  func.call @print_as_i1_8xi5(%0) : (vector<8xi5>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 1, 1,
+  // CHECK-SAME: 0, 1, 1, 1, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 1,
+  // CHECK-SAME: 1, 1, 1, 1, 0,
+  // CHECK-SAME: 1, 0, 1, 1, 1 )
+
+  %1 = arith.extui %0 : vector<8xi5> to vector<8xi16>
+  func.call @print_as_i1_8xi16(%1) : (vector<8xi16>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  return
+}
+
 func.func @entry() {
   %v = arith.constant dense<[
     0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
@@ -141,6 +182,11 @@ func.func @entry() {
   ]> : vector<2xi48>
   func.call @f3(%v3) : (vector<2xi48>) -> ()
 
+  %v4 = arith.constant dense<[
+    0xef, 0xee, 0xed, 0xec, 0xeb
+  ]> : vector<5xi8>
+  func.call @fext(%v4) : (vector<5xi8>) -> ()
+
   return
 }

From f20bf26007202e0f42bbec18a28c60b4530e6f62 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Mon, 18 Sep 2023 18:55:56 +0200
Subject: [PATCH 2/3] Fix typo

---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d2524d7e35cf9..18b46ac921809 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -526,9 +526,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
-/// advantage of high-level information to avoid leaving LLVM to scramble with
-/// peephole optimizations.
+/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
+/// take advantage of high-level information to avoid leaving LLVM to scramble
+/// with peephole optimizations.
 template <typename ExtOpType>
 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
   using OpRewritePattern<ExtOpType>::OpRewritePattern;

From 8821ac4bdba653df5e295904dc0119b4bd49fa6e Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Mon, 18 Sep 2023 19:01:39 +0200
Subject: [PATCH 3/3] Add endianness warning

---
 .../mlir/Dialect/Vector/TransformOps/VectorTransformOps.td | 2 ++
 .../mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 133ee4e030f01..3ac6f28dcb938 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -300,6 +300,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+
+    Warning: these patterns currently only work for little endian targets.
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -303,6 +303,7 @@
+/// Warning: these patterns currently only work for little endian targets.
 FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                         vector::BitCastOp bitCastOp,
                                         arith::TruncIOp truncOp,
@@ -312,12 +313,14 @@ FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
 
 /// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
+/// Warning: these patterns currently only work for little endian targets.
 FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
                                      vector::BitCastOp bitCastOp,
                                      vector::BroadcastOp maybeBroadcastOp);
 
 /// Appends patterns for rewriting vector operations over narrow types with
 /// ops over wider types.
+/// Warning: these patterns currently only work for little endian targets.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
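
As a concrete illustration of the little-endian assumption (derived from the
expected output of the `@fext` integration test above; the bit splitting in
the comments is a sketch, not part of the patch):

```mlir
// On a little-endian target, vector.bitcast re-reads the underlying bits
// starting from the least significant bit of element 0. With
// %a = [0xef, 0xee, 0xed, 0xec, 0xeb] : vector<5xi8>:
//   0xef, LSB first: 1,1,1,1,0 | 1,1,1  -> i5 #0 = (1,1,1,1,0)
//   0xee, LSB first: 0,1 | 1,1,0,1,1,1  -> i5 #1 = (1,1,1) ++ (0,1)
// On a big-endian layout the grouping would differ, hence the warning.
%0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
```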