Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 81 additions & 40 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29480,7 +29480,6 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
uint64_t SizeInBits = VT.getScalarSizeInBits();
APInt PreferredZero = APInt::getZero(SizeInBits);
APInt OppositeZero = PreferredZero;
EVT IVT = VT.changeTypeToInteger();
X86ISD::NodeType MinMaxOp;
if (IsMaxOp) {
MinMaxOp = X86ISD::FMAX;
Expand All @@ -29492,8 +29491,8 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
EVT SetCCType =
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

// The tables below show the expected result of Max in cases of NaN and
// signed zeros.
// The tables below show the expected result of Max in cases of NaN and signed
// zeros.
//
// Y Y
// Num xNaN +0 -0
Expand All @@ -29503,12 +29502,9 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
// xNaN | X | X/Y | -0 | +0 | -0 |
// --------------- ---------------
//
// It is achieved by means of FMAX/FMIN with preliminary checks and operand
// reordering.
//
// We check if any of operands is NaN and return NaN. Then we check if any of
// operands is zero or negative zero (for fmaximum and fminimum respectively)
// to ensure the correct zero is returned.
// It is achieved by means of FMAX/FMIN with preliminary checks, operand
// reordering if one operand is a constant, and bitwise operations and selects
// to handle signed zero and NaN operands otherwise.
auto MatchesZero = [](SDValue Op, APInt Zero) {
Op = peekThroughBitcasts(Op);
if (auto *CstOp = dyn_cast<ConstantFPSDNode>(Op))
Expand Down Expand Up @@ -29539,15 +29535,17 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
Op->getFlags().hasNoSignedZeros() ||
DAG.isKnownNeverZeroFloat(X) ||
DAG.isKnownNeverZeroFloat(Y);
SDValue NewX, NewY;
bool ShouldHandleZeros = true;
SDValue NewX = X;
SDValue NewY = Y;
if (IgnoreSignedZero || MatchesZero(Y, PreferredZero) ||
MatchesZero(X, OppositeZero)) {
// Operands are already in right order or order does not matter.
NewX = X;
NewY = Y;
ShouldHandleZeros = false;
} else if (MatchesZero(X, PreferredZero) || MatchesZero(Y, OppositeZero)) {
NewX = Y;
NewY = X;
ShouldHandleZeros = false;
} else if (!VT.isVector() && (VT == MVT::f16 || Subtarget.hasDQI()) &&
(Op->getFlags().hasNoNaNs() || IsXNeverNaN || IsYNeverNaN)) {
if (IsXNeverNaN)
Expand All @@ -29569,33 +29567,6 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
NewX = DAG.getSelect(DL, VT, NeedSwap, Y, X);
NewY = DAG.getSelect(DL, VT, NeedSwap, X, Y);
return DAG.getNode(MinMaxOp, DL, VT, NewX, NewY, Op->getFlags());
} else {
SDValue IsXSigned;
if (Subtarget.is64Bit() || VT != MVT::f64) {
SDValue XInt = DAG.getNode(ISD::BITCAST, DL, IVT, X);
SDValue ZeroCst = DAG.getConstant(0, DL, IVT);
IsXSigned = DAG.getSetCC(DL, SetCCType, XInt, ZeroCst, ISD::SETLT);
} else {
assert(VT == MVT::f64);
SDValue Ins = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v2f64,
DAG.getConstantFP(0, DL, MVT::v2f64), X,
DAG.getVectorIdxConstant(0, DL));
SDValue VX = DAG.getNode(ISD::BITCAST, DL, MVT::v4f32, Ins);
SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, VX,
DAG.getVectorIdxConstant(1, DL));
Hi = DAG.getBitcast(MVT::i32, Hi);
SDValue ZeroCst = DAG.getConstant(0, DL, MVT::i32);
EVT SetCCType = TLI.getSetCCResultType(DAG.getDataLayout(),
*DAG.getContext(), MVT::i32);
IsXSigned = DAG.getSetCC(DL, SetCCType, Hi, ZeroCst, ISD::SETLT);
}
if (MinMaxOp == X86ISD::FMAX) {
NewX = DAG.getSelect(DL, VT, IsXSigned, X, Y);
NewY = DAG.getSelect(DL, VT, IsXSigned, Y, X);
} else {
NewX = DAG.getSelect(DL, VT, IsXSigned, Y, X);
NewY = DAG.getSelect(DL, VT, IsXSigned, X, Y);
}
}

bool IgnoreNaN = DAG.getTarget().Options.NoNaNsFPMath ||
Expand All @@ -29612,10 +29583,80 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,

SDValue MinMax = DAG.getNode(MinMaxOp, DL, VT, NewX, NewY, Op->getFlags());

