@Mogball
Contributor

@Mogball Mogball commented Apr 19, 2024

Stacked PRs:


[mlir][ods] Allow sharding of op definitions

Adds an option, -op-shard-count=N, to mlir-tblgen -gen-op-defs that divides the
op class definitions and op list into N segments, e.g.

    // mlir-tblgen -gen-op-defs -op-shard-count=2

    void FooDialect::initialize() {
      addOperations< >();
      addOperations< >();
    }

When split across multiple source files, this can help significantly improve
dialect compile time for dialects with a large opset.
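
To make the idea of dividing an op list into N segments concrete, here is a minimal, self-contained sketch. It is not the PR's `shardOpDefinitions` (whose body falls in the truncated part of the patch); `shardItems` is a hypothetical helper that assumes contiguous chunks of `ceil(size / N)` elements. That assumption happens to match the expectations in the PR's test (OpA and OpB in shard 0, OpC in shard 1), but the real splitting policy may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of the PR): split `items` into at most
// `shardCount` contiguous chunks of ceil(size / shardCount) elements each.
template <typename T>
std::vector<std::vector<T>> shardItems(const std::vector<T> &items,
                                       std::size_t shardCount) {
  assert(shardCount > 0 && "expected at least one shard");
  std::vector<std::vector<T>> shards;
  if (items.empty())
    return shards;

  // Round up so that every element lands in some shard.
  std::size_t chunkSize = (items.size() + shardCount - 1) / shardCount;
  for (std::size_t begin = 0; begin < items.size(); begin += chunkSize) {
    std::size_t end = std::min(begin + chunkSize, items.size());
    shards.emplace_back(items.begin() + begin, items.begin() + end);
  }
  return shards;
}
```

Calling `shardItems` on the three op names from the test with a shard count of 2 yields two chunks. Note that with this rounding scheme the number of chunks can be smaller than the requested count (for example, 4 items and 3 shards give 2 chunks of 2).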

@Mogball Mogball requested a review from rupprecht as a code owner April 19, 2024 17:51
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure), mlir, and bazel ("Peripheral" support tier build system: utils/bazel) labels Apr 19, 2024
@llvmbot
Member

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jeff Niu (Mogball)

Changes

Adds an option, -op-shard-count=N, to mlir-tblgen -gen-op-defs that divides the op class definitions and op list into N segments, e.g.

    // mlir-tblgen -gen-op-defs -op-shard-count=2

    void FooDialect::initialize() {
      addOperations< >();
      addOperations< >();
    }

When split across multiple source files, this can help significantly improve dialect compile time for dialects with a large opset.


Patch is 29.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89423.diff

14 Files Affected:

  • (modified) mlir/CMakeLists.txt (+3)
  • (modified) mlir/cmake/modules/AddMLIR.cmake (+38)
  • (modified) mlir/cmake/modules/CMakeLists.txt (+2)
  • (modified) mlir/cmake/modules/MLIRConfig.cmake.in (+1)
  • (modified) mlir/include/mlir/TableGen/CodeGenHelpers.h (+8-4)
  • (modified) mlir/lib/TableGen/CodeGenHelpers.cpp (+6-9)
  • (added) mlir/test/mlir-tblgen/shard-op-defs.td (+33)
  • (added) mlir/tools/mlir-src-sharder/CMakeLists.txt (+14)
  • (added) mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp (+114)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+129-35)
  • (modified) mlir/tools/mlir-tblgen/OpGenHelpers.cpp (+24-1)
  • (modified) mlir/tools/mlir-tblgen/OpGenHelpers.h (+5)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+9)
  • (modified) utils/bazel/llvm-project-overlay/mlir/tblgen.bzl (+133)
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 5c4301af040b47..4c0ef8387b8dff 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
 add_subdirectory(tools/mlir-linalg-ods-gen)
 add_subdirectory(tools/mlir-pdll)
 add_subdirectory(tools/mlir-tblgen)
+add_subdirectory(tools/mlir-src-sharder)
 set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
 
 add_subdirectory(include/mlir)
 add_subdirectory(lib)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 1d2ed748bc2f13..afb74fb2d00025 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
   tablegen(MLIR ${ARGV})
   set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
       PARENT_SCOPE)
+
+  # Get the current set of include paths for this td file.
+  cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
+  get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
+  list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
+  # Filter out any empty include items.
+  list(REMOVE_ITEM tblgen_includes "")
+
+  # Build the absolute path for the current input file.
+  if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+    set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+  else()
+    set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
+  endif()
+
+  # Append the includes used for this file to the tablegen_compile_commands
+  # file.
+  file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
+      "--- !FileInfo:\n"
+      " filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
+      " includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
+  )
 endfunction()
 
 # Clear out any pre-existing compile_commands file before processing. This
@@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
   add_dependencies(mlir-headers MLIR${dialect}IncGen)
 endfunction()
 
+# Declare sharded dialect operation declarations and definitions
+function(add_sharded_ops ops_target shard_count)
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
+  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
+  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
+  foreach(index RANGE ${shard_count})
+    set(SHARDED_SRC ${ops_target}.${index}.cpp)
+    list(APPEND SHARDED_SRCS ${SHARDED_SRC})
+    tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
+    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
+  endforeach()
+  add_public_tablegen_target(MLIR${ops_target}ShardGen)
+  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
+endfunction()
+
 # Declare a dialect in the include directory
 function(add_mlir_interface interface)
   set(LLVM_TARGET_DEFINITIONS ${interface}.td)
diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt
index 8d2904ef46dfe8..3ac1c79b090ed6 100644
--- a/mlir/cmake/modules/CMakeLists.txt
+++ b/mlir/cmake/modules/CMakeLists.txt
@@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
 # Refer to the best host mlir-tbgen, which might be a host-optimized version
 set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
 set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
 
 configure_file(
   ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
@@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
 # if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
 set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
 set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
 
 configure_file(
   ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in
index d4da3cd98cce98..7076d94a32f2bc 100644
--- a/mlir/cmake/modules/MLIRConfig.cmake.in
+++ b/mlir/cmake/modules/MLIRConfig.cmake.in
@@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
 set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
 set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
 set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
 set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
 set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
 set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index dd17a44c889bbe..c263c69c53d1e3 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -99,8 +99,14 @@ class NamespaceEmitter {
 ///
 class StaticVerifierFunctionEmitter {
 public:
+  /// Create a constraint uniquer with a unique prefix derived from the record
+  /// keeper with an optional tag.
   StaticVerifierFunctionEmitter(raw_ostream &os,
-                                const llvm::RecordKeeper &records);
+                                const llvm::RecordKeeper &records,
+                                StringRef tag = "");
+
+  /// Collect and unique all the constraints used by operations.
+  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Collect and unique all compatible type, attribute, successor, and region
   /// constraints from the operations in the file and emit them at the top of
@@ -108,7 +114,7 @@ class StaticVerifierFunctionEmitter {
   ///
   /// Constraints that do not meet the restriction that they can only reference
   /// `$_self` and `$_op` are not uniqued.
-  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Unique all compatible type and attribute constraints from a pattern file
   /// and emit them at the top of the generated file.
@@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
   /// Emit pattern constraints.
   void emitPatternConstraints();
 
-  /// Collect and unique all the constraints used by operations.
-  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
   /// Collect and unique all pattern constraints.
   void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
 
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index d906de6b56afc0..59865146e20bc4 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -24,7 +24,8 @@ using namespace mlir::tblgen;
 
 /// Generate a unique label based on the current file name to prevent name
 /// collisions if multiple generated files are included at once.
-static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
+                                        StringRef tag) {
   // Use the input file name when generating a unique name.
   std::string inputFilename = records.getInputFilename();
 
@@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
   nameRef.consume_back(".td");
 
   // Sanitize any invalid characters.
-  std::string uniqueName;
+  std::string uniqueName(tag);
   for (char c : nameRef) {
     if (llvm::isAlnum(c) || c == '_')
       uniqueName.push_back(c);
@@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
 }
 
 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    raw_ostream &os, const llvm::RecordKeeper &records)
-    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+    raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
+    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
 
 void StaticVerifierFunctionEmitter::emitOpConstraints(
-    ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
-  collectOpConstraints(opDefs);
-  if (emitDecl)
-    return;
-
+    ArrayRef<llvm::Record *> opDefs) {
   NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
   emitTypeConstraints();
   emitAttrConstraints();
diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td
new file mode 100644
index 00000000000000..84ac6b0fbe9ebe
--- /dev/null
+++ b/mlir/test/mlir-tblgen/shard-op-defs.td
@@ -0,0 +1,33 @@
+// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
+// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "test";
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def OpA : Test_Op<"a">;
+def OpB : Test_Op<"b">;
+def OpC : Test_Op<"c">;
+
+// DECLS: OpA
+// DECLS: OpB
+// DECLS: OpC
+// DECLS: registerTestDialectOperations(
+// DECLS: registerTestDialectOperations0(
+// DECLS: registerTestDialectOperations1(
+
+// DEFS-LABEL: GET_OP_DEFS_0
+// DEFS: void test::registerTestDialectOperations(
+// DEFS: void test::registerTestDialectOperations0(
+// DEFS: OpAAdaptor
+// DEFS: OpBAdaptor
+
+// DEFS-LABEL: GET_OP_DEFS_1
+// DEFS: void test::registerTestDialectOperations1(
+// DEFS: OpCAdaptor
diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt
new file mode 100644
index 00000000000000..4ef870b61124ad
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt
@@ -0,0 +1,14 @@
+set(LLVM_LINK_COMPONENTS Support)
+set(LIBS MLIRSupport)
+
+add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
+  mlir-src-sharder.cpp
+
+  DEPENDS
+  ${LIBS}
+  )
+
+set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
+target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
+
+mlir_check_all_link_libraries(mlir-src-sharder)
diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
new file mode 100644
index 00000000000000..dc1e2939c7d25b
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
@@ -0,0 +1,114 @@
+//===- mlir-src-sharder.cpp - A tool for sharder generated source files ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+/// Create a dependency file for `-d` option.
+///
+/// This functionality is generally only for the benefit of the build system,
+/// and is modeled after the same option in TableGen.
+static LogicalResult createDependencyFile(StringRef outputFilename,
+                                          StringRef dependencyFile) {
+  if (outputFilename == "-") {
+    llvm::errs() << "error: the option -d must be used together with -o\n";
+    return failure();
+  }
+
+  std::string errorMessage;
+  std::unique_ptr<llvm::ToolOutputFile> outputFile =
+      openOutputFile(dependencyFile, &errorMessage);
+  if (!outputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  outputFile->os() << outputFilename << ":\n";
+  outputFile->keep();
+  return success();
+}
+
+int main(int argc, char **argv) {
+  // FIXME: This is necessary because we link in TableGen, which defines its
+  // options as static variables.. some of which overlap with our options.
+  llvm::cl::ResetCommandLineParser();
+
+  llvm::cl::opt<unsigned> opShardIndex(
+      "op-shard-index", llvm::cl::desc("The current shard index"));
+  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                           llvm::cl::desc("<input file>"),
+                                           llvm::cl::init("-"));
+  llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+  llvm::cl::list<std::string> includeDirs(
+      "I", llvm::cl::desc("Directory of include files"),
+      llvm::cl::value_desc("directory"), llvm::cl::Prefix);
+  llvm::cl::opt<std::string> dependencyFilename(
+      "d", llvm::cl::desc("Dependency filename"),
+      llvm::cl::value_desc("filename"), llvm::cl::init(""));
+  llvm::cl::opt<bool> writeIfChanged(
+      "write-if-changed",
+      llvm::cl::desc("Only write to the output file if it changed"));
+
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv);
+
+  // Open the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> inputFile =
+      openInputFile(inputFilename, &errorMessage);
+  if (!inputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  // Write the output to a buffer.
+  std::string outputStr;
+  llvm::raw_string_ostream os(outputStr);
+  os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
+     << inputFile->getBuffer();
+
+  // Determine whether we need to write the output file.
+  bool shouldWriteOutput = true;
+  if (writeIfChanged) {
+    // Only update the real output file if there are any differences. This
+    // prevents recompilation of all the files depending on it if there aren't
+    // any.
+    if (auto existingOrErr =
+            llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
+      if (std::move(existingOrErr.get())->getBuffer() == os.str())
+        shouldWriteOutput = false;
+  }
+
+  // Populate the output file if necessary.
+  if (shouldWriteOutput) {
+    std::unique_ptr<llvm::ToolOutputFile> outputFile =
+        openOutputFile(outputFilename, &errorMessage);
+    if (!outputFile) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+    outputFile->os() << os.str();
+    outputFile->keep();
+  }
+
+  // Always write the depfile, even if the main output hasn't changed. If it's
+  // missing, Ninja considers the output dirty.
+  if (!dependencyFilename.empty())
+    if (failed(createDependencyFile(outputFilename, dependencyFilename)))
+      return 1;
+
+  return 0;
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 53ed5cb7c043ec..63fe5a80990746 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef(
   emitter.adaptor.writeDefTo(os);
 }
 
-// Emits the opcode enum and op classes.
-static void emitOpClasses(const RecordKeeper &recordKeeper,
-                          const std::vector<Record *> &defs, raw_ostream &os,
-                          bool emitDecl) {
-  // First emit forward declaration for each class, this allows them to refer
-  // to each others in traits for example.
-  if (emitDecl) {
-    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
-    os << "#undef GET_OP_FWD_DEFINES\n";
-    for (auto *def : defs) {
-      Operator op(*def);
-      NamespaceEmitter emitter(os, op.getCppNamespace());
-      os << "class " << op.getCppClassName() << ";\n";
-    }
-    os << "#endif\n\n";
-  }
-
-  IfDefScope scope("GET_OP_CLASSES", os);
+/// Emit the class declarations or definitions for the given op defs.
+static void
+emitOpClasses(const RecordKeeper &recordKeeper,
+              const std::vector<Record *> &defs, raw_ostream &os,
+              const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+              bool emitDecl) {
   if (defs.empty())
     return;
 
-  // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
-  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-  staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
-
   for (auto *def : defs) {
     Operator op(*def);
     if (emitDecl) {
@@ -4358,34 +4341,145 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
   }
 }
 
-// Emits a comma-separated list of the ops.
-static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
-  IfDefScope scope("GET_OP_LIST", os);
+/// Emit the declarations for the provided op classes.
+static void emitOpClassDecls(const RecordKeeper &recordKeeper,
+                             const std::vector<Record *> &defs,
+                             raw_ostream &os) {
+  // First emit forward declaration for each class, this allows them to refer
+  // to each others in traits for example.
+  for (auto *def : defs) {
+    Operator op(*def);
+    NamespaceEmitter emitter(os, op.getCppNamespace());
+    os << "class " << op.getCppClassName() << ";\n";
+  }
+
+  // Emit the op class declarations.
+  IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
+  staticVerifierEmitter.collectOpConstraints(defs);
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/true);
+}
+
+/// Emit the definitions for the provided op classes.
+static void emitOpClassDefs(const RecordKeeper &recordKeeper,
+                            ArrayRef<Record *> defs, raw_ostream &os,
+                            StringRef constraintPrefix = "") {
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
+                                                      constraintPrefix);
+  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.collectOpConstraints(defs);
+  staticVerifierEmitter.emitOpConstraints(defs);
 
-  interleave(
-      // TODO: We are constructing the Operator wrapper instance just for
-      // getting it's qualified class name here. Reduce the overhead by having a
-      // lightweight version of Operator class just for that purpose.
-      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
-      [&os]() { os << ",\n"; });
+  // Emit the classes.
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/false);
 }
 
+/// Emit op declarations for all op records.
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os, recordKeeper);
 
   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+  emitOpClassDecls(recordKeeper, defs, os);
+
+  // If we are generating sharded op definitions, emit the sharded op
+  // registration hooks.
+  SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+  shardOpDefinitions(defs, shardedDefs);
+  if (defs.empty() || shardedDefs.size() <= 1)
+    return false;
+
+  Dialect dialect = Operator(defs.front()).getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  const char *const opRegistrationHook =
+      "void register{0}Operations{1}({2}::{0} *dialect);\n";
+  os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
+                dialect.getCppNamespace());
+  for (unsigned i = 0; i < shardedDefs.size(); ++i) {
+    os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
+                  dialect.getCppNamespace());
+  }
 
   return false;
 }
 
+/// Generate the dialect op registration hook and the op class definitions for a
+/// shard of ops.
+static void emitOpDefShard(const RecordKeeper &recordKeeper,
... [truncated]
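
As a rough illustration of what the mlir-src-sharder tool in the patch does to its input, the sketch below reproduces the core text transformation without any LLVM dependencies: each shard source is the input file prefixed with `#define GET_OP_DEFS_<index>`. The function names `makeShardSource` and `shouldWriteOutput` are hypothetical; the second one only approximates the `-write-if-changed` check by comparing in-memory strings instead of reading the existing output file.

```cpp
#include <cassert>
#include <string>

// Mirrors what the tool writes to its output buffer in the patch:
// "#define GET_OP_DEFS_<index>\n" followed by the unmodified input text.
std::string makeShardSource(unsigned shardIndex, const std::string &input) {
  return "#define GET_OP_DEFS_" + std::to_string(shardIndex) + "\n" + input;
}

// Hypothetical stand-in for the -write-if-changed check: rewrite the output
// only when there is no previous content (nullptr) or the content differs.
bool shouldWriteOutput(const std::string *existing, const std::string &fresh) {
  return existing == nullptr || *existing != fresh;
}
```

Skipping the write when nothing changed keeps the output file's timestamp stable, which per the comment in the patch avoids recompiling everything that depends on it.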
@joker-eph
Collaborator

Can we apply this to the TestDialect as a proof of concept?

@joker-eph
Collaborator

> Can we apply this to the TestDialect as a proof of concept?

Actually, I just saw that this is what you do in the follow-up PR, so LGTM here.

Member

"MLIR src sharder" makes me think it's related to .mlir files rather than ODS ones. Did you consider making this an mlir-tblgen "function" (such as attribute gen or doc gen) and then calling it twice?

Contributor Author

mlir-tblgen only ingests .td files as records. Do you want me to inject a hook into its main function to sniff the command and change its operating mode?

Collaborator

I don't follow: can you expand on the command line issue with mlir-tblgen?

Contributor Author

mlir-tblgen calls into the TableGen parser and then calls into a function based on the command line with the parsed records. In order to make it ingest a C++ file (or another kind of file), I have to intercept it in the main function:

Turn this

    // Generator that prints records.
    GenRegistration printRecords("print-records", "Print all records to stdout",
                                 [](const RecordKeeper &records, raw_ostream &os) {
                                   os << records;
                                   return false;
                                 });

    int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }

Into this:

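    int main(int argc, char **argv) {
      // Compare string contents; `argv[1] == "..."` would compare pointers.
      if (argc > 1 && llvm::StringRef(argv[1]) == "shard-src-files")
        return shardSourceFiles(argc, argv);
      return MlirTblgenMain(argc, argv);
    }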
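
The dispatch described in this comment can be sketched in a standalone form. This is only an illustration under assumptions: `ToolMode` and `selectMode` are hypothetical names, and `shard-src-files` is the illustrative subcommand from the comment, not an existing mlir-tblgen option. The sketch compares string contents with `std::string_view` and guards against a missing `argv[1]`.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical operating modes for a tool that peeks at its first argument.
enum class ToolMode { ShardSourceFiles, TableGen };

// Hypothetical dispatcher: pick the sharding mode only when the first
// argument is exactly "shard-src-files"; otherwise fall through to TableGen.
ToolMode selectMode(int argc, const char *const *argv) {
  assert(argc >= 1 && argv && "expected the program name in argv[0]");
  // Compare contents; comparing `argv[1]` to a literal with == on raw
  // pointers would compare addresses instead.
  if (argc > 1 && std::string_view(argv[1]) == "shard-src-files")
    return ToolMode::ShardSourceFiles;
  return ToolMode::TableGen;
}
```

A real `main` would then call the matching entry point for the selected mode. The need for this kind of argument sniffing is what the thread weighs against shipping a separate tool.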
Contributor Author

@joker-eph ping!

Collaborator

Ping should be for @jpienaar who started in this direction first :)

I agree, though, that using mlir-tblgen for non-TableGen files does not seem right.

Collaborator

This wasn't resolved?

@Mogball Mogball force-pushed the users/mogball/pr_1 branch from 4cb871c to caff32b Compare April 22, 2024 16:48
jollaitbot pushed a commit to sailfishos-mirror/llvm-project that referenced this pull request Apr 22, 2024
Adds an option, `-op-shard-count=N`, to `mlir-tblgen -gen-op-defs` that divides the op class definitions and op list into N segments, e.g.

```
// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations< >();
  addOperations< >();
}
```

When split across multiple source files, this can help significantly improve dialect compile time for dialects with a large opset.

stack-info: PR: llvm/llvm-project#89423, branch: users/mogball/pr_1
@joker-eph
Collaborator

Before I forget: we should add documentation for this, including how to structure the dialect to support it.

@Mogball Mogball force-pushed the users/mogball/pr_1 branch from caff32b to 1ab44e5 Compare April 22, 2024 20:41
@Mogball Mogball changed the base branch from main to users/mogball/pr_2 April 22, 2024 20:41
@Mogball
Contributor Author

Mogball commented Apr 22, 2024

> Before I forget: we should add documentation for this, including how to structure the dialect to support it.

Added in #89664

Base automatically changed from users/mogball/pr_2 to main April 22, 2024 20:42
Adds an option, `-op-shard-count=N`, to `mlir-tblgen -gen-op-defs` that divides the op class definitions and op list into N segments, e.g.

```
// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations< >();
  addOperations< >();
}
```

When split across multiple source files, this can help significantly improve dialect compile time for dialects with a large opset.

stack-info: PR: #89423, branch: users/mogball/pr_1
@Mogball Mogball force-pushed the users/mogball/pr_1 branch from 1ab44e5 to f341dcf Compare April 22, 2024 20:42
@Mogball Mogball merged commit 1b232fa into main Apr 24, 2024
@Mogball Mogball deleted the users/mogball/pr_1 branch April 24, 2024 21:58
Labels

bazel "Peripheral" support tier build system: utils/bazel mlir:core MLIR Core Infrastructure mlir

5 participants