- Notifications
You must be signed in to change notification settings - Fork 15.3k
[SPIRV] Support for extension SPV_INTEL_masked_gather_scatter #131566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
| @llvm/pr-subscribers-backend-spir-v Author: VISHAKH PRAKASH (VishMCW) Changes: Add intrinsic SPV_INTEL_masked_gather_scatter
Full diff: https://github.com/llvm/llvm-project/pull/131566.diff 14 Files Affected:
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst index 3e19ff881dffc..781a16dff0d0f 100644 --- a/llvm/docs/SPIRVUsage.rst +++ b/llvm/docs/SPIRVUsage.rst @@ -209,6 +209,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na - Adds the ability to declare extended instruction sets that have no semantic impact and can be safely removed from a module. * - ``SPV_INTEL_fp_max_error`` - Adds the ability to specify the maximum error for floating-point operations. + * - ``SPV_INTEL_masked_gather_scatter`` + - Allows OpTypeVector to have a phyiscal pointer type component type and introduces gather scatter instructions To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use: diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index df3e137c80980..5a70350af1804 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -143,4 +143,16 @@ let TargetPrefix = "spv" in { // FPMaxErrorDecorationINTEL def int_spv_assign_fpmaxerror_decoration: Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>; + + // Masked Gather Scatter Intrinsics + def int_spv_masked_gather + :DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>], + [IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>; + def int_spv_masked_scatter + :DefaultAttrsIntrinsic<[], + [llvm_anyvector_ty, LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], + [IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>; } diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index 4a2b534b948d6..48ef19b334695 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -46,6 +46,7 @@ 
add_llvm_target(SPIRVCodeGen SPIRVTargetMachine.cpp SPIRVUtils.cpp SPIRVEmitNonSemanticDI.cpp + SPIRVCodeGenPrepare.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index d765dfe370be2..a843a7144f3c2 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -19,6 +19,7 @@ class SPIRVSubtarget; class InstructionSelector; class RegisterBankInfo; +ModulePass *createSPIRVCodeGenPreparePass( const SPIRVTargetMachine &TM); ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM); FunctionPass *createSPIRVStructurizerPass(); FunctionPass *createSPIRVMergeRegionExitTargetsPass(); diff --git a/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp new file mode 100644 index 0000000000000..ea497880360ff --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp @@ -0,0 +1,125 @@ +//===-- SPIRVCodeGenPreparePass.cpp - preserve masked scatter gather --*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass preserves the intrinsic @llvm.masked.* intrinsics by replacing +// it with a spv intrinsic +//===----------------------------------------------------------------------===// +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Pass.h" +#include "llvm/PassRegistry.h" +#include "llvm/Support/raw_ostream.h" + +#include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVCodeGenPreparePass(PassRegistry &); +} // namespace llvm + +namespace { +class SPIRVCodeGenPrepare : public ModulePass { + + const SPIRVTargetMachine &TM; + +public: + static char ID; + SPIRVCodeGenPrepare(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) { + initializeSPIRVCodeGenPreparePass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; + + StringRef getPassName() const override { + return "SPIRV CodeGen prepare pass"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + ModulePass::getAnalysisUsage(AU); + } +}; + +} // namespace + +char SPIRVCodeGenPrepare::ID = 0; +INITIALIZE_PASS(SPIRVCodeGenPrepare, "codegen-prepare", "SPIRV codegen prepare", + false, false) + +static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID, + ArrayRef<unsigned> OpNos) { + Function *F = nullptr; + if (OpNos.empty()) { + F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID); + } else { + SmallVector<Type *> Tys; + for (unsigned OpNo : OpNos) { + Tys.push_back(II->getOperand(OpNo)->getType()); + } + + F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID, Tys); + } + II->setCalledFunction(F); + return true; +} + +static 
bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic, + const SPIRVSubtarget &ST, + SPIRVGlobalRegistry &GR) { + auto IntrinsicID = Intrinsic->getIntrinsicID(); + if (ST.canUseExtension( + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) { + switch (IntrinsicID) { + case Intrinsic::masked_scatter: { + return toSpvOverloadedIntrinsic( + Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_scatter, + {0, 1}); + } break; + + case Intrinsic::masked_gather: { + VectorType* Vty = dyn_cast<VectorType>(Intrinsic -> getOperand(0) -> getType()); + PointerType* PTy = dyn_cast<PointerType>(Vty -> getElementType()); + + VectorType* ResVecType = dyn_cast<VectorType>(Intrinsic -> getType()); + Type *CompType = ResVecType -> getElementType(); + GR.addPointerToBaseTypeMap(PTy, CompType); + return toSpvOverloadedIntrinsic( + Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_gather, {3, 0}); + } break; + default: + break; + } + } + return false; +} + +bool SPIRVCodeGenPrepare::runOnModule(Module &M) { + bool Changed = false; + for (Function &F : M) { + const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F); + SPIRVGlobalRegistry &GR = *(STI.getSPIRVGlobalRegistry()); + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + Changed |= lowerIntrinsicToFunction(II, STI, GR); + } + } + } + return Changed; +} + +ModulePass *llvm::createSPIRVCodeGenPreparePass(const SPIRVTargetMachine &TM) { + return new SPIRVCodeGenPrepare(TM); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 37119bf01545c..357bccbbc2c1a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -92,7 +92,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>> {"SPV_INTEL_long_composites", SPIRV::Extension::Extension::SPV_INTEL_long_composites}, {"SPV_INTEL_fp_max_error", - 
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}}; + SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}, + {"SPV_INTEL_masked_gather_scatter", + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter}}; bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName, llvm::StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index cbec1c95eadc3..829de62ed8213 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -231,11 +231,17 @@ SPIRVType *SPIRVGlobalRegistry::createOpType( SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, MachineIRBuilder &MIRBuilder) { + + const SPIRVSubtarget &ST = + cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); auto EleOpc = ElemType->getOpcode(); (void)EleOpc; - assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || - EleOpc == SPIRV::OpTypeBool) && - "Invalid vector element type"); + if (!ST.canUseExtension( + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) { + assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || + EleOpc == SPIRV::OpTypeBool) && + "Invalid vector element type"); + } return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { return MIRBuilder.buildInstr(SPIRV::OpTypeVector) @@ -1060,6 +1066,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( return Width == 1 ? 
getOpTypeBool(MIRBuilder) : getOpTypeInt(Width, MIRBuilder, false); } + if (Ty->isFloatingPointTy()) return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); if (Ty->isVoidTy()) @@ -1088,11 +1095,12 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR)); return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); } - unsigned AddrSpace = typeToAddressSpace(Ty); SPIRVType *SpvElementType = nullptr; if (Type *ElemTy = ::getPointeeType(Ty)) SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR); + else if (Type *ElemTy = this->findPointerToBaseTypeMap(Ty)) + SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR); else SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 89599f17ef737..24cc62b3b69a0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -78,7 +78,8 @@ class SPIRVGlobalRegistry { // Holds the maximum ID we have in the module. unsigned Bound; - + /// maps the pointer type to the base type + DenseMap<Type *, Type *> PointerToBaseTypeMap; // Maps values associated with untyped pointers into deduced element types of // untyped pointers. DenseMap<Value *, Type *> DeducedElTys; @@ -635,6 +636,19 @@ class SPIRVGlobalRegistry { void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg); void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg); void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType); + + void addPointerToBaseTypeMap(Type *PTy, Type *BaseTy) { + if(PTy == nullptr) + return; + assert(PTy->isPointerTy() && "PTy must be a pointer type"); + PointerToBaseTypeMap[PTy] = BaseTy; + } + + Type *findPointerToBaseTypeMap(const Type *PTy) { + auto BaseTyIter = PointerToBaseTypeMap.find(PTy); + return BaseTyIter == PointerToBaseTypeMap.end() ? 
nullptr : BaseTyIter -> second; + } + }; } // end namespace llvm #endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index a8f862271dbab..d16621eff4f3b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -956,3 +956,9 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab "$res = OpAliasScopeDeclINTEL $AliasDomain">; def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops), "$res = OpAliasScopeListDeclINTEL">; + +//SPV_INTEL_masked_gather_scatter +def OpMaskedGatherINTEL: Op<6428, (outs ID:$res) , (ins TYPE:$type, ID:$PtrVector, i32imm:$alignment, ID:$mask, ID:$fillempty), + "$res = OpMaskedGatherINTEL $type $PtrVector $alignment $mask $fillempty">; +def OpMaskedScatterINTEL: Op<6429, (outs) , (ins ID:$inVector, ID:$PtrVector, i32imm:$alignment, ID:$mask), + "OpMaskedScatterINTEL $inVector $PtrVector $alignment $mask">; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index b188f36ca9a9e..1b25470791286 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3192,6 +3192,34 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_discard: { return selectDiscard(ResVReg, ResType, I); } + case Intrinsic::spv_masked_gather: { + Register MemLoc = I.getOperand(2).getReg(); + int32_t Alignment = I.getOperand(3).getImm(); + Register Mask = I.getOperand(4).getReg(); + Register PassThrough = I.getOperand(5).getReg(); + return BuildMI(*(I.getParent()), I, I.getDebugLoc(), + TII.get(SPIRV::OpMaskedGatherINTEL)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(MemLoc) + .addImm(Alignment) + .addUse(Mask) + .addUse(PassThrough) + .constrainAllUses(TII, TRI, RBI); + } + case Intrinsic::spv_masked_scatter: { + 
Register Value = I.getOperand(1).getReg(); + Register MemLocs = I.getOperand(2).getReg(); + int32_t Alignment = I.getOperand(3).getImm(); + Register Mask = I.getOperand(4).getReg(); + auto MIB = BuildMI(*(I.getParent()), I, I.getDebugLoc(), + TII.get(SPIRV::OpMaskedScatterINTEL)) + .addUse(Value) + .addUse(MemLocs) + .addImm(Alignment) + .addUse(Mask); + return MIB.constrainAllUses(TII, TRI, RBI); + } default: { std::string DiagMsg; raw_string_ostream OS(DiagMsg); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 63894acacbc73..e487cca7decd5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1770,6 +1770,14 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL); break; } + case SPIRV::OpMaskedGatherINTEL: + case SPIRV::OpMaskedScatterINTEL: + if (ST.canUseExtension( + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter); + Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL); + } + break; default: break; diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index caee778eddbc4..d3ee45b6591a7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -313,6 +313,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>; defm SPV_INTEL_long_composites : ExtensionOperand<117>; defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>; defm SPV_INTEL_fp_max_error : ExtensionOperand<119>; +defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<120>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -513,6 +514,7 @@ defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, 
[SPV_INTEL_long_composi defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>; defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>; defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>; +defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp index 0aa214dd354ee..ad26ed3b597db 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -175,8 +175,12 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) { } void SPIRVPassConfig::addIRPasses() { - TargetPassConfig::addIRPasses(); + if (TM.getSubtargetImpl()->canUseExtension( + SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) { + addPass(createSPIRVCodeGenPreparePass(TM)); + } + TargetPassConfig::addIRPasses(); if (TM.getSubtargetImpl()->isVulkanEnv()) { // 1. Simplify loop for subsequent transformations. 
After this steps, loops // have the following properties: diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll new file mode 100644 index 0000000000000..b10c63b700a5a --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll @@ -0,0 +1,50 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s +; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %} + +; CHECK-NOT: Name [[#]] "llvm.masked.gather.v4i32.v4p4" +; CHECK-NOT: Name [[#]] "llvm.masked.scatter.v4i32.v4p4" + +; CHECK-DAG: OpCapability MaskedGatherScatterINTEL +; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter" + +; CHECK-DAG: %[[#TYPEINT:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#TYPEPTRINT:]] = OpTypePointer Generic %[[#TYPEINT]] +; CHECK-DAG: %[[#TYPEVECPTR:]] = OpTypeVector %[[#TYPEPTRINT]] 4 +; CHECK-DAG: %[[#TYPEVECINT:]] = OpTypeVector %[[#TYPEINT]] 4 + +; CHECK-DAG: %[[#CONST4:]] = OpConstant %[[#TYPEINT]] 4 +; CHECK-DAG: %[[#CONST0:]] = OpConstant %[[#TYPEINT]] 0 +; CHECK-DAG: %[[#CONST1:]] = OpConstant %[[#TYPEINT]] 1 +; CHECK-DAG: %[[#TRUE:]] = OpConstantTrue %[[#]] +; CHECK-DAG: %[[#FALSE:]] = OpConstantFalse %[[#]] +; CHECK-DAG: %[[#MASK1:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#FALSE]] %[[#TRUE]] %[[#TRUE]] +; CHECK-DAG: %[[#FILL:]] = OpConstantComposite %[[#]] %[[#CONST4]] %[[#CONST0]] %[[#CONST1]] %[[#CONST0]] +; CHECK-DAG: %[[#MASK2:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] + +; CHECK: %[[#VECGATHER:]] = OpLoad %[[#TYPEVECPTR]] +; CHECK: %[[#VECSCATTER:]] = OpLoad %[[#TYPEVECPTR]] +; CHECK: %[[#GATHER:]] = OpMaskedGatherINTEL %[[#TYPEVECINT]] %[[#VECGATHER]] 4 %[[#MASK1]] 
%[[#FILL]] +; CHECK: OpMaskedScatterINTEL %[[#GATHER]] %[[#VECSCATTER]] 4 %[[#MASK2]] + +; Function Attrs: nounwind readnone +define spir_kernel void @foo() { +entry: + %arg0 = alloca <4 x ptr addrspace(4)> + %arg1 = alloca <4 x ptr addrspace(4)> + %0 = load <4 x ptr addrspace(4)>, ptr %arg0 + %1 = load <4 x ptr addrspace(4)>, ptr %arg1 + %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)> %0, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>) + call void @llvm.masked.scatter.v4i32.v4p4(<4 x i32> %res, <4 x ptr addrspace(4)> %1, i32 4, <4 x i1> splat (i1 true)) + ret void +} + +declare <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)>, i32, <4 x i1>, <4 x i32>) + +declare void @llvm.masked.scatter.v4i32.v4p4(<4 x i32>, <4 x ptr addrspace(4)>, i32, <4 x i1>) + +!llvm.module.flags = !{!0} +!opencl.spir.version = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 1, i32 2} + |
| @llvm/pr-subscribers-llvm-ir Author: VISHAKH PRAKASH (VishMCW) Changes: Add intrinsic SPV_INTEL_masked_gather_scatter
Full diff: https://github.com/llvm/llvm-project/pull/131566.diff 14 Files Affected:
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst index 3e19ff881dffc..781a16dff0d0f 100644 --- a/llvm/docs/SPIRVUsage.rst +++ b/llvm/docs/SPIRVUsage.rst @@ -209,6 +209,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na - Adds the ability to declare extended instruction sets that have no semantic impact and can be safely removed from a module. * - ``SPV_INTEL_fp_max_error`` - Adds the ability to specify the maximum error for floating-point operations. + * - ``SPV_INTEL_masked_gather_scatter`` + - Allows OpTypeVector to have a phyiscal pointer type component type and introduces gather scatter instructions To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use: diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index df3e137c80980..5a70350af1804 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -143,4 +143,16 @@ let TargetPrefix = "spv" in { // FPMaxErrorDecorationINTEL def int_spv_assign_fpmaxerror_decoration: Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>; + + // Masked Gather Scatter Intrinsics + def int_spv_masked_gather + :DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>], + [IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>; + def int_spv_masked_scatter + :DefaultAttrsIntrinsic<[], + [llvm_anyvector_ty, LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], + [IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>; } diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index 4a2b534b948d6..48ef19b334695 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -46,6 +46,7 @@ 
add_llvm_target(SPIRVCodeGen SPIRVTargetMachine.cpp SPIRVUtils.cpp SPIRVEmitNonSemanticDI.cpp + SPIRVCodeGenPrepare.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index d765dfe370be2..a843a7144f3c2 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -19,6 +19,7 @@ class SPIRVSubtarget; class InstructionSelector; class RegisterBankInfo; +ModulePass *createSPIRVCodeGenPreparePass( const SPIRVTargetMachine &TM); ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM); FunctionPass *createSPIRVStructurizerPass(); FunctionPass *createSPIRVMergeRegionExitTargetsPass(); diff --git a/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp new file mode 100644 index 0000000000000..ea497880360ff --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp @@ -0,0 +1,125 @@ +//===-- SPIRVCodeGenPreparePass.cpp - preserve masked scatter gather --*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass preserves the intrinsic @llvm.masked.* intrinsics by replacing +// it with a spv intrinsic +//===----------------------------------------------------------------------===// +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Pass.h" +#include "llvm/PassRegistry.h" +#include "llvm/Support/raw_ostream.h" + +#include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVCodeGenPreparePass(PassRegistry &); +} // namespace llvm + +namespace { +class SPIRVCodeGenPrepare : public ModulePass { + + const SPIRVTargetMachine &TM; + +public: + static char ID; + SPIRVCodeGenPrepare(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) { + initializeSPIRVCodeGenPreparePass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; + + StringRef getPassName() const override { + return "SPIRV CodeGen prepare pass"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + ModulePass::getAnalysisUsage(AU); + } +}; + +} // namespace + +char SPIRVCodeGenPrepare::ID = 0; +INITIALIZE_PASS(SPIRVCodeGenPrepare, "codegen-prepare", "SPIRV codegen prepare", + false, false) + +static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID, + ArrayRef<unsigned> OpNos) { + Function *F = nullptr; + if (OpNos.empty()) { + F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID); + } else { + SmallVector<Type *> Tys; + for (unsigned OpNo : OpNos) { + Tys.push_back(II->getOperand(OpNo)->getType()); + } + + F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID, Tys); + } + II->setCalledFunction(F); + return true; +} + +static 
bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic, + const SPIRVSubtarget &ST, + SPIRVGlobalRegistry &GR) { + auto IntrinsicID = Intrinsic->getIntrinsicID(); + if (ST.canUseExtension( + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) { + switch (IntrinsicID) { + case Intrinsic::masked_scatter: { + return toSpvOverloadedIntrinsic( + Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_scatter, + {0, 1}); + } break; + + case Intrinsic::masked_gather: { + VectorType* Vty = dyn_cast<VectorType>(Intrinsic -> getOperand(0) -> getType()); + PointerType* PTy = dyn_cast<PointerType>(Vty -> getElementType()); + + VectorType* ResVecType = dyn_cast<VectorType>(Intrinsic -> getType()); + Type *CompType = ResVecType -> getElementType(); + GR.addPointerToBaseTypeMap(PTy, CompType); + return toSpvOverloadedIntrinsic( + Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_gather, {3, 0}); + } break; + default: + break; + } + } + return false; +} + +bool SPIRVCodeGenPrepare::runOnModule(Module &M) { + bool Changed = false; + for (Function &F : M) { + const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F); + SPIRVGlobalRegistry &GR = *(STI.getSPIRVGlobalRegistry()); + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + Changed |= lowerIntrinsicToFunction(II, STI, GR); + } + } + } + return Changed; +} + +ModulePass *llvm::createSPIRVCodeGenPreparePass(const SPIRVTargetMachine &TM) { + return new SPIRVCodeGenPrepare(TM); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 37119bf01545c..357bccbbc2c1a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -92,7 +92,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>> {"SPV_INTEL_long_composites", SPIRV::Extension::Extension::SPV_INTEL_long_composites}, {"SPV_INTEL_fp_max_error", - 
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}}; + SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}, + {"SPV_INTEL_masked_gather_scatter", + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter}}; bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName, llvm::StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index cbec1c95eadc3..829de62ed8213 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -231,11 +231,17 @@ SPIRVType *SPIRVGlobalRegistry::createOpType( SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, MachineIRBuilder &MIRBuilder) { + + const SPIRVSubtarget &ST = + cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); auto EleOpc = ElemType->getOpcode(); (void)EleOpc; - assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || - EleOpc == SPIRV::OpTypeBool) && - "Invalid vector element type"); + if (!ST.canUseExtension( + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) { + assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || + EleOpc == SPIRV::OpTypeBool) && + "Invalid vector element type"); + } return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { return MIRBuilder.buildInstr(SPIRV::OpTypeVector) @@ -1060,6 +1066,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( return Width == 1 ? 
getOpTypeBool(MIRBuilder) : getOpTypeInt(Width, MIRBuilder, false); } + if (Ty->isFloatingPointTy()) return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); if (Ty->isVoidTy()) @@ -1088,11 +1095,12 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR)); return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); } - unsigned AddrSpace = typeToAddressSpace(Ty); SPIRVType *SpvElementType = nullptr; if (Type *ElemTy = ::getPointeeType(Ty)) SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR); + else if (Type *ElemTy = this->findPointerToBaseTypeMap(Ty)) + SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR); else SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 89599f17ef737..24cc62b3b69a0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -78,7 +78,8 @@ class SPIRVGlobalRegistry { // Holds the maximum ID we have in the module. unsigned Bound; - + /// maps the pointer type to the base type + DenseMap<Type *, Type *> PointerToBaseTypeMap; // Maps values associated with untyped pointers into deduced element types of // untyped pointers. DenseMap<Value *, Type *> DeducedElTys; @@ -635,6 +636,19 @@ class SPIRVGlobalRegistry { void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg); void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg); void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType); + + void addPointerToBaseTypeMap(Type *PTy, Type *BaseTy) { + if(PTy == nullptr) + return; + assert(PTy->isPointerTy() && "PTy must be a pointer type"); + PointerToBaseTypeMap[PTy] = BaseTy; + } + + Type *findPointerToBaseTypeMap(const Type *PTy) { + auto BaseTyIter = PointerToBaseTypeMap.find(PTy); + return BaseTyIter == PointerToBaseTypeMap.end() ? 
nullptr : BaseTyIter -> second; + } + }; } // end namespace llvm #endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index a8f862271dbab..d16621eff4f3b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -956,3 +956,9 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab "$res = OpAliasScopeDeclINTEL $AliasDomain">; def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops), "$res = OpAliasScopeListDeclINTEL">; + +//SPV_INTEL_masked_gather_scatter +def OpMaskedGatherINTEL: Op<6428, (outs ID:$res) , (ins TYPE:$type, ID:$PtrVector, i32imm:$alignment, ID:$mask, ID:$fillempty), + "$res = OpMaskedGatherINTEL $type $PtrVector $alignment $mask $fillempty">; +def OpMaskedScatterINTEL: Op<6429, (outs) , (ins ID:$inVector, ID:$PtrVector, i32imm:$alignment, ID:$mask), + "OpMaskedScatterINTEL $inVector $PtrVector $alignment $mask">; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index b188f36ca9a9e..1b25470791286 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3192,6 +3192,34 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_discard: { return selectDiscard(ResVReg, ResType, I); } + case Intrinsic::spv_masked_gather: { + Register MemLoc = I.getOperand(2).getReg(); + int32_t Alignment = I.getOperand(3).getImm(); + Register Mask = I.getOperand(4).getReg(); + Register PassThrough = I.getOperand(5).getReg(); + return BuildMI(*(I.getParent()), I, I.getDebugLoc(), + TII.get(SPIRV::OpMaskedGatherINTEL)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(MemLoc) + .addImm(Alignment) + .addUse(Mask) + .addUse(PassThrough) + .constrainAllUses(TII, TRI, RBI); + } + case Intrinsic::spv_masked_scatter: { + 
Register Value = I.getOperand(1).getReg(); + Register MemLocs = I.getOperand(2).getReg(); + int32_t Alignment = I.getOperand(3).getImm(); + Register Mask = I.getOperand(4).getReg(); + auto MIB = BuildMI(*(I.getParent()), I, I.getDebugLoc(), + TII.get(SPIRV::OpMaskedScatterINTEL)) + .addUse(Value) + .addUse(MemLocs) + .addImm(Alignment) + .addUse(Mask); + return MIB.constrainAllUses(TII, TRI, RBI); + } default: { std::string DiagMsg; raw_string_ostream OS(DiagMsg); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 63894acacbc73..e487cca7decd5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1770,6 +1770,14 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL); break; } + case SPIRV::OpMaskedGatherINTEL: + case SPIRV::OpMaskedScatterINTEL: + if (ST.canUseExtension( + SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter); + Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL); + } + break; default: break; diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index caee778eddbc4..d3ee45b6591a7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -313,6 +313,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>; defm SPV_INTEL_long_composites : ExtensionOperand<117>; defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>; defm SPV_INTEL_fp_max_error : ExtensionOperand<119>; +defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<120>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -513,6 +514,7 @@ defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, 
[SPV_INTEL_long_composi defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>; defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>; defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>; +defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp index 0aa214dd354ee..ad26ed3b597db 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -175,8 +175,12 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) { } void SPIRVPassConfig::addIRPasses() { - TargetPassConfig::addIRPasses(); + if (TM.getSubtargetImpl()->canUseExtension( + SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) { + addPass(createSPIRVCodeGenPreparePass(TM)); + } + TargetPassConfig::addIRPasses(); if (TM.getSubtargetImpl()->isVulkanEnv()) { // 1. Simplify loop for subsequent transformations. 
After this steps, loops // have the following properties: diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll new file mode 100644 index 0000000000000..b10c63b700a5a --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll @@ -0,0 +1,50 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s +; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %} + +; CHECK-NOT: Name [[#]] "llvm.masked.gather.v4i32.v4p4" +; CHECK-NOT: Name [[#]] "llvm.masked.scatter.v4i32.v4p4" + +; CHECK-DAG: OpCapability MaskedGatherScatterINTEL +; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter" + +; CHECK-DAG: %[[#TYPEINT:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#TYPEPTRINT:]] = OpTypePointer Generic %[[#TYPEINT]] +; CHECK-DAG: %[[#TYPEVECPTR:]] = OpTypeVector %[[#TYPEPTRINT]] 4 +; CHECK-DAG: %[[#TYPEVECINT:]] = OpTypeVector %[[#TYPEINT]] 4 + +; CHECK-DAG: %[[#CONST4:]] = OpConstant %[[#TYPEINT]] 4 +; CHECK-DAG: %[[#CONST0:]] = OpConstant %[[#TYPEINT]] 0 +; CHECK-DAG: %[[#CONST1:]] = OpConstant %[[#TYPEINT]] 1 +; CHECK-DAG: %[[#TRUE:]] = OpConstantTrue %[[#]] +; CHECK-DAG: %[[#FALSE:]] = OpConstantFalse %[[#]] +; CHECK-DAG: %[[#MASK1:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#FALSE]] %[[#TRUE]] %[[#TRUE]] +; CHECK-DAG: %[[#FILL:]] = OpConstantComposite %[[#]] %[[#CONST4]] %[[#CONST0]] %[[#CONST1]] %[[#CONST0]] +; CHECK-DAG: %[[#MASK2:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] + +; CHECK: %[[#VECGATHER:]] = OpLoad %[[#TYPEVECPTR]] +; CHECK: %[[#VECSCATTER:]] = OpLoad %[[#TYPEVECPTR]] +; CHECK: %[[#GATHER:]] = OpMaskedGatherINTEL %[[#TYPEVECINT]] %[[#VECGATHER]] 4 %[[#MASK1]] 
%[[#FILL]] +; CHECK: OpMaskedScatterINTEL %[[#GATHER]] %[[#VECSCATTER]] 4 %[[#MASK2]] + +; Function Attrs: nounwind readnone +define spir_kernel void @foo() { +entry: + %arg0 = alloca <4 x ptr addrspace(4)> + %arg1 = alloca <4 x ptr addrspace(4)> + %0 = load <4 x ptr addrspace(4)>, ptr %arg0 + %1 = load <4 x ptr addrspace(4)>, ptr %arg1 + %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)> %0, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>) + call void @llvm.masked.scatter.v4i32.v4p4(<4 x i32> %res, <4 x ptr addrspace(4)> %1, i32 4, <4 x i1> splat (i1 true)) + ret void +} + +declare <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)>, i32, <4 x i1>, <4 x i32>) + +declare void @llvm.masked.scatter.v4i32.v4p4(<4 x i32>, <4 x ptr addrspace(4)>, i32, <4 x i1>) + +!llvm.module.flags = !{!0} +!opencl.spir.version = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 1, i32 2} + |
Add support for the SPV_INTEL_masked_gather_scatter extension