// We handle signed-zero ordering by taking the larger (or smaller) sign bit.
if (ShouldHandleZeros) {
const fltSemantics &Sem = VT.getFltSemantics();
unsigned EltBits = VT.getScalarSizeInBits();
bool IsFakeVector = !VT.isVector();
MVT LogicVT = VT.getSimpleVT();
if (IsFakeVector)
LogicVT = (VT == MVT::f64) ? MVT::v2f64
: (VT == MVT::f32) ? MVT::v4f32
: MVT::v8f16;

// We take the sign bit from the first operand and combine it with the
// output sign bit (see below). Right now, if ShouldHandleZeros is true, the
// operands will never have been swapped. If you add another optimization
// that swaps the input operands if one is a known value, make sure this
// logic stays correct!
SDValue LogicX = NewX;
SDValue LogicMinMax = MinMax;
if (IsFakeVector) {
// Promote scalars to vectors for bitwise operations.
LogicX = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, LogicVT, NewX);
LogicMinMax = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, LogicVT, MinMax);
}

// x86's min/max operations return the second operand if both inputs are
// signed zero. For the maximum operation, we want to "and" the sign bit of
// the output with the sign bit of the first operand--that means that if the
// first operand is +0.0, the output will be too. For the minimum, it's the
// opposite: we "or" the output sign bit with the sign bit of the first
// operand, ensuring that if the first operand is -0.0, the output will be
// too.
SDValue Result;
if (IsMaxOp) {
// getSignedMaxValue returns a bit pattern of all ones but the highest
// bit. We "or" that with the first operand, then "and" that with the max
// operation's result. That clears only the sign bit, and only if the
// first operand is positive.
SDValue OrMask = DAG.getConstantFP(
APFloat(Sem, APInt::getSignedMaxValue(EltBits)), DL, LogicVT);
SDValue MaskedSignBit =
DAG.getNode(X86ISD::FOR, DL, LogicVT, LogicX, OrMask);
Result =
DAG.getNode(X86ISD::FAND, DL, LogicVT, MaskedSignBit, LogicMinMax);
} else {
// Likewise, getSignMask returns a bit pattern with only the highest bit
// set. This one *sets* only the sign bit, and only if the first operand
// is *negative*.
SDValue AndMask = DAG.getConstantFP(
APFloat(Sem, APInt::getSignMask(EltBits)), DL, LogicVT);
SDValue MaskedSignBit =
DAG.getNode(X86ISD::FAND, DL, LogicVT, LogicX, AndMask);
Result =
DAG.getNode(X86ISD::FOR, DL, LogicVT, MaskedSignBit, LogicMinMax);
}

// Extract scalar back from vector.
if (IsFakeVector)
MinMax = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Result,
DAG.getVectorIdxConstant(0, DL));
else
MinMax = Result;
}

if (IgnoreNaN || DAG.isKnownNeverNaN(IsNum ? NewY : NewX))
return MinMax;

SDValue NaNSrc = IsNum ? MinMax : NewX;
// The x86 min/max return the second operand if either is NaN, which doesn't
// match the numeric or non-numeric semantics. For the non-numeric versions,
// we want to return NaN if either operand is NaN. To do that, we check if
// NewX (the first operand) is NaN, and select it if so. For the numeric
// versions, we want to return the non-NaN operand if there is one. So we
// check if NewY (the second operand) is NaN, and again select the first
// operand if so.
SDValue NaNSrc = IsNum ? NewY : NewX;
SDValue IsNaN = DAG.getSetCC(DL, SetCCType, NaNSrc, NaNSrc, ISD::SETUO);

