Skip to content

Commit 21c251a

Browse files
authored
[LowerMatrixIntrinsics] Drop support for typed pointers (#65605)
1 parent 390b486 commit 21c251a

File tree

1 file changed

+7
-31
lines changed

1 file changed

+7
-31
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
177177
assert((!isa<ConstantInt>(Stride) ||
178178
cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
179179
"Stride must be >= the number of elements in the result vector.");
180-
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
181180

182181
// Compute the start of the vector with index VecIdx as VecIdx * Stride.
183182
Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
@@ -189,11 +188,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
189188
else
190189
VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
191190

192-
// Cast elementwise vector start pointer to a pointer to a vector
193-
// (EltType x NumElements)*.
194-
auto *VecType = FixedVectorType::get(EltType, NumElements);
195-
Type *VecPtrType = PointerType::get(VecType, AS);
196-
return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
191+
return VecStart;
197192
}
198193

199194
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -1060,13 +1055,6 @@ class LowerMatrixIntrinsics {
10601055
return Changed;
10611056
}
10621057

1063-
/// Turns \p BasePtr into an elementwise pointer to \p EltType.
1064-
Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
1065-
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
1066-
Type *EltPtrType = PointerType::get(EltType, AS);
1067-
return Builder.CreatePointerCast(BasePtr, EltPtrType);
1068-
}
1069-
10701058
/// Replace intrinsic calls
10711059
bool VisitCallInst(CallInst *Inst) {
10721060
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
@@ -1118,7 +1106,7 @@ class LowerMatrixIntrinsics {
11181106
auto *VType = cast<VectorType>(Ty);
11191107
Type *EltTy = VType->getElementType();
11201108
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
1121-
Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
1109+
Value *EltPtr = Ptr;
11221110
MatrixTy Result;
11231111
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
11241112
Value *GEP = computeVectorAddr(
@@ -1144,17 +1132,11 @@ class LowerMatrixIntrinsics {
11441132
Value *Offset = Builder.CreateAdd(
11451133
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
11461134

1147-
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1148-
Value *EltPtr =
1149-
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1150-
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1135+
Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
11511136
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
11521137
ResultShape.NumColumns);
1153-
Type *TilePtrTy = PointerType::get(TileTy, AS);
1154-
Value *TilePtr =
1155-
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
11561138

1157-
return loadMatrix(TileTy, TilePtr, Align,
1139+
return loadMatrix(TileTy, TileStart, Align,
11581140
Builder.getInt64(MatrixShape.getStride()), IsVolatile,
11591141
ResultShape, Builder);
11601142
}
@@ -1190,17 +1172,11 @@ class LowerMatrixIntrinsics {
11901172
Value *Offset = Builder.CreateAdd(
11911173
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
11921174

1193-
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
1194-
Value *EltPtr =
1195-
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
1196-
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
1175+
Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
11971176
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
11981177
StoreVal.getNumColumns());
1199-
Type *TilePtrTy = PointerType::get(TileTy, AS);
1200-
Value *TilePtr =
1201-
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
12021178

1203-
storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
1179+
storeMatrix(TileTy, StoreVal, TileStart, MAlign,
12041180
Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
12051181
}
12061182

@@ -1210,7 +1186,7 @@ class LowerMatrixIntrinsics {
12101186
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
12111187
IRBuilder<> &Builder) {
12121188
auto VType = cast<VectorType>(Ty);
1213-
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
1189+
Value *EltPtr = Ptr;
12141190
for (auto Vec : enumerate(StoreVal.vectors())) {
12151191
Value *GEP = computeVectorAddr(
12161192
EltPtr,

0 commit comments

Comments
 (0)