@@ -1096,43 +1096,55 @@ class VectorExtractOpConversion
10961096 SmallVector<OpFoldResult> positionVec = getMixedValues (
10971097 adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
10981098
1099- // Extract entire vector. Should be handled by folder, but just to be safe.
1100- ArrayRef<OpFoldResult> position (positionVec);
1101- if (position.empty ()) {
1102- rewriter.replaceOp (extractOp, adaptor.getVector ());
1103- return success ();
1104- }
1105-
1106- // One-shot extraction of vector from array (only requires extractvalue).
1107- // Except for extracting 1-element vectors.
1108- if (isa<VectorType>(resultType) &&
1109- position.size () !=
1110- static_cast <size_t >(extractOp.getSourceVectorType ().getRank ())) {
1111- if (extractOp.hasDynamicPosition ())
1112- return failure ();
1113-
1114- Value extracted = rewriter.create <LLVM::ExtractValueOp>(
1115- loc, adaptor.getVector (), getAsIntegers (position));
1116- rewriter.replaceOp (extractOp, extracted);
1117- return success ();
1099+ // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100+ // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101+ // from a N-d vector extract to a nested aggregate vector extract in two
1102+ // steps:
1103+ // - Extract a member from the nested aggregate. The result can be
1104+ // a lower rank nested aggregate or a vector (1-D). This is done using
1105+ // `llvm.extractvalue`.
1106+ // - Extract a scalar out of the vector if needed. This is done using
1107+ // `llvm.extractelement`.
1108+
1109+ // Determine if we need to extract a member out of the aggregate. We
1110+ // always need to extract a member if the input rank >= 2.
1111+ bool extractsAggregate = extractOp.getSourceVectorType ().getRank () >= 2 ;
1112+ // Determine if we need to extract a scalar as the result. We extract
1113+ // a scalar if the extract is full rank, i.e., the number of indices is
1114+ // equal to source vector rank.
1115+ bool extractsScalar = static_cast <int64_t >(positionVec.size ()) ==
1116+ extractOp.getSourceVectorType ().getRank ();
1117+
1118+ // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119+ // need to add a position for this change.
1120+ if (extractOp.getSourceVectorType ().getRank () == 0 ) {
1121+ Type idxType = typeConverter->convertType (rewriter.getIndexType ());
1122+ positionVec.push_back (rewriter.getZeroAttr (idxType));
11181123 }
11191124
1120- // Potential extraction of 1-D vector from array.
11211125 Value extracted = adaptor.getVector ();
1122- if (position.size () > 1 ) {
1123- if (extractOp.hasDynamicPosition ())
1126+ if (extractsAggregate) {
1127+ ArrayRef<OpFoldResult> position (positionVec);
1128+ if (extractsScalar) {
1129+ // If we are extracting a scalar from the extracted member, we drop
1130+ // the last index, which will be used to extract the scalar out of the
1131+ // vector.
1132+ position = position.drop_back ();
1133+ }
1134+ // llvm.extractvalue does not support dynamic dimensions.
1135+ if (!llvm::all_of (position, llvm::IsaPred<Attribute>)) {
11241136 return failure ();
1137+ }
1138+ extracted = rewriter.create <LLVM::ExtractValueOp>(
1139+ loc, extracted, getAsIntegers (position));
1140+ }
11251141
1126- SmallVector<int64_t > nMinusOnePosition =
1127- getAsIntegers (position.drop_back ());
1128- extracted = rewriter.create <LLVM::ExtractValueOp>(loc, extracted,
1129- nMinusOnePosition);
1142+ if (extractsScalar) {
1143+ extracted = rewriter.create <LLVM::ExtractElementOp>(
1144+ loc, extracted, getAsLLVMValue (rewriter, loc, positionVec.back ()));
11301145 }
11311146
1132- Value lastPosition = getAsLLVMValue (rewriter, loc, position.back ());
1133- // Remaining extraction of element from 1-D LLVM vector.
1134- rewriter.replaceOpWithNewOp <LLVM::ExtractElementOp>(extractOp, extracted,
1135- lastPosition);
1147+ rewriter.replaceOp (extractOp, extracted);
11361148 return success ();
11371149 }
11381150};
0 commit comments