Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds the ability to declare extended instruction sets that have no semantic impact and can be safely removed from a module.
* - ``SPV_INTEL_fp_max_error``
- Adds the ability to specify the maximum error for floating-point operations.
* - ``SPV_INTEL_masked_gather_scatter``
- Allows OpTypeVector to have a phyiscal pointer type component type and introduces gather scatter instructions

To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:

Expand Down
12 changes: 12 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,16 @@ let TargetPrefix = "spv" in {

// FPMaxErrorDecorationINTEL
def int_spv_assign_fpmaxerror_decoration: Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;

// Masked Gather Scatter Intrinsics
def int_spv_masked_gather
:DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>],
[IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
def int_spv_masked_scatter
:DefaultAttrsIntrinsic<[],
[llvm_anyvector_ty, LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
[IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>;
}
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVTargetMachine.cpp
SPIRVUtils.cpp
SPIRVEmitNonSemanticDI.cpp
SPIRVCodeGenPrepare.cpp

LINK_COMPONENTS
Analysis
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SPIRVSubtarget;
class InstructionSelector;
class RegisterBankInfo;

ModulePass *createSPIRVCodeGenPreparePass( const SPIRVTargetMachine &TM);
ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
FunctionPass *createSPIRVStructurizerPass();
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
Expand Down
125 changes: 125 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//===-- SPIRVCodeGenPreparePass.cpp - preserve masked scatter gather --*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass preserves the intrinsic @llvm.masked.* intrinsics by replacing
// it with a spv intrinsic
//===----------------------------------------------------------------------===//
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/raw_ostream.h"

#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"

using namespace llvm;

namespace llvm {
void initializeSPIRVCodeGenPreparePass(PassRegistry &);
} // namespace llvm

namespace {
class SPIRVCodeGenPrepare : public ModulePass {

const SPIRVTargetMachine &TM;

public:
static char ID;
SPIRVCodeGenPrepare(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
initializeSPIRVCodeGenPreparePass(*PassRegistry::getPassRegistry());
}

bool runOnModule(Module &M) override;

StringRef getPassName() const override {
return "SPIRV CodeGen prepare pass";
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
ModulePass::getAnalysisUsage(AU);
}
};

} // namespace

char SPIRVCodeGenPrepare::ID = 0;
INITIALIZE_PASS(SPIRVCodeGenPrepare, "codegen-prepare", "SPIRV codegen prepare",
false, false)

static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
ArrayRef<unsigned> OpNos) {
Function *F = nullptr;
if (OpNos.empty()) {
F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID);
} else {
SmallVector<Type *> Tys;
for (unsigned OpNo : OpNos) {
Tys.push_back(II->getOperand(OpNo)->getType());
}

F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID, Tys);
}
II->setCalledFunction(F);
return true;
}

static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic,
const SPIRVSubtarget &ST,
SPIRVGlobalRegistry &GR) {
auto IntrinsicID = Intrinsic->getIntrinsicID();
if (ST.canUseExtension(
SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
switch (IntrinsicID) {
case Intrinsic::masked_scatter: {
return toSpvOverloadedIntrinsic(
Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_scatter,
{0, 1});
} break;

case Intrinsic::masked_gather: {
VectorType* Vty = dyn_cast<VectorType>(Intrinsic -> getOperand(0) -> getType());
PointerType* PTy = dyn_cast<PointerType>(Vty -> getElementType());

VectorType* ResVecType = dyn_cast<VectorType>(Intrinsic -> getType());
Type *CompType = ResVecType -> getElementType();
GR.addPointerToBaseTypeMap(PTy, CompType);
return toSpvOverloadedIntrinsic(
Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_gather, {3, 0});
} break;
default:
break;
}
}
return false;
}

bool SPIRVCodeGenPrepare::runOnModule(Module &M) {
bool Changed = false;
for (Function &F : M) {
const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F);
SPIRVGlobalRegistry &GR = *(STI.getSPIRVGlobalRegistry());
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
if (auto *II = dyn_cast<IntrinsicInst>(&I))
Changed |= lowerIntrinsicToFunction(II, STI, GR);
}
}
}
return Changed;
}

