Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 264 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ UnrollVerifyLoopInfo("unroll-verify-loopinfo", cl::Hidden,
#endif
);

// Hidden command-line switch (-unroll-simplify-reductions), on by default.
// When set, simplifyLoopAfterUnroll runs trySimplifyReductions on the phi
// nodes of each block of the unrolled loop.
static cl::opt<bool>
UnrollSimplifyReductions("unroll-simplify-reductions", cl::init(true),
cl::Hidden, cl::desc("Try to simplify reductions "
"after unrolling a loop."));


/// Check if unrolling created a situation where we need to insert phi nodes to
/// preserve LCSSA form.
Expand Down Expand Up @@ -209,6 +214,255 @@ static bool isEpilogProfitable(Loop *L) {
return false;
}

/// This function tries to break apart simple reduction loops like the one
/// below:
///
///   loop:
///     PN = PHI [SUM2, loop], ...
///     X = ...
///     SUM1 = ADD (X, PN)
///     Y = ...
///     SUM2 = ADD (Y, SUM1)
///     br loop
///
/// into independent sums of the form:
///
///   loop:
///     PN1 = PHI [SUM1, loop], ...
///     PN2 = PHI [SUM2, loop], ...
///     X = ...
///     SUM1 = ADD (X, PN1)
///     Y = ...
///     SUM2 = ADD (Y, PN2)
///     <Reductions>
///     br loop
///
/// where <Reductions> are new instructions inserted to compute the final
/// values of the reduction from the partial sums we introduced, in this case:
///
///   <Reductions> =
///     PN.red = ADD (PN1, PN2)
///     SUM1.red = ADD (SUM1, PN2)
///     SUM2.red = ADD (SUM1, SUM2)
///
/// In practice in most cases only one or two of the reduced values are
/// required outside the loop so most of the reduction instructions do not
/// need to be added into the loop. Moreover, these instructions can be sunk
/// from the loop which happens in later passes.
///
/// This is a very common pattern in unrolled loops that compute dot products
/// (for example) and breaking apart the reduction chains can help greatly with
/// vectorisation.
///
/// Because the rewrite reassociates the reduction, flags that are only valid
/// for the original evaluation order (nsw/nuw and friends on integer
/// operations) are dropped from the chain, and the newly created reduction
/// instructions only carry flags that are common to every link of the chain.
///
/// \param I The instruction to inspect; only PHI nodes that have an incoming
///          edge from their own basic block are considered.
/// \returns true if a reduction chain starting at \p I was found and split
///          into independent partial reductions, false if the IR was left
///          unchanged.
static bool trySimplifyReductions(Instruction &I) {
  // Check if I is a PHINode (potentially the start of a reduction chain).
  // Note: For simplicity we only consider loops that consist of a single
  // basic block that branches to itself.
  BasicBlock *BB = I.getParent();
  PHINode *PN = dyn_cast<PHINode>(&I);
  if (!PN || PN->getBasicBlockIndex(BB) == -1)
    return false;

  // Attempt to construct a list of instructions that are chained together
  // (i.e. that perform a reduction).
  SmallVector<BinaryOperator *, 16> Ops;
  for (Instruction *Cur = PN, *Next = nullptr; /* true */;
       Cur = Next, Next = nullptr) {
    // Try to find the next element in the reduction chain.
    for (auto *U : Cur->users()) {
      auto *Candidate = dyn_cast<Instruction>(U);
      if (Candidate && Candidate->getParent() == BB) {
        // If we've already found a candidate element for the chain and we find
        // *another* candidate we bail out as this means the intermediate
        // values of the reduction are needed within the loop, and so there is
        // no point in breaking the reduction apart.
        if (Next)
          return false;
        Next = Candidate;
      }
    }
    // If we've reached the start, i.e. the next element in the chain would be
    // the PN we started with, we are done.
    if (Next == PN)
      break;
    // Else, check if we found a candidate at all and if so if it is a binary
    // operator.
    if (!Next || !isa<BinaryOperator>(Next))
      return false;
    // If everything checks out, add the new element to the chain.
    Ops.push_back(cast<BinaryOperator>(Next));
  }

  // Ensure the reduction comprises at least two instructions, otherwise this
  // is a trivial reduction of a single element that doesn't need to be
  // simplified.
  if (Ops.size() < 2)
    return false;

  LLVM_DEBUG(dbgs() << "Candidate reduction of length " << Ops.size()
                    << " found at " << I << ".\n");

  // Ensure all instructions perform the same operation and that the operation
  // is associative and commutative so that we can break the chain apart and
  // reassociate the Ops.
  Instruction::BinaryOps const Opcode = Ops[0]->getOpcode();
  for (auto const *Op : Ops)
    if (Op->getOpcode() != Opcode || !Op->isAssociative() ||
        !Op->isCommutative())
      return false;

  // Define the neutral element of the reduction or bail out if we don't have
  // one defined.
  // TODO: This could be generalised to other operations (e.g. MUL's).
  Value *NeutralElem = nullptr;
  switch (Opcode) {
  case Instruction::BinaryOps::Add:
  case Instruction::BinaryOps::Or:
  case Instruction::BinaryOps::Xor:
  case Instruction::BinaryOps::FAdd:
    NeutralElem = Constant::getNullValue(PN->getType());
    break;
  case Instruction::BinaryOps::And:
    NeutralElem = Constant::getAllOnesValue(PN->getType());
    break;
  case Instruction::BinaryOps::Mul:
  case Instruction::BinaryOps::FMul:
  default:
    return false;
  }
  assert(NeutralElem && "Neutral element of reduction undefined.");

  // --------------------------------------------------------------------- //
  // At this point Ops is a list of chained binary operations performing a //
  // reduction that we know we can break apart.                            //
  // --------------------------------------------------------------------- //

  // For shorthand, let N be the length of the chain.
  unsigned const N = Ops.size();
  LLVM_DEBUG(dbgs() << "Simplifying reduction of length " << N << ".\n");

  // Splitting the chain changes the order in which values are combined. A
  // partial accumulator may therefore wrap even though the original running
  // total never did (e.g. one lane only sees large positive terms and another
  // only large negative ones), so flags such as nsw/nuw that were proven for
  // the original order no longer hold and would introduce poison. Drop them
  // for integer reductions. Fast-math flags are left in place: isAssociative()
  // above already required the flags that license reassociation.
  if (!PN->getType()->isFPOrFPVectorTy())
    for (BinaryOperator *Op : Ops)
      Op->dropPoisonGeneratingFlags();

  // Create new phi nodes for all but the first element in the chain.
  SmallVector<PHINode *, 16> Phis{PN};
  for (unsigned i = 1; i < N; i++) {
    PHINode *NewPN = PHINode::Create(PN->getType(), PN->getNumIncomingValues(),
                                     PN->getName());
    // Copy incoming blocks from the first/original PN to the new Phi and set
    // their incoming values to the neutral element of the reduction.
    for (auto *IncomingBB : PN->blocks())
      NewPN->addIncoming(NeutralElem, IncomingBB);
    NewPN->insertAfter(Phis.back());
    Phis.push_back(NewPN);
  }

  // Set the chained operands of the Ops to the Phis and the incoming values of
  // the Phis (for this BB) to the Ops.
  for (unsigned i = 0; i < N; i++) {
    PHINode *Phi = Phis[i];
    Instruction *Op = Ops[i];

    // Find the index of the operand of Op to replace. The first Op reads its
    // value from the first Phi node. The other Ops read their value from the
    // previous Op.
    Value *OperandToReplace = i == 0 ? cast<Value>(PN) : Ops[i-1];
    unsigned OperandIdx = Op->getOperand(0) == OperandToReplace ? 0 : 1;
    assert(Op->getOperand(OperandIdx) == OperandToReplace &&
           "Operand mismatch. Perhaps a malformed chain?");

    // Set the operand of Op to Phi and the incoming value of Phi for BB to Op.
    Op->setOperand(OperandIdx, Phi);
    Phi->setIncomingValueForBlock(BB, Op);
  }

  // Replace old uses of PN and Ops outside this BB with the updated totals.
  // The "old" total corresponding to PN now corresponds to the sum of all
  // Phis. Similarly, the old totals in Ops correspond to the sum of the
  // partial results in the new Ops up to the index of the Op we want to
  // compute, plus the sum of the Phis from that index onwards.
  //
  // More rigorously, the totals can be computed as follows.
  // 1. Let k be an index in the list of length N+1 below of the variables we
  //    want to compute the new totals for:
  //      { PN, Ops[0], Ops[1], ... }
  // 2. Let Sum(k) denote the new total to compute for the k-th variable in the
  //    list above. Then,
  //      Sum(0) = Sum(PN)       = \sum_{0 <= i < N} Phis[i],
  //      Sum(1) = Sum(Ops[0])   = \sum_{0 <= i < 1} Ops[i] +
  //                               \sum_{1 <= i < N} Phis[i],
  //      ...
  //      Sum(N) = Sum(Ops[N-1]) = \sum_{0 <= i < N} Ops[i].
  // 3. More generally,
  //      Sum(k) = Sum(PN) if k == 0 else Sum(Ops[k-1])
  //             = \sum_{0 <= i < k} Ops[i] +
  //               \sum_{k <= i < N} Phis[i],
  //    for 0 <= k <= N.
  // 4. Finally, if we name the sums in Ops and Phis separately, i.e.
  //      SOps(k)  = \sum_{0 <= i < k} Ops[i],
  //      SPhis(k) = \sum_{k <= i < N} Phis[i],
  //    then
  //      Sum(k) = SOps(k) + SPhis(k), 0 <= k <= N.

  // Helper function to create a new binary op right before the terminator of
  // BB. The new instruction starts with the flags of Ops[0] and is then
  // restricted to the flags shared by every Op in the chain, so it never
  // claims a property that only some links of the chain guarantee.
  // Note: The name Twine is built inline in the call on purpose; Twines
  // reference temporaries and must not be stored in local variables.
  auto CreateBinOp = [&](Value *V1, Value *V2) {
    BinaryOperator *Red = BinaryOperator::CreateWithCopiedFlags(
        Opcode, V1, V2, Ops[0], PN->getName() + ".red", &BB->back());
    for (unsigned i = 1; i < N; i++)
      Red->andIRFlags(Ops[i]);
    return Red;
  };

  // Compute the partial sums of the Ops:
  //   SOps[k] = \sum_{0 <= i < k} Ops[i], 0 <= k <= N.
  // For 1 <= k <= N we have:
  //   SOps[k] = Ops[k-1] + \sum_{0 <= i < k-1} Ops[i]
  //           = Ops[k-1] + SOps[k-1],
  // so if we compute SOps in order (i.e. from 0 to N) we can reuse partial
  // results.
  SmallVector<Value *, 16> SOps(N+1);
  SOps[0] = nullptr; // alternatively we could use NeutralElem
  SOps[1] = Ops.front();
  for (unsigned k = 2; k <= N; k++)
    SOps[k] = CreateBinOp(SOps[k-1], Ops[k-1]);

  // Compute the partial sums of the Phis:
  //   SPhis[k] = \sum_{k <= i < N} Phis[i], 0 <= k <= N.
  // Similarly, for 0 <= k <= N-1 we have:
  //   SPhis[k] = Phis[k] + \sum_{k+1 <= i < N} Phis[i]
  //            = Phis[k] + SPhis[k+1],
  // so if we compute SPhis in reverse (i.e. from N down to 0) we can reuse the
  // partial sums computed thus far.
  SmallVector<Value *, 16> SPhis(N+1);
  SPhis[N] = nullptr; // alternatively we could use NeutralElem
  SPhis[N-1] = Phis.back();
  for (signed k = N-2; k >= 0; k--)
    SPhis[k] = CreateBinOp(SPhis[k+1], Phis[k]);

  // Finally, compute the total sums for PN and Ops from:
  //   Sums[k] = SOps[k] + SPhis[k], 0 <= k <= N.
  // These sums might be dead so we add them to a weak tracking vector for
  // cleanup after.
  SmallVector<WeakTrackingVH, 16> Sums(N+1);
  for (unsigned k = 0; k <= N; k++) {
    // Pick the Op we want to compute the new total for.
    Value *Op = k == 0 ? cast<Value>(PN) : Ops[k-1];

    Value *SOp = SOps[k], *SPhi = SPhis[k];
    if (SOp && SPhi)
      Sums[k] = CreateBinOp(SOp, SPhi);
    else if (SOp)
      Sums[k] = SOp;
    else
      Sums[k] = SPhi;

    // Replace uses of the old total with the new total.
    Op->replaceUsesOutsideBlock(Sums[k], BB);
  }

  // Drop dead totals. In case the totals *are* used they could and should be
  // sunk, but this happens in later passes so we don't bother doing it here.
  RecursivelyDeleteTriviallyDeadInstructionsPermissive(Sums);

  return true;
}

