Skip to content
63 changes: 63 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3963,6 +3963,66 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
return nullptr;
}

// fcmp + sel patterns into max/min intrinsic.
static Value *foldSelectICmpIntoMaxMin(SelectInst &SI,
InstCombiner::BuilderTy &Builder) {
// Do this transformation only when select instruction
// gives NSZ guarantee.
auto *SIFOp = dyn_cast<FPMathOperator>(&SI);
if (!SIFOp || !SIFOp->hasNoSignedZeros())
return nullptr;

auto TryFoldIntoMaxMinIntrinsic =
[&Builder, &SI](CmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS,
Value *TVal, Value *FVal) -> Value * {
// Early exit if the operands are not in the expected form.
if ((CmpRHS != TVal || CmpLHS != FVal) &&
(CmpLHS != TVal || CmpRHS != FVal))
return nullptr;

bool isSwapped = (CmpLHS == FVal && CmpRHS == TVal);
// Only these relational predicates can be transformed into maxnum/minnum
// intrinsic.
// X > C ? X : C --> maxnum(X, C)
// X > C ? C : X --> minnum(X, C)
if (Pred == CmpInst::FCMP_OGT) {
Intrinsic::ID MaxMinIID =
isSwapped ? Intrinsic::minnum : Intrinsic::maxnum;
return Builder.CreateIntrinsic(SI.getType(), MaxMinIID, {TVal, FVal},
&SI);
}

// X < C ? X : C --> minnum(X, C)
// X < C ? C : X --> maxnum(X, C)
if (Pred == CmpInst::FCMP_OLT) {
Intrinsic::ID MaxMinIID =
isSwapped ? Intrinsic::maxnum : Intrinsic::minnum;
return Builder.CreateIntrinsic(SI.getType(), MaxMinIID, {TVal, FVal},
&SI);
}

return nullptr;
};

// select((fcmp Pred, X, Y), X, Y)
// => minnum/maxnum(X, Y)
//
// Pred := OGT and OLT
Value *X, *Y;
Value *TVal, *FVal;
CmpPredicate Pred;

// Note: OneUse check for `Cmp` is necessary because it makes sure that other
// InstCombine folds don't undo this transformation and cause an infinite
// loop. Furthermore, it could also increase the operation count.
if (match(&SI,
m_OneUse(m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Y))),
m_Value(TVal), m_Value(FVal)))))
return TryFoldIntoMaxMinIntrinsic(Pred, X, Y, TVal, FVal);

return nullptr;
}

static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
Value *FalseVal,
InstCombiner::BuilderTy &Builder,
Expand Down Expand Up @@ -4455,6 +4515,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Value *V = foldSelectIntoAddConstant(SI, Builder))
return replaceInstUsesWith(SI, V);

if (Value *V = foldSelectICmpIntoMaxMin(SI, Builder))
return replaceInstUsesWith(SI, V);

// select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
// Load inst is intentionally not checked for hasOneUse()
if (match(FalseVal, m_Zero()) &&
Expand Down
30 changes: 30 additions & 0 deletions llvm/test/Transforms/InstCombine/float-clamp-to-minmax-nsz.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

define float @src(float %arg0) {
; CHECK-LABEL: define float @src(
; CHECK-SAME: float [[ARG0:%.*]]) {
; CHECK-NEXT: [[V1:%.*]] = call nsz float @llvm.maxnum.f32(float [[ARG0]], float 0.000000e+00)
; CHECK-NEXT: [[V3:%.*]] = call nsz float @llvm.minnum.f32(float [[V1]], float 0x3FE96C8000000000)
; CHECK-NEXT: ret float [[V3]]
;
%v0 = fcmp nsz ogt float %arg0, 0.000000e+00
%v1 = select nsz i1 %v0, float %arg0, float 0.000000e+00
%v2 = fcmp nsz ogt float %v1, 0x3FE96C8000000000
%v3 = select nsz i1 %v2, float 0x3FE96C8000000000, float %v1
ret float %v3
}

define float @src2(float %arg0) {
; CHECK-LABEL: define float @src2(
; CHECK-SAME: float [[ARG0:%.*]]) {
; CHECK-NEXT: [[V1:%.*]] = call nsz float @llvm.minnum.f32(float [[ARG0]], float 0.000000e+00)
; CHECK-NEXT: [[V3:%.*]] = call nsz float @llvm.maxnum.f32(float [[V1]], float -1.000000e+02)
; CHECK-NEXT: ret float [[V3]]
;
%v0 = fcmp nsz olt float %arg0, 0.000000e+00
%v1 = select nsz i1 %v0, float %arg0, float 0.000000e+00
%v2 = fcmp nsz olt float %v1, -100.00e+00
%v3 = select nsz i1 %v2, float -100.00e+00, float %v1
ret float %v3
}