ModulePass *llvm::createSPIRVCodeGenPreparePass(const SPIRVTargetMachine &TM) {
return new SPIRVCodeGenPrepare(TM);
}
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_INTEL_long_composites",
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
{"SPV_INTEL_fp_max_error",
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}};
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
{"SPV_INTEL_masked_gather_scatter",
SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter}};

bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
llvm::StringRef ArgValue,
Expand Down
16 changes: 12 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,17 @@ SPIRVType *SPIRVGlobalRegistry::createOpType(
SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder) {

const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
auto EleOpc = ElemType->getOpcode();
(void)EleOpc;
assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
EleOpc == SPIRV::OpTypeBool) &&
"Invalid vector element type");
if (!ST.canUseExtension(
SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
EleOpc == SPIRV::OpTypeBool) &&
"Invalid vector element type");
}

return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
Expand Down Expand Up @@ -1060,6 +1066,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}

if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isVoidTy())
Expand Down Expand Up @@ -1088,11 +1095,12 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}

unsigned AddrSpace = typeToAddressSpace(Ty);
SPIRVType *SpvElementType = nullptr;
if (Type *ElemTy = ::getPointeeType(Ty))
SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
else if (Type *ElemTy = this->findPointerToBaseTypeMap(Ty))
SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
else
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

Expand Down
16 changes: 15 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class SPIRVGlobalRegistry {

// Holds the maximum ID we have in the module.
unsigned Bound;

/// maps the pointer type to the base type
DenseMap<Type *, Type *> PointerToBaseTypeMap;
// Maps values associated with untyped pointers into deduced element types of
// untyped pointers.
DenseMap<Value *, Type *> DeducedElTys;
Expand Down Expand Up @@ -635,6 +636,19 @@ class SPIRVGlobalRegistry {
void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg);
void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType);

void addPointerToBaseTypeMap(Type *PTy, Type *BaseTy) {
if(PTy == nullptr)
return;
assert(PTy->isPointerTy() && "PTy must be a pointer type");
PointerToBaseTypeMap[PTy] = BaseTy;
}

Type *findPointerToBaseTypeMap(const Type *PTy) {
auto BaseTyIter = PointerToBaseTypeMap.find(PTy);
return BaseTyIter == PointerToBaseTypeMap.end() ? nullptr : BaseTyIter -> second;
}

};
} // end namespace llvm
#endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
6 changes: 6 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -956,3 +956,9 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab
"$res = OpAliasScopeDeclINTEL $AliasDomain">;
def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops),
"$res = OpAliasScopeListDeclINTEL">;