/// Perform some cleanup and simplifications on loops after unrolling. It is
/// useful to simplify the IV's in the new loop, as well as do a quick
/// simplify/dce pass of the instructions.
Expand Down Expand Up @@ -272,6 +526,16 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI,
// have a phi which (potentially indirectly) uses instructions later in
// the block we're iterating through.
RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
// Try to simplify reductions (e.g. chains of floating-point adds) into
// independent operations (see more at trySimplifyReductions). This is a
// very common pattern in unrolled loops that compute dot products (for
// example).
//
// We do this outside the loop over the instructions above to let
// instsimplify kick in before trying to apply this transform.
if (UnrollSimplifyReductions)
for (PHINode &PN : BB->phis())
trySimplifyReductions(PN);
}
}

Expand Down
55 changes: 55 additions & 0 deletions llvm/test/CodeGen/AArch64/polybench-3mm.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: opt -passes=loop-unroll,instcombine -unroll-count=2 %s | llc --mattr=,+neon | FileCheck %s

target triple = "aarch64"

; This is a reduced example adapted from the Polybench 3MM kernel.
; We are doing something similar to:
; double dot = 0.0;
; for (long k = 0; k < 1000; k++)
; dot += A[k] * B[k*nb];
; return dot;

define double @test(ptr %A, ptr %B, i64 %nb) {
; CHECK-LABEL: test:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi d0, #0000000000000000
; CHECK-NEXT: movi d1, #0000000000000000
; CHECK-NEXT: lsl x8, x2, #4
; CHECK-NEXT: mov x9, xzr
; CHECK-NEXT: .LBB0_1: // %loop
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: add x10, x0, x9, lsl #3
; CHECK-NEXT: ldr d2, [x1]
; CHECK-NEXT: ldr d5, [x1, x2, lsl #3]
; CHECK-NEXT: add x9, x9, #2
; CHECK-NEXT: add x1, x1, x8
; CHECK-NEXT: ldp d3, d4, [x10]
; CHECK-NEXT: cmp x9, #1000
; CHECK-NEXT: fmadd d0, d2, d3, d0
; CHECK-NEXT: fmadd d1, d5, d4, d1
; CHECK-NEXT: b.ne .LBB0_1
; CHECK-NEXT: // %bb.2: // %exit
; CHECK-NEXT: fadd d0, d0, d1
; CHECK-NEXT: ret
entry:
br label %loop

; Single-block loop that branches back to itself; %k counts from 0 up to 1000.
loop:
%k = phi i64 [ %k.next, %loop ], [ 0, %entry ]
; Running total of the dot product, starting at 0.0 on entry.
%dot = phi double [ %dot.next, %loop ], [ 0.000000e+00, %entry ]
; A is read at index k, B is read at index k*nb.
%A.gep = getelementptr inbounds double, ptr %A, i64 %k
%A.val = load double, ptr %A.gep, align 8
%B.idx = mul nsw i64 %k, %nb
%B.gep = getelementptr inbounds double, ptr %B, i64 %B.idx
%B.val = load double, ptr %B.gep, align 8
; Both floating-point operations carry the 'fast' flag.
%fmul = fmul fast double %B.val, %A.val
%dot.next = fadd fast double %fmul, %dot
%k.next = add nuw nsw i64 %k, 1
%cmp = icmp eq i64 %k.next, 1000
br i1 %cmp, label %exit, label %loop

; The final total leaves the loop through a single-entry phi in the exit block.
exit:
%dot.next.lcssa = phi double [ %dot.next, %loop ]
ret double %dot.next.lcssa
}
20 changes: 20 additions & 0 deletions llvm/test/Transforms/LoopUnroll/AArch64/falkor-prefetch.ll
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ exit:
; NOHWPF-LABEL: loop2:
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: phi
; NOHWPF-NEXT: getelementptr
; NOHWPF-NEXT: load
; NOHWPF-NEXT: add
Expand Down Expand Up @@ -106,13 +113,23 @@ exit:
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: icmp
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: add
; NOHWPF-NEXT: br
; NOHWPF-NEXT-LABEL: exit2:
;
; CHECK-LABEL: @unroll2(
; CHECK-LABEL: loop2:
; CHECK-NEXT: phi
; CHECK-NEXT: phi
; CHECK-NEXT: phi
; CHECK-NEXT: phi
; CHECK-NEXT: phi
; CHECK-NEXT: getelementptr
; CHECK-NEXT: load
; CHECK-NEXT: add
Expand All @@ -130,6 +147,9 @@ exit:
; CHECK-NEXT: add
; CHECK-NEXT: add
; CHECK-NEXT: icmp
; CHECK-NEXT: add
; CHECK-NEXT: add
; CHECK-NEXT: add
; CHECK-NEXT: br
; CHECK-NEXT-LABEL: exit2:

Expand Down
Loading