Skip to content

[mlir][test] Shard the Test Dialect (NFC) #89628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
add_subdirectory(tools/mlir-tblgen)
add_subdirectory(tools/mlir-src-sharder)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")

add_subdirectory(include/mlir)
add_subdirectory(lib)
Expand Down
38 changes: 38 additions & 0 deletions mlir/cmake/modules/AddMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
tablegen(MLIR ${ARGV})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)

# Get the current set of include paths for this td file.
cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
# Filter out any empty include items.
list(REMOVE_ITEM tblgen_includes "")

# Build the absolute path for the current input file.
if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
else()
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
endif()

# Append the includes used for this file to the tablegen_compile_commands
# file.
file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
"--- !FileInfo:\n"
" filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
" includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
)
endfunction()

# Clear out any pre-existing compile_commands file before processing. This
Expand Down Expand Up @@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()

# Declare sharded dialect operation declarations and definitions
function(add_sharded_ops ops_target shard_count)
set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
foreach(index RANGE ${shard_count})
set(SHARDED_SRC ${ops_target}.${index}.cpp)
list(APPEND SHARDED_SRCS ${SHARDED_SRC})
tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
endforeach()
add_public_tablegen_target(MLIR${ops_target}ShardGen)
set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
endfunction()

# Declare a dialect in the include directory
function(add_mlir_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
Expand Down
2 changes: 2 additions & 0 deletions mlir/cmake/modules/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# Refer to the best host mlir-tbgen, which might be a host-optimized version
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
Expand Down Expand Up @@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
Expand Down
1 change: 1 addition & 0 deletions mlir/cmake/modules/MLIRConfig.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/TableGen/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,22 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
/// Create a constraint uniquer with a unique prefix derived from the record
/// keeper with an optional tag.
StaticVerifierFunctionEmitter(raw_ostream &os,
const llvm::RecordKeeper &records);
const llvm::RecordKeeper &records,
StringRef tag = "");

/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);

/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
/// the generated file.
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);

/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
Expand Down Expand Up @@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
/// Emit pattern constraints.
void emitPatternConstraints();

/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all pattern constraints.
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);

Expand Down
15 changes: 6 additions & 9 deletions mlir/lib/TableGen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ using namespace mlir::tblgen;

/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
StringRef tag) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();

Expand All @@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
nameRef.consume_back(".td");

// Sanitize any invalid characters.
std::string uniqueName;
std::string uniqueName(tag);
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
Expand All @@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
}

StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const llvm::RecordKeeper &records)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}

void StaticVerifierFunctionEmitter::emitOpConstraints(
ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
collectOpConstraints(opDefs);
if (emitDecl)
return;

ArrayRef<llvm::Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
Expand Down
6 changes: 4 additions & 2 deletions mlir/test/lib/Dialect/Test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTestEnumDefIncGen)

set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
Expand All @@ -43,6 +41,8 @@ mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls)
mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)

add_sharded_ops(TestOps 20)

# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
Expand All @@ -56,6 +56,7 @@ add_mlir_library(MLIRTestDialect
TestTypes.cpp
TestOpsSyntax.cpp
TestDialectInterfaces.cpp
${SHARDED_SRCS}

EXCLUDE_FROM_LIBMLIR

Expand All @@ -66,6 +67,7 @@ add_mlir_library(MLIRTestDialect
MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
MLIRTestOpsSyntaxIncGen
MLIRTestOpsShardGen

LINK_LIBS PUBLIC
MLIRControlFlowInterfaces
Expand Down
5 changes: 1 addition & 4 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,9 @@ struct TestOpEffectInterfaceFallback
void TestDialect::initialize() {
registerAttributes();
registerTypes();
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
registerOpsSyntax();
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
Expand Down
1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/Test/TestOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@
using namespace mlir;
using namespace test;

#define GET_OP_CLASSES
#include "TestOps.cpp.inc"
33 changes: 33 additions & 0 deletions mlir/test/mlir-tblgen/shard-op-defs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS

include "mlir/IR/OpBase.td"

def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "test";
}

class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

def OpA : Test_Op<"a">;
def OpB : Test_Op<"b">;
def OpC : Test_Op<"c">;

// DECLS: OpA
// DECLS: OpB
// DECLS: OpC
// DECLS: registerTestDialectOperations(
// DECLS: registerTestDialectOperations0(
// DECLS: registerTestDialectOperations1(

// DEFS-LABEL: GET_OP_DEFS_0
// DEFS: void test::registerTestDialectOperations(
// DEFS: void test::registerTestDialectOperations0(
// DEFS: OpAAdaptor
// DEFS: OpBAdaptor

// DEFS-LABEL: GET_OP_DEFS_1
// DEFS: void test::registerTestDialectOperations1(
// DEFS: OpCAdaptor
14 changes: 14 additions & 0 deletions mlir/tools/mlir-src-sharder/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
set(LLVM_LINK_COMPONENTS Support)
set(LIBS MLIRSupport)

add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
mlir-src-sharder.cpp

DEPENDS
${LIBS}
)

set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})

