Skip to content

Commit 5518580

Browse files
committed
[LV] Add initial support for vectorizing literal struct return values
This patch adds initial support for vectorizing literal struct return values. Currently, this is limited to the case where the struct is homogeneous (all elements have the same type) and not packed. The intended use case for this is vectorizing intrinsics such as: ``` declare { float, float } @llvm.sincos.f32(float %x) ``` Mapping them to structure-returning library calls such as: ``` declare { <4 x float>, <4 x float> } @Sleef_sincosf4_u10advsimd(<4 x float>) ``` It could also be possible to vectorize the intrinsic (without a libcall) and then later lower the intrinsic to a library call. This may be desired if the only library calls available take output pointers rather than return multiple values. Implementing this required two main changes: 1. Supporting widening `extractvalue` 2. Adding support for "wide" types (in LV and parts of the cost model) The first change is relatively straightforward, the second is larger as it requires changing assumptions that types are always scalars or vectors. In this patch, a "wide" type is defined as a vector, or a struct literal where all elements are vectors (of the same element count). To help with the second change some helpers for wide types have been added (that work similarly to existing vector helpers). These have been used along the paths needed to support vectorizing calls, however, I expect there are many places that still only expect vector types.
1 parent 7773243 commit 5518580

File tree

16 files changed

+506
-79
lines changed

16 files changed

+506
-79
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

+3-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/Analysis/LoopAccessAnalysis.h"
1919
#include "llvm/IR/Module.h"
2020
#include "llvm/IR/VFABIDemangler.h"
21+
#include "llvm/IR/VectorUtils.h"
2122
#include "llvm/Support/CheckedArithmetic.h"
2223

2324
namespace llvm {
@@ -127,18 +128,8 @@ namespace Intrinsic {
127128
typedef unsigned ID;
128129
}
129130

130-
/// A helper function for converting Scalar types to vector types. If
131-
/// the incoming type is void, we return void. If the EC represents a
132-
/// scalar, we return the scalar type.
133-
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
134-
if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
135-
return Scalar;
136-
return VectorType::get(Scalar, EC);
137-
}
138-
139-
inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
140-
return ToVectorTy(Scalar, ElementCount::getFixed(VF));
141-
}
131+
/// Returns true if `Ty` can be widened by the loop vectorizer.
132+
bool canWidenType(Type *Ty);
142133

143134
/// Identify if the intrinsic is trivially vectorizable.
144135
/// This method returns true if the intrinsic's argument types are all scalars

llvm/include/llvm/CodeGen/BasicTTIImpl.h

+25-17
Original file line numberDiff line numberDiff line change
@@ -1561,8 +1561,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
15611561
Type *RetTy = ICA.getReturnType();
15621562

15631563
ElementCount RetVF =
1564-
(RetTy->isVectorTy() ? cast<VectorType>(RetTy)->getElementCount()
1565-
: ElementCount::getFixed(1));
1564+
isWideTy(RetTy) ? getWideTypeVF(RetTy) : ElementCount::getFixed(1);
1565+
15661566
const IntrinsicInst *I = ICA.getInst();
15671567
const SmallVectorImpl<const Value *> &Args = ICA.getArgs();
15681568
FastMathFlags FMF = ICA.getFlags();
@@ -1883,10 +1883,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
18831883
InstructionCost ScalarizationCost = InstructionCost::getInvalid();
18841884
if (RetVF.isVector() && !RetVF.isScalable()) {
18851885
ScalarizationCost = 0;
1886-
if (!RetTy->isVoidTy())
1887-
ScalarizationCost += getScalarizationOverhead(
1888-
cast<VectorType>(RetTy),
1889-
/*Insert*/ true, /*Extract*/ false, CostKind);
1886+
if (!RetTy->isVoidTy()) {
1887+
for (Type *VectorTy : getContainedTypes(RetTy)) {
1888+
ScalarizationCost += getScalarizationOverhead(
1889+
cast<VectorType>(VectorTy),
1890+
/*Insert*/ true, /*Extract*/ false, CostKind);
1891+
}
1892+
}
18901893
ScalarizationCost +=
18911894
getOperandsScalarizationOverhead(Args, ICA.getArgTypes(), CostKind);
18921895
}
@@ -2477,27 +2480,32 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
24772480
// Else, assume that we need to scalarize this intrinsic. For math builtins
24782481
// this will emit a costly libcall, adding call overhead and spills. Make it
24792482
// very expensive.
2480-
if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
2483+
if (isWideTy(RetTy)) {
2484+
const SmallVector<Type *, 2> RetVTys = getContainedTypes(RetTy);
2485+
24812486
// Scalable vectors cannot be scalarized, so return Invalid.
2482-
if (isa<ScalableVectorType>(RetTy) || any_of(Tys, [](const Type *Ty) {
2483-
return isa<ScalableVectorType>(Ty);
2484-
}))
2487+
if (any_of(concat<Type *const>(RetVTys, Tys),
2488+
[](Type *Ty) { return isa<ScalableVectorType>(Ty); }))
24852489
return InstructionCost::getInvalid();
24862490

2487-
InstructionCost ScalarizationCost =
2488-
SkipScalarizationCost
2489-
? ScalarizationCostPassed
2490-
: getScalarizationOverhead(RetVTy, /*Insert*/ true,
2491-
/*Extract*/ false, CostKind);
2491+
InstructionCost ScalarizationCost = ScalarizationCostPassed;
2492+
if (!SkipScalarizationCost) {
2493+
ScalarizationCost = 0;
2494+
for (Type *RetVTy : RetVTys) {
2495+
ScalarizationCost += getScalarizationOverhead(
2496+
cast<VectorType>(RetVTy), /*Insert*/ true,
2497+
/*Extract*/ false, CostKind);
2498+
}
2499+
}
24922500

2493-
unsigned ScalarCalls = cast<FixedVectorType>(RetVTy)->getNumElements();
2501+
unsigned ScalarCalls = getWideTypeVF(RetTy).getFixedValue();
24942502
SmallVector<Type *, 4> ScalarTys;
24952503
for (Type *Ty : Tys) {
24962504
if (Ty->isVectorTy())
24972505
Ty = Ty->getScalarType();
24982506
ScalarTys.push_back(Ty);
24992507
}
2500-
IntrinsicCostAttributes Attrs(IID, RetTy->getScalarType(), ScalarTys, FMF);
2508+
IntrinsicCostAttributes Attrs(IID, ToNarrowTy(RetTy), ScalarTys, FMF);
25012509
InstructionCost ScalarCost =
25022510
thisT()->getIntrinsicInstrCost(Attrs, CostKind);
25032511
for (Type *Ty : Tys) {

llvm/include/llvm/IR/DerivedTypes.h

+4
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ class StructType : public Type {
301301
/// {<vscale x 2 x i32>, <vscale x 4 x i64>}}
302302
bool containsHomogeneousScalableVectorTypes() const;
303303

304+
/// Return true if this struct is non-empty and all element types are the
305+
/// same.
306+
bool containsHomogeneousTypes() const;
307+
304308
/// Return true if this is a named struct that has a non-empty name.
305309
bool hasName() const { return SymbolTableEntry != nullptr; }
306310

llvm/include/llvm/IR/VectorUtils.h

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===----------- VectorUtils.h - Vector type utility functions -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_IR_VECTORUTILS_H
#define LLVM_IR_VECTORUTILS_H

// This header defines inline functions and is included both directly and via
// llvm/Analysis/VectorUtils.h, so it must be guarded against multiple
// inclusion within one translation unit.

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DerivedTypes.h"
#include <cassert>

namespace llvm {

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void or metadata, it is returned unchanged. If the EC
/// represents a scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

/// Overload of ToVectorTy that takes a fixed vectorization factor `VF`
/// instead of an ElementCount.
inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}

/// A helper for converting to wider (vector) types. For scalar types, this is
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
/// struct where each element type has been widened to a vector type. Note: Only
/// unpacked literal struct types are supported.
Type *ToWideTy(Type *Ty, ElementCount EC);

/// A helper for converting wide types to narrow (non-vector) types. For vector
/// types, this is equivalent to calling .getScalarType(). For struct types,
/// this returns a new struct where each element type has been converted to a
/// scalar type. Note: Only unpacked literal struct types are supported.
Type *ToNarrowTy(Type *Ty);

/// Returns the types contained in `Ty`. For struct types, it returns the
/// elements, all other types are returned directly.
SmallVector<Type *, 2> getContainedTypes(Type *Ty);

/// Returns true if `Ty` is a vector type or a struct of vector types where all
/// vector types share the same VF.
bool isWideTy(Type *Ty);

/// Returns the vectorization factor for a widened type. `Ty` must satisfy
/// `isWideTy` (checked by assertion); the element count of its first
/// contained vector type is returned.
inline ElementCount getWideTypeVF(Type *Ty) {
  assert(isWideTy(Ty) && "expected widened type!");
  return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
}

} // namespace llvm

#endif // LLVM_IR_VECTORUTILS_H

llvm/lib/Analysis/VectorUtils.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ static cl::opt<unsigned> MaxInterleaveGroupFactor(
3939
cl::desc("Maximum factor for an interleaved access group (default = 8)"),
4040
cl::init(8));
4141

42+
/// Returns true if `Ty` can be widened by the loop vectorizer.
43+
bool llvm::canWidenType(Type *Ty) {
44+
Type *ElTy = Ty;
45+
// For now, only allow widening non-packed literal structs where all
46+
// element types are the same. This simplifies the cost model and
47+
// conversion between scalar and wide types.
48+
if (auto *StructTy = dyn_cast<StructType>(Ty);
49+
StructTy && !StructTy->isPacked() && StructTy->isLiteral() &&
50+
StructTy->containsHomogeneousTypes()) {
51+
ElTy = StructTy->elements().front();
52+
}
53+
return VectorType::isValidElementType(ElTy);
54+
}
55+
4256
/// Return true if all of the intrinsic's arguments and return type are scalars
4357
/// for the scalar form of the intrinsic, and vectors for the vector form of the
4458
/// intrinsic (except operands that are marked as always being scalar by

llvm/lib/IR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ add_llvm_component_library(LLVMCore
7171
Value.cpp
7272
ValueSymbolTable.cpp
7373
VectorBuilder.cpp
74+
VectorUtils.cpp
7475
Verifier.cpp
7576
VFABIDemangler.cpp
7677
RuntimeLibcalls.cpp

llvm/lib/IR/Type.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,12 @@ bool StructType::containsHomogeneousScalableVectorTypes() const {
433433
Type *FirstTy = getNumElements() > 0 ? elements()[0] : nullptr;
434434
if (!FirstTy || !isa<ScalableVectorType>(FirstTy))
435435
return false;
436-
for (Type *Ty : elements())
437-
if (Ty != FirstTy)
438-
return false;
439-
return true;
436+
return containsHomogeneousTypes();
437+
}
438+
439+
bool StructType::containsHomogeneousTypes() const {
440+
ArrayRef<Type *> ElementTys = elements();
441+
return !ElementTys.empty() && all_equal(ElementTys);
440442
}
441443

442444
void StructType::setBody(ArrayRef<Type*> Elements, bool isPacked) {

llvm/lib/IR/VFABIDemangler.cpp

+11-7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/ADT/SmallString.h"
1212
#include "llvm/ADT/StringSwitch.h"
1313
#include "llvm/IR/Module.h"
14+
#include "llvm/IR/VectorUtils.h"
1415
#include "llvm/Support/Debug.h"
1516
#include "llvm/Support/raw_ostream.h"
1617
#include <limits>
@@ -346,12 +347,15 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
346347
// Also check the return type if not void.
347348
Type *RetTy = Signature->getReturnType();
348349
if (!RetTy->isVoidTy()) {
349-
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
350-
// If we have an unknown scalar element type we can't find a reasonable VF.
351-
if (!ReturnEC)
352-
return std::nullopt;
353-
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
354-
MinEC = *ReturnEC;
350+
for (Type *RetTy : getContainedTypes(RetTy)) {
351+
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
352+
// If we have an unknown scalar element type we can't find a reasonable
353+
// VF.
354+
if (!ReturnEC)
355+
return std::nullopt;
356+
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
357+
MinEC = *ReturnEC;
358+
}
355359
}
356360

357361
// The SVE Vector function call ABI bases the VF on the widest element types
@@ -566,7 +570,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,
566570

567571
auto *RetTy = ScalarFTy->getReturnType();
568572
if (!RetTy->isVoidTy())
569-
RetTy = VectorType::get(RetTy, VF);
573+
RetTy = ToWideTy(RetTy, VF);
570574
return FunctionType::get(RetTy, VecTypes, false);
571575
}
572576

llvm/lib/IR/VectorUtils.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===----------- VectorUtils.cpp - Vector type utility functions ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/IR/VectorUtils.h"
10+
#include "llvm/ADT/SmallVectorExtras.h"
11+
12+
using namespace llvm;
13+
14+
/// A helper for converting to wider (vector) types. For scalar types, this is
15+
/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
16+
/// struct where each element type has been widened to a vector type. Note: Only
17+
/// unpacked literal struct types are supported.
18+
Type *llvm::ToWideTy(Type *Ty, ElementCount EC) {
19+
if (EC.isScalar())
20+
return Ty;
21+
auto *StructTy = dyn_cast<StructType>(Ty);
22+
if (!StructTy)
23+
return ToVectorTy(Ty, EC);
24+
assert(StructTy->isLiteral() && !StructTy->isPacked() &&
25+
"expected unpacked struct literal");
26+
return StructType::get(
27+
Ty->getContext(),
28+
map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
29+
return VectorType::get(ElTy, EC);
30+
}));
31+
}
32+
33+
/// A helper for converting wide types to narrow (non-vector) types. For vector
34+
/// types, this is equivalent to calling .getScalarType(). For struct types,
35+
/// this returns a new struct where each element type has been converted to a
36+
/// scalar type. Note: Only unpacked literal struct types are supported.
37+
Type *llvm::ToNarrowTy(Type *Ty) {
38+
auto *StructTy = dyn_cast<StructType>(Ty);
39+
if (!StructTy)
40+
return Ty->getScalarType();
41+
assert(StructTy->isLiteral() && !StructTy->isPacked() &&
42+
"expected unpacked struct literal");
43+
return StructType::get(
44+
Ty->getContext(),
45+
map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
46+
return ElTy->getScalarType();
47+
}));
48+
}
49+
50+
/// Returns the types contained in `Ty`. For struct types, it returns the
51+
/// elements, all other types are returned directly.
52+
SmallVector<Type *, 2> llvm::getContainedTypes(Type *Ty) {
53+
auto *StructTy = dyn_cast<StructType>(Ty);
54+
if (StructTy)
55+
return to_vector<2>(StructTy->elements());
56+
return {Ty};
57+
}
58+
59+
/// Returns true if `Ty` is a vector type or a struct of vector types where all
60+
/// vector types share the same VF.
61+
bool llvm::isWideTy(Type *Ty) {
62+
auto ContainedTys = getContainedTypes(Ty);
63+
if (ContainedTys.empty() || !ContainedTys.front()->isVectorTy())
64+
return false;
65+
ElementCount VF = cast<VectorType>(ContainedTys.front())->getElementCount();
66+
return all_of(ContainedTys, [&](Type *Ty) {
67+
return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
68+
});
69+
}

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -945,8 +945,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
945945
// Check that the instruction return type is vectorizable.
946946
// We can't vectorize casts from vector type to scalar type.
947947
// Also, we can't vectorize extractelement instructions.
948-
if ((!VectorType::isValidElementType(I.getType()) &&
949-
!I.getType()->isVoidTy()) ||
948+
Type *InstTy = I.getType();
949+
if (!(InstTy->isVoidTy() || canWidenType(InstTy)) ||
950950
(isa<CastInst>(I) &&
951951
!VectorType::isValidElementType(I.getOperand(0)->getType())) ||
952952
isa<ExtractElementInst>(I)) {

0 commit comments

Comments
 (0)