return DAG.getSelect(DL, VT, IsNaN, NewX, MinMax);
Expand Down
42 changes: 16 additions & 26 deletions llvm/test/CodeGen/X86/avx512fp16-fminimum-fmaximum.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@ declare <32 x half> @llvm.maximum.v32f16(<32 x half>, <32 x half>)
define half @test_fminimum(half %x, half %y) {
; CHECK-LABEL: test_fminimum:
; CHECK: # %bb.0:
; CHECK-NEXT: vmovw %xmm0, %eax
; CHECK-NEXT: testw %ax, %ax
; CHECK-NEXT: sets %al
; CHECK-NEXT: kmovd %eax, %k1
; CHECK-NEXT: vmovaps %xmm1, %xmm2
; CHECK-NEXT: vmovsh %xmm0, %xmm0, %xmm2 {%k1}
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm0 {%k1}
; CHECK-NEXT: vminsh %xmm2, %xmm0, %xmm1
; CHECK-NEXT: vminsh %xmm1, %xmm0, %xmm2
; CHECK-NEXT: vpbroadcastw {{.*#+}} xmm1 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
; CHECK-NEXT: vpternlogq {{.*#+}} xmm1 = (xmm1 & xmm0) | xmm2
; CHECK-NEXT: vcmpunordsh %xmm0, %xmm0, %k1
; CHECK-NEXT: vmovsh %xmm0, %xmm0, %xmm1 {%k1}
; CHECK-NEXT: vmovaps %xmm1, %xmm0
Expand Down Expand Up @@ -92,16 +87,12 @@ define half @test_fminimum_combine_cmps(half %x, half %y) {
define half @test_fmaximum(half %x, half %y) {
; CHECK-LABEL: test_fmaximum:
; CHECK: # %bb.0:
; CHECK-NEXT: vmovw %xmm0, %eax
; CHECK-NEXT: testw %ax, %ax
; CHECK-NEXT: sets %al
; CHECK-NEXT: kmovd %eax, %k1
; CHECK-NEXT: vmovaps %xmm0, %xmm2
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm2 {%k1}
; CHECK-NEXT: vmaxsh %xmm1, %xmm0, %xmm2
; CHECK-NEXT: vpbroadcastw {{.*#+}} xmm1 = [NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN]
; CHECK-NEXT: vpternlogq {{.*#+}} xmm1 = xmm2 & (xmm1 | xmm0)
; CHECK-NEXT: vcmpunordsh %xmm0, %xmm0, %k1
; CHECK-NEXT: vmovsh %xmm0, %xmm0, %xmm1 {%k1}
; CHECK-NEXT: vmaxsh %xmm2, %xmm1, %xmm0
; CHECK-NEXT: vcmpunordsh %xmm1, %xmm1, %k1
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm0 {%k1}
; CHECK-NEXT: vmovaps %xmm1, %xmm0
; CHECK-NEXT: retq
%r = call half @llvm.maximum.f16(half %x, half %y)
ret half %r
Expand Down Expand Up @@ -196,10 +187,9 @@ define <16 x half> @test_fmaximum_v16f16_nans(<16 x half> %x, <16 x half> %y) "n
define <32 x half> @test_fminimum_v32f16_szero(<32 x half> %x, <32 x half> %y) "no-nans-fp-math"="true" {
; CHECK-LABEL: test_fminimum_v32f16_szero:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovw2m %zmm0, %k1
; CHECK-NEXT: vpblendmw %zmm0, %zmm1, %zmm2 {%k1}
; CHECK-NEXT: vmovdqu16 %zmm1, %zmm0 {%k1}
; CHECK-NEXT: vminph %zmm2, %zmm0, %zmm0
; CHECK-NEXT: vminph %zmm1, %zmm0, %zmm1
; CHECK-NEXT: vpbroadcastw {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
; CHECK-NEXT: vpternlogq {{.*#+}} zmm0 = (zmm0 & zmm2) | zmm1
; CHECK-NEXT: retq
%r = call <32 x half> @llvm.minimum.v32f16(<32 x half> %x, <32 x half> %y)
ret <32 x half> %r
Expand All @@ -208,12 +198,12 @@ define <32 x half> @test_fminimum_v32f16_szero(<32 x half> %x, <32 x half> %y) "
define <32 x half> @test_fmaximum_v32f16_nans_szero(<32 x half> %x, <32 x half> %y) {
; CHECK-LABEL: test_fmaximum_v32f16_nans_szero:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovw2m %zmm0, %k1
; CHECK-NEXT: vpblendmw %zmm1, %zmm0, %zmm2 {%k1}
; CHECK-NEXT: vmaxph %zmm1, %zmm0, %zmm2
; CHECK-NEXT: vpbroadcastw {{.*#+}} zmm1 = [NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN]
; CHECK-NEXT: vpternlogq {{.*#+}} zmm1 = zmm2 & (zmm1 | zmm0)
Comment on lines +201 to +203
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code seems better since no memory load there.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm-mca seems to think the new code is faster despite the memory load, at least in a tight loop. You can try it yourself with:

# LLVM-MCA-BEGIN old
vpmovw2m %zmm0, %k1
vpblendmw %zmm0, %zmm1, %zmm2 {%k1}
vmovdqu16 %zmm1, %zmm0 {%k1}
vminph %zmm2, %zmm0, %zmm0
# LLVM-MCA-END
# LLVM-MCA-BEGIN new
vminph %zmm1, %zmm0, %zmm1
vpbroadcastw (%rdi), %zmm2
vpternlogq $248, %zmm2, %zmm0, %zmm1
# LLVM-MCA-END

Here's what I get with -mcpu=sapphirerapids:

[0] Code Region - old

Iterations:        100
Instructions:      400
Total Cycles:      1103
Total uOps:        400

Dispatch Width:    6
uOps Per Cycle:    0.36
IPC:               0.36
Block RThroughput: 2.0

[snip]

[1] Code Region - new

Iterations:        100
Instructions:      300
Total Cycles:      607
Total uOps:        400

Dispatch Width:    6
uOps Per Cycle:    0.66
IPC:               0.49
Block RThroughput: 1.0
; CHECK-NEXT: vcmpunordph %zmm0, %zmm0, %k1
; CHECK-NEXT: vmovdqu16 %zmm0, %zmm1 {%k1}
; CHECK-NEXT: vmaxph %zmm2, %zmm1, %zmm0
; CHECK-NEXT: vcmpunordph %zmm1, %zmm1, %k1
; CHECK-NEXT: vmovdqu16 %zmm1, %zmm0 {%k1}
; CHECK-NEXT: vmovdqa64 %zmm1, %zmm0
; CHECK-NEXT: retq
%r = call <32 x half> @llvm.maximum.v32f16(<32 x half> %x, <32 x half> %y)
ret <32 x half> %r
Expand Down
Loading