//SPV_INTEL_masked_gather_scatter
def OpMaskedGatherINTEL: Op<6428, (outs ID:$res) , (ins TYPE:$type, ID:$PtrVector, i32imm:$alignment, ID:$mask, ID:$fillempty),
"$res = OpMaskedGatherINTEL $type $PtrVector $alignment $mask $fillempty">;
def OpMaskedScatterINTEL: Op<6429, (outs) , (ins ID:$inVector, ID:$PtrVector, i32imm:$alignment, ID:$mask),
"OpMaskedScatterINTEL $inVector $PtrVector $alignment $mask">;
28 changes: 28 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3192,6 +3192,34 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_discard: {
return selectDiscard(ResVReg, ResType, I);
}
case Intrinsic::spv_masked_gather: {
Register MemLoc = I.getOperand(2).getReg();
int32_t Alignment = I.getOperand(3).getImm();
Register Mask = I.getOperand(4).getReg();
Register PassThrough = I.getOperand(5).getReg();
return BuildMI(*(I.getParent()), I, I.getDebugLoc(),
TII.get(SPIRV::OpMaskedGatherINTEL))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(MemLoc)
.addImm(Alignment)
.addUse(Mask)
.addUse(PassThrough)
.constrainAllUses(TII, TRI, RBI);
}
case Intrinsic::spv_masked_scatter: {
Register Value = I.getOperand(1).getReg();
Register MemLocs = I.getOperand(2).getReg();
int32_t Alignment = I.getOperand(3).getImm();
Register Mask = I.getOperand(4).getReg();
auto MIB = BuildMI(*(I.getParent()), I, I.getDebugLoc(),
TII.get(SPIRV::OpMaskedScatterINTEL))
.addUse(Value)
.addUse(MemLocs)
.addImm(Alignment)
.addUse(Mask);
return MIB.constrainAllUses(TII, TRI, RBI);
}
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,14 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
break;
}
case SPIRV::OpMaskedGatherINTEL:
case SPIRV::OpMaskedScatterINTEL:
if (ST.canUseExtension(
SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
}
break;

default:
break;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>;
defm SPV_INTEL_long_composites : ExtensionOperand<117>;
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<120>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -513,6 +514,7 @@ defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, [SPV_INTEL_long_composi
defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>;
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,12 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
}

void SPIRVPassConfig::addIRPasses() {
TargetPassConfig::addIRPasses();

if (TM.getSubtargetImpl()->canUseExtension(
SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
addPass(createSPIRVCodeGenPreparePass(TM));
}
TargetPassConfig::addIRPasses();
if (TM.getSubtargetImpl()->isVulkanEnv()) {
// 1. Simplify loop for subsequent transformations. After this steps, loops
// have the following properties:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}

; CHECK-NOT: Name [[#]] "llvm.masked.gather.v4i32.v4p4"
; CHECK-NOT: Name [[#]] "llvm.masked.scatter.v4i32.v4p4"

; CHECK-DAG: OpCapability MaskedGatherScatterINTEL
; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter"

; CHECK-DAG: %[[#TYPEINT:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#TYPEPTRINT:]] = OpTypePointer Generic %[[#TYPEINT]]
; CHECK-DAG: %[[#TYPEVECPTR:]] = OpTypeVector %[[#TYPEPTRINT]] 4
; CHECK-DAG: %[[#TYPEVECINT:]] = OpTypeVector %[[#TYPEINT]] 4

; CHECK-DAG: %[[#CONST4:]] = OpConstant %[[#TYPEINT]] 4
; CHECK-DAG: %[[#CONST0:]] = OpConstant %[[#TYPEINT]] 0
; CHECK-DAG: %[[#CONST1:]] = OpConstant %[[#TYPEINT]] 1
; CHECK-DAG: %[[#TRUE:]] = OpConstantTrue %[[#]]
; CHECK-DAG: %[[#FALSE:]] = OpConstantFalse %[[#]]
; CHECK-DAG: %[[#MASK1:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#FALSE]] %[[#TRUE]] %[[#TRUE]]
; CHECK-DAG: %[[#FILL:]] = OpConstantComposite %[[#]] %[[#CONST4]] %[[#CONST0]] %[[#CONST1]] %[[#CONST0]]
; CHECK-DAG: %[[#MASK2:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]]

; CHECK: %[[#VECGATHER:]] = OpLoad %[[#TYPEVECPTR]]
; CHECK: %[[#VECSCATTER:]] = OpLoad %[[#TYPEVECPTR]]
; CHECK: %[[#GATHER:]] = OpMaskedGatherINTEL %[[#TYPEVECINT]] %[[#VECGATHER]] 4 %[[#MASK1]] %[[#FILL]]
; CHECK: OpMaskedScatterINTEL %[[#GATHER]] %[[#VECSCATTER]] 4 %[[#MASK2]]

; Function Attrs: nounwind readnone
define spir_kernel void @foo() {
entry:
%arg0 = alloca <4 x ptr addrspace(4)>
%arg1 = alloca <4 x ptr addrspace(4)>
%0 = load <4 x ptr addrspace(4)>, ptr %arg0
%1 = load <4 x ptr addrspace(4)>, ptr %arg1
%res = call <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)> %0, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>)
call void @llvm.masked.scatter.v4i32.v4p4(<4 x i32> %res, <4 x ptr addrspace(4)> %1, i32 4, <4 x i1> splat (i1 true))
ret void
}

declare <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)>, i32, <4 x i1>, <4 x i32>)

declare void @llvm.masked.scatter.v4i32.v4p4(<4 x i32>, <4 x ptr addrspace(4)>, i32, <4 x i1>)

!llvm.module.flags = !{!0}
!opencl.spir.version = !{!1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 2}