@@ -158,57 +158,71 @@ static std::optional<APInt> getSplatableConstant(const Constant *C,
158
158
return std::nullopt;
159
159
}
160
160
161
- // Attempt to rebuild a normalized splat vector constant of the requested splat
162
- // width, built up of potentially smaller scalar values.
161
+ // Split raw bits into a constant vector of elements of a specific bit width.
163
162
// NOTE: We don't always bother converting to scalars if the vector length is 1.
164
- static Constant *rebuildSplatableConstant (const Constant *C,
165
- unsigned SplatBitWidth) {
166
- std::optional<APInt> Splat = getSplatableConstant (C, SplatBitWidth);
167
- if (!Splat)
168
- return nullptr ;
169
-
170
- // Determine scalar size to use for the constant splat vector, clamping as we
171
- // might have found a splat smaller than the original constant data.
172
- const Type *OriginalType = C->getType ();
173
- Type *SclTy = OriginalType->getScalarType ();
174
- unsigned NumSclBits = SclTy->getPrimitiveSizeInBits ();
175
- NumSclBits = std::min<unsigned >(NumSclBits, SplatBitWidth);
176
- LLVMContext &Ctx = OriginalType->getContext ();
163
+ static Constant *rebuildConstant (LLVMContext &Ctx, Type *SclTy,
164
+ const APInt &Bits, unsigned NumSclBits) {
165
+ unsigned BitWidth = Bits.getBitWidth ();
177
166
178
167
if (NumSclBits == 8 ) {
179
168
SmallVector<uint8_t > RawBits;
180
- for (unsigned I = 0 ; I != SplatBitWidth ; I += 8 )
181
- RawBits.push_back (Splat-> extractBits (8 , I).getZExtValue ());
169
+ for (unsigned I = 0 ; I != BitWidth ; I += 8 )
170
+ RawBits.push_back (Bits. extractBits (8 , I).getZExtValue ());
182
171
return ConstantDataVector::get (Ctx, RawBits);
183
172
}
184
173
185
174
if (NumSclBits == 16 ) {
186
175
SmallVector<uint16_t > RawBits;
187
- for (unsigned I = 0 ; I != SplatBitWidth ; I += 16 )
188
- RawBits.push_back (Splat-> extractBits (16 , I).getZExtValue ());
176
+ for (unsigned I = 0 ; I != BitWidth ; I += 16 )
177
+ RawBits.push_back (Bits. extractBits (16 , I).getZExtValue ());
189
178
if (SclTy->is16bitFPTy ())
190
179
return ConstantDataVector::getFP (SclTy, RawBits);
191
180
return ConstantDataVector::get (Ctx, RawBits);
192
181
}
193
182
194
183
if (NumSclBits == 32 ) {
195
184
SmallVector<uint32_t > RawBits;
196
- for (unsigned I = 0 ; I != SplatBitWidth ; I += 32 )
197
- RawBits.push_back (Splat-> extractBits (32 , I).getZExtValue ());
185
+ for (unsigned I = 0 ; I != BitWidth ; I += 32 )
186
+ RawBits.push_back (Bits. extractBits (32 , I).getZExtValue ());
198
187
if (SclTy->isFloatTy ())
199
188
return ConstantDataVector::getFP (SclTy, RawBits);
200
189
return ConstantDataVector::get (Ctx, RawBits);
201
190
}
202
191
203
- // Fallback to i64 / double.
192
+ assert (NumSclBits == 64 && " Unhandled vector element width" );
193
+
204
194
SmallVector<uint64_t > RawBits;
205
- for (unsigned I = 0 ; I != SplatBitWidth ; I += 64 )
206
- RawBits.push_back (Splat-> extractBits (64 , I).getZExtValue ());
195
+ for (unsigned I = 0 ; I != BitWidth ; I += 64 )
196
+ RawBits.push_back (Bits. extractBits (64 , I).getZExtValue ());
207
197
if (SclTy->isDoubleTy ())
208
198
return ConstantDataVector::getFP (SclTy, RawBits);
209
199
return ConstantDataVector::get (Ctx, RawBits);
210
200
}
211
201
202
+ // Attempt to rebuild a normalized splat vector constant of the requested splat
203
+ // width, built up of potentially smaller scalar values.
204
+ static Constant *rebuildSplatableConstant (const Constant *C,
205
+ unsigned SplatBitWidth) {
206
+ std::optional<APInt> Splat = getSplatableConstant (C, SplatBitWidth);
207
+ if (!Splat)
208
+ return nullptr ;
209
+
210
+ // Determine scalar size to use for the constant splat vector, clamping as we
211
+ // might have found a splat smaller than the original constant data.
212
+ const Type *OriginalType = C->getType ();
213
+ Type *SclTy = OriginalType->getScalarType ();
214
+ unsigned NumSclBits = SclTy->getPrimitiveSizeInBits ();
215
+ NumSclBits = std::min<unsigned >(NumSclBits, SplatBitWidth);
216
+
217
+ // Fallback to i64 / double.
218
+ NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32 )
219
+ ? NumSclBits
220
+ : 64 ;
221
+
222
+ // Extract per-element bits.
223
+ return rebuildConstant (OriginalType->getContext (), SclTy, *Splat, NumSclBits);
224
+ }
225
+
212
226
bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
213
227
MachineBasicBlock &MBB,
214
228
MachineInstr &MI) {
0 commit comments