From 13faa3333b395e8bea8ae45cee32005ea90b2392 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 14 May 2024 15:07:57 +0800 Subject: [PATCH 01/32] add cpuruntime dialect --- include/gc/Dialect/CMakeLists.txt | 1 + include/gc/Dialect/CPURuntime/CMakeLists.txt | 2 + .../gc/Dialect/CPURuntime/IR/CMakeLists.txt | 1 + .../Dialect/CPURuntime/IR/CPURuntimeDialect.h | 18 ++ .../CPURuntime/IR/CPURuntimeDialect.td | 34 ++++ .../gc/Dialect/CPURuntime/IR/CPURuntimeOps.h | 26 +++ .../gc/Dialect/CPURuntime/IR/CPURuntimeOps.td | 72 ++++++++ .../CPURuntime/Transforms/CMakeLists.txt | 5 + .../CPURuntime/Transforms/CPURuntimePasses.h | 29 ++++ .../CPURuntime/Transforms/CPURuntimePasses.td | 57 +++++++ lib/gc/Dialect/CMakeLists.txt | 1 + lib/gc/Dialect/CPURuntime/CMakeLists.txt | 2 + lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt | 16 ++ .../CPURuntime/IR/CPURuntimeDialect.cpp | 26 +++ .../Dialect/CPURuntime/IR/CPURuntimeOps.cpp | 56 ++++++ .../CPURuntime/Transforms/CMakeLists.txt | 16 ++ .../Transforms/CPURuntimePasses.cpp | 77 +++++++++ .../Transforms/CPURuntimeToLLVM.cpp | 159 ++++++++++++++++++ lib/gc/Transforms/CMakeLists.txt | 2 + src/CMakeLists.txt | 1 + src/gc-opt/CMakeLists.txt | 2 +- src/gc-opt/gc-opt.cpp | 5 +- .../Dialect/CPURuntime/cpu-runner/printf.mlir | 17 ++ .../CPURuntime/cpuruntime-atexit-to-omp.mlir | 41 +++++ .../CPURuntime/cpuruntime-to-llvm.mlir | 19 +++ 25 files changed, 683 insertions(+), 2 deletions(-) create mode 100644 include/gc/Dialect/CPURuntime/CMakeLists.txt create mode 100644 include/gc/Dialect/CPURuntime/IR/CMakeLists.txt create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td create mode 100644 include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt create mode 100644 
include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h create mode 100644 include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td create mode 100644 lib/gc/Dialect/CPURuntime/CMakeLists.txt create mode 100644 lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt create mode 100644 lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp create mode 100644 lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp create mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt create mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp create mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp create mode 100644 test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir create mode 100644 test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir create mode 100644 test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir diff --git a/include/gc/Dialect/CMakeLists.txt b/include/gc/Dialect/CMakeLists.txt index ffeda0aa7..a23f3f9f1 100644 --- a/include/gc/Dialect/CMakeLists.txt +++ b/include/gc/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(CPURuntime) add_subdirectory(OnednnGraph) add_subdirectory(Microkernel) add_subdirectory(Linalgx) \ No newline at end of file diff --git a/include/gc/Dialect/CPURuntime/CMakeLists.txt b/include/gc/Dialect/CPURuntime/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/include/gc/Dialect/CPURuntime/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/include/gc/Dialect/CPURuntime/IR/CMakeLists.txt new file mode 100644 index 000000000..fb73ae02b --- /dev/null +++ b/include/gc/Dialect/CPURuntime/IR/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(CPURuntimeOps cpuruntime) diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h new file mode 100644 index 000000000..757182964 --- /dev/null +++ 
b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h @@ -0,0 +1,18 @@ +//===- CPURuntimeDialect.h - CPU Runtime dialect ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CPURUNTIME_CPURUNTIMEDIALECT_H +#define CPURUNTIME_CPURUNTIMEDIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOpsDialect.h.inc" + +#endif // CPURUNTIME_CPURUNTIMEDIALECT_H diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td new file mode 100644 index 000000000..06f3af526 --- /dev/null +++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td @@ -0,0 +1,34 @@ +//===- CPURuntimeDialect.td - CPU Runtime Dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CPUPARALLEL_DIALECT +#define CPUPARALLEL_DIALECT + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// CPURuntime dialect definition. +//===----------------------------------------------------------------------===// + +def CPURuntime_Dialect : Dialect { + let name = "cpuruntime"; + let summary = "A dialect for CPU parallel primitives."; + let description = [{ + This dialect contains primitives for CPU runtime. 
+ }]; + let cppNamespace = "::mlir::cpuruntime"; +} + +//===----------------------------------------------------------------------===// +// Base cpuruntime operation definition. +//===----------------------------------------------------------------------===// + +class CPURuntime_Op traits = []> : + Op; + +#endif // CPUPARALLEL_DIALECT diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h new file mode 100644 index 000000000..5ce667a91 --- /dev/null +++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h @@ -0,0 +1,26 @@ +//===- CPURuntimeOps.h - CPU Runtime Ops ====--------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CPURUNTIME_CPURUNTIMEOPS_H +#define CPURUNTIME_CPURUNTIMEOPS_H + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h.inc" + +#endif // CPURUNTIME_CPURUNTIMEOPS_H diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td new file mode 100644 index 000000000..cc1b7c555 --- /dev/null +++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td @@ -0,0 +1,72 @@ +//===- CPURuntimeOps.td - CPU Runtime Ops -----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 
with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CPURUNTIME_OPS +#define CPURUNTIME_OPS + +include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" + + +def CPURuntime_AtParallelExitOp : CPURuntime_Op<"at_parallel_exit", [ + ParentOneOf<["scf::ForallOp", "scf::ParallelOp", "omp::WsloopOp", "memref::AllocaScopeOp"]>, + SingleBlockImplicitTerminator<"ParallelExitReturnOp"> + ]> { + let summary = "Runs the block once in all threads at the exit of the parallel section"; + let description = [{ + It executes the block for each thread working in the parallel section for + once, at the exit of parallel section. + }]; + + let regions = (region SizedRegion<1>:$region); + + let hasCustomAssemblyFormat = 1; + + // The default builder does not add a region with an empty body, add our own. 
+ let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins)>, + ]; +} + +def CPURuntime_ParallelExitReturnOp : CPURuntime_Op<"parallel_exit.return", [ + Pure, + HasParent<"AtParallelExitOp">, + Terminator, ReturnLike + ]> { + let summary = "Terminates at_parallel_exit block"; + let description = [{ + at_parallel_exit should end with parallel_exit.return + }]; + let assemblyFormat = + [{ attr-dict }]; +} + + +def CPURuntime_PrintfOp : CPURuntime_Op<"printf", [MemoryEffects<[MemWrite]>]>, + Arguments<(ins StrAttr:$format, + Variadic>:$args)> { + let summary = "C-style printf"; + let description = [{ + `cpuruntime.printf` takes a literal format string `format` and an arbitrary number of + scalar arguments that should be printed. + + The format string is a C-style printf string, subject to any restrictions + imposed by one's target platform. + }]; + let assemblyFormat = [{ + $format attr-dict ($args^ `:` type($args))? + }]; +} + + +#endif // CPURUNTIME_OPS diff --git a/include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt new file mode 100644 index 000000000..763ffab86 --- /dev/null +++ b/include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS CPURuntimePasses.td) +mlir_tablegen(CPURuntimePasses.h.inc --gen-pass-decls -name CPURuntime) +mlir_tablegen(CPURuntimePasses.capi.h.inc -gen-pass-capi-header --prefix CPURuntime) +mlir_tablegen(CPURuntimePasses.capi.cpp.inc -gen-pass-capi-impl --prefix CPURuntime) +add_public_tablegen_target(MLIRCPURuntimePassesIncGen) diff --git a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h new file mode 100644 index 000000000..8fde8f4fd --- /dev/null +++ b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h @@ -0,0 +1,29 @@ +//===- CPURuntimePasses.h - CPU Runtime Passes ------------------*- C++ -*-===// +// +// This file is licensed under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CPURUNTIME_CPURUNTIMEPASSES_H +#define CPURUNTIME_CPURUNTIMEPASSES_H + +#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h" +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace cpuruntime { +void registerConvertCPURuntimeToLLVMInterface(DialectRegistry &registry); + +#define GEN_PASS_DECL +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" + +#define GEN_PASS_REGISTRATION +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" +} // namespace cpuruntime +} // namespace mlir + +#endif diff --git a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td new file mode 100644 index 000000000..0685ce498 --- /dev/null +++ b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td @@ -0,0 +1,57 @@ +//===- CPURuntimePasses.td - CPU Runtime Passes -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CPURUNTIME_PASS +#define CPURUNTIME_PASS + +include "mlir/Pass/PassBase.td" + + +def CPURuntimeAtExitToOmp: Pass<"cpuruntime-atexit-to-omp", "::mlir::func::FuncOp"> { + let summary = "Lower at_parallel_exit to code in omp.parallel section"; + let description = [{ + Moves the body of `cpuruntime.at_parallel_exit` to the end of the enclosing `omp.parallel`, before its terminator. 
+ ``` + omp.parallel { + omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { + memref.alloca_scope { + cpuruntime.at_parallel_exit { + "your.op"() + cpuruntime.parallel_exit.return + } + } + omp.yield + } + omp.terminator + } + ``` + Will be changed into + ``` + omp.parallel { + omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { + memref.alloca_scope { + } + omp.yield + } + "your.op"() + omp.terminator + } + ``` + }]; +} + + +def CPURuntimeToLLVM: Pass<"convert-cpuruntime-to-llvm"> { + let summary = "Convert cpuruntime to LLVM dialect"; + let description = [{ + This pass converts supported cpuruntime ops to LLVM dialect instructions. + }]; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + +#endif // CPURUNTIME_PASS diff --git a/lib/gc/Dialect/CMakeLists.txt b/lib/gc/Dialect/CMakeLists.txt index a880ff2ed..8720bd8e6 100644 --- a/lib/gc/Dialect/CMakeLists.txt +++ b/lib/gc/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(CPURuntime) add_subdirectory(Linalgx) add_subdirectory(Microkernel) add_subdirectory(OnednnGraph) diff --git a/lib/gc/Dialect/CPURuntime/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt new file mode 100644 index 000000000..e349da72c --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIRCPURuntimeDialect + CPURuntimeDialect.cpp + CPURuntimeOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/ + + DEPENDS + MLIRCPURuntimeOpsIncGen + MLIRCPURuntimePassesIncGen + + LINK_LIBS PUBLIC + MLIR + # MLIRInferTypeOpInterface + # MLIRFuncDialect + ) diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp 
b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp new file mode 100644 index 000000000..9f3e97b57 --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp @@ -0,0 +1,26 @@ +//===- CPURuntimeDialect.cpp - CPU Runtime Dialect --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h" +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h" + +using namespace mlir; +using namespace mlir::cpuruntime; + +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// CPURuntime dialect. +//===----------------------------------------------------------------------===// + +void CPURuntimeDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp.inc" + >(); +} diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp new file mode 100644 index 000000000..ca632e9db --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp @@ -0,0 +1,56 @@ +//===- CPURuntimeOps.cpp - CPU Runtime Ops ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h" +#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h" + +#define GET_OP_CLASSES +#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp.inc" + +#include + +namespace mlir { +using namespace bufferization; + +namespace cpuruntime { + +void AtParallelExitOp::build(OpBuilder &b, OperationState &result) { + OpBuilder::InsertionGuard g(b); + Region *bodyRegion = result.addRegion(); + b.createBlock(bodyRegion); +} + +void AtParallelExitOp::print(OpAsmPrinter &p) { + p << " "; + p.printRegion(getRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + p.printOptionalAttrDict(getOperation()->getAttrs()); +} + +ParseResult AtParallelExitOp::parse(OpAsmParser &parser, + OperationState &result) { + auto &builder = parser.getBuilder(); + + SmallVector regionOperands; + std::unique_ptr region = std::make_unique(); + if (parser.parseRegion(*region, regionOperands)) + return failure(); + + if (region->empty()) + OpBuilder(builder.getContext()).createBlock(region.get()); + result.addRegion(std::move(region)); + + // Parse the optional attribute list. 
+ if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + return success(); +} + +} // namespace cpuruntime +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt new file mode 100644 index 000000000..ee6148aa4 --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIRCPURuntimeTransforms + CPURuntimePasses.cpp + CPURuntimeToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/ + + DEPENDS + MLIRCPURuntimePassesIncGen + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRCPURuntimeDialect + ) + +set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS MLIRCPURuntimeTransforms) \ No newline at end of file diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp new file mode 100644 index 000000000..f2a098fcf --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp @@ -0,0 +1,77 @@ +//===- CPURuntimePasses.cpp - CPU Runtime Passes ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" + +namespace mlir::cpuruntime { +#define GEN_PASS_DEF_CPURUNTIMEATEXITTOOMP +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" + +namespace { + +class CPURuntimeAtExitToOmpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtParallelExitOp op, + PatternRewriter &rewriter) const final { + auto parent = op->getParentOp(); + Operation *secondLast = nullptr; + while (parent && (llvm::isa(parent) || + llvm::isa(parent))) { + secondLast = parent; + parent = parent->getParentOp(); + } + auto parallel = llvm::dyn_cast(parent); + if (!parallel) { + return failure(); + } + assert(secondLast->getBlock()); + auto itr = secondLast->getBlock()->end(); + --itr; + rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(), + secondLast->getBlock(), itr); + rewriter.eraseOp(op); + return success(); + } +}; + +class CPURuntimeExitReturnRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ParallelExitReturnOp op, + PatternRewriter &rewriter) const final { + rewriter.eraseOp(op); + return success(); + } +}; + +class CPURuntimeAtExitToOmp + : public impl::CPURuntimeAtExitToOmpBase { +public: + using impl::CPURuntimeAtExitToOmpBase< + CPURuntimeAtExitToOmp>::CPURuntimeAtExitToOmpBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if 
(failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::cpuruntime diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp new file mode 100644 index 000000000..73cf14a84 --- /dev/null +++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp @@ -0,0 +1,159 @@ +//===- CPURuntimeToLLVM.cpp - CPU Runtime To LLVM ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" + +namespace mlir::cpuruntime { + +void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +#define GEN_PASS_DEF_CPURUNTIMETOLLVM +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" + +namespace { +static const char formatStringPrefix[] = "cpuprintfFormat_"; + +static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, + const Location loc, + ConversionPatternRewriter &rewriter, + StringRef name, + LLVM::LLVMFunctionType type) { + 
LLVM::LLVMFuncOp ret; + if (!(ret = moduleOp.template lookupSymbol(name))) { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + ret = rewriter.create(loc, name, type, + LLVM::Linkage::External); + } + return ret; +} + + +class PrintfRewriter : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(PrintfOp op, PrintfOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto moduleOp = op->getParentOfType(); + auto loc = op->getLoc(); + mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); + mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); + mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); + mlir::Type i8Ptr = LLVM::LLVMPointerType::get(op.getContext()); + auto printfFunc = getOrDefineFunction( + moduleOp, loc, rewriter, "printf", + LLVM::LLVMFunctionType::get(llvmI32, {i8Ptr}, /*isVarArg*/ true)); + + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<20> formatString(adaptor.getFormat()); + formatString.push_back('\0'); // Null terminate for C + size_t formatStringSize = formatString.size_in_bytes(); + + auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize); + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(formatString)); + } + Value globalPtr = rewriter.create( + loc, + LLVM::LLVMPointerType::get(rewriter.getContext(), + global.getAddrSpace()), + global.getSymNameAttr()); + Value stringStart = 
rewriter.create( + loc, i8Ptr, globalType, globalPtr, ArrayRef{0, 0}); + SmallVector appendFormatArgs = {stringStart}; + for (auto arg : adaptor.getArgs()) { + if (auto floatType = dyn_cast(arg.getType())) { + if (!floatType.isF64()) + arg = rewriter.create( + loc, typeConverter->convertType(rewriter.getF64Type()), arg); + } + if (arg.getType().getIntOrFloatBitWidth() != 64) + arg = rewriter.create(loc, llvmI64, arg); + appendFormatArgs.push_back(arg); + } + rewriter.create(loc, printfFunc, appendFormatArgs); + rewriter.eraseOp(op); + return success(); + } +}; + +class CPURuntimeToLLVM + : public impl::CPURuntimeToLLVMBase { +public: + using Base::Base; + void runOnOperation() final { + LLVMConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions options(&getContext()); + LLVMTypeConverter converter(&getContext(), options); + populateCPURuntimeToLLVMConversionPatterns(converter, patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +/// Implement the interface to convert CPURuntime to LLVM. +struct CPURuntimeToDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. 
+ void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateCPURuntimeToLLVMConversionPatterns(typeConverter, patterns); + } +}; + +} // namespace + +void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add(converter); +} + +void registerConvertCPURuntimeToLLVMInterface(DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + dialect->addInterfaces(); + }); +} + +} // namespace mlir::cpuruntime diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index df8a14d01..e87106946 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -14,3 +14,5 @@ add_mlir_library(GCPasses MLIRBufferizationToMemRef MLIRBufferizationPipelines ) + +set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCPasses) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f6298c270..e71fe30a0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(gc_pass_libs GLOBAL PROPERTY GC_PASS_LIBS) add_subdirectory(dnnl) add_subdirectory(gc-cpu-runner) diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt index ff33375de..0deb242a0 100644 --- a/src/gc-opt/CMakeLists.txt +++ b/src/gc-opt/CMakeLists.txt @@ -2,7 +2,7 @@ set(gc_opt_libs ${dialect_libs} ${conversion_libs} MLIROptLib - GCPasses) + ${gc_pass_libs}) if(GC_MLIR_CXX_FLAGS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}") endif() diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 72a25abf5..e1996c050 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -18,6 +18,7 @@ */ #include "gc/Transforms/Passes.h" +#include 
"gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" @@ -25,9 +26,11 @@ int main(int argc, char *argv[]) { mlir::registerAllPasses(); mlir::gc::registerGraphCompilerPasses(); - + mlir::cpuruntime::registerCPURuntimePasses(); mlir::DialectRegistry registry; + registry.insert(); mlir::registerAllDialects(registry); + mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Graph Compiler modular optimizer driver\n", registry)); } diff --git a/test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir b/test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir new file mode 100644 index 000000000..e95471d50 --- /dev/null +++ b/test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir @@ -0,0 +1,17 @@ +// RUN: gc-opt %s --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm | gc-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | FileCheck %s + +module { + func.func @doprint(%t: f32, %t2: i32, %t3: i64) { + cpuruntime.printf "Hello world %f %d %lld\n" %t, %t2, %t3 : f32, i32, i64 + return + } + + func.func @main() { + %c2 = arith.constant 2.0 : f32 + %c32i = arith.constant 2000000 : i32 + %c64i = arith.constant 2000000 : i64 + call @doprint(%c2, %c32i, %c64i) : (f32, i32, i64) -> () + return + } + // CHECK: Hello world 2.000000 2000000 2000000 +} \ No newline at end of file diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir new file mode 100644 index 000000000..401de95cc --- /dev/null +++ b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir @@ -0,0 +1,41 @@ +// RUN: gc-opt %s --cpuruntime-atexit-to-omp | FileCheck %s + +module { + func.func @parallel_insert_slice(%arg0: memref<512x512xf32>) -> memref<512x512xf32> { 
+ %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<512x512xf32> + %c512 = arith.constant 512 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + memref.copy %arg0, %alloc : memref<512x512xf32> to memref<512x512xf32> + %0 = llvm.mlir.constant(1 : i64) : i64 + omp.parallel { + omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { + memref.alloca_scope { + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32> + %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>> + memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> + memref.dealloc %alloc_0 : memref<512xf32> + cpuruntime.at_parallel_exit { + memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32> + cpuruntime.parallel_exit.return + } + } + omp.yield + } + memref.prefetch %alloc[%c0,%c0], read, locality<3>, data : memref<512x512xf32> + omp.terminator + } + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK: omp.parallel + // CHECK-NEXT: omp.wsloop + // CHECK-NEXT: memref.alloca_scope + // CHECK-NOT: cpuruntime.at_parallel_exit + // CHECK: omp.yield + // CHECK: memref.prefetch {{%alloc}}[%[[C0]], %[[C0]]] + // CHECK-NEXT: memref.prefetch {{%alloc}}[%[[C1]], %[[C0]]] + // CHECK-NEXT: omp.terminator + return %alloc : memref<512x512xf32> + } +} diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir new file mode 100644 index 000000000..fb8d748f1 --- /dev/null +++ b/test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir @@ -0,0 +1,19 @@ +// RUN: gc-opt %s --convert-cpuruntime-to-llvm | FileCheck %s + +module { + // CHECK: llvm.mlir.global internal constant @cpuprintfFormat_0("Hello world %f %d %lld\0A\00") {addr_space = 0 : i32} + // CHECK: llvm.func 
@printf(!llvm.ptr, + // CHECK-NEXT: func.func @doprint(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i64) + func.func @doprint(%t: f32, %t2: i32, %t3: i64) { + // CHECK-NEXT: llvm.mlir.addressof + // CHECK-DAG: %[[C1:.*]] = llvm.getelementptr + // CHECK-SAME: !llvm.ptr, !llvm.array<24 x i8> + // CHECK: %[[C2:.*]] = llvm.fpext %[[ARG0]] + // CHECK: %[[C3:.*]] = llvm.zext %[[ARG1]] + // CHECK-NOT: cpuruntime.printf + // CHECK-NEXT: llvm.call @printf(%[[C1]], %[[C2]], %[[C3]], %[[ARG2]]) + cpuruntime.printf "Hello world %f %d %lld\n" %t, %t2, %t3 : f32, i32, i64 + return + } + +} \ No newline at end of file From 161848e3e442ac1f6007de5c664e3b70f5a20b02 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 14 May 2024 15:11:57 +0800 Subject: [PATCH 02/32] format --- lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp | 8 +++----- src/gc-opt/gc-opt.cpp | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp index 73cf14a84..c56621d45 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp +++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp @@ -25,7 +25,7 @@ namespace mlir::cpuruntime { void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns); #define GEN_PASS_DEF_CPURUNTIMETOLLVM #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" @@ -48,7 +48,6 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, return ret; } - class PrintfRewriter : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -110,8 +109,7 @@ class PrintfRewriter : public ConvertOpToLLVMPattern { } }; -class CPURuntimeToLLVM - : public impl::CPURuntimeToLLVMBase { +class CPURuntimeToLLVM : public impl::CPURuntimeToLLVMBase { public: using Base::Base; void runOnOperation() final { @@ 
-146,7 +144,7 @@ struct CPURuntimeToDialectInterface : public ConvertToLLVMPatternInterface { } // namespace void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { + RewritePatternSet &patterns) { patterns.add(converter); } diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index e1996c050..9b06ecf1f 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -17,8 +17,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "gc/Transforms/Passes.h" #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" +#include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" From 447ef129fdbd73e15533c1bbc9a8f1e6f1273413 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 14 May 2024 15:18:43 +0800 Subject: [PATCH 03/32] add dependency --- lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt index e349da72c..3a1d63d3d 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt @@ -10,7 +10,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIR - # MLIRInferTypeOpInterface - # MLIRFuncDialect + MLIRFuncDialect ) From a73dcc12e1023de4b59d4303835c2489b9f955b7 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 14 May 2024 16:03:52 +0800 Subject: [PATCH 04/32] fix new MLIR --- .../Transforms/CPURuntimePasses.cpp | 19 +++++++------- .../CPURuntime/cpuruntime-atexit-to-omp.mlir | 25 +++++++++++-------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp index f2a098fcf..a8f74c079 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp 
+++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp @@ -27,21 +27,22 @@ class CPURuntimeAtExitToOmpRewriter LogicalResult matchAndRewrite(AtParallelExitOp op, PatternRewriter &rewriter) const final { auto parent = op->getParentOp(); - Operation *secondLast = nullptr; - while (parent && (llvm::isa(parent) || - llvm::isa(parent))) { - secondLast = parent; + omp::ParallelOp parallel; + while (parent) { + parallel = llvm::dyn_cast(parent); + if (parallel) { + break; + } parent = parent->getParentOp(); } - auto parallel = llvm::dyn_cast(parent); if (!parallel) { return failure(); } - assert(secondLast->getBlock()); - auto itr = secondLast->getBlock()->end(); + auto &block = parallel.getRegion().front(); + auto itr = block.end(); --itr; - rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(), - secondLast->getBlock(), itr); + rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(), &block, + itr); rewriter.eraseOp(op); return success(); } diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir index 401de95cc..172777690 100644 --- a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir +++ b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir @@ -10,18 +10,21 @@ module { memref.copy %arg0, %alloc : memref<512x512xf32> to memref<512x512xf32> %0 = llvm.mlir.constant(1 : i64) : i64 omp.parallel { - omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32> - %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.dealloc %alloc_0 : memref<512xf32> - cpuruntime.at_parallel_exit { - memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32> - 
cpuruntime.parallel_exit.return + omp.wsloop { + omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { + memref.alloca_scope { + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32> + %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>> + memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> + memref.dealloc %alloc_0 : memref<512xf32> + cpuruntime.at_parallel_exit { + memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32> + cpuruntime.parallel_exit.return + } } + omp.yield } - omp.yield + omp.terminator } memref.prefetch %alloc[%c0,%c0], read, locality<3>, data : memref<512x512xf32> omp.terminator @@ -30,7 +33,7 @@ module { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 // CHECK: omp.parallel // CHECK-NEXT: omp.wsloop - // CHECK-NEXT: memref.alloca_scope + // CHECK: memref.alloca_scope // CHECK-NOT: cpuruntime.at_parallel_exit // CHECK: omp.yield // CHECK: memref.prefetch {{%alloc}}[%[[C0]], %[[C0]]] From 1cfede8e24231618e69ef692c3b1a0120d5a8857 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 15 May 2024 15:24:41 +0800 Subject: [PATCH 05/32] add --- include/gc/Transforms/Passes.h | 26 +++ include/gc/Transforms/Passes.td | 13 ++ lib/gc/Transforms/CMakeLists.txt | 1 + lib/gc/Transforms/Pipeline.cpp | 164 +++++++++++++++++++ test/gc/Transforms/Pipeline/run.mlir | 23 +++ test/gc/Transforms/Pipeline/tensor_args.mlir | 13 ++ 6 files changed, 240 insertions(+) create mode 100644 lib/gc/Transforms/Pipeline.cpp create mode 100644 test/gc/Transforms/Pipeline/run.mlir create mode 100644 test/gc/Transforms/Pipeline/tensor_args.mlir diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h index 243a6f4f6..e5e4aee33 100644 --- a/include/gc/Transforms/Passes.h +++ b/include/gc/Transforms/Passes.h @@ -12,8 +12,34 @@ #include "mlir/Pass/Pass.h" namespace mlir { + +namespace LLVM { 
+class LLVMDialect; +} + +namespace scf { +class SCFDialect; +} + +namespace openmp { +class OpenMPDialect; +} + +namespace linalg { +class LinalgDialect; +} + +namespace MemRef { +class MemRefDialect; +} + +class PassManager; + namespace gc { +void populateFrontendPasses(mlir::PassManager &); +void populateCPUPipeline(mlir::PassManager &); + #define GEN_PASS_DECL #include "gc/Transforms/Passes.h.inc" diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index d31baa5a7..ff0cd8a90 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -31,4 +31,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { ]; } +def GCCPUPipeline: Pass<"gc-cpu-pipeline"> { + let summary = "All-in-one pipeline for GC for CPU"; + let dependentDialects = ["onednn_graph::OneDNNGraphDialect", + "tensor::TensorDialect", + "memref::MemRefDialect", + "linalg::LinalgDialect", + "LLVM::LLVMDialect", + "scf::SCFDialect", + "bufferization::BufferizationDialect", + "omp::OpenMPDialect", + "vector::VectorDialect"]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index e7e97ea26..f3fa43e04 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS add_mlir_library(GCPasses OneDNNGraphToLinalg.cpp + Pipeline.cpp TileNamed.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp new file mode 100644 index 000000000..ed50925f6 --- /dev/null +++ b/lib/gc/Transforms/Pipeline.cpp @@ -0,0 +1,164 @@ +//===- Pipeline.cpp - Graph Compiler all-in-one pipeline --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" + +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Transforms/Passes.h" + +namespace mlir::gc { + +void populateFrontendPasses(mlir::PassManager &pm) { + // pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg()); +} +// linalg + linalgX + tensor ==> GC V1 GIR + +void populateTensorPasses(mlir::PassManager &pm) { + // + padding propagation pass, upstream-able 127x127 -> tilling size:32 + // ->padding to 128x128 + // + layout propagation pass, upstream-able 4x32x4x32 -> + // tensor.pack/tensor.unpack + // + tensor constant propagation pass, down-stream pass, designed to support + // oneDNN graph spec + // + linalg.matmul lowering to (scf.loop + linalg.brgemm) pass, upstream-able + // + fine-grain fusion pass, upstream-able -> scf.for + linalgx.mask + // + lower linalg to arith/math on virtual vector pass, up-streamable + + // REMOVE this pass after the above passes are added. 
Currently we add this + // pass to make the pipeline work properly + pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); +} +// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack ==> +// GC V1 TIR + +void populateVectorPasses(mlir::PassManager &pm) { + // + bf16 promotion pass, down-stream pass, device dependent pass, maybe can + // upstream + // + bf16 cast elimilation pass, down-stream pass, fast-math kind pass, + // designed to support oneDNN graph spec + pm.addNestedPass(arith::createArithExpandOpsPass()); + // + lower to physical vector pass, down-stream pass, device dependent pass, + // maybe can upstream +} +// scf + arith + math + vector + tensor + linalg.brgemm + +void populateBufferizationPasses(mlir::PassManager &pm) { + bufferization::OneShotBufferizationOptions options; + pm.addPass(bufferization::createOneShotBufferizePass(options)); + pm.addPass(createCSEPass()); + pm.addPass(mlir::func::createFuncBufferizePass()); + pm.addPass(bufferization::createBufferResultsToOutParamsPass()); + pm.addNestedPass( + bufferization::createBufferizationBufferizePass()); + pm.addNestedPass( + bufferization::createFinalizingBufferizePass()); + // + buffer schedule pass, down-stream pass, to migrate buffer reschedule pass + // from GC V1. + pm.addNestedPass( + bufferization::createBufferHoistingPass()); // Need to improve this pass + // to support thread-local + // allocator. 
+ pm.addNestedPass(bufferization::createBufferLoopHoistingPass()); + pm.addNestedPass(bufferization::createBufferDeallocationPass()); + pm.addPass(createBufferizationToMemRefPass()); +} +// scf + arith + math + vector + memref + linalg.brgemm + +void populateMicroKernelPasses(mlir::PassManager &pm) { + // + ConvertLinalgToMicrokernel pass, upstream-able, + // + CleanupInvalidMicrokernel pass, upstream-able + // + InvariantMicrokernelMotion pass, upstream-able + // + ConvertMicrokernelToDnnlFunc, down-stream pass, to lower brgemm to dnnl + // call + // + ConvertMicrokernelToXsmm, down-stream pass, to lower brgemm to libxsmm + // call + // + LowerMicrokernel pass, upstream-able + // + DispatchMicrokernel, down-stream pass +} +// scf + arith + math + vector + memref + func/microkernel + +void populateCPURuntimePasses(mlir::PassManager &pm) { + // + flatten nested parallel pass, down-stream pass, to support coarse-grain + // fusion + // pm.addNestedPass(parallelcpu::createParallelCPUAtExitToOmp()); + // remove this pass after we add FlattenNestedParallel + pm.addPass(createConvertSCFToOpenMPPass()); +} + +void populateLoweringToLLVMPasses(mlir::PassManager &pm) { + pm.addPass(createConvertSCFToCFPass()); + // pm.addPass(parallelcpu::createParallelCPUToLLVM()); + pm.addPass(createConvertOpenMPToLLVMPass()); + pm.addNestedPass(createConvertMathToLLVMPass()); + pm.addPass(createConvertMathToLibmPass()); + pm.addPass(createFinalizeMemRefToLLVMConversionPass()); + pm.addNestedPass(createArithToLLVMConversionPass()); + pm.addPass(createConvertFuncToLLVMPass()); + pm.addPass(createConvertControlFlowToLLVMPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addPass(createSymbolDCEPass()); +} + +void populateLLVMPasses(mlir::PassManager &pm) { + pm.addPass(memref::createExpandOpsPass()); + pm.addPass(memref::createExpandStridedMetadataPass()); + populateLoweringToLLVMPasses(pm); +} + +void 
populateCPUPipeline(mlir::PassManager &pm) { + // front-end, oneDNN graph dialect + populateFrontendPasses(pm); + // middle-end, LinalgX/Linalg/tensor dialects + populateTensorPasses(pm); + // middle-end, arith/math/vector dialects + populateVectorPasses(pm); + // back-end, arith/math/vector/memref dialects + populateBufferizationPasses(pm); + // REMOVE this pass after the TensorPasses are added. Currently we add this + // pass to make the pipeline work properly + pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); + populateMicroKernelPasses(pm); + populateCPURuntimePasses(pm); + // // back-end, llvm dialect + populateLLVMPasses(pm); +} + +#define GEN_PASS_DEF_GCCPUPIPELINE +#include "gc/Transforms/Passes.h.inc" +namespace { + +class GCCPUPipeline : public impl::GCCPUPipelineBase { +public: + friend struct PassHelper; + using impl::GCCPUPipelineBase::GCCPUPipelineBase; + void runOnOperation() final { + auto op = getOperation(); + PassManager pm{op->getContext()}; + populateCPUPipeline(pm); + if (failed(pm.run(op))) + signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::gc diff --git a/test/gc/Transforms/Pipeline/run.mlir b/test/gc/Transforms/Pipeline/run.mlir new file mode 100644 index 000000000..799935006 --- /dev/null +++ b/test/gc/Transforms/Pipeline/run.mlir @@ -0,0 +1,23 @@ +// RUN: gc-opt %s --gc-cpu-pipeline | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s + +module { +func.func @aaa() -> tensor<128xf32> { + %c2 = arith.constant 2.0 : f32 + %a = tensor.empty() : tensor<128xf32> + %2 = linalg.fill ins(%c2 : f32) outs(%a : tensor<128xf32>) -> tensor<128xf32> + return %2 : tensor<128xf32> +} + +func.func @main() { + %result = call @aaa() : ()-> tensor<128xf32> + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c128 step %c1 { + %4 = tensor.extract %result[%iv] : tensor<128xf32> + parallelcpu.printf "%f\n" %4 : f32 + } + return +} +// 
CHECK-COUNT-128: 2.000000 +} \ No newline at end of file diff --git a/test/gc/Transforms/Pipeline/tensor_args.mlir b/test/gc/Transforms/Pipeline/tensor_args.mlir new file mode 100644 index 000000000..73d916d04 --- /dev/null +++ b/test/gc/Transforms/Pipeline/tensor_args.mlir @@ -0,0 +1,13 @@ +// RUN: gc-opt %s --gc-cpu-pipeline | FileCheck %s + +module { +// CHECK: aaa +// check that the func returns void +// CHECK-NOT: ) -> !llvm.struct< +func.func @aaa(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + // CHECK: memcpy + return %out : tensor<128xf32> +} +} \ No newline at end of file From 475faf8052309cfd9e170f61e2622d5c4cd7a5ad Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 15 May 2024 15:46:46 +0800 Subject: [PATCH 06/32] update --- lib/gc/Transforms/Pipeline.cpp | 9 ++++++--- test/gc/Transforms/Pipeline/run.mlir | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index ed50925f6..8b74df1b9 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -21,6 +21,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" @@ -65,7 +66,9 @@ void populateBufferizationPasses(mlir::PassManager &pm) { pm.addPass(bufferization::createOneShotBufferizePass(options)); pm.addPass(createCSEPass()); pm.addPass(mlir::func::createFuncBufferizePass()); - pm.addPass(bufferization::createBufferResultsToOutParamsPass()); + bufferization::BufferResultsToOutParamsOpts opt{}; + // opt.hoistStaticAllocs = true; + pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); pm.addNestedPass( 
bufferization::createBufferizationBufferizePass()); pm.addNestedPass( @@ -98,14 +101,14 @@ void populateMicroKernelPasses(mlir::PassManager &pm) { void populateCPURuntimePasses(mlir::PassManager &pm) { // + flatten nested parallel pass, down-stream pass, to support coarse-grain // fusion - // pm.addNestedPass(parallelcpu::createParallelCPUAtExitToOmp()); + pm.addNestedPass(cpuruntime::createCPURuntimeAtExitToOmp()); // remove this pass after we add FlattenNestedParallel pm.addPass(createConvertSCFToOpenMPPass()); } void populateLoweringToLLVMPasses(mlir::PassManager &pm) { pm.addPass(createConvertSCFToCFPass()); - // pm.addPass(parallelcpu::createParallelCPUToLLVM()); + pm.addPass(cpuruntime::createCPURuntimeToLLVM()); pm.addPass(createConvertOpenMPToLLVMPass()); pm.addNestedPass(createConvertMathToLLVMPass()); pm.addPass(createConvertMathToLibmPass()); diff --git a/test/gc/Transforms/Pipeline/run.mlir b/test/gc/Transforms/Pipeline/run.mlir index 799935006..71feb0843 100644 --- a/test/gc/Transforms/Pipeline/run.mlir +++ b/test/gc/Transforms/Pipeline/run.mlir @@ -15,7 +15,7 @@ func.func @main() { %c1 = arith.constant 1 : index scf.for %iv = %c0 to %c128 step %c1 { %4 = tensor.extract %result[%iv] : tensor<128xf32> - parallelcpu.printf "%f\n" %4 : f32 + cpuruntime.printf "%f\n" %4 : f32 } return } From 0ac087deb28d4ec1efcf53519d95e1700690467f Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 15 May 2024 17:02:27 +0800 Subject: [PATCH 07/32] fix --- src/gc-opt/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt index 74eb4b28e..36ace6847 100644 --- a/src/gc-opt/CMakeLists.txt +++ b/src/gc-opt/CMakeLists.txt @@ -15,7 +15,9 @@ endif() set(gc_opt_libs ${dialect_libs} ${conversion_libs} - MLIROptLib) + ${MLIR_LINK_COMPONENTS} + GCPasses) + if(GC_MLIR_CXX_FLAGS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}") endif() From 74b0d342fba8e5896a9ca07bf57b449f34911155 Mon 
Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 11:00:23 +0800 Subject: [PATCH 08/32] remove at exit --- .../gc/Dialect/CPURuntime/IR/CPURuntimeOps.td | 41 +--------- .../CPURuntime/Transforms/CPURuntimePasses.td | 35 --------- .../Dialect/CPURuntime/IR/CPURuntimeOps.cpp | 34 -------- .../CPURuntime/Transforms/CMakeLists.txt | 1 - .../Transforms/CPURuntimePasses.cpp | 78 ------------------- .../CPURuntime/cpuruntime-atexit-to-omp.mlir | 44 ----------- 6 files changed, 1 insertion(+), 232 deletions(-) delete mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp delete mode 100644 test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td index cc1b7c555..bd77ad997 100644 --- a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td +++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td @@ -10,46 +10,7 @@ #define CPURUNTIME_OPS include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td" -include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" - - -def CPURuntime_AtParallelExitOp : CPURuntime_Op<"at_parallel_exit", [ - ParentOneOf<["scf::ForallOp", "scf::ParallelOp", "omp::WsloopOp", "memref::AllocaScopeOp"]>, - SingleBlockImplicitTerminator<"ParallelExitReturnOp"> - ]> { - let summary = "Runs the block once in all threads at the exit of the parallel section"; - let description = [{ - It executes the block for each thread working in the parallel section for - once, at the exit of parallel section. - }]; - - let regions = (region SizedRegion<1>:$region); - - let hasCustomAssemblyFormat = 1; - - // The default builder does not add a region with an empty body, add our own. 
- let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins)>, - ]; -} - -def CPURuntime_ParallelExitReturnOp : CPURuntime_Op<"parallel_exit.return", [ - Pure, - HasParent<"AtParallelExitOp">, - Terminator, ReturnLike - ]> { - let summary = "Terminates at_parallel_exit block"; - let description = [{ - at_parallel_exit should ends with parallel_exit.return - }]; - let assemblyFormat = - [{ attr-dict }]; -} def CPURuntime_PrintfOp : CPURuntime_Op<"printf", [MemoryEffects<[MemWrite]>]>, @@ -61,7 +22,7 @@ def CPURuntime_PrintfOp : CPURuntime_Op<"printf", [MemoryEffects<[MemWrite]>]>, scalar arguments that should be printed. The format string is a C-style printf string, subject to any restrictions - imposed by one's target platform. + imposed by the target platform. }]; let assemblyFormat = [{ $format attr-dict ($args^ `:` type($args))? diff --git a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td index 0685ce498..20c81e10a 100644 --- a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td +++ b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td @@ -11,41 +11,6 @@ include "mlir/Pass/PassBase.td" - -def CPURuntimeAtExitToOmp: Pass<"cpuruntime-atexit-to-omp", "::mlir::func::FuncOp"> { - let summary = "Lower at_parallel_exit to code in omp.parallel section"; - let description = [{ - Switches the name of a FuncOp named `bar` to `foo` and folds. 
- ``` - omp.parallel { - omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - cpuruntime.at_parallel_exit { - "your.op"() - cpuruntime.parallel_exit.return - } - } - omp.yield - } - omp.terminator - } - ``` - Will be changed into - ``` - omp.parallel { - omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - } - omp.yield - } - "your.op"() - omp.terminator - } - ``` - }]; -} - - def CPURuntimeToLLVM: Pass<"convert-cpuruntime-to-llvm"> { let summary = "Convert cpuruntime to LLVM dialect"; let description = [{ diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp index ca632e9db..ed7bc6581 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp +++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp @@ -15,42 +15,8 @@ #include namespace mlir { -using namespace bufferization; - namespace cpuruntime { -void AtParallelExitOp::build(OpBuilder &b, OperationState &result) { - OpBuilder::InsertionGuard g(b); - Region *bodyRegion = result.addRegion(); - b.createBlock(bodyRegion); -} - -void AtParallelExitOp::print(OpAsmPrinter &p) { - p << " "; - p.printRegion(getRegion(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - p.printOptionalAttrDict(getOperation()->getAttrs()); -} - -ParseResult AtParallelExitOp::parse(OpAsmParser &parser, - OperationState &result) { - auto &builder = parser.getBuilder(); - - SmallVector regionOperands; - std::unique_ptr region = std::make_unique(); - if (parser.parseRegion(*region, regionOperands)) - return failure(); - - if (region->empty()) - OpBuilder(builder.getContext()).createBlock(region.get()); - result.addRegion(std::move(region)); - - // Parse the optional attribute list. 
- if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - return success(); -} } // namespace cpuruntime } // namespace mlir \ No newline at end of file diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt index ee6148aa4..3bc84f6c8 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms - CPURuntimePasses.cpp CPURuntimeToLLVM.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp deleted file mode 100644 index a8f74c079..000000000 --- a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp +++ /dev/null @@ -1,78 +0,0 @@ -//===- CPURuntimePasses.cpp - CPU Runtime Passes ----------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" - -namespace mlir::cpuruntime { -#define GEN_PASS_DEF_CPURUNTIMEATEXITTOOMP -#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" - -namespace { - -class CPURuntimeAtExitToOmpRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtParallelExitOp op, - PatternRewriter &rewriter) const final { - auto parent = op->getParentOp(); - omp::ParallelOp parallel; - while (parent) { - parallel = llvm::dyn_cast(parent); - if (parallel) { - break; - } - parent = parent->getParentOp(); - } - if (!parallel) { - return failure(); - } - auto &block = parallel.getRegion().front(); - auto itr = block.end(); - --itr; - rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(), &block, - itr); - rewriter.eraseOp(op); - return success(); - } -}; - -class CPURuntimeExitReturnRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ParallelExitReturnOp op, - PatternRewriter &rewriter) const final { - rewriter.eraseOp(op); - return success(); - } -}; - -class CPURuntimeAtExitToOmp - : public impl::CPURuntimeAtExitToOmpBase { -public: - using impl::CPURuntimeAtExitToOmpBase< - CPURuntimeAtExitToOmp>::CPURuntimeAtExitToOmpBase; - void runOnOperation() final { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) - 
signalPassFailure(); - } -}; - -} // namespace -} // namespace mlir::cpuruntime diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir deleted file mode 100644 index 172777690..000000000 --- a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: gc-opt %s --cpuruntime-atexit-to-omp | FileCheck %s - -module { - func.func @parallel_insert_slice(%arg0: memref<512x512xf32>) -> memref<512x512xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<512x512xf32> - %c512 = arith.constant 512 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - memref.copy %arg0, %alloc : memref<512x512xf32> to memref<512x512xf32> - %0 = llvm.mlir.constant(1 : i64) : i64 - omp.parallel { - omp.wsloop { - omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32> - %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.dealloc %alloc_0 : memref<512xf32> - cpuruntime.at_parallel_exit { - memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32> - cpuruntime.parallel_exit.return - } - } - omp.yield - } - omp.terminator - } - memref.prefetch %alloc[%c0,%c0], read, locality<3>, data : memref<512x512xf32> - omp.terminator - } - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK: omp.parallel - // CHECK-NEXT: omp.wsloop - // CHECK: memref.alloca_scope - // CHECK-NOT: cpuruntime.at_parallel_exit - // CHECK: omp.yield - // CHECK: memref.prefetch {{%alloc}}[%[[C0]], %[[C0]]] - // CHECK-NEXT: memref.prefetch {{%alloc}}[%[[C1]], %[[C0]]] - // CHECK-NEXT: 
omp.terminator - return %alloc : memref<512x512xf32> - } -} From 2cebba99452bd68876269b77293838b3a263f112 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 11:01:58 +0800 Subject: [PATCH 09/32] fix lint --- lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp index ed7bc6581..460c421cc 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp +++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp @@ -15,8 +15,5 @@ #include namespace mlir { -namespace cpuruntime { - - -} // namespace cpuruntime +namespace cpuruntime {} // namespace cpuruntime } // namespace mlir \ No newline at end of file From 34d10ea127052b423b3384c2ada95b3c30a1eb36 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 13:29:48 +0800 Subject: [PATCH 10/32] Add kmp_* wrapper for gomp environment --- lib/gc/CMakeLists.txt | 3 +- lib/gc/ExecutionEngine/CMakeLists.txt | 1 + .../ExecutionEngine/CPURuntime/CMakeLists.txt | 15 ++ .../ExecutionEngine/CPURuntime/Parallel.cpp | 189 ++++++++++++++++++ src/gc-cpu-runner/CMakeLists.txt | 3 +- src/gc-cpu-runner/gc-cpu-runner.cpp | 4 + test/gc/cpu-runner/tid.mlir | 37 ++++ 7 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 lib/gc/ExecutionEngine/CMakeLists.txt create mode 100644 lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt create mode 100644 lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp create mode 100644 test/gc/cpu-runner/tid.mlir diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt index fd78d6cab..921853a08 100644 --- a/lib/gc/CMakeLists.txt +++ b/lib/gc/CMakeLists.txt @@ -5,4 +5,5 @@ endif() include(functions) add_subdirectory(Dialect) -add_subdirectory(Transforms) \ No newline at end of file +add_subdirectory(Transforms) +add_subdirectory(ExecutionEngine) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt 
b/lib/gc/ExecutionEngine/CMakeLists.txt new file mode 100644 index 000000000..8aa223412 --- /dev/null +++ b/lib/gc/ExecutionEngine/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(CPURuntime) diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt new file mode 100644 index 000000000..6be58e28f --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -0,0 +1,15 @@ +find_package(OpenMP REQUIRED) + +if ("iomp" IN_LIST OpenMP_C_LIB_NAMES OR "omp" IN_LIST OpenMP_C_LIB_NAMES OR "omp5" IN_LIST OpenMP_C_LIB_NAMES) +else() + add_definitions("-DGC_NEEDS_OMP_WRAPPER=1") +endif() + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +add_mlir_library(GCCpuRuntime + SHARED + Parallel.cpp + + EXCLUDE_FROM_LIBMLIR + ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp new file mode 100644 index 000000000..a71b8522c --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -0,0 +1,189 @@ +//===- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +#define WEAK_SYMBOL __attribute__((weak)) + +namespace { +struct barrier_t { + alignas(64) std::atomic pending_; + std::atomic rounds_; + uint64_t total_; + // pad barrier to size of cacheline to avoid false sharing + char padding_[64 - 4 * sizeof(int32_t)]; +}; + +typedef uint64_t (*barrier_idle_func)(std::atomic *remaining, + int32_t expected_remain, int32_t tid, + void *args); +} // namespace + +extern "C" { +int gc_runtime_keep_alive = 0; +void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func, + void *idle_args) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + assert(cnt >= 0); + // int count = 0; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + if (idle_func) { + if (cur_round != b->rounds_.load()) { + return; + } + idle_func(&b->rounds_, cur_round + 1, -1, idle_args); + // count = ret & 0xffffffff; + } + while (cur_round == b->rounds_.load()) { + _mm_pause(); + } + } +} + +static_assert(sizeof(barrier_t) == 64, "size of barrier_t should be 64-byte"); + +void gc_init_barrier(barrier_t *b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} + +#if GC_NEEDS_OMP_WRAPPER +void WEAK_SYMBOL __kmpc_barrier(void *loc, int32_t global_tid) { +#pragma omp barrier +} + +int WEAK_SYMBOL __kmpc_global_thread_num(void *loc) { + return omp_get_thread_num(); +} + +void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid, + int32_t schedtype, + int32_t *plastiter, uint64_t *plower, + uint64_t *pupper, int64_t *pstride, + int64_t 
incr, int64_t chunk) { + if (unlikely(schedtype != 34)) { + std::abort(); + } + const int32_t FALSE = 0; + const int32_t TRUE = 0; + using UT = uint64_t; + // using ST = int64_t; + /* this all has to be changed back to TID and such.. */ + uint32_t tid = gtid; + uint32_t nth = omp_get_num_threads(); + UT trip_count; + + /* special handling for zero-trip loops */ + if (incr > 0 ? (*pupper < *plower) : (*plower < *pupper)) { + if (plastiter != nullptr) + *plastiter = FALSE; + /* leave pupper and plower set to entire iteration space */ + *pstride = incr; /* value should never be used */ + return; + } + + if (nth == 1) { + if (plastiter != nullptr) + *plastiter = TRUE; + *pstride = + (incr > 0) ? (*pupper - *plower + 1) : (-(*plower - *pupper + 1)); + return; + } + + /* compute trip count */ + if (incr == 1) { + trip_count = *pupper - *plower + 1; + } else if (incr == -1) { + trip_count = *plower - *pupper + 1; + } else if (incr > 0) { + // upper-lower can exceed the limit of signed type + trip_count = (UT)(*pupper - *plower) / incr + 1; + } else { + trip_count = (UT)(*plower - *pupper) / (-incr) + 1; + } + if (trip_count < nth) { + if (tid < trip_count) { + *pupper = *plower = *plower + tid * incr; + } else { + // set bounds so non-active threads execute no iterations + *plower = *pupper + (incr > 0 ? 1 : -1); + } + if (plastiter != nullptr) + *plastiter = (tid == trip_count - 1); + } else { + UT small_chunk = trip_count / nth; + UT extras = trip_count % nth; + *plower += incr * (tid * small_chunk + (tid < extras ? tid : extras)); + *pupper = *plower + small_chunk * incr - (tid < extras ? 0 : incr); + if (plastiter != nullptr) + *plastiter = (tid == nth - 1); + } + *pstride = trip_count; +} + +void WEAK_SYMBOL __kmpc_for_static_fini(void *ptr, int32_t v) {} + +static thread_local int next_num_threads = 0; + +/*! +@ingroup PARALLEL +The type for a microtask which gets passed to @ref __kmpc_fork_call(). 
+The arguments to the outlined function are +@param global_tid the global thread identity of the thread executing the +function. +@param bound_tid the local identity of the thread executing the function +@param ... pointers to shared variables accessed by the function. +*/ +using kmpc_micro = void (*)(int32_t *global_tid, int32_t *bound_tid, ...); +void WEAK_SYMBOL __kmpc_fork_call(void *loc, int32_t argc, void *pfunc, ...) { + if (unlikely(argc != 1 && argc != 0)) { + std::abort(); + } + va_list ap; + va_start(ap, pfunc); + void *c = va_arg(ap, void *); + int32_t global_tid = 0; + if (unlikely(next_num_threads)) { +#pragma omp parallel num_threads(next_num_threads) + { + kmpc_micro func = (kmpc_micro)(pfunc); + func(&global_tid, nullptr, c); + } + next_num_threads = 0; + } else { +#pragma omp parallel + { + kmpc_micro func = (kmpc_micro)(pfunc); + func(&global_tid, nullptr, c); + } + } + va_end(ap); +} + +void WEAK_SYMBOL __kmpc_push_num_threads(void *loc, int32_t global_tid, + int32_t num_threads) { + next_num_threads = num_threads; +} +#endif +} diff --git a/src/gc-cpu-runner/CMakeLists.txt b/src/gc-cpu-runner/CMakeLists.txt index f3f768612..85dbb6995 100644 --- a/src/gc-cpu-runner/CMakeLists.txt +++ b/src/gc-cpu-runner/CMakeLists.txt @@ -36,7 +36,8 @@ endif() #LLVM_LINK_COMPONENTS is processed by LLVM cmake in add_llvm_executable set(gc_cpu_runner_libs - ${MLIR_LINK_COMPONENTS}) + ${MLIR_LINK_COMPONENTS} + GCCpuRuntime) add_mlir_tool(gc-cpu-runner gc-cpu-runner.cpp ) diff --git a/src/gc-cpu-runner/gc-cpu-runner.cpp b/src/gc-cpu-runner/gc-cpu-runner.cpp index 3ece8f2ff..353abffe9 100644 --- a/src/gc-cpu-runner/gc-cpu-runner.cpp +++ b/src/gc-cpu-runner/gc-cpu-runner.cpp @@ -27,7 +27,11 @@ #include "llvm/Support/TargetSelect.h" #include +extern int gc_runtime_keep_alive; + int main(int argc, char **argv) { + // keeps GCCPURuntime linked + gc_runtime_keep_alive = 0; llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); 
llvm::InitializeNativeTargetAsmPrinter(); diff --git a/test/gc/cpu-runner/tid.mlir b/test/gc/cpu-runner/tid.mlir new file mode 100644 index 000000000..aedcc0a20 --- /dev/null +++ b/test/gc/cpu-runner/tid.mlir @@ -0,0 +1,37 @@ +// RUN: gc-opt %s --convert-cpuruntime-to-llvm --convert-openmp-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --reconcile-unrealized-casts | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s +module { + func.func private @omp_get_thread_num() -> i32 + + func.func @check_parallel() { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %0 = llvm.mlir.constant(1 : i64) : i64 + omp.parallel num_threads(%c8: index) { + omp.wsloop { + omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c64) step (%c1, %c1) { + cpuruntime.printf "ITR %zu\n" %arg2 : index + omp.yield + } + omp.terminator + } + %tid = func.call @omp_get_thread_num() : () -> i32 + cpuruntime.printf "EXIT %d\n" %tid : i32 + omp.terminator + } + return + } + + func.func @main() { + %0 = func.call @omp_get_thread_num() : () -> i32 + cpuruntime.printf "TID %d\n" %0 : i32 + call @check_parallel() : ()->() + return + } + // CHECK: TID 0 + // CHECK-COUNT-64: ITR {{[0-9]+}} + // CHECK-NOT: ITR + // CHECK-COUNT-8: EXIT {{[0-9]+}} + // CHECK-NOT: EXIT +} \ No newline at end of file From 80a597f0887889f1721c83e502de97e0c05dac97 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 14:18:55 +0800 Subject: [PATCH 11/32] fix --- lib/gc/Transforms/Pipeline.cpp | 7 +++---- test/gc/Transforms/Pipeline/tensor_args.mlir | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 8b74df1b9..b336684a1 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -66,13 +66,13 @@ void populateBufferizationPasses(mlir::PassManager &pm) { 
pm.addPass(bufferization::createOneShotBufferizePass(options)); pm.addPass(createCSEPass()); pm.addPass(mlir::func::createFuncBufferizePass()); - bufferization::BufferResultsToOutParamsOpts opt{}; - // opt.hoistStaticAllocs = true; - pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); pm.addNestedPass( bufferization::createBufferizationBufferizePass()); pm.addNestedPass( bufferization::createFinalizingBufferizePass()); + bufferization::BufferResultsToOutParamsOpts opt{}; + opt.hoistStaticAllocs = true; + pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); // + buffer schedule pass, down-stream pass, to migrate buffer reschedule pass // from GC V1. pm.addNestedPass( @@ -101,7 +101,6 @@ void populateMicroKernelPasses(mlir::PassManager &pm) { void populateCPURuntimePasses(mlir::PassManager &pm) { // + flatten nested parallel pass, down-stream pass, to support coarse-grain // fusion - pm.addNestedPass(cpuruntime::createCPURuntimeAtExitToOmp()); // remove this pass after we add FlattenNestedParallel pm.addPass(createConvertSCFToOpenMPPass()); } diff --git a/test/gc/Transforms/Pipeline/tensor_args.mlir b/test/gc/Transforms/Pipeline/tensor_args.mlir index 73d916d04..adcfb3bd8 100644 --- a/test/gc/Transforms/Pipeline/tensor_args.mlir +++ b/test/gc/Transforms/Pipeline/tensor_args.mlir @@ -7,7 +7,7 @@ module { func.func @aaa(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> { %out = tensor.empty() : tensor<128xf32> %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> - // CHECK: memcpy - return %out : tensor<128xf32> + // CHECK-NOT: memcpy + return %2 : tensor<128xf32> } } \ No newline at end of file From 0b4332b2fe3a1e0c481fa1bcfb67a86130241ac1 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 17:06:44 +0800 Subject: [PATCH 12/32] fix --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git 
a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index a71b8522c..4a25a5ee0 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -40,7 +40,6 @@ void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func, auto cur_round = b->rounds_.load(std::memory_order_acquire); auto cnt = --b->pending_; assert(cnt >= 0); - // int count = 0; if (cnt == 0) { b->pending_.store(b->total_); b->rounds_.store(cur_round + 1); @@ -50,7 +49,6 @@ void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func, return; } idle_func(&b->rounds_, cur_round + 1, -1, idle_args); - // count = ret & 0xffffffff; } while (cur_round == b->rounds_.load()) { _mm_pause(); @@ -86,7 +84,7 @@ void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid, std::abort(); } const int32_t FALSE = 0; - const int32_t TRUE = 0; + const int32_t TRUE = 1; using UT = uint64_t; // using ST = int64_t; /* this all has to be changed back to TID and such.. 
*/ From b1c79a292428f7a34670c2b9b2c43d4529b98acd Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:45:07 +0800 Subject: [PATCH 13/32] add wraper --- .../gc/ExecutionEngine/JitWrapper/Module.hpp | 44 +++++++++++ lib/gc/ExecutionEngine/CMakeLists.txt | 1 + .../ExecutionEngine/JitWrapper/CMakeLists.txt | 41 ++++++++++ lib/gc/ExecutionEngine/JitWrapper/Module.cpp | 79 +++++++++++++++++++ lib/gc/Transforms/Pipeline.cpp | 1 + unittests/CMakeLists.txt | 1 + unittests/ExecutionEngine/CMakeLists.txt | 7 ++ unittests/ExecutionEngine/JitWrapper.cpp | 69 ++++++++++++++++ 8 files changed, 243 insertions(+) create mode 100644 include/gc/ExecutionEngine/JitWrapper/Module.hpp create mode 100644 lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt create mode 100644 lib/gc/ExecutionEngine/JitWrapper/Module.cpp create mode 100644 unittests/ExecutionEngine/CMakeLists.txt create mode 100644 unittests/ExecutionEngine/JitWrapper.cpp diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/JitWrapper/Module.hpp new file mode 100644 index 000000000..ec259680f --- /dev/null +++ b/include/gc/ExecutionEngine/JitWrapper/Module.hpp @@ -0,0 +1,44 @@ +//===- Module.h - Jit module and Execution engine wrapper -------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_EXECUTIONENGINE_JITWRAPPER_H +#define GC_EXECUTIONENGINE_JITWRAPPER_H + +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include +#include + +namespace mlir { +class DialectRegistry; +namespace gc { + +const DialectRegistry &initAndGetDialects(); + +using JitModuleFuncT = void (*)(void **); + +class JitModule : public std::enable_shared_from_this { +public: + static llvm::Expected> + create(Operation *op, bool transform, llvm::StringRef entry_name = {}, + const ExecutionEngineOptions &options = {}, + std::unique_ptr tm = nullptr); + + void call(void **args) { entry(args); } + + JitModule(std::unique_ptr engine, JitModuleFuncT entry); + ~JitModule(); + +private: + std::unique_ptr engine; + JitModuleFuncT entry; +}; + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt index 8aa223412..a1592d74d 100644 --- a/lib/gc/ExecutionEngine/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(CPURuntime) +add_subdirectory(JitWrapper) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt b/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt new file mode 100644 index 000000000..3186b018d --- /dev/null +++ b/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt @@ -0,0 +1,41 @@ +if(GC_DEV_LINK_LLVM_DYLIB) + set(LLVM_LINK_COMPONENTS + LLVM + ) + get_property(dialect_libs GLOBAL PROPERTY GC_DIALECT_LIBS) + get_property(conversion_libs GLOBAL PROPERTY GC_PASS_LIBS) + set(MLIR_LINK_COMPONENTS + MLIR + MLIRExecutionEngineShared + ) +else() + set(LLVM_LINK_COMPONENTS + Core + Support + nativecodegen + native + ) + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + 
set(MLIR_LINK_COMPONENTS + MLIRBuiltinToLLVMIRTranslation + MLIRExecutionEngine + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRToLLVMIRTranslationRegistration + ) +endif() + +add_mlir_library(GCJitWrapper + Module.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include + + LINK_LIBS PUBLIC + ${MLIR_LINK_COMPONENTS} + ${dialect_libs} + ${conversion_libs} + GCPasses + ) + diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp new file mode 100644 index 000000000..7c80cc993 --- /dev/null +++ b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp @@ -0,0 +1,79 @@ +//===- Module.cpp -----------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Transforms/Passes.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" + +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/TargetSelect.h" + +namespace mlir { +namespace gc { + +static DialectRegistry initDialects() { + mlir::registerAllPasses(); + mlir::gc::registerGraphCompilerPasses(); + mlir::cpuruntime::registerCPURuntimePasses(); + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerAllDialects(registry); + mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + registry.insert(); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + mlir::registerAllToLLVMIRTranslations(registry); + return 
registry; +} + +const DialectRegistry &initAndGetDialects() { + static DialectRegistry reg = initDialects(); + return reg; +} + +static const char defaultEntryName[] = "_mlir_ciface_main_entry"; +llvm::Expected> +JitModule::create(Operation *op, bool transform, llvm::StringRef entry_name, + const ExecutionEngineOptions &options, + std::unique_ptr tm) { + if (transform) { + mlir::PassManager pm{op->getContext()}; + populateCPUPipeline(pm); + if (auto result = pm.run(op); failed(result)) { + return llvm::make_error( + "MLIR pass error", llvm::inconvertibleErrorCode()); + } + } + auto exec = ExecutionEngine::create(op, options, std::move(tm)); + if (!exec) { + return exec.takeError(); + } + auto &engine = *exec; + if (entry_name.empty()) { + entry_name = defaultEntryName; + } + auto mainEntry = engine->lookupPacked(entry_name); + if (!mainEntry) { + return mainEntry.takeError(); + } + return std::make_shared(std::move(engine), *mainEntry); +} + +JitModule::JitModule(std::unique_ptr engine, + JitModuleFuncT entry) + : engine{std::move(engine)}, entry{entry} {} +JitModule::~JitModule() = default; + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index b336684a1..e14f86589 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index c93735c63..a9bf31c38 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -18,4 +18,5 @@ function(add_mlir_unittest test_dirname) endfunction() add_subdirectory(Example) 
+add_subdirectory(ExecutionEngine) diff --git a/unittests/ExecutionEngine/CMakeLists.txt b/unittests/ExecutionEngine/CMakeLists.txt new file mode 100644 index 000000000..0e7315a0f --- /dev/null +++ b/unittests/ExecutionEngine/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(GCExecutionEngineTests + JitWrapper.cpp +) +target_link_libraries(GCExecutionEngineTests + PRIVATE + GCJitWrapper + GCCpuRuntime) diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp new file mode 100644 index 000000000..0653c97ee --- /dev/null +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -0,0 +1,69 @@ +//===- JitWrapper.cpp - Wrapper of JIT ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/ExecutionEngine/MemRefUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" + +using namespace mlir; + +static const char code1[] = R"mlir( +module { +func.func @main_entry(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + return %2 : tensor<128xf32> +} +} +)mlir"; + +extern "C" { +extern int gc_runtime_keep_alive; +} + +TEST(ExecutionEngine, JitWrapper) { + 
gc_runtime_keep_alive = 0; + MLIRContext ctx{gc::initAndGetDialects()}; + std::unique_ptr ir_buffer = + llvm::MemoryBuffer::getMemBuffer(code1); + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &ctx); + ASSERT_TRUE(module); + auto jited = gc::JitModule::create(module.get(), true); + bool jit_success = static_cast(jited); + if (!jit_success) { + auto err = jited.takeError(); + llvm::errs() << err; + llvm::consumeError(std::move(err)); + } + ASSERT_TRUE(jit_success); + OwningMemRef bufA{ + {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; + OwningMemRef bufB{ + {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; + OwningMemRef bufC{{128}, {128}}; + void *args[] = {&*bufA, &*bufB, &*bufC}; + void *pargs[] = {&args[0], &args[1], &args[2]}; + jited.get()->call(pargs); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufC[{i}], 1.0f + i); + } +} From 382171bf38ec16e7b9f3edeb0d07834800aba98d Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:49:08 +0800 Subject: [PATCH 14/32] fix lint --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 7 +++---- src/gc-cpu-runner/CMakeLists.txt | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index 4a25a5ee0..6efb38142 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,5 +1,4 @@ //===- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// -//-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -28,9 +27,9 @@ struct barrier_t { char padding_[64 - 4 * sizeof(int32_t)]; }; -typedef uint64_t (*barrier_idle_func)(std::atomic *remaining, - int32_t expected_remain, int32_t tid, - void *args); +using barrier_idle_func = uint64_t (*)(std::atomic *remaining, + int32_t expected_remain, int32_t tid, + void *args); } // namespace extern "C" { diff --git a/src/gc-cpu-runner/CMakeLists.txt b/src/gc-cpu-runner/CMakeLists.txt index 85dbb6995..2599eef84 100644 --- a/src/gc-cpu-runner/CMakeLists.txt +++ b/src/gc-cpu-runner/CMakeLists.txt @@ -1,3 +1,20 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0 +################################################################################ + if(GC_DEV_LINK_LLVM_DYLIB) set(LLVM_LINK_COMPONENTS LLVM From f1fd0ae69fff9b17c8d7bf354e3e160a5d3be0aa Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:51:25 +0800 Subject: [PATCH 15/32] fix --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index 6efb38142..d81e8d3a9 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,4 +1,4 @@ -//===- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// +//===-- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From a773ea6bcbcfebc1370aa78552a703278ee502da Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:52:29 +0800 Subject: [PATCH 16/32] f --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index d81e8d3a9..ea7641417 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,9 +1,9 @@ -//===-- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// -// +//===-- Parallel.cpp - parallel ---------------------------------*- C++ -*-===// +// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// +// //===----------------------------------------------------------------------===// #include From 84933c21b7f2c73361953166add0e1143238906d Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:57:14 +0800 Subject: [PATCH 17/32] fix --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index ea7641417..5591dc3af 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,9 +1,9 @@ //===-- Parallel.cpp - parallel ---------------------------------*- C++ -*-===// -// +// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// +// //===----------------------------------------------------------------------===// #include From 4cca4dfe7bbcf9dbe4ddab9ff66c0c4ff59ffa86 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 17:11:32 +0800 Subject: [PATCH 18/32] add reference --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index 5591dc3af..3a5b4c2c1 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -74,6 +74,8 @@ int WEAK_SYMBOL __kmpc_global_thread_num(void *loc) { return omp_get_thread_num(); } +// The implementation was extracted and simplified from LLVM libomp +// at openmp/runtime/src/kmp_sched.cpp void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid, int32_t schedtype, int32_t *plastiter, uint64_t *plower, From 678cef9885605e7d94199c0d1e0cecf6527fc6c9 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" 
Date: Fri, 24 May 2024 16:27:35 +0800 Subject: [PATCH 19/32] enable const cache --- .../CPURuntime/ConstantCache.hpp | 137 ++++++++++++ .../gc/ExecutionEngine/JitWrapper/Module.hpp | 44 +++- .../ExecutionEngine/CPURuntime/CMakeLists.txt | 1 + .../CPURuntime/ConstantCache.cpp | 40 ++++ lib/gc/ExecutionEngine/JitWrapper/Module.cpp | 209 ++++++++++++++++-- unittests/ExecutionEngine/JitWrapper.cpp | 114 +++++++++- 6 files changed, 518 insertions(+), 27 deletions(-) create mode 100644 include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp create mode 100644 lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp new file mode 100644 index 000000000..120d1d69f --- /dev/null +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp @@ -0,0 +1,137 @@ +#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H +#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include +#include +#include + +namespace mlir { +namespace gc { +/** + * The helper class to manage ref count manually for an object allocated with + * shared ptr. It holds an additional shared ptr reference to the object and + * contains an additional self-managed refcount. The refcount will be set to 1 + * when the object is initialized (see init()). When the refcount counts down to + * 0, the additional shared ptr is reset. + */ +struct ref_count_managed { + ref_count_managed() = default; + ref_count_managed(const std::shared_ptr &keep_alive) { + init(keep_alive); + } + void init(const std::shared_ptr &keep_alive) { + keep_alive_ = keep_alive; + ref_count_.store(1); + } + + void ref() { ++ref_count_; } + void deref() { + auto newv = --ref_count_; + if (newv == 0) { + keep_alive_ = nullptr; + } + } + + // atomically check if ref_count_ > 0. if so, ref() the object and return + // true. 
Otherwise (if ref_count_==0), return false + bool check_alive_and_ref() { + auto oldv = ref_count_.load(); + for (;;) { + if (oldv <= 0) { + return false; + } + if (ref_count_.compare_exchange_strong(oldv, oldv + 1)) { + return true; + } + // CAS failed, oldv has now the newest known value of ref_count_ + } + } + + bool is_alive() const { return ref_count_ > 0; } + void *unsafe_get_ptr() const { return keep_alive_.get(); } + +private: + std::shared_ptr keep_alive_; + std::atomic ref_count_{0}; +}; + +/** + * The proxy for the constant cache of Graph API. It holds a shared ptr pointing + * to the cache item in the cache manager (keep_alive) to extend the lifetime by + * refcount, @see ref_count_managed. To access the memory buffer of the const + * cache, use sc_acquire_const_cache and sc_release_const_cache functions. They + * will ref/deref the const_cache_proxy to make sure the cache is alive after + * calling sc_acquire_const_cache and before sc_release_const_cache. The cache + * manager of Graph API may evict the cache item by dereferenceing this + * ref_count_managed object. sc_{acquire,release}_const_cache functions will + * find out that the cache has been invalidated and they will then use the + * memory allocator in the runtime::stream_t to re-allocate the buffer. Usually + * we expect JIT modules to hold shared ptr to const_cache_proxy via + * cached_const_graph_tensor. + * If is_lazy_ == true, the cache item's lifetime will be managed by the cache + * manager of Graph API and it is filled with data after the first execution of + * the computation. Otherwise, the cache item is always alive as long as the + * jit_module of the kernel is alive. + */ +struct const_cache_proxy : ref_count_managed { + const_cache_proxy(const std::shared_ptr &keep_alive, void *buffer, + size_t size, bool is_lazy) + : ref_count_managed(keep_alive), size_(size), is_lazy_(is_lazy), + buffer_(buffer) {} + ~const_cache_proxy(); + + // get the buffer and increment the refcount. 
If the buffer is evicted, + // returns null + void *acquire(int32_t *inited) { + if (check_alive_and_ref()) { + *inited = *inited && initialized_; + return buffer_; + } + return nullptr; + } + // decrement the refcount + bool release() { + if (is_alive()) { + deref(); + initialized_ = 1; + return true; + } + return false; + } + + // return the buffer. Do not directly use the buffer because it may be already + // release! To access the buffer, always acquire() before using it. + void *get_buffer_unsafe() const { return buffer_; } + + size_t size_; + // if the buffer is lazy-initialized. If false, it should be filled before + // computation + bool is_lazy_; + +private: + // raw pointer to the buffer + void *buffer_; + // if the buffer has been initialized. calling release() will set this to 1 + int32_t initialized_ = 0; +}; + +struct cached_graph_tensor { + std::shared_ptr base; + size_t offset; + cached_graph_tensor(const std::shared_ptr &base, + size_t offset); + friend class JitModule; + +private: + StridedMemRefType ref; +}; + +std::shared_ptr query_cached_tensor(uint64_t key); +bool reg_cached_tensor(uint64_t key, + const std::shared_ptr &base, + size_t offset); + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/JitWrapper/Module.hpp index ec259680f..afa9fb541 100644 --- a/include/gc/ExecutionEngine/JitWrapper/Module.hpp +++ b/include/gc/ExecutionEngine/JitWrapper/Module.hpp @@ -9,6 +9,8 @@ #ifndef GC_EXECUTIONENGINE_JITWRAPPER_H #define GC_EXECUTIONENGINE_JITWRAPPER_H +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include #include @@ -19,23 +21,53 @@ namespace gc { const DialectRegistry &initAndGetDialects(); +// the pointers to XXXMemRefType +using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); class JitModule 
: public std::enable_shared_from_this { public: static llvm::Expected> - create(Operation *op, bool transform, llvm::StringRef entry_name = {}, - const ExecutionEngineOptions &options = {}, - std::unique_ptr tm = nullptr); + create(Operation *op, const ExecutionEngineOptions &options = {}, + std::unique_ptr tm = nullptr, + bool transform = true); - void call(void **args) { entry(args); } + // args should be an array of XXXMemrefType* + void call(GeneralMemrefPtr *args); - JitModule(std::unique_ptr engine, JitModuleFuncT entry); + JitModule( + std::unique_ptr engine, JitModuleFuncT compute, + JitModuleFuncT fold, size_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive = {}); ~JitModule(); private: std::unique_ptr engine; - JitModuleFuncT entry; + JitModuleFuncT compute; + JitModuleFuncT fold; + size_t numOrigArgs; + // `keepAlive` has the ownership of the objects pointed by this vector + llvm::SmallVector cacheBases; + struct CacheBufferInfo { + // index in cacheBases + size_t baseIdx; + size_t offset; + }; + // the info for each folded cached buffer + llvm::SmallVector cacheInfo; + // holding the pointers to StridedMemRefType of folded cache + // `keepAlive` holds the the ownership of the pointers + llvm::SmallVector fastFoldBuffers; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs; + + std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 6be58e28f..97f039834 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") 
add_mlir_library(GCCpuRuntime SHARED Parallel.cpp + ConstantCache.cpp EXCLUDE_FROM_LIBMLIR ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp new file mode 100644 index 000000000..3be4b7dce --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -0,0 +1,40 @@ +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include +#include + +namespace mlir::gc { + +const_cache_proxy::~const_cache_proxy() = default; + + +cached_graph_tensor::cached_graph_tensor( + const std::shared_ptr &base, size_t offset) + : base{base}, offset{offset} { + // todo: fill in real values + ref.basePtr = (char *)base->get_buffer_unsafe() + offset; + ref.data = ref.basePtr; + ref.offset = 0; + memset(ref.sizes, 0, sizeof(ref.sizes)); + memset(ref.strides, 0, sizeof(ref.strides)); +} + +static std::unordered_map> cache; + +std::shared_ptr query_cached_tensor(uint64_t key) { + auto itr = cache.find(key); + if (itr != cache.end()) { + return itr->second; + } + return nullptr; +} + +bool reg_cached_tensor(uint64_t key, + const std::shared_ptr &base, + size_t offset) { + if (query_cached_tensor(key)) { + return false; + } + cache[key] = std::make_shared(base, offset); + return true; +} +} // namespace mlir::gc \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp index 7c80cc993..fa4a8151b 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp +++ b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp @@ -7,17 +7,20 @@ //===----------------------------------------------------------------------===// #include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" 
-#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "mlir/Target/LLVMIR/Dialect/All.h" - +#include "string.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + namespace mlir { namespace gc { @@ -38,15 +41,15 @@ static DialectRegistry initDialects() { } const DialectRegistry &initAndGetDialects() { - static DialectRegistry reg = initDialects(); - return reg; + static DialectRegistry reg = initDialects(); + return reg; } -static const char defaultEntryName[] = "_mlir_ciface_main_entry"; +static const char defaultComputeName[] = "_mlir_ciface_compute"; +static const char defaultFoldName[] = "_mlir_ciface_fold"; llvm::Expected> -JitModule::create(Operation *op, bool transform, llvm::StringRef entry_name, - const ExecutionEngineOptions &options, - std::unique_ptr tm) { +JitModule::create(Operation *op, const ExecutionEngineOptions &options, + std::unique_ptr tm, bool transform) { if (transform) { mlir::PassManager pm{op->getContext()}; populateCPUPipeline(pm); @@ -60,20 +63,192 @@ JitModule::create(Operation *op, bool transform, llvm::StringRef entry_name, return exec.takeError(); } auto &engine = *exec; - if (entry_name.empty()) { - entry_name = defaultEntryName; + uint32_t numOrigArgs; + { + auto expectArgs = engine->lookup("__num_orig_num_args"); + if (!expectArgs) { + return expectArgs.takeError(); + } + numOrigArgs = *reinterpret_cast(*expectArgs); } - auto mainEntry = engine->lookupPacked(entry_name); - if (!mainEntry) { - return mainEntry.takeError(); + JitModuleFuncT compute; + { + auto expectCompute = engine->lookupPacked(defaultComputeName); + if (!expectCompute) { + return expectCompute.takeError(); + } + compute = *expectCompute; + } + llvm::ArrayRef foldBufferIds; + JitModuleFuncT fold = nullptr; + llvm::ArrayRef computeArgs; + llvm::ArrayRef foldArgs; + do { + auto expectBufferIds = 
engine->lookup("__fold_buffer_ids"); + if (!expectBufferIds) { + // nothing to fold, It is OK. + llvm::consumeError(expectBufferIds.takeError()); + // break out of the scope, don't need to lookup "fold" function + break; + } else { + auto raw = reinterpret_cast(*expectBufferIds); + foldBufferIds = llvm::ArrayRef{raw + 1, raw[0]}; + } + + // find "fold" func + { + auto expectFold = engine->lookupPacked(defaultFoldName); + if (!expectFold) { + return expectFold.takeError(); + } + fold = *expectFold; + } + + // find "foldArgs" + { + auto expectFold = engine->lookup("__fold_args"); + if (!expectFold) { + return expectFold.takeError(); + } + auto raw = reinterpret_cast(*expectFold); + foldArgs = llvm::ArrayRef{raw + 1, raw[0]}; + } + + // find "computeArgs" + { + auto expect = engine->lookup("__compute_args"); + if (!expect) { + return expect.takeError(); + } + auto raw = reinterpret_cast(*expect); + computeArgs = llvm::ArrayRef{raw + 1, raw[0]}; + } + } while (0); + + std::vector> foldInfo; + foldInfo.reserve(foldBufferIds.size()); + for (auto bufId : foldBufferIds) { + auto ret = query_cached_tensor(bufId); + if (!ret) { + return llvm::make_error( + "Failed to query the folded cached tensor", + llvm::inconvertibleErrorCode()); + } + foldInfo.emplace_back(std::move(ret)); } - return std::make_shared(std::move(engine), *mainEntry); + + return std::make_shared(std::move(engine), compute, fold, + numOrigArgs, computeArgs, foldArgs, + std::move(foldInfo)); } -JitModule::JitModule(std::unique_ptr engine, - JitModuleFuncT entry) - : engine{std::move(engine)}, entry{entry} {} +JitModule::JitModule( + std::unique_ptr engine, JitModuleFuncT compute, + JitModuleFuncT fold, size_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive) + : engine{std::move(engine)}, compute{compute}, fold{fold}, + 
numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, computeArgs{computeArgs}, + keepAlive{std::move(cachekeepAlive)} { + for (const auto &cache : keepAlive) { + auto currentItr = + std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); + if (currentItr == cacheBases.end()) { + cacheBases.push_back(cache->base.get()); + currentItr = cacheBases.end() - 1; + } + cacheInfo.emplace_back(CacheBufferInfo{ + static_cast(currentItr - cacheBases.begin()), cache->offset}); + fastFoldBuffers.push_back(&cache->ref); + } +} JitModule::~JitModule() = default; +static void prepareCallArgs(llvm::SmallVector &realargs, + GeneralMemrefPtr *origargs, size_t numOrigArgs, + GeneralMemrefPtr *foldedCache, + llvm::ArrayRef realArgIdx) { + realargs.reserve(realArgIdx.size()); + for (auto argIdx : realArgIdx) { + if (argIdx < numOrigArgs) { + realargs.push_back(&origargs[argIdx]); + } else { + realargs.push_back(&foldedCache[argIdx - numOrigArgs]); + } + } +} + +void JitModule::call(GeneralMemrefPtr *args) { + if (unlikely(cacheInfo.empty())) { + // fast path, no folded cached buffers + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + realargs.reserve(numOrigArgs); + for (size_t i = 0; i < numOrigArgs; i++) { + realargs.push_back(&args[i]); + } + compute(realargs.data()); + return; + } + + // stage 1, acquire the foldBasePtr + llvm::SmallVector foldBasePtr; + int32_t inited = 1; + for (auto b : cacheBases) { + auto ptr = b->acquire(&inited); + if (unlikely(!ptr)) { + ptr = std::aligned_alloc(/*alignment*/ 64, b->size_); + inited = 0; + } + foldBasePtr.push_back((char *)ptr); + } + + // stage 2, run fold() if necessary + GeneralMemrefPtr *foldedCache; + // only used when !inited + std::vector slowFold; + std::vector> slowFoldObj; + if (likely(inited)) { + foldedCache = fastFoldBuffers.data(); + } else { + slowFold.reserve(cacheInfo.size()); + slowFoldObj.reserve(cacheInfo.size()); + for (auto &info : cacheInfo) { + 
slowFoldObj.emplace_back(); + auto &obj = slowFoldObj.back(); + obj.basePtr = foldBasePtr[info.baseIdx] + info.offset; + obj.data = obj.basePtr; + memset(obj.sizes, 0, sizeof(obj.sizes)); + memset(obj.strides, 0, sizeof(obj.strides)); + slowFold.push_back(&obj); + } + foldedCache = slowFold.data(); + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numOrigArgs, foldedCache, foldArgs); + fold(realargs.data()); + } + + // stage 3, call compute + { + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numOrigArgs, foldedCache, computeArgs); + compute(realargs.data()); + } + + // stage 4, cleanup + for (size_t i = 0; i < cacheBases.size(); i++) { + auto b = cacheBases[i]; + if (unlikely(!b->release())) { + // if the cached buffer is already free'd, foldBasePtr[i] is allocated via + // std::aligned_alloc by us, free it + std::free(foldBasePtr[i]); + } + } +} + } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 0653c97ee..d1a07c86f 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -19,12 +19,14 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include "gtest/gtest.h" +#include using namespace mlir; static const char code1[] = R"mlir( module { -func.func @main_entry(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { +llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32 +func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { %out = tensor.empty() : tensor<128xf32> %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> return %2 : tensor<128xf32> @@ -47,7 +49,7 @@ TEST(ExecutionEngine, JitWrapper) { mlir::OwningOpRef module = mlir::parseSourceFile(sourceMgr, &ctx); 
ASSERT_TRUE(module); - auto jited = gc::JitModule::create(module.get(), true); + auto jited = gc::JitModule::create(module.get()); bool jit_success = static_cast(jited); if (!jit_success) { auto err = jited.takeError(); @@ -61,9 +63,113 @@ TEST(ExecutionEngine, JitWrapper) { {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; OwningMemRef bufC{{128}, {128}}; void *args[] = {&*bufA, &*bufB, &*bufC}; - void *pargs[] = {&args[0], &args[1], &args[2]}; - jited.get()->call(pargs); + jited.get()->call(args); for (int i = 0; i < 128; i++) { ASSERT_EQ(bufC[{i}], 1.0f + i); } } + +// compute d = (a+a) + (b+b) + c, where a,b is marked constant +// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5 +static const char code2[] = R"mlir( +module { +llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 +llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> +// a,b, foldedA,foldedB +llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32> +// foldedA, foldedB, c, d +llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> + +func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { + %c0 = arith.constant 0 : index + cpuruntime.printf "HI%zu\n" %c0 : index + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + %out2 = tensor.empty() : tensor<128xf32> + %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> + return %2, %3 : tensor<128xf32>, tensor<128xf32> +} + +func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%ax2, %bx2 : 
tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> + return %d : tensor<128xf32> +} +} +)mlir"; + +TEST(ExecutionEngine, JitWrapperCached) { + MLIRContext ctx{gc::initAndGetDialects()}; + std::unique_ptr ir_buffer = + llvm::MemoryBuffer::getMemBuffer(code2); + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &ctx); + + // foldedA and foldedB uses this buffer + auto ret = std::shared_ptr(new float[128 * 2]); + auto proxy = std::make_shared( + ret, ret.get(), 128 * 2 * sizeof(float), true); + + ASSERT_TRUE(gc::reg_cached_tensor(114514, proxy, 0)); + ASSERT_TRUE(gc::reg_cached_tensor(1919810, proxy, 128 * sizeof(float))); + + ASSERT_TRUE(module); + auto jited = gc::JitModule::create(module.get()); + bool jit_success = static_cast(jited); + if (!jit_success) { + auto err = jited.takeError(); + llvm::errs() << err; + llvm::consumeError(std::move(err)); + } + ASSERT_TRUE(jit_success); + OwningMemRef bufA{ + {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; + OwningMemRef bufB{ + {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; + OwningMemRef bufC{ + {128}, {128}, [](float &ptr, ArrayRef idx) { + ptr = -idx[0] * 3; + }}; + OwningMemRef bufD{{128}, {128}}; + void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD}; + + // first call, should run fold() + { + testing::internal::CaptureStdout(); + // first call, should run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_EQ(output, "HI0\n"); + } + + { + testing::internal::CaptureStdout(); + // second call, should not run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + 
ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_TRUE(output.empty()); + } + + // the cache is evicted + proxy->deref(); + { + testing::internal::CaptureStdout(); + // third call, should run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_EQ(output, "HI0\n"); + } +} From c12156c02c77f0e4c7a1923f829b340acb7afc5e Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Fri, 24 May 2024 16:36:54 +0800 Subject: [PATCH 20/32] reduce size --- lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt | 4 +++- lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt index 3a1d63d3d..8d92804d6 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect) + add_mlir_dialect_library(MLIRCPURuntimeDialect CPURuntimeDialect.cpp CPURuntimeOps.cpp @@ -10,5 +12,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIRFuncDialect + ${MLIR_LINK_COMPONENTS} ) diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt index 3bc84f6c8..52c0d7441 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -1,3 +1,5 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect) + add_mlir_dialect_library(MLIRCPURuntimeTransforms CPURuntimeToLLVM.cpp @@ -8,7 +10,7 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIRFuncDialect + ${MLIR_LINK_COMPONENTS} MLIRCPURuntimeDialect ) From 
e24b1df24a4f5b60557d756010c8217104e2c012 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 10:36:42 +0800 Subject: [PATCH 21/32] rename --- .../CPURuntime/ConstantCache.hpp | 22 +++++++++---------- .../gc/ExecutionEngine/JitWrapper/Module.hpp | 6 ++--- .../CPURuntime/ConstantCache.cpp | 18 +++++++-------- lib/gc/ExecutionEngine/JitWrapper/Module.cpp | 6 ++--- unittests/ExecutionEngine/JitWrapper.cpp | 6 ++--- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp index 120d1d69f..4000f53dc 100644 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp @@ -60,25 +60,25 @@ struct ref_count_managed { * to the cache item in the cache manager (keep_alive) to extend the lifetime by * refcount, @see ref_count_managed. To access the memory buffer of the const * cache, use sc_acquire_const_cache and sc_release_const_cache functions. They - * will ref/deref the const_cache_proxy to make sure the cache is alive after + * will ref/deref the ConstCacheProxy to make sure the cache is alive after * calling sc_acquire_const_cache and before sc_release_const_cache. The cache * manager of Graph API may evict the cache item by dereferenceing this * ref_count_managed object. sc_{acquire,release}_const_cache functions will * find out that the cache has been invalidated and they will then use the * memory allocator in the runtime::stream_t to re-allocate the buffer. Usually - * we expect JIT modules to hold shared ptr to const_cache_proxy via + * we expect JIT modules to hold shared ptr to ConstCacheProxy via * cached_const_graph_tensor. * If is_lazy_ == true, the cache item's lifetime will be managed by the cache * manager of Graph API and it is filled with data after the first execution of * the computation. 
Otherwise, the cache item is always alive as long as the * jit_module of the kernel is alive. */ -struct const_cache_proxy : ref_count_managed { - const_cache_proxy(const std::shared_ptr &keep_alive, void *buffer, +struct ConstCacheProxy : ref_count_managed { + ConstCacheProxy(const std::shared_ptr &keep_alive, void *buffer, size_t size, bool is_lazy) : ref_count_managed(keep_alive), size_(size), is_lazy_(is_lazy), buffer_(buffer) {} - ~const_cache_proxy(); + ~ConstCacheProxy(); // get the buffer and increment the refcount. If the buffer is evicted, // returns null @@ -115,10 +115,10 @@ struct const_cache_proxy : ref_count_managed { int32_t initialized_ = 0; }; -struct cached_graph_tensor { - std::shared_ptr base; +struct CachedGraphTensor { + std::shared_ptr base; size_t offset; - cached_graph_tensor(const std::shared_ptr &base, + CachedGraphTensor(const std::shared_ptr &base, size_t offset); friend class JitModule; @@ -126,9 +126,9 @@ struct cached_graph_tensor { StridedMemRefType ref; }; -std::shared_ptr query_cached_tensor(uint64_t key); -bool reg_cached_tensor(uint64_t key, - const std::shared_ptr &base, +std::shared_ptr queryCacheTensor(uint64_t key); +bool regCachedTensor(uint64_t key, + const std::shared_ptr &base, size_t offset); } // namespace gc diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/JitWrapper/Module.hpp index afa9fb541..926aba3f6 100644 --- a/include/gc/ExecutionEngine/JitWrapper/Module.hpp +++ b/include/gc/ExecutionEngine/JitWrapper/Module.hpp @@ -42,7 +42,7 @@ class JitModule : public std::enable_shared_from_this { llvm::ArrayRef computeArgs, // The code inside `engine` has the ownership of the buffer llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive = {}); + std::vector> &&cachekeepAlive = {}); ~JitModule(); private: @@ -51,7 +51,7 @@ class JitModule : public std::enable_shared_from_this { JitModuleFuncT fold; size_t numOrigArgs; // `keepAlive` has the ownership of the objects pointed by 
this vector - llvm::SmallVector cacheBases; + llvm::SmallVector cacheBases; struct CacheBufferInfo { // index in cacheBases size_t baseIdx; @@ -67,7 +67,7 @@ class JitModule : public std::enable_shared_from_this { // The code inside `engine` has the ownership of the buffer llvm::ArrayRef computeArgs; - std::vector> keepAlive; + std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index 3be4b7dce..133cc362f 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -4,11 +4,11 @@ namespace mlir::gc { -const_cache_proxy::~const_cache_proxy() = default; +ConstCacheProxy::~ConstCacheProxy() = default; -cached_graph_tensor::cached_graph_tensor( - const std::shared_ptr &base, size_t offset) +CachedGraphTensor::CachedGraphTensor( + const std::shared_ptr &base, size_t offset) : base{base}, offset{offset} { // todo: fill in real values ref.basePtr = (char *)base->get_buffer_unsafe() + offset; @@ -18,9 +18,9 @@ cached_graph_tensor::cached_graph_tensor( memset(ref.strides, 0, sizeof(ref.strides)); } -static std::unordered_map> cache; +static std::unordered_map> cache; -std::shared_ptr query_cached_tensor(uint64_t key) { +std::shared_ptr queryCacheTensor(uint64_t key) { auto itr = cache.find(key); if (itr != cache.end()) { return itr->second; @@ -28,13 +28,13 @@ std::shared_ptr query_cached_tensor(uint64_t key) { return nullptr; } -bool reg_cached_tensor(uint64_t key, - const std::shared_ptr &base, +bool regCachedTensor(uint64_t key, + const std::shared_ptr &base, size_t offset) { - if (query_cached_tensor(key)) { + if (queryCacheTensor(key)) { return false; } - cache[key] = std::make_shared(base, offset); + cache[key] = std::make_shared(base, offset); return true; } } // namespace mlir::gc \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp 
b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp index fa4a8151b..06f9c02b6 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp +++ b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp @@ -125,10 +125,10 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, } } while (0); - std::vector> foldInfo; + std::vector> foldInfo; foldInfo.reserve(foldBufferIds.size()); for (auto bufId : foldBufferIds) { - auto ret = query_cached_tensor(bufId); + auto ret = queryCacheTensor(bufId); if (!ret) { return llvm::make_error( "Failed to query the folded cached tensor", @@ -149,7 +149,7 @@ JitModule::JitModule( llvm::ArrayRef computeArgs, // The code inside `engine` has the ownership of the buffer llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive) + std::vector> &&cachekeepAlive) : engine{std::move(engine)}, compute{compute}, fold{fold}, numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index d1a07c86f..222c73116 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -111,11 +111,11 @@ TEST(ExecutionEngine, JitWrapperCached) { // foldedA and foldedB uses this buffer auto ret = std::shared_ptr(new float[128 * 2]); - auto proxy = std::make_shared( + auto proxy = std::make_shared( ret, ret.get(), 128 * 2 * sizeof(float), true); - ASSERT_TRUE(gc::reg_cached_tensor(114514, proxy, 0)); - ASSERT_TRUE(gc::reg_cached_tensor(1919810, proxy, 128 * sizeof(float))); + ASSERT_TRUE(gc::regCachedTensor(114514, proxy, 0)); + ASSERT_TRUE(gc::regCachedTensor(1919810, proxy, 128 * sizeof(float))); ASSERT_TRUE(module); auto jited = gc::JitModule::create(module.get()); From 1e06c9844c7753400c6efc254898475c52b2c38b Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 10:42:03 +0800 Subject: [PATCH 22/32] fix license.py --- scripts/license.py | 10 +++++----- 1 file 
changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/license.py b/scripts/license.py index 49f28eaa8..3ed2ce521 100644 --- a/scripts/license.py +++ b/scripts/license.py @@ -15,10 +15,10 @@ # SPDX-License-Identifier: Apache-2.0 import datetime, sys, re, argparse -from typing import Dict, Set +from typing import Dict, Set, List WIDTH: int = 80 -intel_license: list[str] = [ +intel_license: List[str] = [ 'Copyright \\(C\\) (\\d\\d\\d\\d-)?$YEAR Intel Corporation', '', 'Licensed under the Apache License, Version 2.0 (the "License");', @@ -35,7 +35,7 @@ 'SPDX-License-Identifier: Apache-2.0', ] -llvm_license: list[str] = [ +llvm_license: List[str] = [ "===-{1,2} $FILE - .* -*\\*- $LANG -\\*-===", '', 'This file is licensed under the Apache License v2.0 with LLVM Exceptions.', @@ -45,7 +45,7 @@ "===-*===", ] -def check_license(filepath: str, license: list[str], var: Dict[str, str], re_line: Set[int]): +def check_license(filepath: str, license: List[str], var: Dict[str, str], re_line: Set[int]): with open(filepath, 'r') as f: idx: int = 0 for line in f.readlines(): @@ -117,7 +117,7 @@ def use_llvm_license(path: str) -> bool: var: Dict[str, str] = {} re_line: Set[int] = set() - lic = list[str] + lic = List[str] if filepath.startswith("test/") or filepath.startswith("./test/"): continue From 7c32bc5c899872e2a5ed1f011b2446d808b5f32a Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 10:58:21 +0800 Subject: [PATCH 23/32] fix --- .../CPURuntime/ConstantCache.hpp | 88 +++++++++---------- .../Module.hpp => Driver/Driver.hpp} | 6 +- lib/gc/ExecutionEngine/CMakeLists.txt | 2 +- .../CPURuntime/ConstantCache.cpp | 10 ++- .../{JitWrapper => Driver}/CMakeLists.txt | 2 +- .../Module.cpp => Driver/Driver.cpp} | 4 +- unittests/ExecutionEngine/JitWrapper.cpp | 6 +- 7 files changed, 63 insertions(+), 55 deletions(-) rename include/gc/ExecutionEngine/{JitWrapper/Module.hpp => Driver/Driver.hpp} (94%) rename lib/gc/ExecutionEngine/{JitWrapper => Driver}/CMakeLists.txt 
(98%) rename lib/gc/ExecutionEngine/{JitWrapper/Module.cpp => Driver/Driver.cpp} (98%) diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp index 4000f53dc..0d96ae6b3 100644 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp @@ -1,3 +1,10 @@ +//===-- ConstantCache.hpp - Constant cache interfaces -----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// #ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H #define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H #include "mlir/ExecutionEngine/CRunnerUtils.h" @@ -14,76 +21,70 @@ namespace gc { * when the object is initialized (see init()). When the refcount counts down to * 0, the additional shared ptr is reset. */ -struct ref_count_managed { - ref_count_managed() = default; - ref_count_managed(const std::shared_ptr &keep_alive) { - init(keep_alive); - } +struct RefCountManaged { + RefCountManaged() = default; + RefCountManaged(const std::shared_ptr &keep_alive) { init(keep_alive); } void init(const std::shared_ptr &keep_alive) { - keep_alive_ = keep_alive; - ref_count_.store(1); + keepAlive = keep_alive; + refCount.store(1); } - void ref() { ++ref_count_; } + void ref() { ++refCount; } void deref() { - auto newv = --ref_count_; + auto newv = --refCount; if (newv == 0) { - keep_alive_ = nullptr; + keepAlive = nullptr; } } - // atomically check if ref_count_ > 0. if so, ref() the object and return - // true. Otherwise (if ref_count_==0), return false - bool check_alive_and_ref() { - auto oldv = ref_count_.load(); + // atomically check if refCount > 0. if so, ref() the object and return + // true. 
Otherwise (if refCount==0), return false + bool checkAliveAndRef() { + auto oldv = refCount.load(); for (;;) { if (oldv <= 0) { return false; } - if (ref_count_.compare_exchange_strong(oldv, oldv + 1)) { + if (refCount.compare_exchange_strong(oldv, oldv + 1)) { return true; } - // CAS failed, oldv has now the newest known value of ref_count_ + // CAS failed, oldv has now the newest known value of refCount } } - bool is_alive() const { return ref_count_ > 0; } - void *unsafe_get_ptr() const { return keep_alive_.get(); } + bool isAlive() const { return refCount > 0; } + void *getPtrUnsafe() const { return keepAlive.get(); } private: - std::shared_ptr keep_alive_; - std::atomic ref_count_{0}; + std::shared_ptr keepAlive; + std::atomic refCount{0}; }; /** * The proxy for the constant cache of Graph API. It holds a shared ptr pointing * to the cache item in the cache manager (keep_alive) to extend the lifetime by - * refcount, @see ref_count_managed. To access the memory buffer of the const - * cache, use sc_acquire_const_cache and sc_release_const_cache functions. They - * will ref/deref the ConstCacheProxy to make sure the cache is alive after - * calling sc_acquire_const_cache and before sc_release_const_cache. The cache - * manager of Graph API may evict the cache item by dereferenceing this - * ref_count_managed object. sc_{acquire,release}_const_cache functions will - * find out that the cache has been invalidated and they will then use the - * memory allocator in the runtime::stream_t to re-allocate the buffer. Usually - * we expect JIT modules to hold shared ptr to ConstCacheProxy via - * cached_const_graph_tensor. - * If is_lazy_ == true, the cache item's lifetime will be managed by the cache - * manager of Graph API and it is filled with data after the first execution of - * the computation. Otherwise, the cache item is always alive as long as the - * jit_module of the kernel is alive. + * refcount, @see RefCountManaged. 
To access the memory buffer of the const + * cache, use acauire/release functions. They will ref/deref the ConstCacheProxy + * to make sure the cache is alive after calling acauire and before release. The + * cache manager of Graph API may evict the cache item by dereferenceing this + * RefCountManaged object. {acquire,release} functions will find out that the + * cache has been invalidated. Usually we expect JIT modules to hold shared ptr + * to ConstCacheProxy via CachedGraphTensor. If is_lazy_ == true, the cache + * item's lifetime will be managed by the cache manager of Graph API and it is + * filled with data after the first execution of the computation. Otherwise, the + * cache item is always alive as long as the jit_module of the kernel is alive. */ -struct ConstCacheProxy : ref_count_managed { +struct ConstCacheProxy : RefCountManaged { ConstCacheProxy(const std::shared_ptr &keep_alive, void *buffer, - size_t size, bool is_lazy) - : ref_count_managed(keep_alive), size_(size), is_lazy_(is_lazy), + size_t size, bool is_lazy) + : RefCountManaged(keep_alive), size_(size), is_lazy_(is_lazy), buffer_(buffer) {} ~ConstCacheProxy(); // get the buffer and increment the refcount. If the buffer is evicted, // returns null void *acquire(int32_t *inited) { - if (check_alive_and_ref()) { + if (checkAliveAndRef()) { *inited = *inited && initialized_; return buffer_; } @@ -91,7 +92,7 @@ struct ConstCacheProxy : ref_count_managed { } // decrement the refcount bool release() { - if (is_alive()) { + if (isAlive()) { deref(); initialized_ = 1; return true; @@ -101,7 +102,7 @@ struct ConstCacheProxy : ref_count_managed { // return the buffer. Do not directly use the buffer because it may be already // release! To access the buffer, always acquire() before using it. - void *get_buffer_unsafe() const { return buffer_; } + void *getBufferUnsafe() const { return buffer_; } size_t size_; // if the buffer is lazy-initialized. 
If false, it should be filled before @@ -119,7 +120,7 @@ struct CachedGraphTensor { std::shared_ptr base; size_t offset; CachedGraphTensor(const std::shared_ptr &base, - size_t offset); + size_t offset); friend class JitModule; private: @@ -127,9 +128,8 @@ struct CachedGraphTensor { }; std::shared_ptr queryCacheTensor(uint64_t key); -bool regCachedTensor(uint64_t key, - const std::shared_ptr &base, - size_t offset); +bool regCachedTensor(uint64_t key, const std::shared_ptr &base, + size_t offset); } // namespace gc } // namespace mlir diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/Driver/Driver.hpp similarity index 94% rename from include/gc/ExecutionEngine/JitWrapper/Module.hpp rename to include/gc/ExecutionEngine/Driver/Driver.hpp index 926aba3f6..0a34514fa 100644 --- a/include/gc/ExecutionEngine/JitWrapper/Module.hpp +++ b/include/gc/ExecutionEngine/Driver/Driver.hpp @@ -1,4 +1,4 @@ -//===- Module.h - Jit module and Execution engine wrapper -------*- C++ -*-===// +//===-- Driver.hpp - The top-level MLIR compiler driver ---------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef GC_EXECUTIONENGINE_JITWRAPPER_H -#define GC_EXECUTIONENGINE_JITWRAPPER_H +#ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H +#define GC_EXECUTIONENGINE_DRIVER_DRIVER_H #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" #include "mlir/ExecutionEngine/CRunnerUtils.h" diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt index a1592d74d..ae0c1c8df 100644 --- a/lib/gc/ExecutionEngine/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CMakeLists.txt @@ -1,2 +1,2 @@ add_subdirectory(CPURuntime) -add_subdirectory(JitWrapper) \ No newline at end of file +add_subdirectory(Driver) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index 133cc362f..245f2ca89 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -1,3 +1,11 @@ +//===-- ConstantCache.cpp - Constant cache ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" #include #include @@ -11,7 +19,7 @@ CachedGraphTensor::CachedGraphTensor( const std::shared_ptr &base, size_t offset) : base{base}, offset{offset} { // todo: fill in real values - ref.basePtr = (char *)base->get_buffer_unsafe() + offset; + ref.basePtr = (char *)base->getBufferUnsafe() + offset; ref.data = ref.basePtr; ref.offset = 0; memset(ref.sizes, 0, sizeof(ref.sizes)); diff --git a/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt similarity index 98% rename from lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt rename to lib/gc/ExecutionEngine/Driver/CMakeLists.txt index 3186b018d..8bda0a16c 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -27,7 +27,7 @@ else() endif() add_mlir_library(GCJitWrapper - Module.cpp + Driver.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp similarity index 98% rename from lib/gc/ExecutionEngine/JitWrapper/Module.cpp rename to lib/gc/ExecutionEngine/Driver/Driver.cpp index 06f9c02b6..67b2ced6a 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -1,4 +1,4 @@ -//===- Module.cpp -----------------------------------------------*- C++ -*-===// +//===-- Driver.cpp - Top-level MLIR compiler driver -------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/ExecutionEngine/Driver/Driver.hpp" #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 222c73116..d902541bd 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -1,12 +1,12 @@ -//===- JitWrapper.cpp - Wrapper of JIT ------------------------------------===// +//===-- JitWrapper.cpp - Wrapper for JIT ------------------------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/ExecutionEngine/Driver/Driver.hpp" #include "mlir/AsmParser/AsmParser.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/IR/AsmState.h" From 4540fb64834a2ebcc4f1297672ff41a080d72555 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 14:12:32 +0800 Subject: [PATCH 24/32] fix lint --- lib/gc/ExecutionEngine/Driver/Driver.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 67b2ced6a..0b3c19113 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -123,7 +123,7 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, auto raw = reinterpret_cast(*expect); computeArgs = llvm::ArrayRef{raw + 1, 
raw[0]}; } - } while (0); + } while (false); std::vector> foldInfo; foldInfo.reserve(foldBufferIds.size()); From 381677a27f0c796dd94e208a4c2f1378994b1f59 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 14:32:18 +0800 Subject: [PATCH 25/32] fix comments --- lib/gc/Transforms/Pipeline.cpp | 64 ++++++++++++++-------------------- 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index b336684a1..81fae6877 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -27,40 +27,35 @@ namespace mlir::gc { +// linalg + linalgX + tensor void populateFrontendPasses(mlir::PassManager &pm) { // pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg()); } -// linalg + linalgX + tensor ==> GC V1 GIR +// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack void populateTensorPasses(mlir::PassManager &pm) { - // + padding propagation pass, upstream-able 127x127 -> tilling size:32 - // ->padding to 128x128 - // + layout propagation pass, upstream-able 4x32x4x32 -> - // tensor.pack/tensor.unpack - // + tensor constant propagation pass, down-stream pass, designed to support - // oneDNN graph spec - // + linalg.matmul lowering to (scf.loop + linalg.brgemm) pass, upstream-able - // + fine-grain fusion pass, upstream-able -> scf.for + linalgx.mask - // + lower linalg to arith/math on virtual vector pass, up-streamable + // todo: padding propagation pass + // todo: layout propagation pass + // todo: tensor constant propagation pass + // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass + // todo: fine-grain fusion pass + // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. 
Currently we add this // pass to make the pipeline work properly pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); } -// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack ==> -// GC V1 TIR +// scf + arith + math + vector + tensor + linalg.brgemm void populateVectorPasses(mlir::PassManager &pm) { - // + bf16 promotion pass, down-stream pass, device dependent pass, maybe can - // upstream - // + bf16 cast elimilation pass, down-stream pass, fast-math kind pass, - // designed to support oneDNN graph spec + // todo: bf16 promotion pass, device dependent pass + // todo: bf16 cast elimilation pass, fast-math kind pass, designed to support + // oneDNN graph spec pm.addNestedPass(arith::createArithExpandOpsPass()); - // + lower to physical vector pass, down-stream pass, device dependent pass, - // maybe can upstream + // todo: lower to physical vector pass, device dependent pass } -// scf + arith + math + vector + tensor + linalg.brgemm +// scf + arith + math + vector + memref + linalg.brgemm void populateBufferizationPasses(mlir::PassManager &pm) { bufferization::OneShotBufferizationOptions options; pm.addPass(bufferization::createOneShotBufferizePass(options)); @@ -73,34 +68,27 @@ void populateBufferizationPasses(mlir::PassManager &pm) { bufferization::BufferResultsToOutParamsOpts opt{}; opt.hoistStaticAllocs = true; pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); - // + buffer schedule pass, down-stream pass, to migrate buffer reschedule pass - // from GC V1. - pm.addNestedPass( - bufferization::createBufferHoistingPass()); // Need to improve this pass - // to support thread-local - // allocator. + // todo: buffer schedule pass + // todo: Need to improve this pass to support nested parallel. 
+ pm.addNestedPass(bufferization::createBufferHoistingPass()); pm.addNestedPass(bufferization::createBufferLoopHoistingPass()); pm.addNestedPass(bufferization::createBufferDeallocationPass()); pm.addPass(createBufferizationToMemRefPass()); } -// scf + arith + math + vector + memref + linalg.brgemm +// scf + arith + math + vector + memref + func/microkernel void populateMicroKernelPasses(mlir::PassManager &pm) { - // + ConvertLinalgToMicrokernel pass, upstream-able, - // + CleanupInvalidMicrokernel pass, upstream-able - // + InvariantMicrokernelMotion pass, upstream-able - // + ConvertMicrokernelToDnnlFunc, down-stream pass, to lower brgemm to dnnl - // call - // + ConvertMicrokernelToXsmm, down-stream pass, to lower brgemm to libxsmm - // call - // + LowerMicrokernel pass, upstream-able - // + DispatchMicrokernel, down-stream pass + // todo: ConvertLinalgToMicrokernel pass + // todo: CleanupInvalidMicrokernel pass + // todo: InvariantMicrokernelMotion pass + // todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call + // todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call + // todo: LowerMicrokernel pass + // todo: DispatchMicrokernel } -// scf + arith + math + vector + memref + func/microkernel void populateCPURuntimePasses(mlir::PassManager &pm) { - // + flatten nested parallel pass, down-stream pass, to support coarse-grain - // fusion + // todo: flatten nested parallel pass to support coarse-grain usion // remove this pass after we add FlattenNestedParallel pm.addPass(createConvertSCFToOpenMPPass()); } From b54b310af2ebd75cd928c4edf13bc842e00eae34 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 29 May 2024 11:36:12 +0800 Subject: [PATCH 26/32] fix --- .../{ConstantCache.hpp => ConstantCache.h} | 34 +++++++++---------- .../Driver/{Driver.hpp => Driver.h} | 4 +-- .../CPURuntime/ConstantCache.cpp | 2 +- lib/gc/ExecutionEngine/Driver/Driver.cpp | 4 +-- unittests/ExecutionEngine/JitWrapper.cpp | 2 +- 5 files changed, 23 insertions(+), 23 
deletions(-) rename include/gc/ExecutionEngine/CPURuntime/{ConstantCache.hpp => ConstantCache.h} (83%) rename include/gc/ExecutionEngine/Driver/{Driver.hpp => Driver.h} (95%) diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h similarity index 83% rename from include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp rename to include/gc/ExecutionEngine/CPURuntime/ConstantCache.h index 0d96ae6b3..f41cb09e8 100644 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h @@ -1,4 +1,4 @@ -//===-- ConstantCache.hpp - Constant cache interfaces -----------*- C++ -*-===// +//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -23,9 +23,9 @@ namespace gc { */ struct RefCountManaged { RefCountManaged() = default; - RefCountManaged(const std::shared_ptr &keep_alive) { init(keep_alive); } - void init(const std::shared_ptr &keep_alive) { - keepAlive = keep_alive; + RefCountManaged(const std::shared_ptr &vkeepAlive) { init(vkeepAlive); } + void init(const std::shared_ptr &vkeepAlive) { + keepAlive = vkeepAlive; refCount.store(1); } @@ -62,31 +62,31 @@ struct RefCountManaged { /** * The proxy for the constant cache of Graph API. It holds a shared ptr pointing - * to the cache item in the cache manager (keep_alive) to extend the lifetime by + * to the cache item in the cache manager (keepAlive) to extend the lifetime by * refcount, @see RefCountManaged. To access the memory buffer of the const * cache, use acauire/release functions. They will ref/deref the ConstCacheProxy * to make sure the cache is alive after calling acauire and before release. The * cache manager of Graph API may evict the cache item by dereferenceing this * RefCountManaged object. 
{acquire,release} functions will find out that the * cache has been invalidated. Usually we expect JIT modules to hold shared ptr - * to ConstCacheProxy via CachedGraphTensor. If is_lazy_ == true, the cache + * to ConstCacheProxy via CachedGraphTensor. If isLazy == true, the cache * item's lifetime will be managed by the cache manager of Graph API and it is * filled with data after the first execution of the computation. Otherwise, the * cache item is always alive as long as the jit_module of the kernel is alive. */ struct ConstCacheProxy : RefCountManaged { - ConstCacheProxy(const std::shared_ptr &keep_alive, void *buffer, + ConstCacheProxy(const std::shared_ptr &vkeepAlive, void *buffer, size_t size, bool is_lazy) - : RefCountManaged(keep_alive), size_(size), is_lazy_(is_lazy), - buffer_(buffer) {} + : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy), + buffer(buffer) {} ~ConstCacheProxy(); // get the buffer and increment the refcount. If the buffer is evicted, // returns null void *acquire(int32_t *inited) { if (checkAliveAndRef()) { - *inited = *inited && initialized_; - return buffer_; + *inited = *inited && initialized; + return buffer; } return nullptr; } @@ -94,7 +94,7 @@ struct ConstCacheProxy : RefCountManaged { bool release() { if (isAlive()) { deref(); - initialized_ = 1; + initialized = 1; return true; } return false; @@ -102,18 +102,18 @@ struct ConstCacheProxy : RefCountManaged { // return the buffer. Do not directly use the buffer because it may be already // release! To access the buffer, always acquire() before using it. - void *getBufferUnsafe() const { return buffer_; } + void *getBufferUnsafe() const { return buffer; } - size_t size_; + size_t size; // if the buffer is lazy-initialized. If false, it should be filled before // computation - bool is_lazy_; + bool isLazy; private: // raw pointer to the buffer - void *buffer_; + void *buffer; // if the buffer has been initialized. 
calling release() will set this to 1 - int32_t initialized_ = 0; + int32_t initialized = 0; }; struct CachedGraphTensor { diff --git a/include/gc/ExecutionEngine/Driver/Driver.hpp b/include/gc/ExecutionEngine/Driver/Driver.h similarity index 95% rename from include/gc/ExecutionEngine/Driver/Driver.hpp rename to include/gc/ExecutionEngine/Driver/Driver.h index 0a34514fa..1ca5aa9f4 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.hpp +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -1,4 +1,4 @@ -//===-- Driver.hpp - The top-level MLIR compiler driver ---------*- C++ -*-===// +//===-- Driver.h - The top-level MLIR compiler driver -----------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,7 +9,7 @@ #ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H #define GC_EXECUTIONENGINE_DRIVER_DRIVER_H -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index 245f2ca89..ea1c10364 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include #include diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 0b3c19113..ca1bda930 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include 
"gc/ExecutionEngine/Driver/Driver.hpp" +#include "gc/ExecutionEngine/Driver/Driver.h" #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" @@ -201,7 +201,7 @@ void JitModule::call(GeneralMemrefPtr *args) { for (auto b : cacheBases) { auto ptr = b->acquire(&inited); if (unlikely(!ptr)) { - ptr = std::aligned_alloc(/*alignment*/ 64, b->size_); + ptr = std::aligned_alloc(/*alignment*/ 64, b->size); inited = 0; } foldBasePtr.push_back((char *)ptr); diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index d902541bd..71a73bf8b 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/Driver/Driver.hpp" +#include "gc/ExecutionEngine/Driver/Driver.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/IR/AsmState.h" From 9d04cd22b3f9218ed06fb38448f834b48df3401c Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 29 May 2024 11:44:11 +0800 Subject: [PATCH 27/32] format --- lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp | 6 ++---- lib/gc/ExecutionEngine/Driver/Driver.cpp | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index ea1c10364..ff45cd180 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -14,7 +14,6 @@ namespace mlir::gc { ConstCacheProxy::~ConstCacheProxy() = default; - CachedGraphTensor::CachedGraphTensor( const std::shared_ptr &base, size_t offset) : base{base}, offset{offset} { @@ -36,9 +35,8 @@ std::shared_ptr queryCacheTensor(uint64_t key) { return nullptr; } -bool regCachedTensor(uint64_t key, - 
const std::shared_ptr &base, - size_t offset) { +bool regCachedTensor(uint64_t key, const std::shared_ptr &base, + size_t offset) { if (queryCacheTensor(key)) { return false; } diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index ca1bda930..4b9b6c42b 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -151,8 +151,8 @@ JitModule::JitModule( llvm::ArrayRef foldArgs, std::vector> &&cachekeepAlive) : engine{std::move(engine)}, compute{compute}, fold{fold}, - numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, computeArgs{computeArgs}, - keepAlive{std::move(cachekeepAlive)} { + numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, + computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { for (const auto &cache : keepAlive) { auto currentItr = std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); From 206c3f3df11b10a8384ff98d15feaf129a127c4e Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 30 May 2024 11:24:14 +0800 Subject: [PATCH 28/32] cleanup --- .../CPURuntime/ConstantCache.h | 137 ------------- include/gc/ExecutionEngine/Driver/Driver.h | 51 ++--- .../ExecutionEngine/CPURuntime/CMakeLists.txt | 1 - .../CPURuntime/ConstantCache.cpp | 46 ----- lib/gc/ExecutionEngine/Driver/Driver.cpp | 180 +----------------- unittests/ExecutionEngine/JitWrapper.cpp | 107 +---------- 6 files changed, 24 insertions(+), 498 deletions(-) delete mode 100644 include/gc/ExecutionEngine/CPURuntime/ConstantCache.h delete mode 100644 lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h deleted file mode 100644 index f41cb09e8..000000000 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h +++ /dev/null @@ -1,137 +0,0 @@ -//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===// -// -// This file is licensed under the Apache 
License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H -#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H -#include "mlir/ExecutionEngine/CRunnerUtils.h" -#include -#include -#include - -namespace mlir { -namespace gc { -/** - * The helper class to manage ref count manually for an object allocated with - * shared ptr. It holds an additional shared ptr reference to the object and - * contains an additional self-managed refcount. The refcount will be set to 1 - * when the object is initialized (see init()). When the refcount counts down to - * 0, the additional shared ptr is reset. - */ -struct RefCountManaged { - RefCountManaged() = default; - RefCountManaged(const std::shared_ptr &vkeepAlive) { init(vkeepAlive); } - void init(const std::shared_ptr &vkeepAlive) { - keepAlive = vkeepAlive; - refCount.store(1); - } - - void ref() { ++refCount; } - void deref() { - auto newv = --refCount; - if (newv == 0) { - keepAlive = nullptr; - } - } - - // atomically check if refCount > 0. if so, ref() the object and return - // true. Otherwise (if refCount==0), return false - bool checkAliveAndRef() { - auto oldv = refCount.load(); - for (;;) { - if (oldv <= 0) { - return false; - } - if (refCount.compare_exchange_strong(oldv, oldv + 1)) { - return true; - } - // CAS failed, oldv has now the newest known value of refCount - } - } - - bool isAlive() const { return refCount > 0; } - void *getPtrUnsafe() const { return keepAlive.get(); } - -private: - std::shared_ptr keepAlive; - std::atomic refCount{0}; -}; - -/** - * The proxy for the constant cache of Graph API. It holds a shared ptr pointing - * to the cache item in the cache manager (keepAlive) to extend the lifetime by - * refcount, @see RefCountManaged. 
To access the memory buffer of the const - * cache, use acauire/release functions. They will ref/deref the ConstCacheProxy - * to make sure the cache is alive after calling acauire and before release. The - * cache manager of Graph API may evict the cache item by dereferenceing this - * RefCountManaged object. {acquire,release} functions will find out that the - * cache has been invalidated. Usually we expect JIT modules to hold shared ptr - * to ConstCacheProxy via CachedGraphTensor. If isLazy == true, the cache - * item's lifetime will be managed by the cache manager of Graph API and it is - * filled with data after the first execution of the computation. Otherwise, the - * cache item is always alive as long as the jit_module of the kernel is alive. - */ -struct ConstCacheProxy : RefCountManaged { - ConstCacheProxy(const std::shared_ptr &vkeepAlive, void *buffer, - size_t size, bool is_lazy) - : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy), - buffer(buffer) {} - ~ConstCacheProxy(); - - // get the buffer and increment the refcount. If the buffer is evicted, - // returns null - void *acquire(int32_t *inited) { - if (checkAliveAndRef()) { - *inited = *inited && initialized; - return buffer; - } - return nullptr; - } - // decrement the refcount - bool release() { - if (isAlive()) { - deref(); - initialized = 1; - return true; - } - return false; - } - - // return the buffer. Do not directly use the buffer because it may be already - // release! To access the buffer, always acquire() before using it. - void *getBufferUnsafe() const { return buffer; } - - size_t size; - // if the buffer is lazy-initialized. If false, it should be filled before - // computation - bool isLazy; - -private: - // raw pointer to the buffer - void *buffer; - // if the buffer has been initialized. 
calling release() will set this to 1 - int32_t initialized = 0; -}; - -struct CachedGraphTensor { - std::shared_ptr base; - size_t offset; - CachedGraphTensor(const std::shared_ptr &base, - size_t offset); - friend class JitModule; - -private: - StridedMemRefType ref; -}; - -std::shared_ptr queryCacheTensor(uint64_t key); -bool regCachedTensor(uint64_t key, const std::shared_ptr &base, - size_t offset); - -} // namespace gc -} // namespace mlir - -#endif \ No newline at end of file diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index 1ca5aa9f4..694cbbe5e 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -9,7 +9,6 @@ #ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H #define GC_EXECUTIONENGINE_DRIVER_DRIVER_H -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include @@ -25,7 +24,7 @@ const DialectRegistry &initAndGetDialects(); using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); -class JitModule : public std::enable_shared_from_this { +class JitModule { public: static llvm::Expected> create(Operation *op, const ExecutionEngineOptions &options = {}, @@ -33,41 +32,29 @@ class JitModule : public std::enable_shared_from_this { bool transform = true); // args should be an array of XXXMemrefType* - void call(GeneralMemrefPtr *args); - - JitModule( - std::unique_ptr engine, JitModuleFuncT compute, - JitModuleFuncT fold, size_t numOrigArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef computeArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive = {}); + void call(GeneralMemrefPtr *args, std::size_t numArgs) { + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + 
realargs.reserve(numArgs); + for (size_t i = 0; i < numArgs; i++) { + realargs.push_back(&args[i]); + } + compute(realargs.data()); + } + + // directly call compute(). args should be an array of void*. args[i] should + // be a pointer to the real data. For passing memref, users need to 1) create + // a pointer to XXXMemrefType 2) store the pointer to pointer to XXXMemrefType + // in args[i] + void callRaw(void **args) { compute(args); } + + JitModule(std::unique_ptr engine, JitModuleFuncT compute); ~JitModule(); private: std::unique_ptr engine; JitModuleFuncT compute; - JitModuleFuncT fold; - size_t numOrigArgs; - // `keepAlive` has the ownership of the objects pointed by this vector - llvm::SmallVector cacheBases; - struct CacheBufferInfo { - // index in cacheBases - size_t baseIdx; - size_t offset; - }; - // the info for each folded cached buffer - llvm::SmallVector cacheInfo; - // holding the pointers to StridedMemRefType of folded cache - // `keepAlive` holds the the ownership of the pointers - llvm::SmallVector fastFoldBuffers; - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef foldArgs; - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef computeArgs; - - std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 97f039834..6be58e28f 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -10,7 +10,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") add_mlir_library(GCCpuRuntime SHARED Parallel.cpp - ConstantCache.cpp EXCLUDE_FROM_LIBMLIR ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp deleted file mode 100644 index ff45cd180..000000000 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ /dev/null @@ -1,46 +0,0 @@ -//===-- 
ConstantCache.cpp - Constant cache ----------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" -#include -#include - -namespace mlir::gc { - -ConstCacheProxy::~ConstCacheProxy() = default; - -CachedGraphTensor::CachedGraphTensor( - const std::shared_ptr &base, size_t offset) - : base{base}, offset{offset} { - // todo: fill in real values - ref.basePtr = (char *)base->getBufferUnsafe() + offset; - ref.data = ref.basePtr; - ref.offset = 0; - memset(ref.sizes, 0, sizeof(ref.sizes)); - memset(ref.strides, 0, sizeof(ref.strides)); -} - -static std::unordered_map> cache; - -std::shared_ptr queryCacheTensor(uint64_t key) { - auto itr = cache.find(key); - if (itr != cache.end()) { - return itr->second; - } - return nullptr; -} - -bool regCachedTensor(uint64_t key, const std::shared_ptr &base, - size_t offset) { - if (queryCacheTensor(key)) { - return false; - } - cache[key] = std::make_shared(base, offset); - return true; -} -} // namespace mlir::gc \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 4b9b6c42b..4f5c33867 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -18,9 +18,6 @@ #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" -#define likely(x) __builtin_expect(!!(x), 1) -#define unlikely(x) __builtin_expect(!!(x), 0) - namespace mlir { namespace gc { @@ -46,7 +43,7 @@ const DialectRegistry &initAndGetDialects() { } static const char defaultComputeName[] = "_mlir_ciface_compute"; -static const char defaultFoldName[] = "_mlir_ciface_fold"; + llvm::Expected> JitModule::create(Operation *op, 
const ExecutionEngineOptions &options, std::unique_ptr tm, bool transform) { @@ -63,14 +60,6 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, return exec.takeError(); } auto &engine = *exec; - uint32_t numOrigArgs; - { - auto expectArgs = engine->lookup("__num_orig_num_args"); - if (!expectArgs) { - return expectArgs.takeError(); - } - numOrigArgs = *reinterpret_cast(*expectArgs); - } JitModuleFuncT compute; { auto expectCompute = engine->lookupPacked(defaultComputeName); @@ -79,176 +68,15 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, } compute = *expectCompute; } - llvm::ArrayRef foldBufferIds; - JitModuleFuncT fold = nullptr; - llvm::ArrayRef computeArgs; - llvm::ArrayRef foldArgs; - do { - auto expectBufferIds = engine->lookup("__fold_buffer_ids"); - if (!expectBufferIds) { - // nothing to fold, It is OK. - llvm::consumeError(expectBufferIds.takeError()); - // break out of the scope, don't need to lookup "fold" function - break; - } else { - auto raw = reinterpret_cast(*expectBufferIds); - foldBufferIds = llvm::ArrayRef{raw + 1, raw[0]}; - } - - // find "fold" func - { - auto expectFold = engine->lookupPacked(defaultFoldName); - if (!expectFold) { - return expectFold.takeError(); - } - fold = *expectFold; - } - - // find "foldArgs" - { - auto expectFold = engine->lookup("__fold_args"); - if (!expectFold) { - return expectFold.takeError(); - } - auto raw = reinterpret_cast(*expectFold); - foldArgs = llvm::ArrayRef{raw + 1, raw[0]}; - } - - // find "computeArgs" - { - auto expect = engine->lookup("__compute_args"); - if (!expect) { - return expect.takeError(); - } - auto raw = reinterpret_cast(*expect); - computeArgs = llvm::ArrayRef{raw + 1, raw[0]}; - } - } while (false); - - std::vector> foldInfo; - foldInfo.reserve(foldBufferIds.size()); - for (auto bufId : foldBufferIds) { - auto ret = queryCacheTensor(bufId); - if (!ret) { - return llvm::make_error( - "Failed to query the folded cached tensor", - 
llvm::inconvertibleErrorCode()); - } - foldInfo.emplace_back(std::move(ret)); - } - - return std::make_shared(std::move(engine), compute, fold, - numOrigArgs, computeArgs, foldArgs, - std::move(foldInfo)); + return std::make_shared(std::move(engine), compute); } JitModule::JitModule( - std::unique_ptr engine, JitModuleFuncT compute, - JitModuleFuncT fold, size_t numOrigArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef computeArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive) - : engine{std::move(engine)}, compute{compute}, fold{fold}, - numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, - computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { - for (const auto &cache : keepAlive) { - auto currentItr = - std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); - if (currentItr == cacheBases.end()) { - cacheBases.push_back(cache->base.get()); - currentItr = cacheBases.end() - 1; - } - cacheInfo.emplace_back(CacheBufferInfo{ - static_cast(currentItr - cacheBases.begin()), cache->offset}); - fastFoldBuffers.push_back(&cache->ref); - } + std::unique_ptr engine, JitModuleFuncT compute) + : engine{std::move(engine)}, compute{compute} { } JitModule::~JitModule() = default; -static void prepareCallArgs(llvm::SmallVector &realargs, - GeneralMemrefPtr *origargs, size_t numOrigArgs, - GeneralMemrefPtr *foldedCache, - llvm::ArrayRef realArgIdx) { - realargs.reserve(realArgIdx.size()); - for (auto argIdx : realArgIdx) { - if (argIdx < numOrigArgs) { - realargs.push_back(&origargs[argIdx]); - } else { - realargs.push_back(&foldedCache[argIdx - numOrigArgs]); - } - } -} - -void JitModule::call(GeneralMemrefPtr *args) { - if (unlikely(cacheInfo.empty())) { - // fast path, no folded cached buffers - // Silly code, MLIR execution engine requires pointers of real args as - // inputs - llvm::SmallVector realargs; - realargs.reserve(numOrigArgs); - for (size_t 
i = 0; i < numOrigArgs; i++) { - realargs.push_back(&args[i]); - } - compute(realargs.data()); - return; - } - - // stage 1, acquire the foldBasePtr - llvm::SmallVector foldBasePtr; - int32_t inited = 1; - for (auto b : cacheBases) { - auto ptr = b->acquire(&inited); - if (unlikely(!ptr)) { - ptr = std::aligned_alloc(/*alignment*/ 64, b->size); - inited = 0; - } - foldBasePtr.push_back((char *)ptr); - } - - // stage 2, run fold() if necessary - GeneralMemrefPtr *foldedCache; - // only used when !inited - std::vector slowFold; - std::vector> slowFoldObj; - if (likely(inited)) { - foldedCache = fastFoldBuffers.data(); - } else { - slowFold.reserve(cacheInfo.size()); - slowFoldObj.reserve(cacheInfo.size()); - for (auto &info : cacheInfo) { - slowFoldObj.emplace_back(); - auto &obj = slowFoldObj.back(); - obj.basePtr = foldBasePtr[info.baseIdx] + info.offset; - obj.data = obj.basePtr; - memset(obj.sizes, 0, sizeof(obj.sizes)); - memset(obj.strides, 0, sizeof(obj.strides)); - slowFold.push_back(&obj); - } - foldedCache = slowFold.data(); - llvm::SmallVector realargs; - prepareCallArgs(realargs, args, numOrigArgs, foldedCache, foldArgs); - fold(realargs.data()); - } - - // stage 3, call compute - { - llvm::SmallVector realargs; - prepareCallArgs(realargs, args, numOrigArgs, foldedCache, computeArgs); - compute(realargs.data()); - } - - // stage 4, cleanup - for (size_t i = 0; i < cacheBases.size(); i++) { - auto b = cacheBases[i]; - if (unlikely(!b->release())) { - // if the cached buffer is already free'd, foldBasePtr[i] is allocated via - // std::aligned_alloc by us, free it - std::free(foldBasePtr[i]); - } - } -} } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 71a73bf8b..84748bb81 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -63,113 +63,8 @@ TEST(ExecutionEngine, JitWrapper) { {128}, {128}, 
[](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; OwningMemRef bufC{{128}, {128}}; void *args[] = {&*bufA, &*bufB, &*bufC}; - jited.get()->call(args); + jited.get()->call(args, 3); for (int i = 0; i < 128; i++) { ASSERT_EQ(bufC[{i}], 1.0f + i); } } - -// compute d = (a+a) + (b+b) + c, where a,b is marked constant -// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5 -static const char code2[] = R"mlir( -module { -llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 -llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> -// a,b, foldedA,foldedB -llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32> -// foldedA, foldedB, c, d -llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> - -func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { - %c0 = arith.constant 0 : index - cpuruntime.printf "HI%zu\n" %c0 : index - %out = tensor.empty() : tensor<128xf32> - %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> - %out2 = tensor.empty() : tensor<128xf32> - %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> - return %2, %3 : tensor<128xf32>, tensor<128xf32> -} - -func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { - %out = tensor.empty() : tensor<128xf32> - %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> - %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> - return %d : tensor<128xf32> -} -} -)mlir"; - -TEST(ExecutionEngine, JitWrapperCached) { - MLIRContext ctx{gc::initAndGetDialects()}; - std::unique_ptr 
ir_buffer = - llvm::MemoryBuffer::getMemBuffer(code2); - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); - mlir::OwningOpRef module = - mlir::parseSourceFile(sourceMgr, &ctx); - - // foldedA and foldedB uses this buffer - auto ret = std::shared_ptr(new float[128 * 2]); - auto proxy = std::make_shared( - ret, ret.get(), 128 * 2 * sizeof(float), true); - - ASSERT_TRUE(gc::regCachedTensor(114514, proxy, 0)); - ASSERT_TRUE(gc::regCachedTensor(1919810, proxy, 128 * sizeof(float))); - - ASSERT_TRUE(module); - auto jited = gc::JitModule::create(module.get()); - bool jit_success = static_cast(jited); - if (!jit_success) { - auto err = jited.takeError(); - llvm::errs() << err; - llvm::consumeError(std::move(err)); - } - ASSERT_TRUE(jit_success); - OwningMemRef bufA{ - {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; - OwningMemRef bufB{ - {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; - OwningMemRef bufC{ - {128}, {128}, [](float &ptr, ArrayRef idx) { - ptr = -idx[0] * 3; - }}; - OwningMemRef bufD{{128}, {128}}; - void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD}; - - // first call, should run fold() - { - testing::internal::CaptureStdout(); - // first call, should run fold() - jited.get()->call(args); - for (int i = 0; i < 128; i++) { - ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); - } - std::string output = testing::internal::GetCapturedStdout(); - ASSERT_EQ(output, "HI0\n"); - } - - { - testing::internal::CaptureStdout(); - // second call, should not run fold() - jited.get()->call(args); - for (int i = 0; i < 128; i++) { - ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); - } - std::string output = testing::internal::GetCapturedStdout(); - ASSERT_TRUE(output.empty()); - } - - // the cache is evicted - proxy->deref(); - { - testing::internal::CaptureStdout(); - // third call, should run fold() - jited.get()->call(args); - for (int i = 0; i < 128; i++) { - ASSERT_EQ(bufD[{i}], 
2 * 1.0f + 2 * i - 3 * i); - } - std::string output = testing::internal::GetCapturedStdout(); - ASSERT_EQ(output, "HI0\n"); - } -} From bc9a7ad97751996e73092045ec49395c410b53aa Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 30 May 2024 11:36:55 +0800 Subject: [PATCH 29/32] refine options --- include/gc/ExecutionEngine/Driver/Driver.h | 14 ++++++++++---- lib/gc/ExecutionEngine/Driver/Driver.cpp | 12 +++++++----- unittests/ExecutionEngine/JitWrapper.cpp | 2 +- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index 694cbbe5e..2ce9531bd 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -18,18 +18,24 @@ namespace mlir { class DialectRegistry; namespace gc { -const DialectRegistry &initAndGetDialects(); +const DialectRegistry &initCompilerAndGetDialects(); // the pointers to XXXMemRefType using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); +struct DriverOptions { + // the optimization level for the LLVM-JIT + llvm::CodeGenOptLevel jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; + // whether to run the MLIR transformation passes + bool runTransforms = true; + // todo: target machine, etc. 
+}; + class JitModule { public: static llvm::Expected> - create(Operation *op, const ExecutionEngineOptions &options = {}, - std::unique_ptr tm = nullptr, - bool transform = true); + create(Operation *op, const DriverOptions &options = {}); // args should be an array of XXXMemrefType* void call(GeneralMemrefPtr *args, std::size_t numArgs) { diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 4f5c33867..6fc8025be 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -37,7 +37,7 @@ static DialectRegistry initDialects() { return registry; } -const DialectRegistry &initAndGetDialects() { +const DialectRegistry &initCompilerAndGetDialects() { static DialectRegistry reg = initDialects(); return reg; } @@ -45,9 +45,8 @@ const DialectRegistry &initAndGetDialects() { static const char defaultComputeName[] = "_mlir_ciface_compute"; llvm::Expected> -JitModule::create(Operation *op, const ExecutionEngineOptions &options, - std::unique_ptr tm, bool transform) { - if (transform) { +JitModule::create(Operation *op, const DriverOptions &options) { + if (options.runTransforms) { mlir::PassManager pm{op->getContext()}; populateCPUPipeline(pm); if (auto result = pm.run(op); failed(result)) { @@ -55,7 +54,10 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, "MLIR pass error", llvm::inconvertibleErrorCode()); } } - auto exec = ExecutionEngine::create(op, options, std::move(tm)); + ExecutionEngineOptions exeOptions; + exeOptions.jitCodeGenOptLevel = options.jitCodeGenOptLevel; + std::unique_ptr tm = nullptr; + auto exec = ExecutionEngine::create(op, exeOptions, std::move(tm)); if (!exec) { return exec.takeError(); } diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 84748bb81..f7b93eaa6 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -40,7 +40,7 @@ extern 
int gc_runtime_keep_alive; TEST(ExecutionEngine, JitWrapper) { gc_runtime_keep_alive = 0; - MLIRContext ctx{gc::initAndGetDialects()}; + MLIRContext ctx{gc::initCompilerAndGetDialects()}; std::unique_ptr ir_buffer = llvm::MemoryBuffer::getMemBuffer(code1); // Parse the input mlir. From bc5c9de317f78457a622c59384da4dcf7c8a249e Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 30 May 2024 11:45:34 +0800 Subject: [PATCH 30/32] fmt --- lib/gc/ExecutionEngine/Driver/Driver.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 6fc8025be..0f2361924 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -73,12 +73,10 @@ JitModule::create(Operation *op, const DriverOptions &options) { return std::make_shared(std::move(engine), compute); } -JitModule::JitModule( - std::unique_ptr engine, JitModuleFuncT compute) - : engine{std::move(engine)}, compute{compute} { -} +JitModule::JitModule(std::unique_ptr engine, + JitModuleFuncT compute) + : engine{std::move(engine)}, compute{compute} {} JitModule::~JitModule() = default; - } // namespace gc } // namespace mlir \ No newline at end of file From 529c403dcd21a374d23f6baeaa83343fb9736084 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 11 Jun 2024 10:51:05 +0800 Subject: [PATCH 31/32] rebase --- {unittests => test/mlir/unittests}/ExecutionEngine/CMakeLists.txt | 0 {unittests => test/mlir/unittests}/ExecutionEngine/JitWrapper.cpp | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {unittests => test/mlir/unittests}/ExecutionEngine/CMakeLists.txt (100%) rename {unittests => test/mlir/unittests}/ExecutionEngine/JitWrapper.cpp (100%) diff --git a/unittests/ExecutionEngine/CMakeLists.txt b/test/mlir/unittests/ExecutionEngine/CMakeLists.txt similarity index 100% rename from unittests/ExecutionEngine/CMakeLists.txt rename to 
test/mlir/unittests/ExecutionEngine/CMakeLists.txt diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp similarity index 100% rename from unittests/ExecutionEngine/JitWrapper.cpp rename to test/mlir/unittests/ExecutionEngine/JitWrapper.cpp From 58d6639e5798f1ad91b269448a902d3014b527db Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 11 Jun 2024 11:39:08 +0800 Subject: [PATCH 32/32] fix comments --- include/gc/ExecutionEngine/Driver/Driver.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index 2ce9531bd..ee8630b53 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -25,11 +25,11 @@ using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); struct DriverOptions { - // the optimization level for the LLVM-JIT + /// the optimization level for the LLVM-JIT llvm::CodeGenOptLevel jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; - // whether to run the MLIR transformation passes + /// whether to run the MLIR transformation passes bool runTransforms = true; - // todo: target machine, etc. + /// todo: target machine, etc. }; class JitModule { @@ -37,7 +37,7 @@ class JitModule { static llvm::Expected> create(Operation *op, const DriverOptions &options = {}); - // args should be an array of XXXMemrefType* + /// args should be an array of XXXMemrefType* void call(GeneralMemrefPtr *args, std::size_t numArgs) { // Silly code, MLIR execution engine requires pointers of real args as // inputs @@ -49,10 +49,10 @@ class JitModule { compute(realargs.data()); } - // directly call compute(). args should be an array of void*. args[i] should - // be a pointer to the real data. 
For passing memref, users need to 1) create - // a pointer to XXXMemrefType 2) store the pointer to pointer to XXXMemrefType - // in args[i] + /// directly call compute(). args should be an array of void*. args[i] should + /// be a pointer to the real data. For passing memref, users need to 1) create + /// a pointer to XXXMemrefType 2) store the pointer to pointer to + /// XXXMemrefType in args[i] void callRaw(void **args) { compute(args); } JitModule(std::unique_ptr engine, JitModuleFuncT compute);