177 changes: 177 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp
@@ -0,0 +1,177 @@
//===-- AMDGPUMLSchedStrategy.cpp - ML-focused Scheduler Strategy ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// ML-focused scheduling strategy for AMDGPU.
//
//===----------------------------------------------------------------------===//

#include "AMDGPUMLSchedStrategy.h"

using namespace llvm;

AMDGPUMLSchedStrategy::AMDGPUMLSchedStrategy(const MachineSchedContext *C)
: GCNSchedStrategy(C) {
SchedStages.push_back(GCNSchedStageID::ILPInitialSchedule);
SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
// Use more accurate GCN pressure trackers.
UseGCNTrackers = true;
}

bool AMDGPUMLSchedStrategy::tryCandidate(SchedCandidate &Cand,
SchedCandidate &TryCand,
SchedBoundary *Zone) const {
// Initialize the candidate if needed.
if (!Cand.isValid()) {
TryCand.Reason = FirstValid;
return true;
}

// Bias physical register defs and copies toward their uses and definitions,
// respectively.
if (tryGreater(biasPhysReg(TryCand.SU, TryCand.AtTop),
biasPhysReg(Cand.SU, Cand.AtTop), TryCand, Cand, PhysReg))
return TryCand.Reason != NoCand;

// Avoid exceeding the target's limit.
if (DAG->isTrackingPressure() &&
tryPressure(TryCand.RPDelta.Excess, Cand.RPDelta.Excess, TryCand, Cand,
RegExcess, TRI, DAG->MF))
return TryCand.Reason != NoCand;

// We only compare a subset of features when comparing nodes between the
// Top and Bottom boundary. Some properties are simply incomparable; in many
// other instances we should only override the other boundary if something
// is a clearly good pick on one boundary. Skip heuristics that are more
// "tie-breaking" in nature.
bool SameBoundary = Zone != nullptr;
if (SameBoundary) {
// For loops that are acyclic path limited, aggressively schedule for
// latency. Within a single cycle, whenever CurrMOps > 0, allow normal
// heuristics to take precedence.
if (Rem.IsAcyclicLatencyLimited && !Zone->getCurrMOps() &&
tryLatency(TryCand, Cand, *Zone))
return TryCand.Reason != NoCand;

// Prioritize instructions that read unbuffered resources by stall cycles.
if (tryLess(Zone->getLatencyStallCycles(TryCand.SU),
Zone->getLatencyStallCycles(Cand.SU), TryCand, Cand, Stall))
return TryCand.Reason != NoCand;
}

// Keep clustered nodes together to encourage downstream peephole
// optimizations which may reduce resource requirements.
//
// This is a best effort to set things up for a post-RA pass. Optimizations
// like generating loads of multiple registers should ideally be done within
// the scheduler pass by combining the loads during DAG postprocessing.
unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID;
unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID;
bool CandIsClusterSucc =
isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx);
bool TryCandIsClusterSucc =
isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx);

if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand,
Cluster))
return TryCand.Reason != NoCand;

if (SameBoundary) {
// Weak edges are for clustering and other constraints.
if (tryLess(getWeakLeft(TryCand.SU, TryCand.AtTop),
getWeakLeft(Cand.SU, Cand.AtTop), TryCand, Cand, Weak))
return TryCand.Reason != NoCand;
}

// Avoid increasing the max pressure of the entire region.
if (DAG->isTrackingPressure() &&
tryPressure(TryCand.RPDelta.CurrentMax, Cand.RPDelta.CurrentMax, TryCand,
Cand, RegMax, TRI, DAG->MF))
return TryCand.Reason != NoCand;

if (SameBoundary) {
// Avoid critical resource consumption and balance the schedule.
TryCand.initResourceDelta(DAG, SchedModel);
if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
TryCand, Cand, ResourceReduce))
return TryCand.Reason != NoCand;
if (tryGreater(TryCand.ResDelta.DemandedResources,
Cand.ResDelta.DemandedResources, TryCand, Cand,
ResourceDemand))
return TryCand.Reason != NoCand;

// Avoid serializing long latency dependence chains.
// For acyclic path limited loops, latency was already checked above.
if (!RegionPolicy.DisableLatencyHeuristic && TryCand.Policy.ReduceLatency &&
!Rem.IsAcyclicLatencyLimited && tryLatency(TryCand, Cand, *Zone))
return TryCand.Reason != NoCand;

// Fall through to original instruction order.
if ((Zone->isTop() && TryCand.SU->NodeNum < Cand.SU->NodeNum) ||
(!Zone->isTop() && TryCand.SU->NodeNum > Cand.SU->NodeNum)) {
TryCand.Reason = NodeOrder;
return true;
}
}

return false;
}

AMDGPUMLPostSchedStrategy::AMDGPUMLPostSchedStrategy(
const MachineSchedContext *C)
: PostGenericScheduler(C) {}

bool AMDGPUMLPostSchedStrategy::tryCandidate(SchedCandidate &Cand,
SchedCandidate &TryCand) {
// Initialize the candidate if needed.
if (!Cand.isValid()) {
TryCand.Reason = FirstValid;
return true;
}

// Prioritize instructions that read unbuffered resources by stall cycles.
if (tryLess(Top.getLatencyStallCycles(TryCand.SU),
Top.getLatencyStallCycles(Cand.SU), TryCand, Cand, Stall))
return TryCand.Reason != NoCand;

// Keep clustered nodes together.
unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID;
unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID;
bool CandIsClusterSucc =
isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx);
bool TryCandIsClusterSucc =
isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx);

if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand,
Cluster))
return TryCand.Reason != NoCand;
// Avoid critical resource consumption and balance the schedule.
if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
TryCand, Cand, ResourceReduce))
return TryCand.Reason != NoCand;
if (tryGreater(TryCand.ResDelta.DemandedResources,
Cand.ResDelta.DemandedResources, TryCand, Cand,
ResourceDemand))
return TryCand.Reason != NoCand;

// We only compare a subset of features when comparing nodes between
// Top and Bottom boundary.
if (Cand.AtTop == TryCand.AtTop) {
// Avoid serializing long latency dependence chains.
if (Cand.Policy.ReduceLatency &&
tryLatency(TryCand, Cand, Cand.AtTop ? Top : Bot))
return TryCand.Reason != NoCand;
}

// Fall through to original instruction order.
if (TryCand.SU->NodeNum < Cand.SU->NodeNum) {
TryCand.Reason = NodeOrder;
return true;
}

return false;
}
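
A note on the comparison idiom used throughout both tryCandidate overrides above: heuristics are tried in strict priority order, each tryLess/tryGreater call records a CandReason on whichever candidate it prefers, and the "return TryCand.Reason != NoCand" lines return true only when TryCand actually displaced the incumbent; a decisive comparison that favors Cand falls through to returning false. A minimal, self-contained sketch of that cascade, using simplified stand-ins rather than LLVM's real SchedCandidate/tryLess, and assuming just two heuristics (stall cycles, then source order):

#include <iostream>

// Simplified stand-ins for GenericSchedulerBase's CandReason / SchedCandidate.
enum CandReason { NoCand, FirstValid, Stall, NodeOrder };

struct Candidate {
  int NodeNum = -1;
  CandReason Reason = NoCand;
  bool isValid() const { return NodeNum >= 0; }
};

// Mirrors the shape of llvm::tryLess: prefer the smaller value. Returns true
// when the comparison is decisive; TryCand.Reason is set only if TryCand won.
static bool tryLess(int TryVal, int CandVal, Candidate &TryCand,
                    Candidate &Cand, CandReason Reason) {
  if (TryVal < CandVal) {
    TryCand.Reason = Reason;
    return true;
  }
  if (TryVal > CandVal) {
    if (Cand.Reason > Reason)
      Cand.Reason = Reason; // Remember the strongest reason Cand won by.
    return true;
  }
  return false; // Tie: fall through to the next heuristic.
}

// The cascade: the first decisive heuristic wins; otherwise fall back to
// original instruction order, like the NodeOrder fall-through above.
static bool tryCandidate(Candidate &Cand, Candidate &TryCand, int TryStall,
                         int CandStall) {
  if (!Cand.isValid()) {
    TryCand.Reason = FirstValid;
    return true;
  }
  if (tryLess(TryStall, CandStall, TryCand, Cand, Stall))
    return TryCand.Reason != NoCand; // True only if TryCand displaced Cand.
  if (TryCand.NodeNum < Cand.NodeNum) {
    TryCand.Reason = NodeOrder;
    return true;
  }
  return false;
}

int main() {
  Candidate Cand{/*NodeNum=*/3}, TryCand{/*NodeNum=*/7};
  // TryCand stalls less, so it wins despite coming later in source order.
  std::cout << tryCandidate(Cand, TryCand, /*TryStall=*/0, /*CandStall=*/2)
            << "\n"; // prints 1
}
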
36 changes: 36 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h
@@ -0,0 +1,36 @@
//===-- AMDGPUMLSchedStrategy.h - ML-focused Scheduler Strategy -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// ML-focused scheduling strategy for AMDGPU.
//
//===----------------------------------------------------------------------===//

#include "GCNSchedStrategy.h"
#include "llvm/CodeGen/MachineScheduler.h"

namespace llvm {

class AMDGPUMLSchedStrategy final : public GCNSchedStrategy {
protected:
bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
SchedBoundary *Zone) const override;

public:
AMDGPUMLSchedStrategy(const MachineSchedContext *C);
};

class AMDGPUMLPostSchedStrategy : public PostGenericScheduler {
protected:
bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand) override;

public:
AMDGPUMLPostSchedStrategy(const MachineSchedContext *C);
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H
28 changes: 25 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -43,6 +43,7 @@
#include "GCNPreRAOptimizations.h"
#include "GCNRewritePartialRegUses.h"
#include "GCNSchedStrategy.h"
#include "AMDGPUMLSchedStrategy.h"
#include "GCNVOPDUtils.h"
#include "R600.h"
#include "R600TargetMachine.h"
@@ -636,6 +637,11 @@ static ScheduleDAGInstrs *createSIMachineScheduler(MachineSchedContext *C) {
return new SIScheduleDAGMI(C);
}

static bool isMLWorkload(const Function &F) {
Attribute WorkloadAttr = F.getFnAttribute("amdgpu-workload-type");
return WorkloadAttr.isValid() && WorkloadAttr.getValueAsString() == "ml";
}

static ScheduleDAGInstrs *
createGCNMaxOccupancyMachineScheduler(MachineSchedContext *C) {
const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -659,6 +665,11 @@ createGCNMaxILPMachineScheduler(MachineSchedContext *C) {
return DAG;
}

static ScheduleDAGInstrs *createGCNMLMachineScheduler(MachineSchedContext *C) {
return new GCNScheduleDAGMILive(C,
std::make_unique<AMDGPUMLSchedStrategy>(C));
}

static ScheduleDAGInstrs *
createGCNMaxMemoryClauseMachineScheduler(MachineSchedContext *C) {
const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -1170,6 +1181,9 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
if (ST.enableSIScheduler())
return createSIMachineScheduler(C);

if (isMLWorkload(C->MF->getFunction()))
return createGCNMLMachineScheduler(C);

Attribute SchedStrategyAttr =
C->MF->getFunction().getFnAttribute("amdgpu-sched-strategy");
StringRef SchedStrategy = SchedStrategyAttr.isValid()
@@ -1191,14 +1205,22 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
if (SchedStrategy == "iterative-maxocc")
return createIterativeGCNMaxOccupancyMachineScheduler(C);

if (SchedStrategy == "ml")
return createGCNMLMachineScheduler(C);

return createGCNMaxOccupancyMachineScheduler(C);
}

ScheduleDAGInstrs *
GCNTargetMachine::createPostMachineScheduler(MachineSchedContext *C) const {
ScheduleDAGMI *DAG =
new GCNPostScheduleDAGMILive(C, std::make_unique<PostGenericScheduler>(C),
/*RemoveKillFlags=*/true);
if (isMLWorkload(C->MF->getFunction()))
return new GCNPostScheduleDAGMILive(
C, std::make_unique<AMDGPUMLPostSchedStrategy>(C),
/*RemoveKillFlags=*/true);

ScheduleDAGMI *DAG = new GCNPostScheduleDAGMILive(
C, std::make_unique<PostGenericScheduler>(C),
/*RemoveKillFlags=*/true);
const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
DAG->addMutation(createLoadClusterDAGMutation(DAG->TII, DAG->TRI));
if (ST.shouldClusterStores())
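
The selection above is per-function: a function is routed to the ML strategy either by the "amdgpu-workload-type"="ml" string attribute checked in isMLWorkload(), or by the existing "amdgpu-sched-strategy"="ml" attribute. A short illustration of how a frontend or IR pass might opt kernels in, using the existing Function::addFnAttr API; the markMLWorkloads helper and the kernels-only policy are hypothetical, not part of this patch:

#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Illustrative only: tag every AMDGPU kernel so that
// GCNTargetMachine::createMachineScheduler() selects the ML strategy.
// Helper name and policy are hypothetical; the attribute string matches
// isMLWorkload() above.
static void markMLWorkloads(Module &M) {
  for (Function &F : M) {
    if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL)
      F.addFnAttr("amdgpu-workload-type", "ml");
  }
}

In textual IR the same opt-in would simply appear as "amdgpu-workload-type"="ml" in the function's attribute group.
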
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/CMakeLists.txt
@@ -89,6 +89,7 @@ add_llvm_target(AMDGPUCodeGen
AMDGPUMacroFusion.cpp
AMDGPUMCInstLower.cpp
AMDGPUMemoryUtils.cpp
AMDGPUMLSchedStrategy.cpp
AMDGPUIGroupLP.cpp
AMDGPULowerVGPREncoding.cpp
AMDGPUMCResourceInfo.cpp
20 changes: 10 additions & 10 deletions llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
@@ -184,15 +184,15 @@ static bool canUsePressureDiffs(const SUnit &SU) {
return true;
}

static void getRegisterPressures(
void GCNSchedStrategy::getRegisterPressures(
bool AtTop, const RegPressureTracker &RPTracker, SUnit *SU,
std::vector<unsigned> &Pressure, std::vector<unsigned> &MaxPressure,
GCNDownwardRPTracker &DownwardTracker, GCNUpwardRPTracker &UpwardTracker,
ScheduleDAGMI *DAG, const SIRegisterInfo *SRI) {
// getDownwardPressure() and getUpwardPressure() make temporary changes to
// the tracker, so we need to pass those functions a non-const copy.
RegPressureTracker &TempTracker = const_cast<RegPressureTracker &>(RPTracker);
if (!GCNTrackers) {
if (!useGCNTrackers()) {
AtTop
? TempTracker.getDownwardPressure(SU->getInstr(), Pressure, MaxPressure)
: TempTracker.getUpwardPressure(SU->getInstr(), Pressure, MaxPressure);
@@ -244,7 +244,7 @@ void GCNSchedStrategy::initCandidate(SchedCandidate &Cand, SUnit *SU,
//
// In EXPENSIVE_CHECKS, we always query RPTracker to verify the results of
// PressureDiffs.
if (AtTop || !canUsePressureDiffs(*SU) || GCNTrackers) {
if (AtTop || !canUsePressureDiffs(*SU) || useGCNTrackers()) {
getRegisterPressures(AtTop, RPTracker, SU, Pressure, MaxPressure,
DownwardTracker, UpwardTracker, DAG, SRI);
} else {
Expand Down Expand Up @@ -388,7 +388,7 @@ void GCNSchedStrategy::pickNodeFromQueue(SchedBoundary &Zone,
unsigned VGPRPressure = 0;
IsPending = false;
if (DAG->isTrackingPressure()) {
if (!GCNTrackers) {
if (!useGCNTrackers()) {
SGPRPressure = Pressure[AMDGPU::RegisterPressureSets::SReg_32];
VGPRPressure = Pressure[AMDGPU::RegisterPressureSets::VGPR_32];
} else {
@@ -611,7 +611,7 @@ SUnit *GCNSchedStrategy::pickNode(bool &IsTopNode) {
}

void GCNSchedStrategy::schedNode(SUnit *SU, bool IsTopNode) {
if (GCNTrackers) {
if (useGCNTrackers()) {
MachineInstr *MI = SU->getInstr();
IsTopNode ? (void)DownwardTracker.advance(MI, false)
: UpwardTracker.recede(*MI);
@@ -693,7 +693,7 @@ GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
SchedStages.push_back(GCNSchedStageID::UnclusteredHighRPReschedule);
SchedStages.push_back(GCNSchedStageID::ClusteredLowOccupancyReschedule);
SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
GCNTrackers = GCNTrackers & !IsLegacyScheduler;
UseGCNTrackers = GCNTrackers & !IsLegacyScheduler;
}

GCNMaxILPSchedStrategy::GCNMaxILPSchedStrategy(const MachineSchedContext *C)
@@ -1115,9 +1115,10 @@ void GCNScheduleDAGMILive::finalizeSchedule() {
void GCNScheduleDAGMILive::runSchedStages() {
LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");

GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
if (!Regions.empty()) {
BBLiveInMap = getRegionLiveInMap();
if (GCNTrackers)
if (S.useGCNTrackers())
RegionLiveOuts.buildLiveRegMap();
}

@@ -1129,7 +1130,6 @@ void GCNScheduleDAGMILive::runSchedStages() {
}
#endif

GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
while (S.advanceStage()) {
auto Stage = createSchedStage(S.getCurrentStage());
if (!Stage->initGCNSchedStage())
@@ -1145,7 +1145,7 @@
continue;
}

if (GCNTrackers) {
if (S.useGCNTrackers()) {
GCNDownwardRPTracker *DownwardTracker = S.getDownwardTracker();
GCNUpwardRPTracker *UpwardTracker = S.getUpwardTracker();
GCNRPTracker::LiveRegSet *RegionLiveIns =
@@ -1294,7 +1294,7 @@ bool PreRARematStage::initGCNSchedStage() {

// Rematerialize identified instructions and update scheduler's state.
rematerialize();
if (GCNTrackers)
if (S.useGCNTrackers())
DAG.RegionLiveOuts.buildLiveRegMap();
REMAT_DEBUG({
dbgs() << "Retrying function scheduling with new min. occupancy of "
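
The mechanical renames in this file replace reads of the file-scope GCNTrackers flag (a cl::opt) with per-strategy state queried through useGCNTrackers(), so strategies such as the ML one can force the GCN trackers on regardless of the command-line default. The matching GCNSchedStrategy.h change (a UseGCNTrackers member with a useGCNTrackers() accessor, plus getRegisterPressures becoming a member function) is not shown in this diff. A compilable toy sketch of that design choice, with placeholder names rather than the real LLVM classes:

#include <iostream>

// Stand-in for the command-line default (cl::opt<bool> GCNTrackers in LLVM).
static bool GCNTrackersFlag = false;

class SchedStrategyBase {
protected:
  // Per-strategy copy of the flag; constructors may override the default.
  bool UseGCNTrackers = GCNTrackersFlag;

public:
  bool useGCNTrackers() const { return UseGCNTrackers; }
  virtual ~SchedStrategyBase() = default;
};

class MLSchedStrategy : public SchedStrategyBase {
public:
  MLSchedStrategy() { UseGCNTrackers = true; } // The ML strategy always opts in.
};

int main() {
  SchedStrategyBase Generic;
  MLSchedStrategy ML;
  // Prints "0 1": the global flag no longer decides for every strategy.
  std::cout << Generic.useGCNTrackers() << " " << ML.useGCNTrackers() << "\n";
}
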