@@ -139,20 +139,33 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
139139}
140140
141141static std::optional<APInt> extractConstantBits (const Constant *C,
142- unsigned NumBits) {
142+ int64_t ByteOffset) {
143+ int64_t BitOffset = ByteOffset * 8 ;
143144 if (std::optional<APInt> Bits = extractConstantBits (C))
145+ return Bits->extractBits (Bits->getBitWidth () - BitOffset, BitOffset);
146+ return std::nullopt ;
147+ }
148+
149+ static std::optional<APInt>
150+ extractConstantBits (const Constant *C, int64_t ByteOffset, unsigned NumBits) {
151+ if (std::optional<APInt> Bits = extractConstantBits (C, ByteOffset))
144152 return Bits->zextOrTrunc (NumBits);
145153 return std::nullopt ;
146154}
147155
148156// Attempt to compute the splat width of bits data by normalizing the splat to
149157// remove undefs.
150158static std::optional<APInt> getSplatableConstant (const Constant *C,
159+ int64_t ByteOffset,
151160 unsigned SplatBitWidth) {
152161 const Type *Ty = C->getType ();
153162 assert ((Ty->getPrimitiveSizeInBits () % SplatBitWidth) == 0 &&
154163 " Illegal splat width" );
155164
165+ // TODO: Add ByteOffset support once we have test coverage.
166+ if (ByteOffset != 0 )
167+ return std::nullopt ;
168+
156169 if (std::optional<APInt> Bits = extractConstantBits (C))
157170 if (Bits->isSplat (SplatBitWidth))
158171 return Bits->trunc (SplatBitWidth);
@@ -241,10 +254,12 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
241254
242255// Attempt to rebuild a normalized splat vector constant of the requested splat
243256// width, built up of potentially smaller scalar values.
244- static Constant *rebuildSplatCst (const Constant *C, unsigned /* NumBits*/ ,
245- unsigned /* NumElts*/ , unsigned SplatBitWidth) {
257+ static Constant *rebuildSplatCst (const Constant *C, int64_t ByteOffset,
258+ unsigned /* NumBits*/ , unsigned /* NumElts*/ ,
259+ unsigned SplatBitWidth) {
246260 // TODO: Truncate to NumBits once ConvertToBroadcastAVX512 support this.
247- std::optional<APInt> Splat = getSplatableConstant (C, SplatBitWidth);
261+ std::optional<APInt> Splat =
262+ getSplatableConstant (C, ByteOffset, SplatBitWidth);
248263 if (!Splat)
249264 return nullptr ;
250265
@@ -263,16 +278,17 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
263278 return rebuildConstant (C->getContext (), SclTy, *Splat, NumSclBits);
264279}
265280
266- static Constant *rebuildZeroUpperCst (const Constant *C, unsigned NumBits ,
267- unsigned /* NumElts*/ ,
281+ static Constant *rebuildZeroUpperCst (const Constant *C, int64_t ByteOffset ,
282+ unsigned NumBits, unsigned /* NumElts*/ ,
268283 unsigned ScalarBitWidth) {
269284 Type *SclTy = C->getType ()->getScalarType ();
270285 unsigned NumSclBits = SclTy->getPrimitiveSizeInBits ();
271286 LLVMContext &Ctx = C->getContext ();
272287
273288 if (NumBits > ScalarBitWidth) {
274289 // Determine if the upper bits are all zero.
275- if (std::optional<APInt> Bits = extractConstantBits (C, NumBits)) {
290+ if (std::optional<APInt> Bits =
291+ extractConstantBits (C, ByteOffset, NumBits)) {
276292 if (Bits->countLeadingZeros () >= (NumBits - ScalarBitWidth)) {
277293 // If the original constant was made of smaller elements, try to retain
278294 // those types.
@@ -290,14 +306,14 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
290306}
291307
292308static Constant *rebuildExtCst (const Constant *C, bool IsSExt,
293- unsigned NumBits , unsigned NumElts ,
294- unsigned SrcEltBitWidth) {
309+ int64_t ByteOffset , unsigned NumBits ,
310+ unsigned NumElts, unsigned SrcEltBitWidth) {
295311 unsigned DstEltBitWidth = NumBits / NumElts;
296312 assert ((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
297313 (DstEltBitWidth % SrcEltBitWidth) == 0 &&
298314 (DstEltBitWidth > SrcEltBitWidth) && " Illegal extension width" );
299315
300- if (std::optional<APInt> Bits = extractConstantBits (C, NumBits)) {
316+ if (std::optional<APInt> Bits = extractConstantBits (C, ByteOffset, NumBits)) {
301317 assert ((Bits->getBitWidth () / DstEltBitWidth) == NumElts &&
302318 (Bits->getBitWidth () % DstEltBitWidth) == 0 &&
303319 " Unexpected constant extension" );
@@ -319,13 +335,15 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
319335
320336 return nullptr ;
321337}
322- static Constant *rebuildSExtCst (const Constant *C, unsigned NumBits,
323- unsigned NumElts, unsigned SrcEltBitWidth) {
324- return rebuildExtCst (C, true , NumBits, NumElts, SrcEltBitWidth);
338+ static Constant *rebuildSExtCst (const Constant *C, int64_t ByteOffset,
339+ unsigned NumBits, unsigned NumElts,
340+ unsigned SrcEltBitWidth) {
341+ return rebuildExtCst (C, true , ByteOffset, NumBits, NumElts, SrcEltBitWidth);
325342}
326- static Constant *rebuildZExtCst (const Constant *C, unsigned NumBits,
327- unsigned NumElts, unsigned SrcEltBitWidth) {
328- return rebuildExtCst (C, false , NumBits, NumElts, SrcEltBitWidth);
343+ static Constant *rebuildZExtCst (const Constant *C, int64_t ByteOffset,
344+ unsigned NumBits, unsigned NumElts,
345+ unsigned SrcEltBitWidth) {
346+ return rebuildExtCst (C, false , ByteOffset, NumBits, NumElts, SrcEltBitWidth);
329347}
330348
331349bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
@@ -344,7 +362,8 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
344362 int Op;
345363 int NumCstElts;
346364 int MemBitWidth;
347- std::function<Constant *(const Constant *, unsigned , unsigned , unsigned )>
365+ std::function<Constant *(const Constant *, int64_t , unsigned , unsigned ,
366+ unsigned )>
348367 RebuildConstant;
349368 };
350369 auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned RegBitWidth,
@@ -359,19 +378,23 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
359378#endif
360379 assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
361380 " Unexpected number of operands!" );
362- if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
381+ int64_t ByteOffset = 0 ;
382+ if (auto *C = X86::getConstantFromPool (MI, OperandNo, &ByteOffset)) {
363383 unsigned CstBitWidth = C->getType ()->getPrimitiveSizeInBits ();
364384 RegBitWidth = RegBitWidth ? RegBitWidth : CstBitWidth;
365385 for (const FixupEntry &Fixup : Fixups) {
366- if (Fixup.Op ) {
386+ if (Fixup.Op && 0 <= ByteOffset &&
387+ (RegBitWidth + (8 * ByteOffset)) <= CstBitWidth) {
367388 // Construct a suitable constant and adjust the MI to use the new
368389 // constant pool entry.
369- if (Constant *NewCst = Fixup.RebuildConstant (
370- C, RegBitWidth, Fixup.NumCstElts , Fixup.MemBitWidth )) {
390+ if (Constant *NewCst =
391+ Fixup.RebuildConstant (C, ByteOffset, RegBitWidth,
392+ Fixup.NumCstElts , Fixup.MemBitWidth )) {
371393 unsigned NewCPI =
372394 CP->getConstantPoolIndex (NewCst, Align (Fixup.MemBitWidth / 8 ));
373395 MI.setDesc (TII->get (Fixup.Op ));
374396 MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
397+ MI.getOperand (OperandNo + X86::AddrDisp).setOffset (0 );
375398 return true ;
376399 }
377400 }
0 commit comments