Skip to content

[mlir][ods] Add documentation on how to use sharded op definitions (NFC) #89664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
add_subdirectory(tools/mlir-tblgen)
add_subdirectory(tools/mlir-src-sharder)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")

add_subdirectory(include/mlir)
add_subdirectory(lib)
Expand Down
38 changes: 38 additions & 0 deletions mlir/cmake/modules/AddMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
tablegen(MLIR ${ARGV})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)

# Get the current set of include paths for this td file.
cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
# Filter out any empty include items.
list(REMOVE_ITEM tblgen_includes "")

# Build the absolute path for the current input file.
if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
else()
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
endif()

# Append the includes used for this file to the tablegen_compile_commands
# file.
file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
"--- !FileInfo:\n"
" filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
" includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
)
endfunction()

# Clear out any pre-existing compile_commands file before processing. This
Expand Down Expand Up @@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()

# Declare sharded dialect operation declarations and definitions.
#
# Arguments:
#   ops_target  - Base name of the op files. `${ops_target}.td` is used as the
#                 TableGen input and `${ops_target}.cpp` as the input handed to
#                 the MLIR_SRC_SHARDER tablegen invocation.
#   shard_count - Number of shards to split the op definitions into. Must be a
#                 positive integer.
#
# Effects:
#   - Generates `${ops_target}.h.inc` and `${ops_target}.cpp.inc` with
#     `-op-shard-count=${shard_count}`.
#   - Generates one source `${ops_target}.<index>.cpp` per shard, for indices
#     0 .. shard_count - 1.
#   - Creates the public tablegen target `MLIR${ops_target}ShardGen`.
#   - Appends the generated source names to SHARDED_SRCS and exports the result
#     to the caller's scope.
function(add_sharded_ops ops_target shard_count)
  # Reject anything that is not a positive integer up front; the shard index
  # range below is derived from this value.
  if(NOT "${shard_count}" MATCHES "^[1-9][0-9]*$")
    message(FATAL_ERROR
      "add_sharded_ops: shard_count must be a positive integer, got '${shard_count}'")
  endif()

  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})

  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
  # `foreach(RANGE <stop>)` includes <stop> itself, which would produce
  # shard_count + 1 sources. Shard indices are zero-based, so iterate
  # explicitly over 0 .. shard_count - 1.
  math(EXPR last_shard_index "${shard_count} - 1")
  foreach(index RANGE 0 ${last_shard_index})
    set(SHARDED_SRC ${ops_target}.${index}.cpp)
    list(APPEND SHARDED_SRCS ${SHARDED_SRC})
    tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
    # Track the generated source alongside the other tablegen outputs
    # (presumably consumed by add_public_tablegen_target below).
    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
  endforeach()

  add_public_tablegen_target(MLIR${ops_target}ShardGen)
  # Hand the list of sharded sources back to the caller.
  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
endfunction()

# Declare an interface in the include directory
function(add_mlir_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
Expand Down
2 changes: 2 additions & 0 deletions mlir/cmake/modules/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# Refer to the best host mlir-tblgen, which might be a host-optimized version
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
Expand Down Expand Up @@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
Expand Down
1 change: 1 addition & 0 deletions mlir/cmake/modules/MLIRConfig.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
Expand Down
94 changes: 94 additions & 0 deletions mlir/docs/DefiningDialects/Operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,100 @@ void process(AddOp op, ArrayRef<Value> newOperands) {
}
```

#### Sharded Operation Definitions

Large dialects with many operations may struggle with C++ compile time of
generated op definitions, due to large compilation units. `mlir-tblgen`
can shard op definitions, splitting them up evenly, when `-op-shard-count` is
passed to `-gen-op-defs` and `-gen-op-decls`. The tool
will generate a single include file for the definitions broken up by
`GET_OP_DEFS_${N}` where `${N}` is the shard number. A shard can be compiled in
a single compilation unit by adding a file like this to your dialect library:

```c++
#include "mlir/IR/Operation.h"
// Add any other required includes.

// Utilities shared by generated op definitions: custom directive parsers,
// printers, etc.
#include "OpUtils.h"

#define GET_OP_DEFS_0
#include "MyDialectOps.cpp.inc"
```

Note: this requires restructuring shared utility functions within the dialect
library so they can be shared by multiple compilation units. I.e. instead of
defining `static` methods in the same source file, you should declare them in a
shared header and define them in their own source file.

The op registration hooks are also sharded, because the template instantiation
can take a very long time to compile. Operations should be registered in your
dialect like:

```c++
void MyDialect::initialize() {
registerMyDialectOperations(this);
}
```

CMake and Bazel functions are included to make sharding dialects easier.
Assuming you have organized your operation utility functions into their own
header, define a file that looks like the one above, but without the `#define`:

```c++
// MyDialectOps.cpp
#include "mlir/IR/Operation.h"

#include "OpUtils.h"

#include "MyDialectOps.cpp.inc"
```

In CMake, remove the manual `mlir_tablegen` invocations and replace them with:

```cmake
set(LLVM_TARGET_DEFINITIONS MyDialectOps.td)
add_sharded_ops(MyDialectOps 8) # shard the op definitions by 8

add_mlir_library(MyDialect
MyDialect.cpp
MyDialectOpDefs.cpp
${SHARDED_SRCS}

DEPENDS
MLIRMyDialectOpsShardGen
)
```

This will automatically duplicate the `MyDialectOps.cpp` source file and add the
`#define` up to the number of shards indicated.

It is recommended that any out-of-line op member functions (like verifiers) be
defined in a separate source file. In this example, it is called
`MyDialectOpDefs.cpp`.

In Bazel, remove the `-gen-op-defs` and `-gen-op-decls` invocations, and add:

```bazel
gentbl_sharded_ops(
name = "MyDialectOpSrcs",
hdr_out = "MyDialectOps.h.inc",
shard_count = 8,
sharder = "//mlir:mlir-src-sharder",
src_file = "MyDialectOps.cpp",
src_out = "MyDialectOps.cpp.inc",
tblgen = "//mlir:mlir-tblgen",
td_file = "MyDialectOps.td",
deps = [":MyDialectOpsTdFiles"],
)

cc_library(
name = "MyDialect",
srcs = glob(["MyDialect/*.cpp"]) + [":MyDialectOpSrcs"]
)
```

## Constraints

Constraint is a core concept in table-driven operation definition: operation
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/TableGen/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,22 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
/// Create a constraint uniquer with a unique prefix derived from the record
/// keeper with an optional tag.
StaticVerifierFunctionEmitter(raw_ostream &os,
const llvm::RecordKeeper &records);
const llvm::RecordKeeper &records,
StringRef tag = "");

