- Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][gpu] Propagate errors from ModuleToObject callbacks #170134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| @llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-llvm Author: Ivan Butygin (Hardcode84) ChangesInitial discussion #170016 (comment) While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail. Full diff: https://github.com/llvm/llvm-project/pull/170134.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 139360f8bd3fc..00f885898ffa1 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -55,10 +55,10 @@ class TargetOptions { StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -97,19 +97,20 @@ class TargetOptions { /// Returns the callback invoked with the initial LLVM IR for the device /// module. - function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module /// after linking the device libraries. - function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> + getOptimizedLlvmIRCallback() const; /// Returns the callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> getISACallback() const; + function_ref<LogicalResult(StringRef)> getISACallback() const; /// Returns the default compilation target: `CompilationTarget::Fatbin`. static CompilationTarget getDefaultCompilationTarget(); @@ -127,10 +128,10 @@ class TargetOptions { StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -153,19 +154,19 @@ class TargetOptions { function_ref<SymbolTable *()> getSymbolTableCallback; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: TypeID typeID; diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index 11fea6f0a4443..0edc20cd32620 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -32,10 +32,10 @@ class ModuleToObject { ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features = {}, int optLevel = 3, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); virtual ~ModuleToObject(); /// Returns the operation being serialized. @@ -120,19 +120,19 @@ class ModuleToObject { int optLevel; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: /// The TargetMachine created for the given Triple, if available. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..a813608fdf209 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink, cmdOptions, elfSection, compilationTarget, getSymbolTableCallback, initialLlvmIRCallback, @@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), compilationTarget(compilationTarget), @@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const { return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getInitialLlvmIRCallback() const { return initialLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getLinkedLlvmIRCallback() const { return linkedLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getOptimizedLlvmIRCallback() const { return optimizedLlvmIRCallback; } -function_ref<void(StringRef)> TargetOptions::getISACallback() const { +function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const { return isaCallback; } diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 4098ccc548dc1..d881dda69453b 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -36,10 +36,11 @@ using namespace mlir::LLVM; ModuleToObject::ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features, - int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + int optLevel, + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : module(module), triple(triple), chip(chip), features(features), optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback), linkedLlvmIRCallback(linkedLlvmIRCallback), @@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { } setDataLayoutAndTriple(*llvmModule); - if (initialLlvmIRCallback) - initialLlvmIRCallback(*llvmModule); + if (initialLlvmIRCallback) { + if (failed(initialLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "InitialLLVMIRCallback failed."; + return std::nullopt; + } + } // Link bitcode files. handleModulePreLink(*llvmModule); @@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { handleModulePostLink(*llvmModule); } - if (linkedLlvmIRCallback) - linkedLlvmIRCallback(*llvmModule); + if (linkedLlvmIRCallback) { + if (failed(linkedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "LinkedLLVMIRCallback failed."; + return std::nullopt; + } + } // Optimize the module. if (failed(optimizeModule(*llvmModule, optLevel))) return std::nullopt; - if (optimizedLlvmIRCallback) - optimizedLlvmIRCallback(*llvmModule); + if (optimizedLlvmIRCallback) { + if (failed(optimizedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "OptimizedLLVMIRCallback failed."; + return std::nullopt; + } + } // Return the serialized object. return moduleToObject(*llvmModule); diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 8760ea8588e2c..cbd6a6d878813 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { return std::nullopt; } - if (isaCallback) - isaCallback(serializedISA.value()); + if (isaCallback) { + if (failed(isaCallback(serializedISA.value()))) { + getOperation().emitError() << "ISACallback failed."; + return std::nullopt; + } + } #define DEBUG_TYPE "serialize-to-isa" LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n" diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index af0af89c7d07e..1692c4490e4d1 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM, ASSERT_TRUE(!!serializer); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string isaResult; - auto isaCallback = [&isaResult](llvm::StringRef isa) { + auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult { isaResult = isa.str(); + return success(); }; gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, @@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM, } } +// Test callback functions failure with ISA. +TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + + auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); + ASSERT_TRUE(!!serializer); + + auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, + {}, {}, {}, {}, isaCallback); + + for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { + std::optional<SmallVector<char, 0>> object = + serializer.serializeToObject(gpuModule, options); + + ASSERT_TRUE(object == std::nullopt); + } +} + // Test linking LLVM IR from a resource attribute. TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { MLIRContext context(registry); @@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { // Hook to intercept the LLVM IR after linking external libs. std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; // Store the bitcode as a DenseI8ArrayAttr. diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 3c880edee4ffc..b392065132787 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM, auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM, ASSERT_TRUE(!serializedBinary->empty()); ASSERT_TRUE(!optimizedLLVMIR.empty()); } + +// Test callback function failure with initial LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, initialCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with linked LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, linkedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with optimized LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, {}, optimizedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} |
| @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesInitial discussion #170016 (comment) While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail. Full diff: https://github.com/llvm/llvm-project/pull/170134.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 139360f8bd3fc..00f885898ffa1 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -55,10 +55,10 @@ class TargetOptions { StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -97,19 +97,20 @@ class TargetOptions { /// Returns the callback invoked with the initial LLVM IR for the device /// module. - function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module /// after linking the device libraries. - function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> + getOptimizedLlvmIRCallback() const; /// Returns the callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> getISACallback() const; + function_ref<LogicalResult(StringRef)> getISACallback() const; /// Returns the default compilation target: `CompilationTarget::Fatbin`. static CompilationTarget getDefaultCompilationTarget(); @@ -127,10 +128,10 @@ class TargetOptions { StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -153,19 +154,19 @@ class TargetOptions { function_ref<SymbolTable *()> getSymbolTableCallback; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: TypeID typeID; diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index 11fea6f0a4443..0edc20cd32620 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -32,10 +32,10 @@ class ModuleToObject { ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features = {}, int optLevel = 3, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); virtual ~ModuleToObject(); /// Returns the operation being serialized. @@ -120,19 +120,19 @@ class ModuleToObject { int optLevel; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: /// The TargetMachine created for the given Triple, if available. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..a813608fdf209 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink, cmdOptions, elfSection, compilationTarget, getSymbolTableCallback, initialLlvmIRCallback, @@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), compilationTarget(compilationTarget), @@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const { return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getInitialLlvmIRCallback() const { return initialLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getLinkedLlvmIRCallback() const { return linkedLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getOptimizedLlvmIRCallback() const { return optimizedLlvmIRCallback; } -function_ref<void(StringRef)> TargetOptions::getISACallback() const { +function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const { return isaCallback; } diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 4098ccc548dc1..d881dda69453b 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -36,10 +36,11 @@ using namespace mlir::LLVM; ModuleToObject::ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features, - int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + int optLevel, + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : module(module), triple(triple), chip(chip), features(features), optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback), linkedLlvmIRCallback(linkedLlvmIRCallback), @@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { } setDataLayoutAndTriple(*llvmModule); - if (initialLlvmIRCallback) - initialLlvmIRCallback(*llvmModule); + if (initialLlvmIRCallback) { + if (failed(initialLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "InitialLLVMIRCallback failed."; + return std::nullopt; + } + } // Link bitcode files. handleModulePreLink(*llvmModule); @@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { handleModulePostLink(*llvmModule); } - if (linkedLlvmIRCallback) - linkedLlvmIRCallback(*llvmModule); + if (linkedLlvmIRCallback) { + if (failed(linkedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "LinkedLLVMIRCallback failed."; + return std::nullopt; + } + } // Optimize the module. if (failed(optimizeModule(*llvmModule, optLevel))) return std::nullopt; - if (optimizedLlvmIRCallback) - optimizedLlvmIRCallback(*llvmModule); + if (optimizedLlvmIRCallback) { + if (failed(optimizedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "OptimizedLLVMIRCallback failed."; + return std::nullopt; + } + } // Return the serialized object. return moduleToObject(*llvmModule); diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 8760ea8588e2c..cbd6a6d878813 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { return std::nullopt; } - if (isaCallback) - isaCallback(serializedISA.value()); + if (isaCallback) { + if (failed(isaCallback(serializedISA.value()))) { + getOperation().emitError() << "ISACallback failed."; + return std::nullopt; + } + } #define DEBUG_TYPE "serialize-to-isa" LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n" diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index af0af89c7d07e..1692c4490e4d1 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM, ASSERT_TRUE(!!serializer); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string isaResult; - auto isaCallback = [&isaResult](llvm::StringRef isa) { + auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult { isaResult = isa.str(); + return success(); }; gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, @@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM, } } +// Test callback functions failure with ISA. +TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + + auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); + ASSERT_TRUE(!!serializer); + + auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, + {}, {}, {}, {}, isaCallback); + + for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { + std::optional<SmallVector<char, 0>> object = + serializer.serializeToObject(gpuModule, options); + + ASSERT_TRUE(object == std::nullopt); + } +} + // Test linking LLVM IR from a resource attribute. TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { MLIRContext context(registry); @@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { // Hook to intercept the LLVM IR after linking external libs. std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; // Store the bitcode as a DenseI8ArrayAttr. diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 3c880edee4ffc..b392065132787 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM, auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM, ASSERT_TRUE(!serializedBinary->empty()); ASSERT_TRUE(!optimizedLLVMIR.empty()); } + +// Test callback function failure with initial LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, initialCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with linked LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, linkedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with optimized LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, {}, optimizedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} |
fabianmcg left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general LGTM, the only change I would make is:
- Passing a Location as well to the callbacks
- Let the callback handle the error message, so no
getOperation().emitError() << "ISACallback failed.";
I'd argue that this is already a wanted change, as I could have a custom out of tree LLVM pass running at any of the LLVM callback levels that could fail.
joker-eph left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a diagnostic issue here: the callbacks don't have any information needed to diagnose (no MLIRContext, no operation to attach the diagnostic to...).
The use-case for this looks quite hand-wavy to me to justify adding more complexity to correctly handle all this though.
I can add source op arg to all the callbacks.
The one potential use case I considered is something similar to Triton |
Initial discussion #170016 (comment)
While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.