mlir_check_all_link_libraries(mlir-src-sharder)
114 changes: 114 additions & 0 deletions mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//===- mlir-src-sharder.cpp - A tool for sharder generated source files ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/ToolOutputFile.h"

using namespace mlir;

/// Create a dependency file for `-d` option.
///
/// This functionality is generally only for the benefit of the build system,
/// and is modeled after the same option in TableGen.
static LogicalResult createDependencyFile(StringRef outputFilename,
StringRef dependencyFile) {
if (outputFilename == "-") {
llvm::errs() << "error: the option -d must be used together with -o\n";
return failure();
}

std::string errorMessage;
std::unique_ptr<llvm::ToolOutputFile> outputFile =
openOutputFile(dependencyFile, &errorMessage);
if (!outputFile) {
llvm::errs() << errorMessage << "\n";
return failure();
}

outputFile->os() << outputFilename << ":\n";
outputFile->keep();
return success();
}

int main(int argc, char **argv) {
// FIXME: This is necessary because we link in TableGen, which defines its
// options as static variables.. some of which overlap with our options.
llvm::cl::ResetCommandLineParser();

llvm::cl::opt<unsigned> opShardIndex(
"op-shard-index", llvm::cl::desc("The current shard index"));
llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::list<std::string> includeDirs(
"I", llvm::cl::desc("Directory of include files"),
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
llvm::cl::opt<std::string> dependencyFilename(
"d", llvm::cl::desc("Dependency filename"),
llvm::cl::value_desc("filename"), llvm::cl::init(""));
llvm::cl::opt<bool> writeIfChanged(
"write-if-changed",
llvm::cl::desc("Only write to the output file if it changed"));

llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);

// Open the input file.
std::string errorMessage;
std::unique_ptr<llvm::MemoryBuffer> inputFile =
openInputFile(inputFilename, &errorMessage);
if (!inputFile) {
llvm::errs() << errorMessage << "\n";
return 1;
}

// Write the output to a buffer.
std::string outputStr;
llvm::raw_string_ostream os(outputStr);
os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
<< inputFile->getBuffer();

// Determine whether we need to write the output file.
bool shouldWriteOutput = true;
if (writeIfChanged) {
// Only update the real output file if there are any differences. This
// prevents recompilation of all the files depending on it if there aren't
// any.
if (auto existingOrErr =
llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
if (std::move(existingOrErr.get())->getBuffer() == os.str())
shouldWriteOutput = false;
}

// Populate the output file if necessary.
if (shouldWriteOutput) {
std::unique_ptr<llvm::ToolOutputFile> outputFile =
openOutputFile(outputFilename, &errorMessage);
if (!outputFile) {
llvm::errs() << errorMessage << "\n";
return 1;
}
outputFile->os() << os.str();
outputFile->keep();
}

// Always write the depfile, even if the main output hasn't changed. If it's
// missing, Ninja considers the output dirty.
if (!dependencyFilename.empty())
if (failed(createDependencyFile(outputFilename, dependencyFilename)))
return 1;

return 0;
}
Loading