@@ -642,19 +642,19 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
   precomputeMetadata(IntegerType shuffledElementType);
 
   /// Rewrite one step of the sequence:
   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in each
@@ -719,21 +719,54 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (preconditionType.getRank() != 1 || preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that the source and destination element types meet the precondition
+/// for the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
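The precondition logic is now layered: `commonConversionPrecondition` checks properties every rewrite needs (1-D, fixed-length, byte-aligned result elements), while `alignedConversionPrecondition` additionally gates the fast aligned path. A minimal standalone sketch of the two predicates (plain C++ with hypothetical helper names, no MLIR dependencies) showing which element-bitwidth combinations get through:

```cpp
#include <cassert>

// Sketch of commonConversionPrecondition's bitwidth check: the result
// element type must span a whole number of bytes.
static bool isByteAligned(unsigned elemBitwidth) {
  return elemBitwidth % 8 == 0;
}

// Sketch of alignedConversionPrecondition: only i4 sources widening into
// elements of at least 8 bits whose bitwidth i4 divides are accepted.
static bool isSupportedAlignedCase(unsigned srcElemBitwidth,
                                   unsigned dstElemBitwidth) {
  return srcElemBitwidth == 4 && dstElemBitwidth >= 8 &&
         dstElemBitwidth % srcElemBitwidth == 0;
}

int main() {
  assert(isSupportedAlignedCase(4, 8));   // i4 -> i8
  assert(isSupportedAlignedCase(4, 32));  // i4 -> i32 / f32 (sitofp path)
  assert(!isSupportedAlignedCase(8, 32)); // non-i4 source: generic path
  assert(!isSupportedAlignedCase(4, 4));  // no widening
  assert(isByteAligned(8) && !isByteAligned(4));
  return 0;
}
```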
@@ -775,9 +808,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +843,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
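The shift trick in step 2 deserves a closer look: assuming the little-endian nibble layout the bitcast produces (element 2k in the low nibble of byte k), `shl 4` followed by `ashr 4` sign-extends the low nibble of each byte, while a bare `ashr 4` sign-extends the high nibble; step 3's mask then restores source order. A scalar model one can check by hand (plain C++ with hypothetical names; the real code operates on whole vectors):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of step 2: a byte packs two i4 values little-endian (element
// 2k in the low nibble, element 2k+1 in the high nibble).
static void signExtendPackedI4(int8_t byte, int8_t &low, int8_t &high) {
  // Shift through unsigned to avoid UB on negative left shifts (pre-C++20).
  low = static_cast<int8_t>(static_cast<uint8_t>(byte) << 4) >> 4;
  high = byte >> 4; // arithmetic right shift sign-extends the high nibble
}

// Model of step 3: the interleave mask mapping low[i], high[i] back into
// source order, e.g. 8 i4 elements -> [0, 4, 1, 5, 2, 6, 3, 7].
static std::vector<int64_t> interleaveMask(int64_t vecDimSize) {
  std::vector<int64_t> mask;
  mask.reserve(vecDimSize);
  for (int64_t i = 0, end = vecDimSize / 2; i < end; ++i) {
    mask.push_back(i);                  // i-th element of `low`
    mask.push_back(i + vecDimSize / 2); // i-th element of `high`
  }
  return mask;
}

int main() {
  int8_t low = 0, high = 0;
  // 0xF7 packs low nibble 0b0111 (7) and high nibble 0b1111 (-1 as i4).
  signExtendPackedI4(static_cast<int8_t>(0xF7), low, high);
  assert(low == 7 && high == -1);
  assert(interleaveMask(8) == (std::vector<int64_t>{0, 4, 1, 5, 2, 6, 3, 7}));
  return 0;
}
```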
@@ -829,7 +900,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +910,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -885,7 +956,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
@@ -896,8 +967,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
// Finalize the rewrite.
@@ -915,6 +986,52 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
915
986
return success ();
916
987
}
917
988
};
989
+
990
+ // / Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
991
+ // / bitwise ops that take advantage of high-level information to avoid leaving
992
+ // / LLVM to scramble with peephole optimizations.
993
+ // /
994
+ // / For example:
995
+ // / extsi vector<8xi4> -> vector<8xi32>
996
+ // / is rewriten as
997
+ // / sequence of shuffles and bitwise of for i4 -> i8
998
+ // / extsi vector<8xi8> -> vector<8xi32>
999
+ // /
1000
+ // / sitofp vector<8xi4> -> vector<8xf32>
1001
+ // / is rewriten as
1002
+ // / sequence of shuffles and bitwise of for i4 -> i8
1003
+ // / sitofp vector<8xi8> -> vector<8xf32>
1004
+ // /
1005
+ template <typename ConversionOpType>
1006
+ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
1007
+ using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1008
+
1009
+ LogicalResult matchAndRewrite (ConversionOpType conversionOp,
1010
+ PatternRewriter &rewriter) const override {
1011
+ // Set up the BitCastRewriter and verify the preconditions.
1012
+ Value srcValue = conversionOp.getIn ();
1013
+ auto srcVecType = dyn_cast<VectorType>(srcValue.getType ());
1014
+ auto dstVecType = dyn_cast<VectorType>(conversionOp.getType ());
1015
+ if (failed (
1016
+ commonConversionPrecondition (rewriter, dstVecType, conversionOp)))
1017
+ return failure ();
1018
+
1019
+ // Check general alignment preconditions.
1020
+ if (failed (alignedConversionPrecondition (rewriter, srcVecType, dstVecType,
1021
+ conversionOp)))
1022
+ return failure ();
1023
+
1024
+ // Perform the rewrite.
1025
+ Value subByteExt =
1026
+ rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1027
+
1028
+ // Finalize the rewrite.
1029
+ rewriter.replaceOpWithNewOp <ConversionOpType>(
1030
+ conversionOp, conversionOp.getType (), subByteExt);
1031
+ return success ();
1032
+ }
1033
+ };
1034
+
918
1035
} // namespace
919
1036
920
1037
// ===----------------------------------------------------------------------===//
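The pattern deliberately stops at i8: it materializes only the i4 -> i8 extension and then re-creates the original conversion op with an i8 source, so the back end sees a standard widening it already handles well. The replacement is sound because signed widening composes: extending i4 -> i8 -> iN, or converting the i8 value to float, yields the same value as the direct conversion. A scalar sanity check of that equivalence (plain C++ model, not the MLIR code):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // For every value representable in i4, direct widening equals widening via
  // the intermediate i8 sign extension the pattern emits.
  for (int v = -8; v <= 7; ++v) {
    int8_t asI8 = static_cast<int8_t>(v); // result of the i4 -> i8 step
    assert(static_cast<int32_t>(v) == static_cast<int32_t>(asI8)); // extsi
    assert(static_cast<float>(v) == static_cast<float>(asI8));     // sitofp
  }
  return 0;
}
```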
@@ -936,4 +1053,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
+
+  // Patterns for aligned cases. We set a higher priority as they are expected
+  // to generate better performance for the aligned cases.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
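For context on the `benefit.getBenefit() + 1` registration: the greedy pattern driver tries higher-benefit patterns first, so the aligned patterns win whenever their preconditions hold, and the generic `RewriteExtOfBitCast` path remains the fallback. A sketch of how a client pass might apply the populated set; the include paths and driver call are the usual MLIR ones, stated here as an assumption rather than taken from this commit:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" // assumed header
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: populate the narrow-type rewrites and run them
// greedily over `root`. The aligned sub-byte patterns were registered with
// benefit + 1 above, so the driver prefers them whenever they match.
static void runNarrowTypeRewrites(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```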