Skip to content

Commit d91a035

Browse files
committed
[X86] getConstantFromPool - add basic handling for non-zero address offsets
As detailed on #127047, getConstantFromPool can't handle cases where the constant pool load address offset is non-zero. This patch adds an optional pointer argument to store the offset, allowing users that can handle it to correctly extract the offset sub-constant data. This is initially just handled by X86FixupVectorConstantsPass, which uses it to extract the offset constant bits. We don't have thorough test coverage for this yet, so I've only added it for the simpler sext/zext/zmovl cases.
1 parent 1435c8e commit d91a035

File tree

5 files changed

+65
-32
lines changed

5 files changed

+65
-32
lines changed

llvm/lib/Target/X86/X86FixupVectorConstants.cpp

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,33 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
139139
}
140140

141141
static std::optional<APInt> extractConstantBits(const Constant *C,
142-
unsigned NumBits) {
142+
int64_t ByteOffset) {
143+
int64_t BitOffset = ByteOffset * 8;
143144
if (std::optional<APInt> Bits = extractConstantBits(C))
145+
return Bits->extractBits(Bits->getBitWidth() - BitOffset, BitOffset);
146+
return std::nullopt;
147+
}
148+
149+
static std::optional<APInt>
150+
extractConstantBits(const Constant *C, int64_t ByteOffset, unsigned NumBits) {
151+
if (std::optional<APInt> Bits = extractConstantBits(C, ByteOffset))
144152
return Bits->zextOrTrunc(NumBits);
145153
return std::nullopt;
146154
}
147155

148156
// Attempt to compute the splat width of bits data by normalizing the splat to
149157
// remove undefs.
150158
static std::optional<APInt> getSplatableConstant(const Constant *C,
159+
int64_t ByteOffset,
151160
unsigned SplatBitWidth) {
152161
const Type *Ty = C->getType();
153162
assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
154163
"Illegal splat width");
155164

165+
// TODO: Add ByteOffset support once we have test coverage.
166+
if (ByteOffset != 0)
167+
return std::nullopt;
168+
156169
if (std::optional<APInt> Bits = extractConstantBits(C))
157170
if (Bits->isSplat(SplatBitWidth))
158171
return Bits->trunc(SplatBitWidth);
@@ -241,10 +254,12 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
241254