/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);

/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
/// the generated file.
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);

/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
Expand Down Expand Up @@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
/// Emit pattern constraints.
void emitPatternConstraints();

/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all pattern constraints.
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);

Expand Down
15 changes: 6 additions & 9 deletions mlir/lib/TableGen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ using namespace mlir::tblgen;

/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
StringRef tag) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();

Expand All @@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
nameRef.consume_back(".td");

// Sanitize any invalid characters.
std::string uniqueName;
std::string uniqueName(tag);
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
Expand All @@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
}

StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const llvm::RecordKeeper &records)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}

void StaticVerifierFunctionEmitter::emitOpConstraints(
ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
collectOpConstraints(opDefs);
if (emitDecl)
return;

ArrayRef<llvm::Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
Expand Down
6 changes: 4 additions & 2 deletions mlir/test/lib/Dialect/Test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTestEnumDefIncGen)

set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
Expand All @@ -43,6 +41,8 @@ mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls)
mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)

add_sharded_ops(TestOps 20)

# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
Expand All @@ -56,6 +56,7 @@ add_mlir_library(MLIRTestDialect
TestTypes.cpp
TestOpsSyntax.cpp
TestDialectInterfaces.cpp
${SHARDED_SRCS}

EXCLUDE_FROM_LIBMLIR

Expand All @@ -66,6 +67,7 @@ add_mlir_library(MLIRTestDialect
MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
MLIRTestOpsSyntaxIncGen
MLIRTestOpsShardGen

LINK_LIBS PUBLIC
MLIRControlFlowInterfaces
Expand Down
5 changes: 1 addition & 4 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,9 @@ struct TestOpEffectInterfaceFallback
void TestDialect::initialize() {
registerAttributes();
registerTypes();
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
registerOpsSyntax();
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
Expand Down
1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/Test/TestOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@
using namespace mlir;
using namespace test;

#define GET_OP_CLASSES
#include "TestOps.cpp.inc"
33 changes: 33 additions & 0 deletions mlir/test/mlir-tblgen/shard-op-defs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS

include "mlir/IR/OpBase.td"

def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "test";
}

class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

def OpA : Test_Op<"a">;
def OpB : Test_Op<"b">;
def OpC : Test_Op<"c">;

// DECLS: OpA
// DECLS: OpB
// DECLS: OpC
// DECLS: registerTestDialectOperations(
// DECLS: registerTestDialectOperations0(
// DECLS: registerTestDialectOperations1(

// DEFS-LABEL: GET_OP_DEFS_0
// DEFS: void test::registerTestDialectOperations(
// DEFS: void test::registerTestDialectOperations0(
// DEFS: OpAAdaptor
// DEFS: OpBAdaptor

// DEFS-LABEL: GET_OP_DEFS_1
// DEFS: void test::registerTestDialectOperations1(
// DEFS: OpCAdaptor
14 changes: 14 additions & 0 deletions mlir/tools/mlir-src-sharder/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Build configuration for the mlir-src-sharder tool.

# Libraries the tool both depends on and links against.
set(LIBS MLIRSupport)
# LLVM components to link.
set(LLVM_LINK_COMPONENTS Support)

# Register the tool through add_tablegen under the MLIR_SRC_SHARDER project
# name.
add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
  mlir-src-sharder.cpp
  DEPENDS ${LIBS}
)

target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
# Group the target under the "Tablegenning" folder.
set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")

mlir_check_all_link_libraries(mlir-src-sharder)
Loading