- Notifications
You must be signed in to change notification settings - Fork 15.3k
[SPIR-V] Fix tracking ptr null constants of builtin types in calls with mangling #94263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[SPIR-V] Fix tracking ptr null constants of builtin types in calls with mangling #94263
Conversation
Before this change, the method for deducing function argument types assumed that any argument of untyped pointer type must be either: 1) A pointer of an LLVM IR element type, passed byval/byref. 2) An OpenCL/SPIR-V builtin type if there is spv_assign_type intrinsic assigning a TargetExtType. 3) Just a pointer (with default size) This does not take into consideration builtin functions which might also have arguments of OpenCL/SPIR-V builtin type. Since builtins have just their prototypes inside a module (no body), no spv_assign_type intrinsics are generared for their arguments. Hence, a fourth option: 4) An OpenCL/SPIR-V builtin type if the mangled function name contains type information. A test mimicking SPIR-V Translator behavior was added.
…th mangling This change makes sure that ptr null (constant) arguments of builtin function calls are assigned proper builtin types if such are deduced from mangled names. Two tests demonstrating the expected bahavior (as in the SPIR-V Translator) are added. WIP: - The test builtin-call-multiple-ptr-null-args-one-of-builtin- type.ll is failing and requires additional work. - processInstrAfterVisit method must be simplified. TrackConstants can be removed.
| @llvm/pr-subscribers-backend-spir-v Author: Michal Paszkowski (michalpaszkowski) ChangesThis change makes sure that ptr null (constant) arguments of builtin Two tests demonstrating the expected bahavior (as in the SPIR-V WIP:
Full diff: https://github.com/llvm/llvm-project/pull/94263.diff 8 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index f4daab7d06eb5..7d4d9801c7ce2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -211,12 +211,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); } - // In case OriginalArgType is of untyped pointer type, there are three + // In case OriginalArgType is of untyped pointer type, there are four // possibilities: // 1) This is a pointer of an LLVM IR element type, passed byval/byref. // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type // intrinsic assigning a TargetExtType. - // 3) This is a pointer, try to retrieve pointer element type from a + // 3) This is an OpenCL/SPIR-V builtin type if the mangled function name + // contains type information (the Arg's function is a builtin, has no + // body). + // 4) This is a pointer, try to retrieve pointer element type from a // spv_assign_ptr_type intrinsic or otherwise use default pointer element // type. if (hasPointeeTypeAttr(Arg)) { @@ -255,6 +258,14 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST)); } + std::string DemangledFuncName = demangleBuiltinCall(F.getName()); + if (!DemangledFuncName.empty()) { + Type *BuiltinType = SPIRV::parseBuiltinCallArgumentBaseType( + DemangledFuncName, ArgIdx, F.getContext()); + if (BuiltinType && BuiltinType->isTargetExtTy()) + return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual); + } + // Replace PointerType with TypedPointerType to be able to map SPIR-V types to // LLVM types in a consistent manner if (isUntypedPointerTy(OriginalArgType)) { @@ -509,7 +520,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // globally later. if (Info.Callee.isGlobal()) { std::string FuncName = Info.Callee.getGlobal()->getName().str(); - DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); + DemangledName = demangleBuiltinCall(FuncName); CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); // TODO: support constexpr casts and indirect calls. if (CF == nullptr) diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index ffbd1e17bad5e..7e9347fa2fbc9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -851,6 +851,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( I->setOperand(OperandToReplace, PtrCastI); } +static bool inline isDirectNonIntrinsicCall(CallInst *CI) { + return CI && !CI->isIndirectCall() && !CI->isInlineAsm() && + CI->getCalledFunction() && !CI->getCalledFunction()->isIntrinsic(); +} + void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B) { // Handle basic instructions: @@ -874,14 +879,12 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, // Handle calls to builtins (non-intrinsics): CallInst *CI = dyn_cast<CallInst>(I); - if (!CI || CI->isIndirectCall() || CI->isInlineAsm() || - !CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic()) + if (!isDirectNonIntrinsicCall(CI)) return; - // collect information about formal parameter types - std::string DemangledName = - getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName()); + // Collect information about formal parameter types Function *CalledF = CI->getCalledFunction(); + std::string DemangledName = demangleBuiltinCall(CalledF->getName()); SmallVector<Type *, 4> CalledArgTys; bool HaveTypes = false; for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) { @@ -1195,12 +1198,21 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, I->replaceAllUsesWith(NewOp); NewOp->setArgOperand(0, I); } + + auto *CI = dyn_cast<CallInst>(I); + bool IsCall = CI && isDirectNonIntrinsicCall(CI); + std::string DemangledCall = + IsCall ? demangleBuiltinCall(CI->getCalledFunction()->getName()) : ""; + bool IsPhi = isa<PHINode>(I), BPrepared = false; + for (const auto &Op : I->operands()) { if ((isa<ConstantAggregateZero>(Op) && Op->getType()->isVectorTy()) || isa<PHINode>(I) || isa<SwitchInst>(I)) TrackConstants = false; if ((isa<ConstantData>(Op) || isa<ConstantExpr>(Op)) && TrackConstants) { + Constant *OpConst = cast<Constant>(Op); + unsigned OpNo = Op.getOperandNo(); if (II && ((II->getIntrinsicID() == Intrinsic::spv_gep && OpNo == 0) || (II->paramHasAttr(OpNo, Attribute::ImmArg)))) @@ -1210,12 +1222,27 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, : B.SetInsertPoint(I); BPrepared = true; } + Value *OpTyVal = Op; - if (Op->getType()->isTargetExtTy()) + Value *ConstVal = Op; + + if (Op->getType()->isTargetExtTy() || + (Op->getType()->isPointerTy() && !DemangledCall.empty() && + OpConst->isNullValue())) { OpTyVal = PoisonValue::get(Op->getType()); + } + + if (Op->getType()->isPointerTy() && !DemangledCall.empty() && + OpConst->isNullValue()) { + Type *DemangledTy = SPIRV::parseBuiltinCallArgumentBaseType( + DemangledCall, Op.getOperandNo(), I->getContext()); + if (DemangledTy) + ConstVal = Constant::getNullValue(DemangledTy); + } + auto *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant, - {Op->getType(), OpTyVal->getType()}, Op, - OpTyVal, {}, B); + {Op->getType(), OpTyVal->getType()}, + ConstVal, OpTyVal, {}, B); I->setOperand(OpNo, NewOp); } } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 624899600693a..3ce3b9b3b1e64 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -81,13 +81,14 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, } GR->add(Const, &MF, SrcReg); if (Const->getType()->isTargetExtTy()) { - // remember association so that we can restore it when assign types + // Remember association so that we can restore it when assign types. MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT || SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)) TargetExtConstTypes[SrcMI] = Const->getType(); if (Const->isNullValue()) { MachineIRBuilder MIB(MF); + MIB.setInsertPt(*MI.getParent(), MI); SPIRVType *ExtType = GR->getOrCreateSPIRVType(Const->getType(), MIB); SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull)); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index c20f3546a3e55..b8cd83d1ca975 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -332,7 +332,7 @@ static bool isNonMangledOCLBuiltin(StringRef Name) { Name == "__translate_sampler_initializer"; } -std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) { +std::string demangleBuiltinCall(StringRef Name) { bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name); bool IsNonMangledSPIRV = Name.starts_with("__spirv_"); bool IsNonMangledHLSL = Name.starts_with("__hlsl_"); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 33cb509dc4a59..a539e960f22a2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -90,9 +90,9 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID); // Get type of i-th operand of the metadata node. Type *getMDOperandAsType(const MDNode *N, unsigned I); -// If OpenCL or SPIR-V builtin function name is recognized, return a demangled -// name, otherwise return an empty string. -std::string getOclOrSpirvBuiltinDemangledName(StringRef Name); +// If SPIR-V builtin function name is recognized, return a demangled name, +// otherwise return an empty string. +std::string demangleBuiltinCall(StringRef Name); // Check if a string contains a builtin prefix. bool hasBuiltinTypePrefix(StringRef Name); diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll new file mode 100644 index 0000000000000..103fab2838607 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll @@ -0,0 +1,13 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent +; CHECK-DAG: %[[#EVENT_NULL:]] = OpConstantNull %[[#EVENT]] +; CHECK-DAG: %[[#]] = OpFunctionCall %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#EVENT_NULL]] + +define spir_kernel void @foo() { + %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr null, ptr null, i64 1, i64 1, ptr null) + ret void +} + +declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr) diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll new file mode 100644 index 0000000000000..51d094e398216 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll @@ -0,0 +1,13 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent +; CHECK-DAG: %[[#EVENT_NULL:]] = OpConstantNull %[[#EVENT]] +; CHECK-DAG: %[[#]] = OpFunctionCall %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#EVENT_NULL]] + +define spir_kernel void @foo(ptr %a, ptr %b) { + %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr %a, ptr %b, i64 1, i64 1, ptr null) + ret void +} + +declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr) diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll new file mode 100644 index 0000000000000..ff386affdc4c5 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll @@ -0,0 +1,15 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0 +; CHECK-DAG: %[[#PTR_INT8:]] = OpTypePointer Function %[[#INT8]] +; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent +; CHECK-DAG: %[[#FUNC_TY:]] = OpTypeFunction %[[#]] %[[#PTR_INT8]] %[[#PTR_INT8]] %[[#]] %[[#]] %[[#EVENT]] +; CHECK-DAG: %[[#]] = OpFunction %[[#]] None %[[#FUNC_TY]] + +define spir_kernel void @foo(ptr %a, ptr %b) { + %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr %a, ptr %b, i64 1, i64 1, ptr null) + ret void +} + +declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr) |
This change makes sure that ptr null (constant) arguments of builtin
function calls are assigned proper builtin types if such are deduced
from mangled names.
Two tests demonstrating the expected bahavior (as in the SPIR-V
Translator) are added.
WIP:
The test builtin-call-multiple-ptr-null-args-one-of-builtin-
type.ll is failing and requires additional work.
processInstrAfterVisit method must be simplified. TrackConstants can
be removed.