242255
// Attempt to rebuild a normalized splat vector constant of the requested splat
243256
// width, built up of potentially smaller scalar values.
244-
static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
245-
unsigned /*NumElts*/, unsigned SplatBitWidth) {
257+
static Constant *rebuildSplatCst(const Constant *C, int64_t ByteOffset,
258+
unsigned /*NumBits*/, unsigned /*NumElts*/,
259+
unsigned SplatBitWidth) {
246260
// TODO: Truncate to NumBits once ConvertToBroadcastAVX512 support this.
247-
std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
261+
std::optional<APInt> Splat =
262+
getSplatableConstant(C, ByteOffset, SplatBitWidth);
248263
if (!Splat)
249264
return nullptr;
250265

@@ -263,16 +278,17 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
263278
return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
264279
}
265280

266-
static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
267-
unsigned /*NumElts*/,
281+
static Constant *rebuildZeroUpperCst(const Constant *C, int64_t ByteOffset,
282+
unsigned NumBits, unsigned /*NumElts*/,
268283
unsigned ScalarBitWidth) {
269284
Type *SclTy = C->getType()->getScalarType();
270285
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
271286
LLVMContext &Ctx = C->getContext();
272287

273288
if (NumBits > ScalarBitWidth) {
274289
// Determine if the upper bits are all zero.
275-
if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
290+
if (std::optional<APInt> Bits =
291+
extractConstantBits(C, ByteOffset, NumBits)) {
276292
if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
277293
// If the original constant was made of smaller elements, try to retain
278294
// those types.
@@ -290,14 +306,14 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
290306
}
291307

292308
static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
293-
unsigned NumBits, unsigned NumElts,
294-
unsigned SrcEltBitWidth) {
309+
int64_t ByteOffset, unsigned NumBits,
310+
unsigned NumElts, unsigned SrcEltBitWidth) {
295311
unsigned DstEltBitWidth = NumBits / NumElts;
296312
assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
297313
(DstEltBitWidth % SrcEltBitWidth) == 0 &&
298314
(DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
299315

300-
if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
316+
if (std::optional<APInt> Bits = extractConstantBits(C, ByteOffset, NumBits)) {
301317
assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
302318
(Bits->getBitWidth() % DstEltBitWidth) == 0 &&
303319
"Unexpected constant extension");
@@ -319,13 +335,15 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
319335

320336
return nullptr;
321337
}
322-
static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
323-
unsigned NumElts, unsigned SrcEltBitWidth) {
324-
return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
338+
static Constant *rebuildSExtCst(const Constant *C, int64_t ByteOffset,
339+
unsigned NumBits, unsigned NumElts,
340+
unsigned SrcEltBitWidth) {
341+
return rebuildExtCst(C, true, ByteOffset, NumBits, NumElts, SrcEltBitWidth);
325342
}
326-
static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
327-
unsigned NumElts, unsigned SrcEltBitWidth) {
328-
return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
343+
static Constant *rebuildZExtCst(const Constant *C, int64_t ByteOffset,
344+
unsigned NumBits, unsigned NumElts,
345+
unsigned SrcEltBitWidth) {
346+
return rebuildExtCst(C, false, ByteOffset, NumBits, NumElts, SrcEltBitWidth);
329347
}
330348

331349
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -344,7 +362,8 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
344362
int Op;
345363
int NumCstElts;
346364
int MemBitWidth;
347-
std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
365+
std::function<Constant *(const Constant *, int64_t, unsigned, unsigned,
366+
unsigned)>
348367
RebuildConstant;
349368
};
350369
auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned RegBitWidth,
@@ -359,19 +378,23 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
359378
#endif
360379
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
361380
"Unexpected number of operands!");
362-
if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
381+
int64_t ByteOffset = 0;
382+
if (auto *C = X86::getConstantFromPool(MI, OperandNo, &ByteOffset)) {
363383
unsigned CstBitWidth = C->getType()->getPrimitiveSizeInBits();
364384
RegBitWidth = RegBitWidth ? RegBitWidth : CstBitWidth;
365385
for (const FixupEntry &Fixup : Fixups) {
366-
if (Fixup.Op) {
386+
if (Fixup.Op && 0 <= ByteOffset &&
387+
(RegBitWidth + (8 * ByteOffset)) <= CstBitWidth) {
367388
// Construct a suitable constant and adjust the MI to use the new
368389
// constant pool entry.
369-
if (Constant *NewCst = Fixup.RebuildConstant(
370-
C, RegBitWidth, Fixup.NumCstElts, Fixup.MemBitWidth)) {
390+
if (Constant *NewCst =
391+
Fixup.RebuildConstant(C, ByteOffset, RegBitWidth,
392+
Fixup.NumCstElts, Fixup.MemBitWidth)) {
371393
unsigned NewCPI =
372394
CP->getConstantPoolIndex(NewCst, Align(Fixup.MemBitWidth / 8));
373395
MI.setDesc(TII->get(Fixup.Op));
374396
MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
397+
MI.getOperand(OperandNo + X86::AddrDisp).setOffset(0);
375398
return true;
376399
}
377400
}

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3656,7 +3656,7 @@ int X86::getFirstAddrOperandIdx(const MachineInstr &MI) {
36563656
}
36573657

36583658
const Constant *X86::getConstantFromPool(const MachineInstr &MI,
3659-
unsigned OpNo) {
3659+
unsigned OpNo, int64_t *ByteOffset) {
36603660
assert(MI.getNumOperands() >= (OpNo + X86::AddrNumOperands) &&
36613661
"Unexpected number of operands!");
36623662

@@ -3665,7 +3665,11 @@ const Constant *X86::getConstantFromPool(const MachineInstr &MI,
36653665
return nullptr;
36663666

36673667
const MachineOperand &Disp = MI.getOperand(OpNo + X86::AddrDisp);
3668-
if (!Disp.isCPI() || Disp.getOffset() != 0)
3668+
if (!Disp.isCPI())
3669+
return nullptr;
3670+
3671+
int64_t Offset = Disp.getOffset();
3672+
if (Offset != 0 && !ByteOffset)
36693673
return nullptr;
36703674

36713675
ArrayRef<MachineConstantPoolEntry> Constants =
@@ -3677,6 +3681,9 @@ const Constant *X86::getConstantFromPool(const MachineInstr &MI,
36773681
if (ConstantEntry.isMachineConstantPoolEntry())
36783682
return nullptr;
36793683

3684+
if (ByteOffset)
3685+
*ByteOffset = Offset;
3686+
36803687
return ConstantEntry.Val.ConstVal;
36813688
}
36823689

llvm/lib/Target/X86/X86InstrInfo.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ bool isX87Instruction(MachineInstr &MI);
112112
int getFirstAddrOperandIdx(const MachineInstr &MI);
113113

114114
/// Find any constant pool entry associated with a specific instruction operand.
115-
const Constant *getConstantFromPool(const MachineInstr &MI, unsigned OpNo);
115+
/// By default returns null if the address offset is non-zero, but will return
116+
/// the entry if \p ByteOffset is non-null to store the value.
117+
const Constant *getConstantFromPool(const MachineInstr &MI, unsigned OpNo,
118+
int64_t *ByteOffset = nullptr);
116119

117120
} // namespace X86
118121

llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
242242
; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
243243
; AVX512-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
244244
; AVX512-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
245-
; AVX512-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
245+
; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
246246
; AVX512-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
247247
; AVX512-FCP-NEXT: vmovq %xmm3, (%rsi)
248248
; AVX512-FCP-NEXT: vmovq %xmm4, (%rdx)
@@ -307,7 +307,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
307307
; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
308308
; AVX512DQ-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
309309
; AVX512DQ-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
310-
; AVX512DQ-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
310+
; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
311311
; AVX512DQ-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
312312
; AVX512DQ-FCP-NEXT: vmovq %xmm3, (%rsi)
313313
; AVX512DQ-FCP-NEXT: vmovq %xmm4, (%rdx)
@@ -372,7 +372,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
372372
; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
373373
; AVX512BW-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
374374
; AVX512BW-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
375-
; AVX512BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
375+
; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
376376
; AVX512BW-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
377377
; AVX512BW-FCP-NEXT: vmovq %xmm3, (%rsi)
378378
; AVX512BW-FCP-NEXT: vmovq %xmm4, (%rdx)
@@ -437,7 +437,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
437437
; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
438438
; AVX512DQ-BW-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
439439
; AVX512DQ-BW-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
440-
; AVX512DQ-BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
440+
; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
441441
; AVX512DQ-BW-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
442442
; AVX512DQ-BW-FCP-NEXT: vmovq %xmm3, (%rsi)
443443
; AVX512DQ-BW-FCP-NEXT: vmovq %xmm4, (%rdx)

llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
226226
; AVX512-FCP-NEXT: vmovaps (%rdi), %ymm4
227227
; AVX512-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
228228
; AVX512-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
229-
; AVX512-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
229+
; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
230230
; AVX512-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
231231
; AVX512-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
232232
; AVX512-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4
@@ -291,7 +291,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
291291
; AVX512DQ-FCP-NEXT: vmovaps (%rdi), %ymm4
292292
; AVX512DQ-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
293293
; AVX512DQ-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
294-
; AVX512DQ-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
294+
; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
295295
; AVX512DQ-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
296296
; AVX512DQ-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
297297
; AVX512DQ-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4
@@ -356,7 +356,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
356356
; AVX512BW-FCP-NEXT: vmovaps (%rdi), %ymm4
357357
; AVX512BW-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
358358
; AVX512BW-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
359-
; AVX512BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
359+
; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
360360
; AVX512BW-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
361361
; AVX512BW-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
362362
; AVX512BW-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4
@@ -421,7 +421,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
421421
; AVX512DQ-BW-FCP-NEXT: vmovaps (%rdi), %ymm4
422422
; AVX512DQ-BW-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
423423
; AVX512DQ-BW-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
424-
; AVX512DQ-BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
424+
; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
425425
; AVX512DQ-BW-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
426426
; AVX512DQ-BW-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
427427
; AVX512DQ-BW-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4

0 commit comments

Comments
 (0)