Skip to content

Conversation

@Hardcode84
Copy link
Contributor

Initial discussion #170016 (comment)

While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.

@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-llvm

Author: Ivan Butygin (Hardcode84)

Changes

Initial discussion #170016 (comment)

While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.


Full diff: https://github.com/llvm/llvm-project/pull/170134.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h (+17-16)
  • (modified) mlir/include/mlir/Target/LLVM/ModuleToObject.h (+8-8)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+12-12)
  • (modified) mlir/lib/Target/LLVM/ModuleToObject.cpp (+23-10)
  • (modified) mlir/lib/Target/LLVM/NVVM/Target.cpp (+6-2)
  • (modified) mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp (+40-5)
  • (modified) mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp (+80-3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 139360f8bd3fc..00f885898ffa1 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -55,10 +55,10 @@ class TargetOptions { StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -97,19 +97,20 @@ class TargetOptions { /// Returns the callback invoked with the initial LLVM IR for the device /// module. - function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module /// after linking the device libraries. - function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> + getOptimizedLlvmIRCallback() const; /// Returns the callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> getISACallback() const; + function_ref<LogicalResult(StringRef)> getISACallback() const; /// Returns the default compilation target: `CompilationTarget::Fatbin`. static CompilationTarget getDefaultCompilationTarget(); @@ -127,10 +128,10 @@ class TargetOptions { StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -153,19 +154,19 @@ class TargetOptions { function_ref<SymbolTable *()> getSymbolTableCallback; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: TypeID typeID; diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index 11fea6f0a4443..0edc20cd32620 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -32,10 +32,10 @@ class ModuleToObject { ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features = {}, int optLevel = 3, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); virtual ~ModuleToObject(); /// Returns the operation being serialized. @@ -120,19 +120,19 @@ class ModuleToObject { int optLevel; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: /// The TargetMachine created for the given Triple, if available. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..a813608fdf209 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink, cmdOptions, elfSection, compilationTarget, getSymbolTableCallback, initialLlvmIRCallback, @@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), compilationTarget(compilationTarget), @@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const { return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getInitialLlvmIRCallback() const { return initialLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getLinkedLlvmIRCallback() const { return linkedLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getOptimizedLlvmIRCallback() const { return optimizedLlvmIRCallback; } -function_ref<void(StringRef)> TargetOptions::getISACallback() const { +function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const { return isaCallback; } diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 4098ccc548dc1..d881dda69453b 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -36,10 +36,11 @@ using namespace mlir::LLVM; ModuleToObject::ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features, - int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + int optLevel, + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : module(module), triple(triple), chip(chip), features(features), optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback), linkedLlvmIRCallback(linkedLlvmIRCallback), @@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { } setDataLayoutAndTriple(*llvmModule); - if (initialLlvmIRCallback) - initialLlvmIRCallback(*llvmModule); + if (initialLlvmIRCallback) { + if (failed(initialLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "InitialLLVMIRCallback failed."; + return std::nullopt; + } + } // Link bitcode files. handleModulePreLink(*llvmModule); @@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { handleModulePostLink(*llvmModule); } - if (linkedLlvmIRCallback) - linkedLlvmIRCallback(*llvmModule); + if (linkedLlvmIRCallback) { + if (failed(linkedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "LinkedLLVMIRCallback failed."; + return std::nullopt; + } + } // Optimize the module. if (failed(optimizeModule(*llvmModule, optLevel))) return std::nullopt; - if (optimizedLlvmIRCallback) - optimizedLlvmIRCallback(*llvmModule); + if (optimizedLlvmIRCallback) { + if (failed(optimizedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "OptimizedLLVMIRCallback failed."; + return std::nullopt; + } + } // Return the serialized object. return moduleToObject(*llvmModule); diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 8760ea8588e2c..cbd6a6d878813 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { return std::nullopt; } - if (isaCallback) - isaCallback(serializedISA.value()); + if (isaCallback) { + if (failed(isaCallback(serializedISA.value()))) { + getOperation().emitError() << "ISACallback failed."; + return std::nullopt; + } + } #define DEBUG_TYPE "serialize-to-isa" LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n" diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index af0af89c7d07e..1692c4490e4d1 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM, ASSERT_TRUE(!!serializer); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string isaResult; - auto isaCallback = [&isaResult](llvm::StringRef isa) { + auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult { isaResult = isa.str(); + return success(); }; gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, @@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM, } } +// Test callback functions failure with ISA. +TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + + auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); + ASSERT_TRUE(!!serializer); + + auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, + {}, {}, {}, {}, isaCallback); + + for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { + std::optional<SmallVector<char, 0>> object = + serializer.serializeToObject(gpuModule, options); + + ASSERT_TRUE(object == std::nullopt); + } +} + // Test linking LLVM IR from a resource attribute. TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { MLIRContext context(registry); @@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { // Hook to intercept the LLVM IR after linking external libs. std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; // Store the bitcode as a DenseI8ArrayAttr. diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 3c880edee4ffc..b392065132787 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM, auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM, ASSERT_TRUE(!serializedBinary->empty()); ASSERT_TRUE(!optimizedLLVMIR.empty()); } + +// Test callback function failure with initial LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, initialCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with linked LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, linkedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with optimized LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, {}, optimizedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} 
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Initial discussion #170016 (comment)

While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.


Full diff: https://github.com/llvm/llvm-project/pull/170134.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h (+17-16)
  • (modified) mlir/include/mlir/Target/LLVM/ModuleToObject.h (+8-8)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+12-12)
  • (modified) mlir/lib/Target/LLVM/ModuleToObject.cpp (+23-10)
  • (modified) mlir/lib/Target/LLVM/NVVM/Target.cpp (+6-2)
  • (modified) mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp (+40-5)
  • (modified) mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp (+80-3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 139360f8bd3fc..00f885898ffa1 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -55,10 +55,10 @@ class TargetOptions { StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -97,19 +97,20 @@ class TargetOptions { /// Returns the callback invoked with the initial LLVM IR for the device /// module. - function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module /// after linking the device libraries. - function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const; + function_ref<LogicalResult(llvm::Module &)> + getOptimizedLlvmIRCallback() const; /// Returns the callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> getISACallback() const; + function_ref<LogicalResult(StringRef)> getISACallback() const; /// Returns the default compilation target: `CompilationTarget::Fatbin`. static CompilationTarget getDefaultCompilationTarget(); @@ -127,10 +128,10 @@ class TargetOptions { StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref<SymbolTable *()> getSymbolTableCallback = {}, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -153,19 +154,19 @@ class TargetOptions { function_ref<SymbolTable *()> getSymbolTableCallback; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: TypeID typeID; diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index 11fea6f0a4443..0edc20cd32620 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -32,10 +32,10 @@ class ModuleToObject { ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features = {}, int optLevel = 3, - function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, - function_ref<void(StringRef)> isaCallback = {}); + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {}, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {}, + function_ref<LogicalResult(StringRef)> isaCallback = {}); virtual ~ModuleToObject(); /// Returns the operation being serialized. @@ -120,19 +120,19 @@ class ModuleToObject { int optLevel; /// Callback invoked with the initial LLVM IR for the device module. - function_ref<void(llvm::Module &)> initialLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref<void(llvm::Module &)> linkedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref<void(StringRef)> isaCallback; + function_ref<LogicalResult(StringRef)> isaCallback; private: /// The TargetMachine created for the given Triple, if available. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..a813608fdf209 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink, cmdOptions, elfSection, compilationTarget, getSymbolTableCallback, initialLlvmIRCallback, @@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref<SymbolTable *()> getSymbolTableCallback, - function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), compilationTarget(compilationTarget), @@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const { return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getInitialLlvmIRCallback() const { return initialLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getLinkedLlvmIRCallback() const { return linkedLlvmIRCallback; } -function_ref<void(llvm::Module &)> +function_ref<LogicalResult(llvm::Module &)> TargetOptions::getOptimizedLlvmIRCallback() const { return optimizedLlvmIRCallback; } -function_ref<void(StringRef)> TargetOptions::getISACallback() const { +function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const { return isaCallback; } diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 4098ccc548dc1..d881dda69453b 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -36,10 +36,11 @@ using namespace mlir::LLVM; ModuleToObject::ModuleToObject( Operation &module, StringRef triple, StringRef chip, StringRef features, - int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback, - function_ref<void(llvm::Module &)> linkedLlvmIRCallback, - function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, - function_ref<void(StringRef)> isaCallback) + int optLevel, + function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback, + function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback, + function_ref<LogicalResult(StringRef)> isaCallback) : module(module), triple(triple), chip(chip), features(features), optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback), linkedLlvmIRCallback(linkedLlvmIRCallback), @@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { } setDataLayoutAndTriple(*llvmModule); - if (initialLlvmIRCallback) - initialLlvmIRCallback(*llvmModule); + if (initialLlvmIRCallback) { + if (failed(initialLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "InitialLLVMIRCallback failed."; + return std::nullopt; + } + } // Link bitcode files. handleModulePreLink(*llvmModule); @@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() { handleModulePostLink(*llvmModule); } - if (linkedLlvmIRCallback) - linkedLlvmIRCallback(*llvmModule); + if (linkedLlvmIRCallback) { + if (failed(linkedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "LinkedLLVMIRCallback failed."; + return std::nullopt; + } + } // Optimize the module. if (failed(optimizeModule(*llvmModule, optLevel))) return std::nullopt; - if (optimizedLlvmIRCallback) - optimizedLlvmIRCallback(*llvmModule); + if (optimizedLlvmIRCallback) { + if (failed(optimizedLlvmIRCallback(*llvmModule))) { + getOperation().emitError() << "OptimizedLLVMIRCallback failed."; + return std::nullopt; + } + } // Return the serialized object. return moduleToObject(*llvmModule); diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 8760ea8588e2c..cbd6a6d878813 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { return std::nullopt; } - if (isaCallback) - isaCallback(serializedISA.value()); + if (isaCallback) { + if (failed(isaCallback(serializedISA.value()))) { + getOperation().emitError() << "ISACallback failed."; + return std::nullopt; + } + } #define DEBUG_TYPE "serialize-to-isa" LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n" diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index af0af89c7d07e..1692c4490e4d1 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM, ASSERT_TRUE(!!serializer); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string isaResult; - auto isaCallback = [&isaResult](llvm::StringRef isa) { + auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult { isaResult = isa.str(); + return success(); }; gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, @@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM, } } +// Test callback functions failure with ISA. +TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + + auto serializer = dyn_cast<gpu::TargetAttrInterface>(target); + ASSERT_TRUE(!!serializer); + + auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, + {}, {}, {}, {}, isaCallback); + + for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) { + std::optional<SmallVector<char, 0>> object = + serializer.serializeToObject(gpuModule, options); + + ASSERT_TRUE(object == std::nullopt); + } +} + // Test linking LLVM IR from a resource attribute. TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { MLIRContext context(registry); @@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { // Hook to intercept the LLVM IR after linking external libs. std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; // Store the bitcode as a DenseI8ArrayAttr. diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 3c880edee4ffc..b392065132787 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) { auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM, auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM, ASSERT_TRUE(!serializedBinary->empty()); ASSERT_TRUE(!optimizedLLVMIR.empty()); } + +// Test callback function failure with initial LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, initialCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with linked LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, linkedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with optimized LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef<ModuleOp> module = + parseSourceString<ModuleOp>(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target); + + auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult { + return failure(); + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, {}, optimizedCallback); + std::optional<SmallVector<char, 0>> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} 
Copy link
Contributor

@fabianmcg fabianmcg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general LGTM, the only change I would make is:

  • Passing a Location as well to the callbacks
  • Let the callback handle the error message, so no getOperation().emitError() << "ISACallback failed.";

I'd argue that this is already a wanted change, as I could have a custom out of tree LLVM pass running at any of the LLVM callback levels that could fail.

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a diagnostic issue here: the callbacks don't have any information needed to diagnose (no MLIRContext, no operation to attach the diagnostic to...).

The use-case for this looks quite hand-wavy to me to justify adding more complexity to correctly handle all this though.

@Hardcode84
Copy link
Contributor Author

There is a diagnostic issue here: the callbacks don't have any information needed to diagnose (no MLIRContext, no operation to attach the diagnostic to...).

I can add source op arg to all the callbacks.

The use-case for this looks quite hand-wavy to me to justify adding more complexity to correctly handle all this though.

The one potential use case I considered is something similar to Triton TRITON_OVERRIDE_DIR https://github.com/triton-lang/triton/blob/main/README.md#tips-for-hacking, where you can override an IR from file on specific step for debugging. You need to properly propagate fs/parsing errors in this case obviously.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment