@@ -177,7 +177,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
177
177
assert ((!isa<ConstantInt>(Stride) ||
178
178
cast<ConstantInt>(Stride)->getZExtValue () >= NumElements) &&
179
179
" Stride must be >= the number of elements in the result vector." );
180
- unsigned AS = cast<PointerType>(BasePtr->getType ())->getAddressSpace ();
181
180
182
181
// Compute the start of the vector with index VecIdx as VecIdx * Stride.
183
182
Value *VecStart = Builder.CreateMul (VecIdx, Stride, " vec.start" );
@@ -189,11 +188,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
189
188
else
190
189
VecStart = Builder.CreateGEP (EltType, BasePtr, VecStart, " vec.gep" );
191
190
192
- // Cast elementwise vector start pointer to a pointer to a vector
193
- // (EltType x NumElements)*.
194
- auto *VecType = FixedVectorType::get (EltType, NumElements);
195
- Type *VecPtrType = PointerType::get (VecType, AS);
196
- return Builder.CreatePointerCast (VecStart, VecPtrType, " vec.cast" );
191
+ return VecStart;
197
192
}
198
193
199
194
// / LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -1060,13 +1055,6 @@ class LowerMatrixIntrinsics {
1060
1055
return Changed;
1061
1056
}
1062
1057
1063
- // / Turns \p BasePtr into an elementwise pointer to \p EltType.
1064
- Value *createElementPtr (Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
1065
- unsigned AS = cast<PointerType>(BasePtr->getType ())->getAddressSpace ();
1066
- Type *EltPtrType = PointerType::get (EltType, AS);
1067
- return Builder.CreatePointerCast (BasePtr, EltPtrType);
1068
- }
1069
-
1070
1058
// / Replace intrinsic calls
1071
1059
bool VisitCallInst (CallInst *Inst) {
1072
1060
if (!Inst->getCalledFunction () || !Inst->getCalledFunction ()->isIntrinsic ())
@@ -1118,7 +1106,7 @@ class LowerMatrixIntrinsics {
1118
1106
auto *VType = cast<VectorType>(Ty);
1119
1107
Type *EltTy = VType->getElementType ();
1120
1108
Type *VecTy = FixedVectorType::get (EltTy, Shape.getStride ());
1121
- Value *EltPtr = createElementPtr ( Ptr , EltTy, Builder) ;
1109
+ Value *EltPtr = Ptr ;
1122
1110
MatrixTy Result;
1123
1111
for (unsigned I = 0 , E = Shape.getNumVectors (); I < E; ++I) {
1124
1112
Value *GEP = computeVectorAddr (
@@ -1144,17 +1132,11 @@ class LowerMatrixIntrinsics {
1144
1132
Value *Offset = Builder.CreateAdd (
1145
1133
Builder.CreateMul (J, Builder.getInt64 (MatrixShape.getStride ())), I);
1146
1134
1147
- unsigned AS = cast<PointerType>(MatrixPtr->getType ())->getAddressSpace ();
1148
- Value *EltPtr =
1149
- Builder.CreatePointerCast (MatrixPtr, PointerType::get (EltTy, AS));
1150
- Value *TileStart = Builder.CreateGEP (EltTy, EltPtr, Offset);
1135
+ Value *TileStart = Builder.CreateGEP (EltTy, MatrixPtr, Offset);
1151
1136
auto *TileTy = FixedVectorType::get (EltTy, ResultShape.NumRows *
1152
1137
ResultShape.NumColumns );
1153
- Type *TilePtrTy = PointerType::get (TileTy, AS);
1154
- Value *TilePtr =
1155
- Builder.CreatePointerCast (TileStart, TilePtrTy, " col.cast" );
1156
1138
1157
- return loadMatrix (TileTy, TilePtr , Align,
1139
+ return loadMatrix (TileTy, TileStart , Align,
1158
1140
Builder.getInt64 (MatrixShape.getStride ()), IsVolatile,
1159
1141
ResultShape, Builder);
1160
1142
}
@@ -1190,17 +1172,11 @@ class LowerMatrixIntrinsics {
1190
1172
Value *Offset = Builder.CreateAdd (
1191
1173
Builder.CreateMul (J, Builder.getInt64 (MatrixShape.getStride ())), I);
1192
1174
1193
- unsigned AS = cast<PointerType>(MatrixPtr->getType ())->getAddressSpace ();
1194
- Value *EltPtr =
1195
- Builder.CreatePointerCast (MatrixPtr, PointerType::get (EltTy, AS));
1196
- Value *TileStart = Builder.CreateGEP (EltTy, EltPtr, Offset);
1175
+ Value *TileStart = Builder.CreateGEP (EltTy, MatrixPtr, Offset);
1197
1176
auto *TileTy = FixedVectorType::get (EltTy, StoreVal.getNumRows () *
1198
1177
StoreVal.getNumColumns ());
1199
- Type *TilePtrTy = PointerType::get (TileTy, AS);
1200
- Value *TilePtr =
1201
- Builder.CreatePointerCast (TileStart, TilePtrTy, " col.cast" );
1202
1178
1203
- storeMatrix (TileTy, StoreVal, TilePtr , MAlign,
1179
+ storeMatrix (TileTy, StoreVal, TileStart , MAlign,
1204
1180
Builder.getInt64 (MatrixShape.getStride ()), IsVolatile, Builder);
1205
1181
}
1206
1182
@@ -1210,7 +1186,7 @@ class LowerMatrixIntrinsics {
1210
1186
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1211
1187
IRBuilder<> &Builder) {
1212
1188
auto VType = cast<VectorType>(Ty);
1213
- Value *EltPtr = createElementPtr ( Ptr , VType-> getElementType (), Builder) ;
1189
+ Value *EltPtr = Ptr ;
1214
1190
for (auto Vec : enumerate(StoreVal.vectors ())) {
1215
1191
Value *GEP = computeVectorAddr (
1216
1192
EltPtr,
0 commit comments