diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h new file mode 100644 index 000000000..ee8630b53 --- /dev/null +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -0,0 +1,69 @@ +//===-- Driver.h - The top-level MLIR compiler driver -----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H +#define GC_EXECUTIONENGINE_DRIVER_DRIVER_H + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include +#include + +namespace mlir { +class DialectRegistry; +namespace gc { + +const DialectRegistry &initCompilerAndGetDialects(); + +// the pointers to XXXMemRefType +using GeneralMemrefPtr = void *; +using JitModuleFuncT = void (*)(void **); + +struct DriverOptions { + /// the optimization level for the LLVM-JIT + llvm::CodeGenOptLevel jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; + /// whether to run the MLIR transformation passes + bool runTransforms = true; + /// todo: target machine, etc. +}; + +class JitModule { +public: + static llvm::Expected> + create(Operation *op, const DriverOptions &options = {}); + + /// args should be an array of XXXMemrefType* + void call(GeneralMemrefPtr *args, std::size_t numArgs) { + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + realargs.reserve(numArgs); + for (size_t i = 0; i < numArgs; i++) { + realargs.push_back(&args[i]); + } + compute(realargs.data()); + } + + /// directly call compute(). args should be an array of void*. args[i] should + /// be a pointer to the real data. For passing memref, users need to 1) create + /// a pointer to XXXMemrefType 2) store the pointer to pointer to + /// XXXMemrefType in args[i] + void callRaw(void **args) { compute(args); } + + JitModule(std::unique_ptr engine, JitModuleFuncT compute); + ~JitModule(); + +private: + std::unique_ptr engine; + JitModuleFuncT compute; +}; + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt index 3a1d63d3d..8d92804d6 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect) + add_mlir_dialect_library(MLIRCPURuntimeDialect CPURuntimeDialect.cpp CPURuntimeOps.cpp @@ -10,5 +12,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIRFuncDialect + ${MLIR_LINK_COMPONENTS} ) diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt index 3bc84f6c8..52c0d7441 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -1,3 +1,5 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect) + add_mlir_dialect_library(MLIRCPURuntimeTransforms CPURuntimeToLLVM.cpp @@ -8,7 +10,7 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIRFuncDialect + ${MLIR_LINK_COMPONENTS} MLIRCPURuntimeDialect ) diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt index 8aa223412..ae0c1c8df 100644 --- a/lib/gc/ExecutionEngine/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(CPURuntime) +add_subdirectory(Driver) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt new file mode 100644 index 000000000..8bda0a16c --- /dev/null +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -0,0 +1,41 @@ +if(GC_DEV_LINK_LLVM_DYLIB) + set(LLVM_LINK_COMPONENTS + LLVM + ) + get_property(dialect_libs GLOBAL PROPERTY GC_DIALECT_LIBS) + get_property(conversion_libs GLOBAL PROPERTY GC_PASS_LIBS) + set(MLIR_LINK_COMPONENTS + MLIR + MLIRExecutionEngineShared + ) +else() + set(LLVM_LINK_COMPONENTS + Core + Support + nativecodegen + native + ) + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + set(MLIR_LINK_COMPONENTS + MLIRBuiltinToLLVMIRTranslation + MLIRExecutionEngine + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRToLLVMIRTranslationRegistration + ) +endif() + +add_mlir_library(GCJitWrapper + Driver.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include + + LINK_LIBS PUBLIC + ${MLIR_LINK_COMPONENTS} + ${dialect_libs} + ${conversion_libs} + GCPasses + ) + diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp new file mode 100644 index 000000000..0f2361924 --- /dev/null +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -0,0 +1,82 @@ +//===-- Driver.cpp - Top-level MLIR compiler driver -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/ExecutionEngine/Driver/Driver.h" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Transforms/Passes.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" +#include "string.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/TargetSelect.h" + +namespace mlir { +namespace gc { + +static DialectRegistry initDialects() { + mlir::registerAllPasses(); + mlir::gc::registerGraphCompilerPasses(); + mlir::cpuruntime::registerCPURuntimePasses(); + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerAllDialects(registry); + mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + registry.insert(); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + mlir::registerAllToLLVMIRTranslations(registry); + return registry; +} + +const DialectRegistry &initCompilerAndGetDialects() { + static DialectRegistry reg = initDialects(); + return reg; +} + +static const char defaultComputeName[] = "_mlir_ciface_compute"; + +llvm::Expected> +JitModule::create(Operation *op, const DriverOptions &options) { + if (options.runTransforms) { + mlir::PassManager pm{op->getContext()}; + populateCPUPipeline(pm); + if (auto result = pm.run(op); failed(result)) { + return llvm::make_error( + "MLIR pass error", llvm::inconvertibleErrorCode()); + } + } + ExecutionEngineOptions exeOptions; + exeOptions.jitCodeGenOptLevel = options.jitCodeGenOptLevel; + std::unique_ptr tm = nullptr; + auto exec = ExecutionEngine::create(op, exeOptions, std::move(tm)); + if (!exec) { + return exec.takeError(); + } + auto &engine = *exec; + JitModuleFuncT compute; + { + auto expectCompute = engine->lookupPacked(defaultComputeName); + if (!expectCompute) { + return expectCompute.takeError(); + } + compute = *expectCompute; + } + return std::make_shared(std::move(engine), compute); +} + +JitModule::JitModule(std::unique_ptr engine, + JitModuleFuncT compute) + : engine{std::move(engine)}, compute{compute} {} +JitModule::~JitModule() = default; + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 63d170dfb..72003224a 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/test/mlir/unittests/CMakeLists.txt b/test/mlir/unittests/CMakeLists.txt index 62ce46eeb..8bdb63b68 100644 --- a/test/mlir/unittests/CMakeLists.txt +++ b/test/mlir/unittests/CMakeLists.txt @@ -13,4 +13,5 @@ function(add_mlir_unittest test_dirname) endfunction() add_subdirectory(Example) +add_subdirectory(ExecutionEngine) diff --git a/test/mlir/unittests/ExecutionEngine/CMakeLists.txt b/test/mlir/unittests/ExecutionEngine/CMakeLists.txt new file mode 100644 index 000000000..0e7315a0f --- /dev/null +++ b/test/mlir/unittests/ExecutionEngine/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(GCExecutionEngineTests + JitWrapper.cpp +) +target_link_libraries(GCExecutionEngineTests + PRIVATE + GCJitWrapper + GCCpuRuntime) diff --git a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp new file mode 100644 index 000000000..f7b93eaa6 --- /dev/null +++ b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp @@ -0,0 +1,70 @@ +//===-- JitWrapper.cpp - Wrapper for JIT ------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/ExecutionEngine/Driver/Driver.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/ExecutionEngine/MemRefUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" +#include + +using namespace mlir; + +static const char code1[] = R"mlir( +module { +llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32 +func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + return %2 : tensor<128xf32> +} +} +)mlir"; + +extern "C" { +extern int gc_runtime_keep_alive; +} + +TEST(ExecutionEngine, JitWrapper) { + gc_runtime_keep_alive = 0; + MLIRContext ctx{gc::initCompilerAndGetDialects()}; + std::unique_ptr ir_buffer = + llvm::MemoryBuffer::getMemBuffer(code1); + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &ctx); + ASSERT_TRUE(module); + auto jited = gc::JitModule::create(module.get()); + bool jit_success = static_cast(jited); + if (!jit_success) { + auto err = jited.takeError(); + llvm::errs() << err; + llvm::consumeError(std::move(err)); + } + ASSERT_TRUE(jit_success); + OwningMemRef bufA{ + {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; + OwningMemRef bufB{ + {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; + OwningMemRef bufC{{128}, {128}}; + void *args[] = {&*bufA, &*bufB, &*bufC}; + jited.get()->call(args, 3); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufC[{i}], 1.0f + i); + } +}