From 13faa3333b395e8bea8ae45cee32005ea90b2392 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Tue, 14 May 2024 15:07:57 +0800
Subject: [PATCH 01/64] add cpuruntime dialect

---
 include/gc/Dialect/CMakeLists.txt             |   1 +
 include/gc/Dialect/CPURuntime/CMakeLists.txt  |   2 +
 .../gc/Dialect/CPURuntime/IR/CMakeLists.txt   |   1 +
 .../Dialect/CPURuntime/IR/CPURuntimeDialect.h |  18 ++
 .../CPURuntime/IR/CPURuntimeDialect.td        |  34 ++++
 .../gc/Dialect/CPURuntime/IR/CPURuntimeOps.h  |  26 +++
 .../gc/Dialect/CPURuntime/IR/CPURuntimeOps.td |  72 ++++++++
 .../CPURuntime/Transforms/CMakeLists.txt      |   5 +
 .../CPURuntime/Transforms/CPURuntimePasses.h  |  29 ++++
 .../CPURuntime/Transforms/CPURuntimePasses.td |  57 +++++++
 lib/gc/Dialect/CMakeLists.txt                 |   1 +
 lib/gc/Dialect/CPURuntime/CMakeLists.txt      |   2 +
 lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt   |  16 ++
 .../CPURuntime/IR/CPURuntimeDialect.cpp       |  26 +++
 .../Dialect/CPURuntime/IR/CPURuntimeOps.cpp   |  56 ++++++
 .../CPURuntime/Transforms/CMakeLists.txt      |  16 ++
 .../Transforms/CPURuntimePasses.cpp           |  77 +++++++++
 .../Transforms/CPURuntimeToLLVM.cpp           | 159 ++++++++++++++++++
 lib/gc/Transforms/CMakeLists.txt              |   2 +
 src/CMakeLists.txt                            |   1 +
 src/gc-opt/CMakeLists.txt                     |   2 +-
 src/gc-opt/gc-opt.cpp                         |   5 +-
 .../Dialect/CPURuntime/cpu-runner/printf.mlir |  17 ++
 .../CPURuntime/cpuruntime-atexit-to-omp.mlir  |  41 +++++
 .../CPURuntime/cpuruntime-to-llvm.mlir        |  19 +++
 25 files changed, 683 insertions(+), 2 deletions(-)
 create mode 100644 include/gc/Dialect/CPURuntime/CMakeLists.txt
 create mode 100644 include/gc/Dialect/CPURuntime/IR/CMakeLists.txt
 create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h
 create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td
 create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h
 create mode 100644 include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td
 create mode 100644 include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
 create mode 100644 include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h
 create mode 100644 include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td
 create mode 100644 lib/gc/Dialect/CPURuntime/CMakeLists.txt
 create mode 100644 lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
 create mode 100644 lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp
 create mode 100644 lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp
 create mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
 create mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp
 create mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp
 create mode 100644 test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir
 create mode 100644 test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir
 create mode 100644 test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir

diff --git a/include/gc/Dialect/CMakeLists.txt b/include/gc/Dialect/CMakeLists.txt
index ffeda0aa7..a23f3f9f1 100644
--- a/include/gc/Dialect/CMakeLists.txt
+++ b/include/gc/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(CPURuntime)
 add_subdirectory(OnednnGraph)
 add_subdirectory(Microkernel)
 add_subdirectory(Linalgx)
\ No newline at end of file
diff --git a/include/gc/Dialect/CPURuntime/CMakeLists.txt b/include/gc/Dialect/CPURuntime/CMakeLists.txt
new file mode 100644
index 000000000..9f57627c3
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/include/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/include/gc/Dialect/CPURuntime/IR/CMakeLists.txt
new file mode 100644
index 000000000..fb73ae02b
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/IR/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(CPURuntimeOps cpuruntime)
diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h
new file mode 100644
index 000000000..757182964
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h
@@ -0,0 +1,18 @@
+//===- CPURuntimeDialect.h - CPU Runtime dialect ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CPURUNTIME_CPURUNTIMEDIALECT_H
+#define CPURUNTIME_CPURUNTIMEDIALECT_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOpsDialect.h.inc"
+
+#endif // CPURUNTIME_CPURUNTIMEDIALECT_H
diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td
new file mode 100644
index 000000000..06f3af526
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td
@@ -0,0 +1,34 @@
+//===- CPURuntimeDialect.td - CPU Runtime Dialect ---------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CPURUNTIME_DIALECT
+#define CPURUNTIME_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// CPURuntime dialect definition.
+//===----------------------------------------------------------------------===//
+
+def CPURuntime_Dialect : Dialect {
+    let name = "cpuruntime";
+    let summary = "A dialect for CPU parallel primitives.";
+    let description = [{
+        This dialect contains primitives for CPU runtime.
+    }];
+    let cppNamespace = "::mlir::cpuruntime";
+}
+
+//===----------------------------------------------------------------------===//
+// Base cpuruntime operation definition.
+//===----------------------------------------------------------------------===//
+
+class CPURuntime_Op<string mnemonic, list<Trait> traits = []> :
+        Op<CPURuntime_Dialect, mnemonic, traits>;
+
+#endif // CPURUNTIME_DIALECT
diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h
new file mode 100644
index 000000000..5ce667a91
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.h
@@ -0,0 +1,26 @@
+//===- CPURuntimeOps.h - CPU Runtime Ops ------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CPURUNTIME_CPURUNTIMEOPS_H
+#define CPURUNTIME_CPURUNTIMEOPS_H
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h.inc"
+
+#endif // CPURUNTIME_CPURUNTIMEOPS_H
diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td
new file mode 100644
index 000000000..cc1b7c555
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td
@@ -0,0 +1,72 @@
+//===- CPURuntimeOps.td - CPU Runtime Ops -----------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CPURUNTIME_OPS
+#define CPURUNTIME_OPS
+
+include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+
+
+def CPURuntime_AtParallelExitOp : CPURuntime_Op<"at_parallel_exit", [
+    ParentOneOf<["scf::ForallOp", "scf::ParallelOp", "omp::WsloopOp", "memref::AllocaScopeOp"]>,
+    SingleBlockImplicitTerminator<"ParallelExitReturnOp">
+  ]> {
+  let summary = "Runs the block once in each thread at the exit of the parallel section";
+  let description = [{
+    Executes the block once for each thread working in the parallel section,
+    at the exit of the parallel section.
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+
+  let hasCustomAssemblyFormat = 1;
+
+  // The default builder does not add a region with an empty body, add our own.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins)>,
+  ];
+}
+
+def CPURuntime_ParallelExitReturnOp : CPURuntime_Op<"parallel_exit.return", [
+    Pure,
+    HasParent<"AtParallelExitOp">,
+    Terminator, ReturnLike
+  ]> {
+  let summary = "Terminates the at_parallel_exit block";
+  let description = [{
+    The body of at_parallel_exit should end with parallel_exit.return.
+  }];
+  let assemblyFormat =
+    [{ attr-dict }];
+}
+
+
+def CPURuntime_PrintfOp : CPURuntime_Op<"printf", [MemoryEffects<[MemWrite]>]>,
+  Arguments<(ins StrAttr:$format,
+                Variadic<AnyTypeOf<[AnyInteger, Index, AnyFloat]>>:$args)> {
+  let summary = "C-style printf";
+  let description = [{
+    `cpuruntime.printf` takes a literal format string `format` and an arbitrary
+    number of scalar arguments that should be printed.
+
+    The format string is a C-style printf string, subject to any restrictions
+    imposed by one's target platform.
+  }];
+  let assemblyFormat = [{
+    $format attr-dict ($args^ `:` type($args))?
+  }];
+}
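+
+// Illustrative usage (mirrors test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir):
+//
+//   cpuruntime.printf "Hello world %f %d %lld\n" %f0, %i0, %l0 : f32, i32, i64
+//
+// The lowering in CPURuntimeToLLVM.cpp interns the format string as an LLVM
+// global and forwards the promoted arguments to the C printf.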
+
+
+#endif // CPURUNTIME_OPS
diff --git a/include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
new file mode 100644
index 000000000..763ffab86
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS CPURuntimePasses.td)
+mlir_tablegen(CPURuntimePasses.h.inc --gen-pass-decls -name CPURuntime)
+mlir_tablegen(CPURuntimePasses.capi.h.inc -gen-pass-capi-header --prefix CPURuntime)
+mlir_tablegen(CPURuntimePasses.capi.cpp.inc -gen-pass-capi-impl --prefix CPURuntime)
+add_public_tablegen_target(MLIRCPURuntimePassesIncGen)
diff --git a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h
new file mode 100644
index 000000000..8fde8f4fd
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h
@@ -0,0 +1,29 @@
+//===- CPURuntimePasses.h - CPU Runtime Passes ------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CPURUNTIME_CPURUNTIMEPASSES_H
+#define CPURUNTIME_CPURUNTIMEPASSES_H
+
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h"
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace cpuruntime {
+void registerConvertCPURuntimeToLLVMInterface(DialectRegistry &registry);
+
+#define GEN_PASS_DECL
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc"
+
+#define GEN_PASS_REGISTRATION
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc"
+} // namespace cpuruntime
+} // namespace mlir
+
+#endif
diff --git a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td
new file mode 100644
index 000000000..0685ce498
--- /dev/null
+++ b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td
@@ -0,0 +1,57 @@
+//===- CPURuntimePasses.td - CPU Runtime Passes -----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CPURUNTIME_PASS
+#define CPURUNTIME_PASS
+
+include "mlir/Pass/PassBase.td"
+
+
+def CPURuntimeAtExitToOmp: Pass<"cpuruntime-atexit-to-omp", "::mlir::func::FuncOp"> {
+  let summary = "Lower at_parallel_exit to code in omp.parallel section";
+  let description = [{
+    Moves the body of each at_parallel_exit region to the end of the
+    enclosing omp.parallel section. For example:
+    ```
+    omp.parallel {
+      omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) {
+        memref.alloca_scope {
+          cpuruntime.at_parallel_exit {
+            "your.op"()
+            cpuruntime.parallel_exit.return
+          }
+        }
+        omp.yield
+      }
+      omp.terminator
+    }
+    ```
+    Will be changed into
+    ```
+    omp.parallel {
+      omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) {
+        memref.alloca_scope {
+        }
+        omp.yield
+      }
+      "your.op"()
+      omp.terminator
+    }
+    ```
+  }];
+}
+
+
+def CPURuntimeToLLVM: Pass<"convert-cpuruntime-to-llvm"> {
+  let summary = "Convert cpuruntime to LLVM dialect";
+  let description = [{
+    This pass converts supported cpuruntime ops to LLVM dialect instructions.
+  }];
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
+#endif // CPURUNTIME_PASS
diff --git a/lib/gc/Dialect/CMakeLists.txt b/lib/gc/Dialect/CMakeLists.txt
index a880ff2ed..8720bd8e6 100644
--- a/lib/gc/Dialect/CMakeLists.txt
+++ b/lib/gc/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(CPURuntime)
 add_subdirectory(Linalgx)
 add_subdirectory(Microkernel)
 add_subdirectory(OnednnGraph)
diff --git a/lib/gc/Dialect/CPURuntime/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/CMakeLists.txt
new file mode 100644
index 000000000..9f57627c3
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
new file mode 100644
index 000000000..e349da72c
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRCPURuntimeDialect
+  CPURuntimeDialect.cpp
+  CPURuntimeOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/
+
+  DEPENDS
+  MLIRCPURuntimeOpsIncGen
+  MLIRCPURuntimePassesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIR
+  # MLIRInferTypeOpInterface
+  # MLIRFuncDialect
+  )
diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp
new file mode 100644
index 000000000..9f3e97b57
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeDialect.cpp
@@ -0,0 +1,26 @@
+//===- CPURuntimeDialect.cpp - CPU Runtime Dialect --------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h"
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h"
+
+using namespace mlir;
+using namespace mlir::cpuruntime;
+
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// CPURuntime dialect.
+//===----------------------------------------------------------------------===//
+
+void CPURuntimeDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp.inc"
+      >();
+}
diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp
new file mode 100644
index 000000000..ca632e9db
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp
@@ -0,0 +1,56 @@
+//===- CPURuntimeOps.cpp - CPU Runtime Ops ----------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.h"
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h"
+
+#define GET_OP_CLASSES
+#include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp.inc"
+
+#include <memory>
+
+namespace mlir {
+using namespace bufferization;
+
+namespace cpuruntime {
+
+void AtParallelExitOp::build(OpBuilder &b, OperationState &result) {
+  OpBuilder::InsertionGuard g(b);
+  Region *bodyRegion = result.addRegion();
+  b.createBlock(bodyRegion);
+}
+
+void AtParallelExitOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+  p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+ParseResult AtParallelExitOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+
+  SmallVector<OpAsmParser::Argument> regionOperands;
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  if (parser.parseRegion(*region, regionOperands))
+    return failure();
+
+  if (region->empty())
+    OpBuilder(builder.getContext()).createBlock(region.get());
+  result.addRegion(std::move(region));
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+} // namespace cpuruntime
+} // namespace mlir
\ No newline at end of file
diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
new file mode 100644
index 000000000..ee6148aa4
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRCPURuntimeTransforms
+  CPURuntimePasses.cpp
+  CPURuntimeToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/
+
+  DEPENDS
+  MLIRCPURuntimePassesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRCPURuntimeDialect
+  )
+
+set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS MLIRCPURuntimeTransforms)
\ No newline at end of file
diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp
new file mode 100644
index 000000000..f2a098fcf
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp
@@ -0,0 +1,77 @@
+//===- CPURuntimePasses.cpp - CPU Runtime Passes ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
+
+namespace mlir::cpuruntime {
+#define GEN_PASS_DEF_CPURUNTIMEATEXITTOOMP
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc"
+
+namespace {
+
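+// Finds the enclosing omp.parallel of each at_parallel_exit op and inlines
+// the op's body right before the end of the parallel block, so the code runs
+// once per thread when the parallel section exits.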
+class CPURuntimeAtExitToOmpRewriter
+    : public OpRewritePattern<AtParallelExitOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtParallelExitOp op,
+                                PatternRewriter &rewriter) const final {
+    auto parent = op->getParentOp();
+    Operation *secondLast = nullptr;
+    while (parent && (llvm::isa<omp::WsloopOp>(parent) ||
+                      llvm::isa<memref::AllocaScopeOp>(parent))) {
+      secondLast = parent;
+      parent = parent->getParentOp();
+    }
+    auto parallel = llvm::dyn_cast<omp::ParallelOp>(parent);
+    if (!parallel) {
+      return failure();
+    }
+    assert(secondLast->getBlock());
+    auto itr = secondLast->getBlock()->end();
+    --itr;
+    rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(),
+                               secondLast->getBlock(), itr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class CPURuntimeExitReturnRewriter
+    : public OpRewritePattern<ParallelExitReturnOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ParallelExitReturnOp op,
+                                PatternRewriter &rewriter) const final {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class CPURuntimeAtExitToOmp
+    : public impl::CPURuntimeAtExitToOmpBase<CPURuntimeAtExitToOmp> {
+public:
+  using impl::CPURuntimeAtExitToOmpBase<
+      CPURuntimeAtExitToOmp>::CPURuntimeAtExitToOmpBase;
+  void runOnOperation() final {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<CPURuntimeAtExitToOmpRewriter>(&getContext());
+    patterns.add<CPURuntimeExitReturnRewriter>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+} // namespace mlir::cpuruntime
diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp
new file mode 100644
index 000000000..73cf14a84
--- /dev/null
+++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp
@@ -0,0 +1,159 @@
+//===- CPURuntimeToLLVM.cpp - CPU Runtime To LLVM ---------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
+
+namespace mlir::cpuruntime {
+
+void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns);
+
+#define GEN_PASS_DEF_CPURUNTIMETOLLVM
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc"
+
+namespace {
+static const char formatStringPrefix[] = "cpuprintfFormat_";
+
+static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
+                                            const Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            StringRef name,
+                                            LLVM::LLVMFunctionType type) {
+  LLVM::LLVMFuncOp ret;
+  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                            LLVM::Linkage::External);
+  }
+  return ret;
+}
+
+
+class PrintfRewriter : public ConvertOpToLLVMPattern<PrintfOp> {
+public:
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(PrintfOp op, PrintfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto loc = op->getLoc();
+    mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
+    mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
+    mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
+    mlir::Type i8Ptr = LLVM::LLVMPointerType::get(op.getContext());
+    auto printfFunc = getOrDefineFunction(
+        moduleOp, loc, rewriter, "printf",
+        LLVM::LLVMFunctionType::get(llvmI32, {i8Ptr}, /*isVarArg*/ true));
+
+    unsigned stringNumber = 0;
+    SmallString<16> stringConstName;
+    do {
+      stringConstName.clear();
+      (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
+    } while (moduleOp.lookupSymbol<LLVM::GlobalOp>(stringConstName));
+
+    llvm::SmallString<20> formatString(adaptor.getFormat());
+    formatString.push_back('\0'); // Null terminate for C
+    size_t formatStringSize = formatString.size_in_bytes();
+
+    auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
+    LLVM::GlobalOp global;
+    {
+      ConversionPatternRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+      global = rewriter.create<LLVM::GlobalOp>(
+          loc, globalType,
+          /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
+          rewriter.getStringAttr(formatString));
+    }
+    Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+        loc,
+        LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                   global.getAddrSpace()),
+        global.getSymNameAttr());
+    Value stringStart = rewriter.create<LLVM::GEPOp>(
+        loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+    SmallVector<Value> appendFormatArgs = {stringStart};
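+    // C-vararg promotion: floats narrower than f64 are extended to f64 and
+    // integers narrower than 64 bits are zero-extended to i64 (matching the
+    // llvm.fpext/llvm.zext CHECK lines in cpuruntime-to-llvm.mlir).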
+    for (auto arg : adaptor.getArgs()) {
+      if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
+        if (!floatType.isF64())
+          arg = rewriter.create<LLVM::FPExtOp>(
+              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
+      }
+      if (arg.getType().getIntOrFloatBitWidth() != 64)
+        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
+      appendFormatArgs.push_back(arg);
+    }
+    rewriter.create<LLVM::CallOp>(loc, printfFunc, appendFormatArgs);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class CPURuntimeToLLVM
+    : public impl::CPURuntimeToLLVMBase<CPURuntimeToLLVM> {
+public:
+  using Base::Base;
+  void runOnOperation() final {
+    LLVMConversionTarget target(getContext());
+    RewritePatternSet patterns(&getContext());
+    LowerToLLVMOptions options(&getContext());
+    LLVMTypeConverter converter(&getContext(), options);
+    populateCPURuntimeToLLVMConversionPatterns(converter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+/// Implement the interface to convert CPURuntime ops to LLVM.
+struct CPURuntimeToDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateCPURuntimeToLLVMConversionPatterns(typeConverter, patterns);
+  }
+};
+
+} // namespace
+
+void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns) {
+  patterns.add<PrintfRewriter>(converter);
+}
+
+void registerConvertCPURuntimeToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    dialect->addInterfaces<CPURuntimeToDialectInterface>();
+  });
+}
+
+} // namespace mlir::cpuruntime
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index df8a14d01..e87106946 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -14,3 +14,5 @@ add_mlir_library(GCPasses
   MLIRBufferizationToMemRef
   MLIRBufferizationPipelines
 )
+
+set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCPasses)
\ No newline at end of file
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f6298c270..e71fe30a0 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,5 +1,6 @@
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(gc_pass_libs GLOBAL PROPERTY GC_PASS_LIBS)
 
 add_subdirectory(dnnl)
 add_subdirectory(gc-cpu-runner)
diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt
index ff33375de..0deb242a0 100644
--- a/src/gc-opt/CMakeLists.txt
+++ b/src/gc-opt/CMakeLists.txt
@@ -2,7 +2,7 @@ set(gc_opt_libs
         ${dialect_libs}
         ${conversion_libs}
         MLIROptLib
-        GCPasses)
+        ${gc_pass_libs})
 if(GC_MLIR_CXX_FLAGS)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}")
 endif()
diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp
index 72a25abf5..e1996c050 100644
--- a/src/gc-opt/gc-opt.cpp
+++ b/src/gc-opt/gc-opt.cpp
@@ -18,6 +18,7 @@
  */
 
 #include "gc/Transforms/Passes.h"
+#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
@@ -25,9 +26,11 @@
 int main(int argc, char *argv[]) {
   mlir::registerAllPasses();
   mlir::gc::registerGraphCompilerPasses();
-
+  mlir::cpuruntime::registerCPURuntimePasses();
   mlir::DialectRegistry registry;
+  registry.insert<mlir::cpuruntime::CPURuntimeDialect>();
   mlir::registerAllDialects(registry);
+  mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
   return mlir::asMainReturnCode(mlir::MlirOptMain(
       argc, argv, "Graph Compiler modular optimizer driver\n", registry));
 }
diff --git a/test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir b/test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir
new file mode 100644
index 000000000..e95471d50
--- /dev/null
+++ b/test/gc/Dialect/CPURuntime/cpu-runner/printf.mlir
@@ -0,0 +1,17 @@
+// RUN: gc-opt %s --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm | gc-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | FileCheck %s
+
+module {
+  func.func @doprint(%t: f32, %t2: i32, %t3: i64) {
+    cpuruntime.printf "Hello world %f %d %lld\n" %t, %t2, %t3 : f32, i32, i64
+    return
+  }
+
+  func.func @main() {
+    %c2 = arith.constant 2.0 : f32
+    %c32i = arith.constant 2000000 : i32
+    %c64i = arith.constant 2000000 : i64
+    call @doprint(%c2, %c32i, %c64i) : (f32, i32, i64) -> ()
+    return
+  }
+  // CHECK: Hello world 2.000000 2000000 2000000
+}
\ No newline at end of file
diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir
new file mode 100644
index 000000000..401de95cc
--- /dev/null
+++ b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir
@@ -0,0 +1,41 @@
+// RUN: gc-opt %s --cpuruntime-atexit-to-omp | FileCheck %s
+
+module {
+  func.func @parallel_insert_slice(%arg0: memref<512x512xf32>) -> memref<512x512xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<512x512xf32>
+    %c512 = arith.constant 512 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    memref.copy %arg0, %alloc : memref<512x512xf32> to memref<512x512xf32>
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    omp.parallel {
+      omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) {
+        memref.alloca_scope {
+          %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32>
+          %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>>
+          memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>>
+          memref.dealloc %alloc_0 : memref<512xf32>
+          cpuruntime.at_parallel_exit {
+            memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32>
+            cpuruntime.parallel_exit.return
+          }
+        }
+        omp.yield
+      }
+      memref.prefetch %alloc[%c0,%c0], read, locality<3>, data : memref<512x512xf32>
+      omp.terminator
+    }
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0
+    // CHECK: omp.parallel
+    // CHECK-NEXT: omp.wsloop
+    // CHECK-NEXT: memref.alloca_scope
+    // CHECK-NOT: cpuruntime.at_parallel_exit
+    // CHECK: omp.yield
+    // CHECK: memref.prefetch {{%alloc}}[%[[C0]], %[[C0]]]
+    // CHECK-NEXT: memref.prefetch {{%alloc}}[%[[C1]], %[[C0]]]
+    // CHECK-NEXT: omp.terminator
+    return %alloc : memref<512x512xf32>
+  }
+}
diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir
new file mode 100644
index 000000000..fb8d748f1
--- /dev/null
+++ b/test/gc/Dialect/CPURuntime/cpuruntime-to-llvm.mlir
@@ -0,0 +1,19 @@
+// RUN: gc-opt %s --convert-cpuruntime-to-llvm | FileCheck %s
+
+module {
+  // CHECK: llvm.mlir.global internal constant @cpuprintfFormat_0("Hello world %f %d %lld\0A\00") {addr_space = 0 : i32}
+  // CHECK: llvm.func @printf(!llvm.ptr,
+  // CHECK-NEXT: func.func @doprint(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i64)
+  func.func @doprint(%t: f32, %t2: i32, %t3: i64) {
+    // CHECK-NEXT: llvm.mlir.addressof
+    // CHECK-DAG: %[[C1:.*]] = llvm.getelementptr
+    // CHECK-SAME: !llvm.ptr, !llvm.array<24 x i8>
+    // CHECK: %[[C2:.*]] = llvm.fpext %[[ARG0]]
+    // CHECK: %[[C3:.*]] = llvm.zext %[[ARG1]]
+    // CHECK-NOT: cpuruntime.printf
+    // CHECK-NEXT: llvm.call @printf(%[[C1]], %[[C2]], %[[C3]], %[[ARG2]])
+    cpuruntime.printf "Hello world %f %d %lld\n" %t, %t2, %t3 : f32, i32, i64
+    return
+  }
+
+}
\ No newline at end of file

From 161848e3e442ac1f6007de5c664e3b70f5a20b02 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Tue, 14 May 2024 15:11:57 +0800
Subject: [PATCH 02/64] format

---
 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp | 8 +++-----
 src/gc-opt/gc-opt.cpp                                     | 2 +-
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp
index 73cf14a84..c56621d45 100644
--- a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp
+++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimeToLLVM.cpp
@@ -25,7 +25,7 @@
 namespace mlir::cpuruntime {
 
 void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                       RewritePatternSet &patterns);
+                                                RewritePatternSet &patterns);
 
 #define GEN_PASS_DEF_CPURUNTIMETOLLVM
 #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc"
@@ -48,7 +48,6 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
   return ret;
 }
 
-
 class PrintfRewriter : public ConvertOpToLLVMPattern<PrintfOp> {
 public:
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -110,8 +109,7 @@ class PrintfRewriter : public ConvertOpToLLVMPattern<PrintfOp> {
   }
 };
 
-class CPURuntimeToLLVM
-    : public impl::CPURuntimeToLLVMBase<CPURuntimeToLLVM> {
+class CPURuntimeToLLVM : public impl::CPURuntimeToLLVMBase<CPURuntimeToLLVM> {
 public:
   using Base::Base;
   void runOnOperation() final {
@@ -146,7 +144,7 @@ struct CPURuntimeToDialectInterface : public ConvertToLLVMPatternInterface {
 } // namespace
 
 void populateCPURuntimeToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                       RewritePatternSet &patterns) {
+                                                RewritePatternSet &patterns) {
   patterns.add<PrintfRewriter>(converter);
 }
 
diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp
index e1996c050..9b06ecf1f 100644
--- a/src/gc-opt/gc-opt.cpp
+++ b/src/gc-opt/gc-opt.cpp
@@ -17,8 +17,8 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-#include "gc/Transforms/Passes.h"
 #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
+#include "gc/Transforms/Passes.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"

From 447ef129fdbd73e15533c1bbc9a8f1e6f1273413 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Tue, 14 May 2024 15:18:43 +0800
Subject: [PATCH 03/64] add dependency

---
 lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
index e349da72c..3a1d63d3d 100644
--- a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
+++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
@@ -10,7 +10,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect
   MLIRCPURuntimePassesIncGen
 
   LINK_LIBS PUBLIC
-  MLIR
-  # MLIRInferTypeOpInterface
-  # MLIRFuncDialect
+  MLIRFuncDialect
   )

From a73dcc12e1023de4b59d4303835c2489b9f955b7 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Tue, 14 May 2024 16:03:52 +0800
Subject: [PATCH 04/64] fix new MLIR

---
 .../Transforms/CPURuntimePasses.cpp           | 19 +++++++-------
 .../CPURuntime/cpuruntime-atexit-to-omp.mlir  | 25 +++++++++++--------
 2 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp
index f2a098fcf..a8f74c079 100644
--- a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp
+++ b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp
@@ -27,21 +27,22 @@ class CPURuntimeAtExitToOmpRewriter
   LogicalResult matchAndRewrite(AtParallelExitOp op,
                                 PatternRewriter &rewriter) const final {
     auto parent = op->getParentOp();
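+    // Walk up the parent chain until the enclosing omp.parallel is found,
+    // instead of assuming a fixed nesting of wrapper ops.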
-    Operation *secondLast = nullptr;
-    while (parent && (llvm::isa<omp::WsloopOp>(parent) ||
-                      llvm::isa<memref::AllocaScopeOp>(parent))) {
-      secondLast = parent;
+    omp::ParallelOp parallel;
+    while (parent) {
+      parallel = llvm::dyn_cast<omp::ParallelOp>(parent);
+      if (parallel) {
+        break;
+      }
       parent = parent->getParentOp();
     }
-    auto parallel = llvm::dyn_cast<omp::ParallelOp>(parent);
     if (!parallel) {
       return failure();
     }
-    assert(secondLast->getBlock());
-    auto itr = secondLast->getBlock()->end();
+    auto &block = parallel.getRegion().front();
+    auto itr = block.end();
     --itr;
-    rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(),
-                               secondLast->getBlock(), itr);
+    rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(), &block,
+                               itr);
     rewriter.eraseOp(op);
     return success();
   }
diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir
index 401de95cc..172777690 100644
--- a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir
+++ b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir
@@ -10,18 +10,21 @@ module {
     memref.copy %arg0, %alloc : memref<512x512xf32> to memref<512x512xf32>
     %0 = llvm.mlir.constant(1 : i64) : i64
     omp.parallel {
-      omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) {
-        memref.alloca_scope {
-          %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32>
-          %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>>
-          memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>>
-          memref.dealloc %alloc_0 : memref<512xf32>
-          cpuruntime.at_parallel_exit {
-            memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32>
-            cpuruntime.parallel_exit.return
+      omp.wsloop {
+        omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) {
+          memref.alloca_scope {
+            %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32>
+            %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>>
+            memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>>
+            memref.dealloc %alloc_0 : memref<512xf32>
+            cpuruntime.at_parallel_exit {
+              memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32>
+              cpuruntime.parallel_exit.return
+            }
           }
+          omp.yield
         }
-        omp.yield
+        omp.terminator
       }
       memref.prefetch %alloc[%c0,%c0], read, locality<3>, data : memref<512x512xf32>
      omp.terminator
    }
    // CHECK-DAG: %[[C1:.*]] = arith.constant 1
    // CHECK-DAG: %[[C0:.*]] = arith.constant 0
    // CHECK: omp.parallel
    // CHECK-NEXT: omp.wsloop
-    // CHECK-NEXT: memref.alloca_scope
+    // CHECK: memref.alloca_scope
    // CHECK-NOT: cpuruntime.at_parallel_exit
    // CHECK: omp.yield
    // CHECK: memref.prefetch {{%alloc}}[%[[C0]], %[[C0]]]
    // CHECK-NEXT: memref.prefetch {{%alloc}}[%[[C1]], %[[C0]]]
    // CHECK-NEXT: omp.terminator
    return %alloc : memref<512x512xf32>
  }
 }

From 1cfede8e24231618e69ef692c3b1a0120d5a8857 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Wed, 15 May 2024 15:24:41 +0800
Subject: [PATCH 05/64] add

---
 include/gc/Transforms/Passes.h               |  26 +++
 include/gc/Transforms/Passes.td              |  13 ++
 lib/gc/Transforms/CMakeLists.txt             |   1 +
 lib/gc/Transforms/Pipeline.cpp               | 164 +++++++++++++++++++
 test/gc/Transforms/Pipeline/run.mlir         |  23 +++
 test/gc/Transforms/Pipeline/tensor_args.mlir |  13 ++
 6 files changed, 240 insertions(+)
 create mode 100644 lib/gc/Transforms/Pipeline.cpp
 create mode 100644 test/gc/Transforms/Pipeline/run.mlir
 create mode 100644 test/gc/Transforms/Pipeline/tensor_args.mlir

diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h
index 243a6f4f6..e5e4aee33 100644
--- a/include/gc/Transforms/Passes.h
+++ b/include/gc/Transforms/Passes.h
@@ -12,8 +12,34 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+
+namespace LLVM {
+class LLVMDialect;
+}
+
+namespace scf {
+class SCFDialect;
+}
+
+namespace omp {
+class OpenMPDialect;
+}
+
+namespace linalg {
+class LinalgDialect;
+}
+
+namespace memref {
+class MemRefDialect;
+}
+
+class PassManager;
+
 namespace gc {
+void populateFrontendPasses(mlir::PassManager &);
+void populateCPUPipeline(mlir::PassManager &);
+
 #define GEN_PASS_DECL
 #include "gc/Transforms/Passes.h.inc"
 
diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index d31baa5a7..ff0cd8a90 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -31,4 +31,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
   ];
 }
 
+def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
+  let summary = "All-in-one pipeline for GC for CPU";
+  let dependentDialects = ["onednn_graph::OneDNNGraphDialect",
+                           "tensor::TensorDialect",
+                           "memref::MemRefDialect",
+                           "linalg::LinalgDialect",
+                           "LLVM::LLVMDialect",
+                           "scf::SCFDialect",
+                           "bufferization::BufferizationDialect",
+                           "omp::OpenMPDialect",
+                           "vector::VectorDialect"];
+}
+
 #endif // GC_DIALECT_GC_PASSES
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index e7e97ea26..f3fa43e04 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(GCPasses
   OneDNNGraphToLinalg.cpp
+  Pipeline.cpp
   TileNamed.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
new file mode 100644
index 000000000..ed50925f6
--- /dev/null
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -0,0 +1,164 @@
+//===- Pipeline.cpp - Graph Compiler all-in-one pipeline --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
+#include "gc/Transforms/Passes.h"
+
+namespace mlir::gc {
+
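+// The helpers below mirror the intended lowering flow: oneDNN graph
+// (frontend) -> tensor-level passes -> vector-level passes -> bufferization
+// -> microkernels -> CPU runtime -> LLVM.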
+void populateFrontendPasses(mlir::PassManager &pm) {
+  // pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg());
+}
+// linalg + linalgX + tensor ==> GC V1 GIR
+
+void populateTensorPasses(mlir::PassManager &pm) {
+  // + padding propagation pass, upstream-able 127x127 -> tiling size: 32
+  //   -> padding to 128x128
+  // + layout propagation pass, upstream-able 4x32x4x32 ->
+  //   tensor.pack/tensor.unpack
+  // + tensor constant propagation pass, down-stream pass, designed to support
+  //   oneDNN graph spec
+  // + linalg.matmul lowering to (scf.loop + linalg.brgemm) pass, upstream-able
+  // + fine-grain fusion pass, upstream-able -> scf.for + linalgx.mask
+  // + lower linalg to arith/math on virtual vector pass, up-streamable
+
+  // REMOVE this pass after the above passes are added. Currently we add this
+  // pass to make the pipeline work properly
+  pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
+}
+// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack ==>
+// GC V1 TIR
+
+void populateVectorPasses(mlir::PassManager &pm) {
+  // + bf16 promotion pass, down-stream pass, device dependent pass, maybe can
+  //   upstream
+  // + bf16 cast elimination pass, down-stream pass, fast-math kind pass,
+  //   designed to support oneDNN graph spec
+  pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
+  // + lower to physical vector pass, down-stream pass, device dependent pass,
+  //   maybe can upstream
+}
+// scf + arith + math + vector + tensor + linalg.brgemm
+
+void populateBufferizationPasses(mlir::PassManager &pm) {
+  bufferization::OneShotBufferizationOptions options;
+  pm.addPass(bufferization::createOneShotBufferizePass(options));
+  pm.addPass(createCSEPass());
+  pm.addPass(mlir::func::createFuncBufferizePass());
+  pm.addPass(bufferization::createBufferResultsToOutParamsPass());
+  pm.addNestedPass<func::FuncOp>(
+      bufferization::createBufferizationBufferizePass());
+  pm.addNestedPass<func::FuncOp>(
+      bufferization::createFinalizingBufferizePass());
+  // + buffer schedule pass, down-stream pass, to migrate buffer reschedule
+  //   pass from GC V1.
+  pm.addNestedPass<func::FuncOp>(
+      bufferization::createBufferHoistingPass()); // Need to improve this pass
+                                                  // to support thread-local
+                                                  // allocator.
+  pm.addNestedPass<func::FuncOp>(bufferization::createBufferLoopHoistingPass());
+  pm.addNestedPass<func::FuncOp>(bufferization::createBufferDeallocationPass());
+  pm.addPass(createBufferizationToMemRefPass());
+}
+// scf + arith + math + vector + memref + linalg.brgemm
+
+void populateMicroKernelPasses(mlir::PassManager &pm) {
+  // + ConvertLinalgToMicrokernel pass, upstream-able,
+  // + CleanupInvalidMicrokernel pass, upstream-able
+  // + InvariantMicrokernelMotion pass, upstream-able
+  // + ConvertMicrokernelToDnnlFunc, down-stream pass, to lower brgemm to dnnl
+  //   call
+  // + ConvertMicrokernelToXsmm, down-stream pass, to lower brgemm to libxsmm
+  //   call
+  // + LowerMicrokernel pass, upstream-able
+  // + DispatchMicrokernel, down-stream pass
+}
+// scf + arith + math + vector + memref + func/microkernel
+
+void populateCPURuntimePasses(mlir::PassManager &pm) {
+  // + flatten nested parallel pass, down-stream pass, to support coarse-grain
+  //   fusion
+  // pm.addNestedPass<func::FuncOp>(cpuruntime::createCPURuntimeAtExitToOmp());
+  // remove this pass after we add FlattenNestedParallel
+  pm.addPass(createConvertSCFToOpenMPPass());
+}
+
+void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
+  pm.addPass(createConvertSCFToCFPass());
+  // pm.addPass(cpuruntime::createCPURuntimeToLLVM());
+  pm.addPass(createConvertOpenMPToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
+  pm.addPass(createConvertMathToLibmPass());
+  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+  pm.addNestedPass<func::FuncOp>(createArithToLLVMConversionPass());
+  pm.addPass(createConvertFuncToLLVMPass());
+  pm.addPass(createConvertControlFlowToLLVMPass());
+  pm.addPass(createCSEPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+  pm.addPass(createSymbolDCEPass());
+}
+
+void populateLLVMPasses(mlir::PassManager &pm) {
+  pm.addPass(memref::createExpandOpsPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  populateLoweringToLLVMPasses(pm);
+}
+
+void populateCPUPipeline(mlir::PassManager &pm) {
+  // front-end, oneDNN graph dialect
+  populateFrontendPasses(pm);
+  // middle-end, LinalgX/Linalg/tensor dialects
+  populateTensorPasses(pm);
+  // middle-end, arith/math/vector dialects
+  populateVectorPasses(pm);
+  // back-end, arith/math/vector/memref dialects
+  populateBufferizationPasses(pm);
+  // REMOVE this pass after the TensorPasses are added. Currently we add this
+  // pass to make the pipeline work properly
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToParallelLoopsPass());
+  populateMicroKernelPasses(pm);
+  populateCPURuntimePasses(pm);
+  // back-end, llvm dialect
+  populateLLVMPasses(pm);
+}
+
+#define GEN_PASS_DEF_GCCPUPIPELINE
+#include "gc/Transforms/Passes.h.inc"
+namespace {
+
+class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
+public:
+  friend struct PassHelper;
+  using impl::GCCPUPipelineBase<GCCPUPipeline>::GCCPUPipelineBase;
+  void runOnOperation() final {
+    auto op = getOperation();
+    PassManager pm{op->getContext()};
+    populateCPUPipeline(pm);
+    if (failed(pm.run(op)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+} // namespace mlir::gc
diff --git a/test/gc/Transforms/Pipeline/run.mlir b/test/gc/Transforms/Pipeline/run.mlir
new file mode 100644
index 000000000..799935006
--- /dev/null
+++ b/test/gc/Transforms/Pipeline/run.mlir
@@ -0,0 +1,23 @@
+// RUN: gc-opt %s --gc-cpu-pipeline | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s
+
+module {
+func.func @aaa() -> tensor<128xf32> {
+  %c2 = arith.constant 2.0 : f32
+  %a = tensor.empty() : tensor<128xf32>
+  %2 = linalg.fill ins(%c2 : f32) outs(%a : tensor<128xf32>) -> tensor<128xf32>
+  return %2 : tensor<128xf32>
+}
+
+func.func @main() {
+  %result = call @aaa() : ()-> tensor<128xf32>
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  scf.for %iv = %c0 to %c128 step %c1 {
+    %4 = tensor.extract %result[%iv] : tensor<128xf32>
+    cpuruntime.printf "%f\n" %4 : f32
+  }
+  return
+}
+// CHECK-COUNT-128: 2.000000
+}
\ No newline at end of file
diff --git a/test/gc/Transforms/Pipeline/tensor_args.mlir b/test/gc/Transforms/Pipeline/tensor_args.mlir
new file mode 100644
index 000000000..73d916d04
--- /dev/null
+++ b/test/gc/Transforms/Pipeline/tensor_args.mlir
@@ -0,0 +1,13 @@
+// RUN: gc-opt %s --gc-cpu-pipeline | FileCheck %s
+
+module {
+// CHECK: aaa
+// check that the func returns void
+// CHECK-NOT: ) -> !llvm.struct<
+func.func @aaa(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> {
+  %out = tensor.empty() : tensor<128xf32>
+  %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
+  // CHECK: memcpy
+  return %2 : tensor<128xf32>
+}
+}
\ No newline at end of file

From 3d3308c43b194bc37a01fa0a6ad94273f3c5368b Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Wed, 15 May 2024 10:07:35 +0800
Subject: [PATCH 06/64] move codes from dnn-compiler

---
 .../DataFlow/ConstantSubgraphAnalysis.h       | 127 +++++
 .../Dialect/OnednnGraph/OnednnGraphDialect.td |   1 +
 include/gc/Transforms/Passes.h                |   5 +
 include/gc/Transforms/Passes.td               |  16 +
 lib/gc/Analysis/CMakeLists.txt                |  16 +
 .../DataFlow/ConstantSubgraphAnalysis.cpp     | 180 +++++++
 lib/gc/CMakeLists.txt                         |   1 +
 .../OnednnGraph/OnednnGraphDialect.cpp        |   6 +
 lib/gc/Transforms/CMakeLists.txt              |   2 +
 lib/gc/Transforms/CSA.cpp                     |  51 ++
 lib/gc/Transforms/CST.cpp                     | 496 ++++++++++++++++++
 src/gc-opt/CMakeLists.txt                     |   3 +-
 .../test_constant_weights_folding.mlir        |  76 +++
 13 files changed, 979 insertions(+), 1 deletion(-)
 create mode 100644 include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h
 create mode 100644 lib/gc/Analysis/CMakeLists.txt
 create mode 100644 lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp
 create mode 100644 lib/gc/Transforms/CSA.cpp
 create mode 100644 lib/gc/Transforms/CST.cpp
 create mode 100644 test/gc/Transforms/test_constant_weights_folding.mlir

diff --git a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h
new file mode 100644
index 000000000..fcb2939d8
--- /dev/null
+++ b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h
@@ -0,0 +1,127 @@
+//===- ConstantSubgraphAnalysis.h - Constant subgraph analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements constant subgraph analysis. In this file are:
+// 1. the lattice value class that represents operations with constant inputs
+//    and outputs in the program, and
+// 2. a sparse constant subgraph analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include <cassert>
+
+namespace mlir {
+namespace dataflow {
+
+//===----------------------------------------------------------------------===//
+// InConstantSubgraph
+//===----------------------------------------------------------------------===//
+
+/// This lattice represents a boolean flag indicating whether an operation has
+/// constant inputs and constant outputs, and hence belongs to the constant
+/// subgraph.
+class InConstantSubgraph {
+public:
+  /// Construct as uninitialized.
+  explicit InConstantSubgraph() = default;
+
+  /// Construct with a known state.
+  explicit InConstantSubgraph(bool initialized, bool inConstantSubgraph)
+      : initialized(initialized), inConstantSubgraph(inConstantSubgraph) {}
+
+  /// Get the state. The state must have been initialized.
+  bool getInConstantSubgraph() const {
+    assert(!isUninitialized());
+    return inConstantSubgraph;
+  }
+
+  /// Compare.
+  bool operator==(const InConstantSubgraph &rhs) const {
+    return initialized == rhs.initialized &&
+           inConstantSubgraph == rhs.inConstantSubgraph;
+  }
+
+  void print(raw_ostream &os) const;
+
+  /// Get uninitialized state. This happens when the
+  /// state hasn't been set during the analysis.
+  static InConstantSubgraph getUninitialized() { return InConstantSubgraph{}; }
+
+  /// Whether the state is uninitialized.
+  bool isUninitialized() const { return !initialized; }
+
+  /// Get unknown state.
+  static InConstantSubgraph getUnknown() {
+    return InConstantSubgraph{/*initialized=*/false,
+                              /*inConstantSubgraph=*/false};
+  }
+
+  // Join two states.
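+  // The joined state is in the constant subgraph only if both inputs are,
+  // i.e. the join is a logical AND: any non-constant input makes the result
+  // non-constant.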
+  static InConstantSubgraph join(const InConstantSubgraph &lhs,
+                                 const InConstantSubgraph &rhs) {
+    // if one is uninitialized, use the other
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+
+    // both are initialized, intersect them
+    if (!lhs.isUninitialized() && !rhs.isUninitialized()) {
+      return InConstantSubgraph(true, lhs.getInConstantSubgraph() &&
+                                          rhs.getInConstantSubgraph());
+    }
+    return getUninitialized();
+  }
+
+private:
+  bool initialized = false;
+  bool inConstantSubgraph = false;
+};
+
+//===----------------------------------------------------------------------===//
+// ConstantSubgraphAnalysis
+//===----------------------------------------------------------------------===//
+
+class ConstantSubgraphAnalysis
+    : public SparseForwardDataFlowAnalysis<Lattice<InConstantSubgraph>> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  void visitOperation(Operation *op,
+                      ArrayRef<const Lattice<InConstantSubgraph> *> operands,
+                      ArrayRef<Lattice<InConstantSubgraph> *> results) override;
+
+  void setToEntryState(Lattice<InConstantSubgraph> *lattice) override;
+};
+
+//===----------------------------------------------------------------------===//
+// RunConstantSubgraphAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Runs constant subgraph analysis on the IR defined by `op`.
+struct RunConstantSubgraphAnalysis {
+public:
+  RunConstantSubgraphAnalysis();
+
+  void run(Operation *op);
+
+  bool getInConstantSubgraph(Value val);
+
+private:
+  /// Stores the result of the analysis.
+  DataFlowSolver solver;
+
+  void getConstantSubgraph(DataFlowSolver &solver, Operation *topFunc);
+};
+} // end namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSIS_H
diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td b/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td
index 16615a4d3..1f6fbe77b 100644
--- a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td
+++ b/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td
@@ -24,6 +24,7 @@ def OnednnGraphDialect : Dialect {
 
   let cppNamespace = "::mlir::onednn_graph";
   let useDefaultTypePrinterParser = 1;
+  let hasOperationAttrVerify = 1;
 }
 
 #endif // ONEDNNGRAPH_DIALECT
diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h
index 243a6f4f6..34d2fd487 100644
--- a/include/gc/Transforms/Passes.h
+++ b/include/gc/Transforms/Passes.h
@@ -15,8 +15,13 @@ namespace mlir {
 namespace gc {
 
 #define GEN_PASS_DECL
+#define GEN_PASS_DECL_CSA
+#define GEN_PASS_DECL_CST
 #include "gc/Transforms/Passes.h.inc"
 
+std::unique_ptr<Pass> createCSAPass();
+std::unique_ptr<Pass> createCSTPass();
+
 #define GEN_PASS_REGISTRATION
 #include "gc/Transforms/Passes.h.inc"
 } // namespace gc
diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index 7274534b7..23593ded3 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -17,4 +17,20 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
     ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
 }
 
+def CSA : Pass<"csa"> {
+  let summary = "Constant Subgraph Analysis";
+  let description = [{
+    This pass implements a constant subgraph analysis.
+  }];
+  let constructor = "mlir::gc::createCSAPass()";
+}
+
+def CST : Pass<"cst"> {
+  let summary = "Constant Subgraph Transform";
+  let description = [{
+    This pass implements a constant subgraph transform.
+  }];
+  let constructor = "mlir::gc::createCSTPass()";
+}
+
 #endif // GC_DIALECT_GC_PASSES
diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt
new file mode 100644
index 000000000..42c3d5541
--- /dev/null
+++ b/lib/gc/Analysis/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(GCAnalysis
+  DataFlow/ConstantSubgraphAnalysis.cpp
+
+  ADDITIONAL_HEADER_DIRS
+    ${PROJECT_SOURCE_DIR}/include/
+
+  DEPENDS
+    GraphCompilerPassIncGen
+
+  LINK_LIBS PUBLIC
+    ${mlir_dialect_libs}
+    MLIRIR
+    MLIRSupport
+    MLIRBufferizationToMemRef
+    MLIRBufferizationPipelines
+  )
diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp
new file mode 100644
index 000000000..2de9e5b4a
--- /dev/null
+++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp
@@ -0,0 +1,180 @@
+//===- ConstantSubgraphAnalysis.cpp - Constant subgraph analysis ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include <optional>
+
+#define DEBUG_TYPE "in-constant-subgraph"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+//===----------------------------------------------------------------------===//
+// InConstantSubgraph
+//===----------------------------------------------------------------------===//
+
+void InConstantSubgraph::print(raw_ostream &os) const {
+  if (isUninitialized()) {
+    os << "<UNINITIALIZED>";
+    return;
+  }
+  os << getInConstantSubgraph();
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantSubgraphAnalysis
+//===----------------------------------------------------------------------===//
+
+void ConstantSubgraphAnalysis::visitOperation(
+    Operation *op, ArrayRef<const Lattice<InConstantSubgraph> *> operands,
+    ArrayRef<Lattice<InConstantSubgraph> *> results) {
+  LLVM_DEBUG(llvm::dbgs() << "ConstantSubgraphAnalysis: Visiting operation:\n"
+                          << *op << "\n");
+
+  bool in = true;
+  if (op->hasTrait<OpTrait::ConstantLike>()) {
+    LLVM_DEBUG(llvm::dbgs() << "Curr op is a Constant op\n");
+    in = true;
+  } else if (operands.empty()) { // For example, tensor.empty()
+    LLVM_DEBUG(llvm::dbgs() << "Curr op has 0 operands, constant\n");
+    in = true;
+  } else {
+    LLVM_DEBUG(llvm::dbgs() << "Curr op has " << operands.size()
+                            << " operands, check if constant\n");
+    for (auto *operandLattice : operands) {
+      auto operandState = operandLattice->getValue().getInConstantSubgraph();
+      LLVM_DEBUG(llvm::dbgs() << "Operand: " << operandLattice->getPoint()
+                              << ", lattice value: " << operandState << "\n");
+      if (!operandState) {
+        in = false;
+        break;
+      }
+    }
+  }
+
+  // At this point the lattices in `results` should still be in the
+  // uninitialized state; propagate the computed state into them.
+  if (!in) {
+    LLVM_DEBUG(llvm::dbgs() << "Curr op not in constant subgraph\n");
+    for (auto lattice : results) {
+      propagateIfChanged(lattice,
+                         lattice->join(InConstantSubgraph(true, false)));
+    }
+  } else {
+    LLVM_DEBUG(llvm::dbgs() << "Curr op in constant subgraph\n");
+    for (auto lattice : results) {
+      propagateIfChanged(lattice,
+                         lattice->join(InConstantSubgraph(true, true)));
+    }
+  }
+}
+
+void ConstantSubgraphAnalysis::setToEntryState(
+    Lattice<InConstantSubgraph> *lattice) {
+  if (auto blockArg = dyn_cast<BlockArgument>(lattice->getPoint())) {
+    auto parentOp = blockArg.getParentBlock()->getParentOp();
+    auto parentOpAttr = parentOp->getAttrDictionary();
+    std::optional<NamedAttribute> constArgs =
+        parentOpAttr.getNamed("onednn_graph.const_args");
+    if (constArgs.has_value()) {
+      ArrayAttr constArgsIndexes =
+          llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
+      for (auto id : constArgsIndexes) {
+        auto idInt = llvm::cast<IntegerAttr>(id).getInt();
+        if (blockArg.getArgNumber() == idInt) {
+          LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+                                  << " is marked as constant\n");
+          propagateIfChanged(lattice,
+                             lattice->join(InConstantSubgraph(true, true)));
+          return;
+        }
+      }
+    }
+    propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false)));
+  } else {
+    propagateIfChanged(lattice,
+                       lattice->join(InConstantSubgraph::getUninitialized()));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// RunConstantSubgraphAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Get the operations whose inputs and outputs are all constant values.
+/// These operations will be put into a separate subgraph.
+void RunConstantSubgraphAnalysis::getConstantSubgraph(DataFlowSolver &solver,
+                                                      Operation *topFunc) {
+  OpBuilder builder(topFunc->getContext());
+  SmallVector<Operation *> constantOperations;
+
+  Block &block = topFunc->getRegions().front().getBlocks().front();
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    // If all the result values of an op are constant, mark the op as
+    // constant.
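+    // For example (illustrative): a `linalg.fill` fed only by an
+    // `arith.constant` and a `tensor.empty` produces constant results and is
+    // therefore tagged with `onednn_graph.in_const_subgraph` below.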
+    bool resultsAllConstant = true;
+    if (op.getNumResults() == 0) {
+      continue;
+    }
+    for (Value res : op.getResults()) {
+      auto *lattice = solver.lookupState<Lattice<InConstantSubgraph>>(res);
+      if (!lattice || lattice->getValue().isUninitialized()) {
+        resultsAllConstant = false;
+        break;
+      }
+      const InConstantSubgraph &latticeValue = lattice->getValue();
+      if (!latticeValue.getInConstantSubgraph()) {
+        resultsAllConstant = false;
+        break;
+      }
+    }
+    if (resultsAllConstant) {
+      op.setAttr("onednn_graph.in_const_subgraph", builder.getBoolAttr(true));
+      constantOperations.push_back(&op);
+    }
+  }
+
+  if (constantOperations.empty()) {
+    return;
+  }
+}
+
+RunConstantSubgraphAnalysis::RunConstantSubgraphAnalysis() {
+  solver.load<DeadCodeAnalysis>();
+  solver.load<ConstantSubgraphAnalysis>();
+}
+
+void RunConstantSubgraphAnalysis::run(Operation *topFunc) {
+  if (failed(solver.initializeAndRun(topFunc))) {
+    return;
+  }
+  getConstantSubgraph(solver, topFunc);
+}
+
+bool RunConstantSubgraphAnalysis::getInConstantSubgraph(Value val) {
+  auto *lattice = solver.lookupState<Lattice<InConstantSubgraph>>(val);
+  const InConstantSubgraph &latticeValue = lattice->getValue();
+  return latticeValue.getInConstantSubgraph();
+}
\ No newline at end of file
diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt
index f5ed3a6e5..308db5f30 100644
--- a/lib/gc/CMakeLists.txt
+++ b/lib/gc/CMakeLists.txt
@@ -3,4 +3,5 @@ if(GC_MLIR_CXX_FLAGS)
 endif()
 
 add_subdirectory(Dialect)
+add_subdirectory(Analysis)
 add_subdirectory(Transforms)
\ No newline at end of file
diff --git a/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp b/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp
index 434fa8a57..6469cc12a 100644
--- a/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp
+++ b/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp
@@ -18,3 +18,9 @@ void OnednnGraphDialect::initialize() {
 #include "gc/Dialect/OnednnGraph/OnednnGraphOps.cpp.inc"
       >();
 }
+
+LogicalResult
+OnednnGraphDialect::verifyOperationAttribute(Operation *op,
+                                             NamedAttribute attr) {
+  return success();
+}
\ No newline at end of file
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index df8a14d01..6421d20c2 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_library(GCPasses
   TileNamed.cpp
+  CSA.cpp
+  CST.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${PROJECT_SOURCE_DIR}/include
diff --git a/lib/gc/Transforms/CSA.cpp b/lib/gc/Transforms/CSA.cpp
new file mode 100644
index 000000000..5175be2f5
--- /dev/null
+++ b/lib/gc/Transforms/CSA.cpp
@@ -0,0 +1,51 @@
+//===- CSA.cpp - Constant Subgraph Analysis -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass performs a constant subgraph analysis in MLIR.
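+// Typical usage (assumed invocation): `gc-opt --csa input.mlir`, which tags
+// every operation inside the constant subgraph with the boolean attribute
+// `onednn_graph.in_const_subgraph`.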
+//
+//===----------------------------------------------------------------------===//
+#include "gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace gc {
+#define GEN_PASS_DEF_CSA
+#include "gc/Transforms/Passes.h.inc"
+} // namespace gc
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace gc {
+
+struct CSA : public impl::CSABase<CSA> {
+  void runOnOperation() override;
+};
+
+void CSA::runOnOperation() {
+  Operation *op = getOperation();
+  auto &func =
+      op->getRegions().front().getBlocks().front().getOperations().front();
+
+  // Hard-code: set the #1 argument to be constant.
+  // OpBuilder builder(op->getContext());
+  // func.setAttr("onednn_graph.const_args",
+  //              builder.getI32ArrayAttr({1,2,3,4}));
+
+  RunConstantSubgraphAnalysis csa;
+  csa.run(&func);
+}
+
+std::unique_ptr<Pass> createCSAPass() { return std::make_unique<CSA>(); }
+
+} // namespace gc
+} // namespace mlir
\ No newline at end of file
diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp
new file mode 100644
index 000000000..2dac0d860
--- /dev/null
+++ b/lib/gc/Transforms/CST.cpp
@@ -0,0 +1,496 @@
+//===- CST.cpp - Constant Subgraph Transform -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass performs a constant subgraph transform in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include <deque>
+#include <unordered_set>
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace gc {
+#define GEN_PASS_DEF_CST
+#include "gc/Transforms/Passes.h.inc"
+} // namespace gc
+
+using namespace mlir;
+
+namespace gc {
+
+struct CST : public impl::CSTBase<CST> {
+  void runOnOperation() override;
+};
+
+bool isInConstantSubgraph(Operation *op) {
+  auto opNamespace = op->getDialect()->getNamespace();
+  if (opNamespace == linalg::LinalgDialect::getDialectNamespace() ||
+      opNamespace == tensor::TensorDialect::getDialectNamespace() ||
+      opNamespace == arith::ArithDialect::getDialectNamespace()) {
+    if (op->getAttr("onednn_graph.in_const_subgraph")) {
+      return true;
+    }
+  }
+  return false;
+}
+
+int64_t getTensorSize(TensorType t) {
+  Type eleType = t.getElementType();
+  unsigned bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes
+  ArrayRef<int64_t> shape = t.getShape();
+  int64_t size = bitWidth;
+  for (auto s : shape) {
+    size *= s;
+  }
+  return size;
+}
+
+bool canMoveBefore(Operation *op) {
+  if (op->getDialect()->getNamespace() ==
+      arith::ArithDialect::getDialectNamespace()) {
+    return true;
+  }
+
+  if (op->getDialect()->getNamespace() !=
+      linalg::LinalgDialect::getDialectNamespace()) {
+    return false;
+  }
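+
+  // For the remaining (linalg) ops, moving across a broadcast is only safe
+  // for elementwise ops on identically-indexed operands: every indexing map
+  // must be an identity and every iterator parallel, as checked below.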
+
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  for (auto &affineMap : indexingMaps) {
+    if (!affineMap.isIdentity()) {
+      return false;
+    }
+  }
+
+  SmallVector<utils::IteratorType> iterTypes = linalgOp.getIteratorTypesArray();
+  for (auto &iterType : iterTypes) {
+    if (iterType != utils::IteratorType::parallel) {
+      return false;
+    }
+  }
+
+  if (op->getNumOperands() > 1) {
+    // int64_t numInputs = linalgOp.getNumDpsInputs();
+    int64_t numInits = linalgOp.getNumDpsInits();
+    // The defining op of each init operand should be tensor.empty().
+    for (int64_t i = 0; i < numInits; ++i) {
+      OpOperand *outOperand = linalgOp.getDpsInitOperand(i);
+      auto parentOp = outOperand->get().getDefiningOp();
+      if (!isa<tensor::EmptyOp>(parentOp)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+void postponeBroadcast(Block &block) {
+  // auto bcOps = block.getOps<linalg::BroadcastOp>();
+  // for (linalg::BroadcastOp bcOp : bcOps) {}
+  SmallVector<Operation *> constBcOps;
+  for (Operation &op : block.getOperations()) {
+    if (isa<linalg::BroadcastOp>(&op)) {
+      Operation *bcOp = &op;
+      if (isInConstantSubgraph(bcOp)) {
+        constBcOps.push_back(bcOp);
+      }
+    }
+  }
+
+  for (auto bcOp : constBcOps) {
+    // For a topology v -> pack -> bc -> mul -> matmul, we transform it to
+    // v -> pack -> mul -> bc -> matmul, so that we can fold
+    // v -> pack -> mul. Note that we require the topology to be sequential
+    // and all the Values to have exactly one user.
+
+    // go upwards to the BlockArgument
+    SmallVector<Operation *> prevOps;
+    Operation *currOp = bcOp;
+    while (true) {
+      if (currOp->getNumOperands() != 1) {
+        break;
+      }
+      Value operand = currOp->getOperand(0);
+      if (isa<BlockArgument>(operand)) {
+        break;
+      } else {
+        currOp = operand.getDefiningOp();
+        prevOps.push_back(currOp);
+      }
+    }
+
+    // go downwards to the last constant op
+    SmallVector<Operation *> postOps;
+    currOp = bcOp;
+    while (true) {
+      if (currOp->getNumResults() != 1 || !currOp->hasOneUse()) {
+        break;
+      }
+      Value input = currOp->getResult(0);
+      currOp = *(input.getUsers().begin());
+      Value output = currOp->getResult(0);
+      // NOTE: we require the input and output shapes of the current op to be
+      // the same. Operations from the tensor dialect, like
+      // pack/unpack/concat/collapse_shape/expand_shape/reshape/pad, are not
+      // supported. So we simply restrict currOp to be from arith or linalg.
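+      // For example (illustrative): an `arith.extf` from
+      // tensor<2x8x32x32xbf16> to tensor<2x8x32x32xf32> keeps the shape and
+      // may be moved before the broadcast, while a `tensor.pack` changes the
+      // shape and may not.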
+      if (!isa<TensorType>(input.getType()) ||
+          !isa<TensorType>(output.getType()) ||
+          dyn_cast<TensorType>(input.getType()).getShape() !=
+              dyn_cast<TensorType>(output.getType()).getShape() ||
+          !canMoveBefore(currOp)) {
+        break;
+      }
+      if (!isInConstantSubgraph(currOp)) {
+        break;
+      } else {
+        postOps.push_back(currOp);
+      }
+    }
+    if (postOps.empty()) {
+      continue;
+    }
+
+    // move bcOp after the last constant op
+    SmallVector<Operation *> newPostOps;
+    Value operand = static_cast<Value>(bcOp->getOperand(0));
+    ArrayRef<int64_t> shapeBeforeBc =
+        dyn_cast<TensorType>(operand.getType()).getShape();
+    size_t postOpId = 0;
+    for (Operation *postOp : postOps) {
+      SmallVector<Type> newOperandTypes;
+      for (auto oriType : postOp->getOperandTypes()) {
+        TensorType tt = dyn_cast<TensorType>(oriType);
+        newOperandTypes.push_back(
+            tt.cloneWith(shapeBeforeBc, tt.getElementType()));
+      }
+      SmallVector<Type> newResultTypes;
+      for (auto oriType : postOp->getResultTypes()) {
+        TensorType tt = dyn_cast<TensorType>(oriType);
+        newResultTypes.push_back(
+            tt.cloneWith(shapeBeforeBc, tt.getElementType()));
+      }
+      auto *newPostOp =
+          Operation::create(postOp->getLoc(), postOp->getName(), newResultTypes,
+                            postOp->getOperands(),
+                            /*postOp->getAttrDictionary()*/ std::nullopt,
+                            /*postOp->getPropertiesStorage()*/ nullptr,
+                            postOp->getSuccessors(), postOp->getNumRegions());
+      for (auto [oldRegion, newRegion] :
+           llvm::zip(postOp->getRegions(), newPostOp->getRegions())) {
+        newRegion.takeBody(oldRegion);
+      }
+
+      if (postOpId == 0) {
+        // Only the first post op needs to replace its operand. Others only
+        // need to call postOp->replaceAllUsesWith(newPostOp->getResults()).
+        newPostOp->getOperand(0).replaceAllUsesWith(operand);
+      }
+      ++postOpId;
+
+      newPostOp->setAttr("onednn_graph.in_const_subgraph",
+                         postOp->getAttr("onednn_graph.in_const_subgraph"));
+      if (postOp->getDialect()->getNamespace() ==
+          linalg::LinalgDialect::getDialectNamespace()) {
+        newPostOp->setAttr("operandSegmentSizes",
+                           postOp->getAttr("operandSegmentSizes"));
+
+        OpBuilder builder(postOp->getContext());
+        size_t indexingMapsSize =
+            dyn_cast<linalg::LinalgOp>(postOp).getIndexingMapsArray().size();
+        unsigned rank = shapeBeforeBc.size();
+        SmallVector<AffineMap> indexingMaps(
+            indexingMapsSize, builder.getMultiDimIdentityMap(rank));
+        auto indexingMapsAttr = builder.getAffineMapArrayAttr(indexingMaps);
+        newPostOp->setAttr("indexing_maps", indexingMapsAttr);
+
+        SmallVector<utils::IteratorType> iterTypes =
+            dyn_cast<linalg::LinalgOp>(postOp).getIteratorTypesArray();
+        iterTypes.resize(rank);
+        auto iterTypesAttr =
+            builder.getArrayAttr(llvm::to_vector(llvm::map_range(
+                iterTypes, [&](utils::IteratorType iter) -> mlir::Attribute {
+                  return linalg::IteratorTypeAttr::get(builder.getContext(),
+                                                       iter);
+                })));
+        newPostOp->setAttr("iterator_types", iterTypesAttr);
+      } else {
+        // Ops from other dialects.
+      }
+
+      // Modify the output operands of postOp. Here we simply assume that the
+      // value comes from tensor.empty().
+      if (postOp->getNumOperands() > 0) {
+        for (size_t i = 1; i < postOp->getNumOperands(); ++i) {
+          auto outOperand = postOp->getOperand(i);
+          outOperand.setType(newOperandTypes.front());
+        }
+      }
+
+      block.getOperations().push_back(newPostOp);
+      newPostOp->moveAfter(postOp);
+      newPostOps.push_back(newPostOp);
+      postOp->replaceAllUsesWith(newPostOp->getResults());
+
+      operand = static_cast<Value>(newPostOp->getResult(0));
+    }
+
+    auto nextOp = *(newPostOps.back()->getUsers().begin());
+    nextOp->getOperand(0).replaceAllUsesWith(bcOp->getResult(0));
+    bcOp->moveAfter(newPostOps.back());
+    bcOp->getOperand(0).replaceUsesWithIf(operand, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op == bcOp;
+    });
+
+    for (auto it = postOps.rbegin(); it != postOps.rend(); ++it) {
+      (*it)->erase();
+    }
+  }
+}
+
+static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
+
+// Operate on tensors. Create fold() and compute() on the module. The folded
+// weights and the first-run flag are maintained by the upper-level runtime.
+void CST::runOnOperation() {
+  Operation *topOp = getOperation();
+  MLIRContext *context = topOp->getContext();
+  // A ModuleOp contains a single region, which contains a single block.
+  auto moduleOp = dyn_cast<ModuleOp>(topOp);
+  SymbolTable symbolTable(moduleOp);
+  auto &topFunc =
+      topOp->getRegions().front().getBlocks().front().getOperations().front();
+  OpBuilder builder(context);
+
+  auto topFuncAttr = topFunc.getAttrDictionary();
+  std::optional<NamedAttribute> constArgs =
+      topFuncAttr.getNamed("onednn_graph.const_args");
+  std::unordered_set<int64_t> constArgsIndexes;
+  if (constArgs.has_value()) {
+    ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
+    for (auto id : constArgsArray) {
+      constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+    }
+  } else {
+    return;
+  }
+  if (constArgsIndexes.empty()) {
+    return;
+  }
+
+  Region &region = topFunc.getRegions().front();
+  Block &block = region.getBlocks().front();
+
+  postponeBroadcast(block);
+
+  SmallVector<Operation *> constOps;
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    if (isInConstantSubgraph(&op)) {
+      constOps.push_back(&op);
+    }
+  }
+
+  std::string funcName("fold");
+  SmallVector<Type> inputTypes; // types of constant weights
+  // values of constant weights in original block
+  SmallVector<Value> inputValues;
+  SmallVector<Type> outputTypes; // types of folded constant weights
+  // values of folded constant weights in original block
+  SmallVector<Value> outputValues;
+  Value v;
+  // TODO: solve complicated topology. Currently we only handle the simple
+  // topology where one constant weight input produces exactly one constant
+  // output, and each constant weight contributes to only one constant output.
+  for (size_t id = 0; id < block.getNumArguments(); ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      auto arg = block.getArgument(id);
+      if (!isa<TensorType>(arg.getType())) {
+        continue;
+      }
+      inputTypes.push_back(arg.getType());
+      v = arg;
+      inputValues.push_back(v);
+      SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      // For v -> pack1 -> pack2 -> matmul, we need the type of the output of
+      // pack2.
+      while (!v.getUsers().empty()) {
+        // v is expected to have exactly one user
+        Operation *user = *(v.getUsers().begin());
+        if (!isInConstantSubgraph(user)) {
+          outputTypes.push_back(v.getType());
+          outputValues.push_back(v);
+          break;
+        }
+        // user is expected to have exactly one result value
+        OpResult result = *(user->result_begin());
+        v = result;
+        valuesOnTheWay.push_back(v);
+      }
+
+      // If the data size of the output value is much greater than that of
+      // the input value, do not fold.
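+      // Illustrative numbers: a tensor<256xbf16> bias is 512 bytes; after
+      // broadcasting to tensor<2x8x32x32xbf16> it becomes 32768 bytes, a 64x
+      // expansion that exceeds DATA_SIZE_EXPANDING_THRESHOLD (8), so folding
+      // stops before the broadcast result.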
+      // Compare data size changes during the traversal to find the last op
+      // that still satisfies this condition.
+      int64_t initSize =
+          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
+      if (!isa<TensorType>(outputTypes.back()) ||
+          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+        size_t lastIdx = 0;
+        for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
+          int64_t size =
+              getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+          if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
+            lastIdx = i;
+          }
+        }
+        if (lastIdx == 0) { // no suitable value found
+          inputTypes.pop_back();
+          outputTypes.pop_back();
+          inputValues.pop_back();
+          outputValues.pop_back();
+          constArgsIndexes.erase(id);
+        } else {
+          outputTypes.back() = valuesOnTheWay[lastIdx].getType();
+          outputValues.back() = valuesOnTheWay[lastIdx];
+        }
+      }
+    }
+  }
+  if (inputTypes.size() != outputTypes.size()) {
+    return;
+  }
+
+  FunctionType foldFuncType =
+      FunctionType::get(context, inputTypes, outputTypes);
+  auto foldFunc =
+      builder.create<func::FuncOp>(topFunc.getLoc(), funcName, foldFuncType);
+  Block *foldBlock = foldFunc.addEntryBlock();
+  // values of folded constant weights in foldBlock
+  SmallVector<Value> outputValuesInFold;
+  IRMapping mapper;
+  for (Operation *op : constOps) {
+    foldBlock->getOperations().push_back(op->clone(mapper));
+  }
+  // the order of outputValuesInFold follows the order of the corresponding
+  // inputValues
+  for (auto &v : outputValues) {
+    auto foldedV = mapper.lookupOrNull(v);
+    outputValuesInFold.push_back(foldedV);
+    v.replaceUsesWithIf(foldedV, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op->getBlock() == foldBlock;
+    });
+  }
+
+  auto returnOp =
+      builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
+  foldBlock->getOperations().push_back(returnOp);
+  for (size_t i = 0; i < inputValues.size(); ++i) {
+    inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
+                                     [&](OpOperand &val) {
+                                       Operation *op = val.getOwner();
+                                       return op->getBlock() == foldBlock;
+                                     });
+  }
+
+  foldFunc.setVisibility(SymbolTable::Visibility::Public);
+  moduleOp.push_back(foldFunc);
+  symbolTable.insert(foldFunc);
+
+  // modify the BlockArguments of block
+  size_t oriNumArgs = block.getNumArguments();
+  size_t argIdx = 0;
+  for (size_t id = 0; id < oriNumArgs; ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      auto loc = block.getArgument(id).getLoc();
+      BlockArgument foldArg =
+          block.insertArgument(id, outputTypes[argIdx], loc);
+      outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
+        Operation *op = val.getOwner();
+        return op->getBlock() == &block;
+      });
+
+      std::deque<Value> dq;
+      SmallVector<Operation *> opsToErase;
+      dq.push_back(block.getArgument(id + 1));
+      while (!dq.empty()) {
+        Value v = dq.front();
+        dq.pop_front();
+        for (Operation *op : v.getUsers()) {
+          for (auto res : op->getResults()) {
+            dq.push_back(res);
+          }
+          opsToErase.push_back(op);
+        }
+      }
+
+      for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
+        (*it)->erase();
+      }
+      block.eraseArgument(id + 1);
+      ++argIdx;
+    }
+  }
+
+  // modify the compute func signature
+  func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
+  FunctionType computeFuncType = computeFunc.getFunctionType();
+  computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
+                                        computeFuncType.getResults()));
+
+  // Delete dead operations via the dialects' canonicalizers.
+  RewritePatternSet owningPatterns(context);
+  for (auto *dialect : context->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(owningPatterns);
+
+  ArrayRef<std::string> disabledPatterns, enabledPatterns;
enabledPatterns; + std::shared_ptr patterns = + std::make_shared( + std::move(owningPatterns), disabledPatterns, enabledPatterns); + GreedyRewriteConfig config; + LogicalResult converged = + applyPatternsAndFoldGreedily(topOp, *patterns, config); + (void)converged; + + // clean up the constant-related attrs on ops + for (auto &op : block.getOperations()) { + if (op.getAttr("onednn_graph.in_const_subgraph")) { + op.removeAttr("onednn_graph.in_const_subgraph"); + } + } + for (auto &op : foldBlock->getOperations()) { + if (op.getAttr("onednn_graph.in_const_subgraph")) { + op.removeAttr("onednn_graph.in_const_subgraph"); + } + } +} + +std::unique_ptr createCSTPass() { return std::make_unique(); } + +} // namespace gc +} // namespace mlir diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt index ff33375de..b07c4dfe6 100644 --- a/src/gc-opt/CMakeLists.txt +++ b/src/gc-opt/CMakeLists.txt @@ -2,7 +2,8 @@ set(gc_opt_libs ${dialect_libs} ${conversion_libs} MLIROptLib - GCPasses) + GCPasses + GCAnalysis) if(GC_MLIR_CXX_FLAGS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}") endif() diff --git a/test/gc/Transforms/test_constant_weights_folding.mlir b/test/gc/Transforms/test_constant_weights_folding.mlir new file mode 100644 index 000000000..52885ae7d --- /dev/null +++ b/test/gc/Transforms/test_constant_weights_folding.mlir @@ -0,0 +1,76 @@ +// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(csa,cst)" %s | FileCheck %s + +// CHECK-LABEL: func.func @entry +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +module { + // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear. + // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. 
+ func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %1 = tensor.empty() : tensor<2x16x32x32xbf16> + %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> + %2 = tensor.empty() : tensor<8x16x32x32xbf16> + %packed_arg1 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<512x256xbf16> -> tensor<8x16x32x32xbf16> + %3 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %packed_packed_arg1 = tensor.pack %packed_arg1 inner_dims_pos = [2] inner_tiles = [2] into %3 : tensor<8x16x32x32xbf16> -> tensor<8x16x16x32x2xbf16> + %4 = tensor.empty() : tensor<2x8x32x32xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %5 = linalg.fill ins(%cst_0 : bf16) outs(%4 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%packed_arg0, %packed_packed_arg1 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%5 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x8x32x32xbf16> + %15 = tensor.empty() : tensor<8x32xbf16> + %packed_arg2 = tensor.pack %arg2 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %15 : tensor<256xbf16> -> tensor<8x32xbf16> + %bc_arg2_init = tensor.empty() : tensor<2x8x32x32xbf16> + %bc_arg2 = linalg.broadcast ins(%packed_arg2 : tensor<8x32xbf16>) outs(%bc_arg2_init : tensor<2x8x32x32xbf16>) dimensions = [0, 2] + %extf32 = arith.extf %bc_arg2 : tensor<2x8x32x32xbf16> to tensor<2x8x32x32xf32> + %cst_2 = arith.constant 2.000000e+00 : f32 + %extf32_mul2_init = tensor.empty() : tensor<2x8x32x32xf32> + %extf32_mul2 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extf32 : tensor<2x8x32x32xf32>) outs(%extf32_mul2_init : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %8 = arith.mulf %in, %cst_2 : f32 + linalg.yield %8 : f32 + } -> tensor<2x8x32x32xf32> + %truncbf16 = arith.truncf %extf32_mul2 : tensor<2x8x32x32xf32> to tensor<2x8x32x32xbf16> + %7 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%truncbf16 : tensor<2x8x32x32xbf16>) outs(%6 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %45 = arith.addf %in, %out : bf16 + linalg.yield %45 : bf16 + } -> tensor<2x8x32x32xbf16> + %8 = tensor.empty() : tensor<32x8x32x32xbf16> + %packed_arg3 = tensor.pack %arg3 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %8 : tensor<256x1024xbf16> -> tensor<32x8x32x32xbf16> + %9 = tensor.empty() : tensor<32x8x16x32x2xbf16> + %packed_packed_arg3 = tensor.pack %packed_arg3 inner_dims_pos = [2] inner_tiles = [2] into %9 : tensor<32x8x32x32xbf16> -> tensor<32x8x16x32x2xbf16> + %10 = tensor.empty() : tensor<2x32x32x32xbf16> + %11 = linalg.fill ins(%cst_0 : bf16) outs(%10 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %12 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", 
"parallel", "reduction"]} ins(%7, %packed_packed_arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%11 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %46 = arith.mulf %in, %in_0 : bf16 + %56 = arith.addf %out, %46 : bf16 + linalg.yield %56 : bf16 + } -> tensor<2x32x32x32xbf16> + %16 = tensor.empty() : tensor<32x32xbf16> + %packed_arg4 = tensor.pack %arg4 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %16 : tensor<1024xbf16> -> tensor<32x32xbf16> + %13 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%packed_arg4 : tensor<32x32xbf16>) outs(%12 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %47 = arith.addf %in, %out : bf16 + linalg.yield %47 : bf16 + } -> tensor<2x32x32x32xbf16> + %14 = tensor.empty() : tensor<64x1024xbf16> + %unpack = tensor.unpack %13 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %14 : tensor<2x32x32x32xbf16> -> tensor<64x1024xbf16> + return %unpack : tensor<64x1024xbf16> + } +} +// CHECK: linalg.broadcast +// CHECK: func.func @fold +// CHECK: arith.extf +// CHECK: arith.truncf +// COM: expected output: +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> +// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) From 475faf8052309cfd9e170f61e2622d5c4cd7a5ad Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 15 May 2024 15:46:46 +0800 Subject: [PATCH 07/64] update --- lib/gc/Transforms/Pipeline.cpp | 9 ++++++--- test/gc/Transforms/Pipeline/run.mlir | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index ed50925f6..8b74df1b9 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -21,6 +21,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" @@ -65,7 +66,9 @@ void populateBufferizationPasses(mlir::PassManager &pm) { pm.addPass(bufferization::createOneShotBufferizePass(options)); pm.addPass(createCSEPass()); pm.addPass(mlir::func::createFuncBufferizePass()); - pm.addPass(bufferization::createBufferResultsToOutParamsPass()); + bufferization::BufferResultsToOutParamsOpts opt{}; + // opt.hoistStaticAllocs = true; + pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); pm.addNestedPass( bufferization::createBufferizationBufferizePass()); pm.addNestedPass( @@ -98,14 +101,14 @@ void populateMicroKernelPasses(mlir::PassManager &pm) { void populateCPURuntimePasses(mlir::PassManager &pm) { // + flatten nested parallel pass, down-stream pass, to support coarse-grain // fusion - // pm.addNestedPass(parallelcpu::createParallelCPUAtExitToOmp()); + pm.addNestedPass(cpuruntime::createCPURuntimeAtExitToOmp()); // remove this pass after we add FlattenNestedParallel pm.addPass(createConvertSCFToOpenMPPass()); } void populateLoweringToLLVMPasses(mlir::PassManager &pm) { pm.addPass(createConvertSCFToCFPass()); - // pm.addPass(parallelcpu::createParallelCPUToLLVM()); + pm.addPass(cpuruntime::createCPURuntimeToLLVM()); 
pm.addPass(createConvertOpenMPToLLVMPass()); pm.addNestedPass(createConvertMathToLLVMPass()); pm.addPass(createConvertMathToLibmPass()); diff --git a/test/gc/Transforms/Pipeline/run.mlir b/test/gc/Transforms/Pipeline/run.mlir index 799935006..71feb0843 100644 --- a/test/gc/Transforms/Pipeline/run.mlir +++ b/test/gc/Transforms/Pipeline/run.mlir @@ -15,7 +15,7 @@ func.func @main() { %c1 = arith.constant 1 : index scf.for %iv = %c0 to %c128 step %c1 { %4 = tensor.extract %result[%iv] : tensor<128xf32> - parallelcpu.printf "%f\n" %4 : f32 + cpuruntime.printf "%f\n" %4 : f32 } return } From 0ac087deb28d4ec1efcf53519d95e1700690467f Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 15 May 2024 17:02:27 +0800 Subject: [PATCH 08/64] fix --- src/gc-opt/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt index 74eb4b28e..36ace6847 100644 --- a/src/gc-opt/CMakeLists.txt +++ b/src/gc-opt/CMakeLists.txt @@ -15,7 +15,9 @@ endif() set(gc_opt_libs ${dialect_libs} ${conversion_libs} - MLIROptLib) + ${MLIR_LINK_COMPONENTS} + GCPasses) + if(GC_MLIR_CXX_FLAGS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}") endif() From 74b0d342fba8e5896a9ca07bf57b449f34911155 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 11:00:23 +0800 Subject: [PATCH 09/64] remove at exit --- .../gc/Dialect/CPURuntime/IR/CPURuntimeOps.td | 41 +--------- .../CPURuntime/Transforms/CPURuntimePasses.td | 35 --------- .../Dialect/CPURuntime/IR/CPURuntimeOps.cpp | 34 -------- .../CPURuntime/Transforms/CMakeLists.txt | 1 - .../Transforms/CPURuntimePasses.cpp | 78 ------------------- .../CPURuntime/cpuruntime-atexit-to-omp.mlir | 44 ----------- 6 files changed, 1 insertion(+), 232 deletions(-) delete mode 100644 lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp delete mode 100644 test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir diff --git a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td index cc1b7c555..bd77ad997 100644 --- a/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td +++ b/include/gc/Dialect/CPURuntime/IR/CPURuntimeOps.td @@ -10,46 +10,7 @@ #define CPURUNTIME_OPS include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.td" -include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" - - -def CPURuntime_AtParallelExitOp : CPURuntime_Op<"at_parallel_exit", [ - ParentOneOf<["scf::ForallOp", "scf::ParallelOp", "omp::WsloopOp", "memref::AllocaScopeOp"]>, - SingleBlockImplicitTerminator<"ParallelExitReturnOp"> - ]> { - let summary = "Runs the block once in all threads at the exit of the parallel section"; - let description = [{ - It executes the block for each thread working in the parallel section for - once, at the exit of parallel section. - }]; - - let regions = (region SizedRegion<1>:$region); - - let hasCustomAssemblyFormat = 1; - - // The default builder does not add a region with an empty body, add our own. 
- let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins)>, - ]; -} - -def CPURuntime_ParallelExitReturnOp : CPURuntime_Op<"parallel_exit.return", [ - Pure, - HasParent<"AtParallelExitOp">, - Terminator, ReturnLike - ]> { - let summary = "Terminates at_parallel_exit block"; - let description = [{ - at_parallel_exit should ends with parallel_exit.return - }]; - let assemblyFormat = - [{ attr-dict }]; -} def CPURuntime_PrintfOp : CPURuntime_Op<"printf", [MemoryEffects<[MemWrite]>]>, @@ -61,7 +22,7 @@ def CPURuntime_PrintfOp : CPURuntime_Op<"printf", [MemoryEffects<[MemWrite]>]>, scalar arguments that should be printed. The format string is a C-style printf string, subject to any restrictions - imposed by one's target platform. + imposed by the target platform. }]; let assemblyFormat = [{ $format attr-dict ($args^ `:` type($args))? diff --git a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td index 0685ce498..20c81e10a 100644 --- a/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td +++ b/include/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.td @@ -11,41 +11,6 @@ include "mlir/Pass/PassBase.td" - -def CPURuntimeAtExitToOmp: Pass<"cpuruntime-atexit-to-omp", "::mlir::func::FuncOp"> { - let summary = "Lower at_parallel_exit to code in omp.parallel section"; - let description = [{ - Switches the name of a FuncOp named `bar` to `foo` and folds. - ``` - omp.parallel { - omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - cpuruntime.at_parallel_exit { - "your.op"() - cpuruntime.parallel_exit.return - } - } - omp.yield - } - omp.terminator - } - ``` - Will be changed into - ``` - omp.parallel { - omp.wsloop for (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - } - omp.yield - } - "your.op"() - omp.terminator - } - ``` - }]; -} - - def CPURuntimeToLLVM: Pass<"convert-cpuruntime-to-llvm"> { let summary = "Convert cpuruntime to LLVM dialect"; let description = [{ diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp index ca632e9db..ed7bc6581 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp +++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp @@ -15,42 +15,8 @@ #include namespace mlir { -using namespace bufferization; - namespace cpuruntime { -void AtParallelExitOp::build(OpBuilder &b, OperationState &result) { - OpBuilder::InsertionGuard g(b); - Region *bodyRegion = result.addRegion(); - b.createBlock(bodyRegion); -} - -void AtParallelExitOp::print(OpAsmPrinter &p) { - p << " "; - p.printRegion(getRegion(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - p.printOptionalAttrDict(getOperation()->getAttrs()); -} - -ParseResult AtParallelExitOp::parse(OpAsmParser &parser, - OperationState &result) { - auto &builder = parser.getBuilder(); - - SmallVector regionOperands; - std::unique_ptr region = std::make_unique(); - if (parser.parseRegion(*region, regionOperands)) - return failure(); - - if (region->empty()) - OpBuilder(builder.getContext()).createBlock(region.get()); - result.addRegion(std::move(region)); - - // Parse the optional attribute list. 
- if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - return success(); -} } // namespace cpuruntime } // namespace mlir \ No newline at end of file diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt index ee6148aa4..3bc84f6c8 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms - CPURuntimePasses.cpp CPURuntimeToLLVM.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp b/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp deleted file mode 100644 index a8f74c079..000000000 --- a/lib/gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.cpp +++ /dev/null @@ -1,78 +0,0 @@ -//===- CPURuntimePasses.cpp - CPU Runtime Passes ----------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" - -namespace mlir::cpuruntime { -#define GEN_PASS_DEF_CPURUNTIMEATEXITTOOMP -#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h.inc" - -namespace { - -class CPURuntimeAtExitToOmpRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtParallelExitOp op, - PatternRewriter &rewriter) const final { - auto parent = op->getParentOp(); - omp::ParallelOp parallel; - while (parent) { - parallel = llvm::dyn_cast(parent); - if (parallel) { - break; - } - parent = parent->getParentOp(); - } - if (!parallel) { - return failure(); - } - auto &block = parallel.getRegion().front(); - auto itr = block.end(); - --itr; - rewriter.inlineBlockBefore(&op->getRegion(0).getBlocks().front(), &block, - itr); - rewriter.eraseOp(op); - return success(); - } -}; - -class CPURuntimeExitReturnRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ParallelExitReturnOp op, - PatternRewriter &rewriter) const final { - rewriter.eraseOp(op); - return success(); - } -}; - -class CPURuntimeAtExitToOmp - : public impl::CPURuntimeAtExitToOmpBase { -public: - using impl::CPURuntimeAtExitToOmpBase< - CPURuntimeAtExitToOmp>::CPURuntimeAtExitToOmpBase; - void runOnOperation() final { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) - signalPassFailure(); - } -}; - -} // namespace -} // namespace mlir::cpuruntime diff --git a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir b/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir deleted file mode 100644 index 172777690..000000000 --- a/test/gc/Dialect/CPURuntime/cpuruntime-atexit-to-omp.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: gc-opt %s --cpuruntime-atexit-to-omp | FileCheck %s - -module { - func.func @parallel_insert_slice(%arg0: 
memref<512x512xf32>) -> memref<512x512xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<512x512xf32> - %c512 = arith.constant 512 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - memref.copy %arg0, %alloc : memref<512x512xf32> to memref<512x512xf32> - %0 = llvm.mlir.constant(1 : i64) : i64 - omp.parallel { - omp.wsloop { - omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c512) step (%c1, %c1) { - memref.alloca_scope { - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<512xf32> - %subview = memref.subview %alloc[%arg1, 0] [1, 512] [1, 1] : memref<512x512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.copy %alloc_0, %subview : memref<512xf32> to memref<512xf32, strided<[1], offset: ?>> - memref.dealloc %alloc_0 : memref<512xf32> - cpuruntime.at_parallel_exit { - memref.prefetch %alloc[%c1,%c0], read, locality<3>, data : memref<512x512xf32> - cpuruntime.parallel_exit.return - } - } - omp.yield - } - omp.terminator - } - memref.prefetch %alloc[%c0,%c0], read, locality<3>, data : memref<512x512xf32> - omp.terminator - } - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK: omp.parallel - // CHECK-NEXT: omp.wsloop - // CHECK: memref.alloca_scope - // CHECK-NOT: cpuruntime.at_parallel_exit - // CHECK: omp.yield - // CHECK: memref.prefetch {{%alloc}}[%[[C0]], %[[C0]]] - // CHECK-NEXT: memref.prefetch {{%alloc}}[%[[C1]], %[[C0]]] - // CHECK-NEXT: omp.terminator - return %alloc : memref<512x512xf32> - } -} From 2cebba99452bd68876269b77293838b3a263f112 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 11:01:58 +0800 Subject: [PATCH 10/64] fix lint --- lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp index ed7bc6581..460c421cc 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp +++ b/lib/gc/Dialect/CPURuntime/IR/CPURuntimeOps.cpp @@ -15,8 +15,5 @@ #include namespace mlir { -namespace cpuruntime { - - -} // namespace cpuruntime +namespace cpuruntime {} // namespace cpuruntime } // namespace mlir \ No newline at end of file From 34d10ea127052b423b3384c2ada95b3c30a1eb36 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 13:29:48 +0800 Subject: [PATCH 11/64] Add kmp_* wrapper for gomp environment --- lib/gc/CMakeLists.txt | 3 +- lib/gc/ExecutionEngine/CMakeLists.txt | 1 + .../ExecutionEngine/CPURuntime/CMakeLists.txt | 15 ++ .../ExecutionEngine/CPURuntime/Parallel.cpp | 189 ++++++++++++++++++ src/gc-cpu-runner/CMakeLists.txt | 3 +- src/gc-cpu-runner/gc-cpu-runner.cpp | 4 + test/gc/cpu-runner/tid.mlir | 37 ++++ 7 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 lib/gc/ExecutionEngine/CMakeLists.txt create mode 100644 lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt create mode 100644 lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp create mode 100644 test/gc/cpu-runner/tid.mlir diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt index fd78d6cab..921853a08 100644 --- a/lib/gc/CMakeLists.txt +++ b/lib/gc/CMakeLists.txt @@ -5,4 +5,5 @@ endif() include(functions) add_subdirectory(Dialect) -add_subdirectory(Transforms) \ No newline at end of file +add_subdirectory(Transforms) +add_subdirectory(ExecutionEngine) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt 
b/lib/gc/ExecutionEngine/CMakeLists.txt
new file mode 100644
index 000000000..8aa223412
--- /dev/null
+++ b/lib/gc/ExecutionEngine/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(CPURuntime)
diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt
new file mode 100644
index 000000000..6be58e28f
--- /dev/null
+++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt
@@ -0,0 +1,15 @@
+find_package(OpenMP REQUIRED)
+
+if ("iomp" IN_LIST OpenMP_C_LIB_NAMES OR "omp" IN_LIST OpenMP_C_LIB_NAMES OR "omp5" IN_LIST OpenMP_C_LIB_NAMES)
+else()
+  add_definitions("-DGC_NEEDS_OMP_WRAPPER=1")
+endif()
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
+add_mlir_library(GCCpuRuntime
+  SHARED
+  Parallel.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+  )
\ No newline at end of file
diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp
new file mode 100644
index 000000000..a71b8522c
--- /dev/null
+++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp
@@ -0,0 +1,189 @@
+//===- Parallel.cpp - Definitions for parallel runtime ----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <atomic>
+#include <cassert>
+#include <cstdarg>
+#include <cstdint>
+#include <cstdlib>
+#include <immintrin.h>
+#include <omp.h>
+
+#define likely(x) __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+
+#define WEAK_SYMBOL __attribute__((weak))
+
+namespace {
+struct barrier_t {
+  alignas(64) std::atomic<int32_t> pending_;
+  std::atomic<int32_t> rounds_;
+  uint64_t total_;
+  // pad barrier to size of cacheline to avoid false sharing
+  char padding_[64 - 4 * sizeof(int32_t)];
+};
+
+typedef uint64_t (*barrier_idle_func)(std::atomic<int32_t> *remaining,
+                                      int32_t expected_remain, int32_t tid,
+                                      void *args);
+} // namespace
+
+extern "C" {
+int gc_runtime_keep_alive = 0;
+void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func,
+                          void *idle_args) {
+  auto cur_round = b->rounds_.load(std::memory_order_acquire);
+  auto cnt = --b->pending_;
+  assert(cnt >= 0);
+  // int count = 0;
+  if (cnt == 0) {
+    b->pending_.store(b->total_);
+    b->rounds_.store(cur_round + 1);
+  } else {
+    if (idle_func) {
+      if (cur_round != b->rounds_.load()) {
+        return;
+      }
+      idle_func(&b->rounds_, cur_round + 1, -1, idle_args);
+      // count = ret & 0xffffffff;
+    }
+    while (cur_round == b->rounds_.load()) {
+      _mm_pause();
+    }
+  }
+}
+
+static_assert(sizeof(barrier_t) == 64, "size of barrier_t should be 64-byte");
+
+void gc_init_barrier(barrier_t *b, int num_barriers, uint64_t thread_count) {
+  for (int i = 0; i < num_barriers; i++) {
+    b[i].total_ = thread_count;
+    b[i].pending_.store(thread_count);
+    b[i].rounds_.store(0);
+  }
+}
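+
+// Illustrative usage sketch (not a stable API contract): generated code is
+// expected to drive the barrier as
+//   gc_init_barrier(&b, /*num_barriers=*/1, /*thread_count=*/8);
+//   // then, in each of the 8 participating threads:
+//   gc_arrive_at_barrier(&b, /*idle_func=*/nullptr, /*idle_args=*/nullptr);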
+
+#if GC_NEEDS_OMP_WRAPPER
+void WEAK_SYMBOL __kmpc_barrier(void *loc, int32_t global_tid) {
+#pragma omp barrier
+}
+
+int WEAK_SYMBOL __kmpc_global_thread_num(void *loc) {
+  return omp_get_thread_num();
+}
+
+void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid,
+                                           int32_t schedtype,
+                                           int32_t *plastiter, uint64_t *plower,
+                                           uint64_t *pupper, int64_t *pstride,
+                                           int64_t incr, int64_t chunk) {
+  if (unlikely(schedtype != 34)) {
+    std::abort();
+  }
+  const int32_t FALSE = 0;
+  const int32_t TRUE = 0;
+  using UT = uint64_t;
+  // using ST = int64_t;
+  /* this all has to be changed back to TID and such.. */
+  uint32_t tid = gtid;
+  uint32_t nth = omp_get_num_threads();
+  UT trip_count;
+
+  /* special handling for zero-trip loops */
+  if (incr > 0 ? (*pupper < *plower) : (*plower < *pupper)) {
+    if (plastiter != nullptr)
+      *plastiter = FALSE;
+    /* leave pupper and plower set to entire iteration space */
+    *pstride = incr; /* value should never be used */
+    return;
+  }
+
+  if (nth == 1) {
+    if (plastiter != nullptr)
+      *plastiter = TRUE;
+    *pstride =
+        (incr > 0) ? (*pupper - *plower + 1) : (-(*plower - *pupper + 1));
+    return;
+  }
+
+  /* compute trip count */
+  if (incr == 1) {
+    trip_count = *pupper - *plower + 1;
+  } else if (incr == -1) {
+    trip_count = *plower - *pupper + 1;
+  } else if (incr > 0) {
+    // upper-lower can exceed the limit of signed type
+    trip_count = (UT)(*pupper - *plower) / incr + 1;
+  } else {
+    trip_count = (UT)(*plower - *pupper) / (-incr) + 1;
+  }
+  if (trip_count < nth) {
+    if (tid < trip_count) {
+      *pupper = *plower = *plower + tid * incr;
+    } else {
+      // set bounds so non-active threads execute no iterations
+      *plower = *pupper + (incr > 0 ? 1 : -1);
+    }
+    if (plastiter != nullptr)
+      *plastiter = (tid == trip_count - 1);
+  } else {
+    UT small_chunk = trip_count / nth;
+    UT extras = trip_count % nth;
+    *plower += incr * (tid * small_chunk + (tid < extras ? tid : extras));
+    *pupper = *plower + small_chunk * incr - (tid < extras ? 0 : incr);
+    if (plastiter != nullptr)
+      *plastiter = (tid == nth - 1);
+  }
+  *pstride = trip_count;
+}
+
+void WEAK_SYMBOL __kmpc_for_static_fini(void *ptr, int32_t v) {}
+
+static thread_local int next_num_threads = 0;
+
+/*!
+@ingroup PARALLEL
+The type for a microtask which gets passed to @ref __kmpc_fork_call().
+The arguments to the outlined function are
+@param global_tid the global thread identity of the thread executing the
+function.
+@param bound_tid the local identity of the thread executing the function
+@param ... pointers to shared variables accessed by the function.
+*/
+using kmpc_micro = void (*)(int32_t *global_tid, int32_t *bound_tid, ...);
+void WEAK_SYMBOL __kmpc_fork_call(void *loc, int32_t argc, void *pfunc, ...)
{ + if (unlikely(argc != 1 && argc != 0)) { + std::abort(); + } + va_list ap; + va_start(ap, pfunc); + void *c = va_arg(ap, void *); + int32_t global_tid = 0; + if (unlikely(next_num_threads)) { +#pragma omp parallel num_threads(next_num_threads) + { + kmpc_micro func = (kmpc_micro)(pfunc); + func(&global_tid, nullptr, c); + } + next_num_threads = 0; + } else { +#pragma omp parallel + { + kmpc_micro func = (kmpc_micro)(pfunc); + func(&global_tid, nullptr, c); + } + } + va_end(ap); +} + +void WEAK_SYMBOL __kmpc_push_num_threads(void *loc, int32_t global_tid, + int32_t num_threads) { + next_num_threads = num_threads; +} +#endif +} diff --git a/src/gc-cpu-runner/CMakeLists.txt b/src/gc-cpu-runner/CMakeLists.txt index f3f768612..85dbb6995 100644 --- a/src/gc-cpu-runner/CMakeLists.txt +++ b/src/gc-cpu-runner/CMakeLists.txt @@ -36,7 +36,8 @@ endif() #LLVM_LINK_COMPONENTS is processed by LLVM cmake in add_llvm_executable set(gc_cpu_runner_libs - ${MLIR_LINK_COMPONENTS}) + ${MLIR_LINK_COMPONENTS} + GCCpuRuntime) add_mlir_tool(gc-cpu-runner gc-cpu-runner.cpp ) diff --git a/src/gc-cpu-runner/gc-cpu-runner.cpp b/src/gc-cpu-runner/gc-cpu-runner.cpp index 3ece8f2ff..353abffe9 100644 --- a/src/gc-cpu-runner/gc-cpu-runner.cpp +++ b/src/gc-cpu-runner/gc-cpu-runner.cpp @@ -27,7 +27,11 @@ #include "llvm/Support/TargetSelect.h" #include +extern int gc_runtime_keep_alive; + int main(int argc, char **argv) { + // keeps GCCPURuntime linked + gc_runtime_keep_alive = 0; llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); diff --git a/test/gc/cpu-runner/tid.mlir b/test/gc/cpu-runner/tid.mlir new file mode 100644 index 000000000..aedcc0a20 --- /dev/null +++ b/test/gc/cpu-runner/tid.mlir @@ -0,0 +1,37 @@ +// RUN: gc-opt %s --convert-cpuruntime-to-llvm --convert-openmp-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --reconcile-unrealized-casts | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s +module { + func.func private @omp_get_thread_num() -> i32 + + func.func @check_parallel() { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %0 = llvm.mlir.constant(1 : i64) : i64 + omp.parallel num_threads(%c8: index) { + omp.wsloop { + omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c64) step (%c1, %c1) { + cpuruntime.printf "ITR %zu\n" %arg2 : index + omp.yield + } + omp.terminator + } + %tid = func.call @omp_get_thread_num() : () -> i32 + cpuruntime.printf "EXIT %d\n" %tid : i32 + omp.terminator + } + return + } + + func.func @main() { + %0 = func.call @omp_get_thread_num() : () -> i32 + cpuruntime.printf "TID %d\n" %0 : i32 + call @check_parallel() : ()->() + return + } + // CHECK: TID 0 + // CHECK-COUNT-64: ITR {{[0-9]+}} + // CHECK-NOT: ITR + // CHECK-COUNT-8: EXIT {{[0-9]+}} + // CHECK-NOT: EXIT +} \ No newline at end of file From 80a597f0887889f1721c83e502de97e0c05dac97 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 14:18:55 +0800 Subject: [PATCH 12/64] fix --- lib/gc/Transforms/Pipeline.cpp | 7 +++---- test/gc/Transforms/Pipeline/tensor_args.mlir | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 8b74df1b9..b336684a1 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -66,13 +66,13 @@ void populateBufferizationPasses(mlir::PassManager &pm) { 
pm.addPass(bufferization::createOneShotBufferizePass(options)); pm.addPass(createCSEPass()); pm.addPass(mlir::func::createFuncBufferizePass()); - bufferization::BufferResultsToOutParamsOpts opt{}; - // opt.hoistStaticAllocs = true; - pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); pm.addNestedPass( bufferization::createBufferizationBufferizePass()); pm.addNestedPass( bufferization::createFinalizingBufferizePass()); + bufferization::BufferResultsToOutParamsOpts opt{}; + opt.hoistStaticAllocs = true; + pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); // + buffer schedule pass, down-stream pass, to migrate buffer reschedule pass // from GC V1. pm.addNestedPass( @@ -101,7 +101,6 @@ void populateMicroKernelPasses(mlir::PassManager &pm) { void populateCPURuntimePasses(mlir::PassManager &pm) { // + flatten nested parallel pass, down-stream pass, to support coarse-grain // fusion - pm.addNestedPass(cpuruntime::createCPURuntimeAtExitToOmp()); // remove this pass after we add FlattenNestedParallel pm.addPass(createConvertSCFToOpenMPPass()); } diff --git a/test/gc/Transforms/Pipeline/tensor_args.mlir b/test/gc/Transforms/Pipeline/tensor_args.mlir index 73d916d04..adcfb3bd8 100644 --- a/test/gc/Transforms/Pipeline/tensor_args.mlir +++ b/test/gc/Transforms/Pipeline/tensor_args.mlir @@ -7,7 +7,7 @@ module { func.func @aaa(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> { %out = tensor.empty() : tensor<128xf32> %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> - // CHECK: memcpy - return %out : tensor<128xf32> + // CHECK-NOT: memcpy + return %2 : tensor<128xf32> } } \ No newline at end of file From 0b4332b2fe3a1e0c481fa1bcfb67a86130241ac1 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 16 May 2024 17:06:44 +0800 Subject: [PATCH 13/64] fix --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index a71b8522c..4a25a5ee0 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -40,7 +40,6 @@ void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func, auto cur_round = b->rounds_.load(std::memory_order_acquire); auto cnt = --b->pending_; assert(cnt >= 0); - // int count = 0; if (cnt == 0) { b->pending_.store(b->total_); b->rounds_.store(cur_round + 1); @@ -50,7 +49,6 @@ void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func, return; } idle_func(&b->rounds_, cur_round + 1, -1, idle_args); - // count = ret & 0xffffffff; } while (cur_round == b->rounds_.load()) { _mm_pause(); @@ -86,7 +84,7 @@ void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid, std::abort(); } const int32_t FALSE = 0; - const int32_t TRUE = 0; + const int32_t TRUE = 1; using UT = uint64_t; // using ST = int64_t; /* this all has to be changed back to TID and such.. 
*/
From b1c79a292428f7a34670c2b9b2c43d4529b98acd Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Thu, 23 May 2024 16:45:07 +0800
Subject: [PATCH 14/64] add wrapper
---
 .../gc/ExecutionEngine/JitWrapper/Module.hpp  | 44 +++++++++++
 lib/gc/ExecutionEngine/CMakeLists.txt         |  1 +
 .../ExecutionEngine/JitWrapper/CMakeLists.txt | 41 ++++++++++
 lib/gc/ExecutionEngine/JitWrapper/Module.cpp  | 79 +++++++++++++++++++
 lib/gc/Transforms/Pipeline.cpp                |  1 +
 unittests/CMakeLists.txt                      |  1 +
 unittests/ExecutionEngine/CMakeLists.txt      |  7 ++
 unittests/ExecutionEngine/JitWrapper.cpp      | 69 ++++++++++++++++
 8 files changed, 243 insertions(+)
 create mode 100644 include/gc/ExecutionEngine/JitWrapper/Module.hpp
 create mode 100644 lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt
 create mode 100644 lib/gc/ExecutionEngine/JitWrapper/Module.cpp
 create mode 100644 unittests/ExecutionEngine/CMakeLists.txt
 create mode 100644 unittests/ExecutionEngine/JitWrapper.cpp

diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/JitWrapper/Module.hpp
new file mode 100644
index 000000000..ec259680f
--- /dev/null
+++ b/include/gc/ExecutionEngine/JitWrapper/Module.hpp
@@ -0,0 +1,44 @@
+//===- Module.h - Jit module and Execution engine wrapper -------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GC_EXECUTIONENGINE_JITWRAPPER_H
+#define GC_EXECUTIONENGINE_JITWRAPPER_H
+
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include
+#include
+
+namespace mlir {
+class DialectRegistry;
+namespace gc {
+
+const DialectRegistry &initAndGetDialects();
+
+using JitModuleFuncT = void (*)(void **);
+
+class JitModule : public std::enable_shared_from_this<JitModule> {
+public:
+  static llvm::Expected<std::shared_ptr<JitModule>>
+  create(Operation *op, bool transform, llvm::StringRef entry_name = {},
+         const ExecutionEngineOptions &options = {},
+         std::unique_ptr<llvm::TargetMachine> tm = nullptr);
+
+  void call(void **args) { entry(args); }
+
+  JitModule(std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT entry);
+  ~JitModule();
+
+private:
+  std::unique_ptr<ExecutionEngine> engine;
+  JitModuleFuncT entry;
+};
+
+} // namespace gc
+} // namespace mlir
+
+#endif
\ No newline at end of file
diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt
index 8aa223412..a1592d74d 100644
--- a/lib/gc/ExecutionEngine/CMakeLists.txt
+++ b/lib/gc/ExecutionEngine/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(CPURuntime)
+add_subdirectory(JitWrapper)
\ No newline at end of file
diff --git a/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt b/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt
new file mode 100644
index 000000000..3186b018d
--- /dev/null
+++ b/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt
@@ -0,0 +1,41 @@
+if(GC_DEV_LINK_LLVM_DYLIB)
+  set(LLVM_LINK_COMPONENTS
+    LLVM
+  )
+  get_property(dialect_libs GLOBAL PROPERTY GC_DIALECT_LIBS)
+  get_property(conversion_libs GLOBAL PROPERTY GC_PASS_LIBS)
+  set(MLIR_LINK_COMPONENTS
+    MLIR
+    MLIRExecutionEngineShared
+  )
+else()
+  set(LLVM_LINK_COMPONENTS
+    Core
+    Support
+    nativecodegen
+    native
+  )
+  get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+  get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+  set(MLIR_LINK_COMPONENTS
+    MLIRBuiltinToLLVMIRTranslation
+    MLIRExecutionEngine
+    MLIRLLVMDialect
+    MLIRLLVMToLLVMIRTranslation
+
MLIRToLLVMIRTranslationRegistration + ) +endif() + +add_mlir_library(GCJitWrapper + Module.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include + + LINK_LIBS PUBLIC + ${MLIR_LINK_COMPONENTS} + ${dialect_libs} + ${conversion_libs} + GCPasses + ) + diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp new file mode 100644 index 000000000..7c80cc993 --- /dev/null +++ b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp @@ -0,0 +1,79 @@ +//===- Module.cpp -----------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Transforms/Passes.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" + +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/TargetSelect.h" + +namespace mlir { +namespace gc { + +static DialectRegistry initDialects() { + mlir::registerAllPasses(); + mlir::gc::registerGraphCompilerPasses(); + mlir::cpuruntime::registerCPURuntimePasses(); + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerAllDialects(registry); + mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + registry.insert(); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + mlir::registerAllToLLVMIRTranslations(registry); + return registry; +} + +const DialectRegistry &initAndGetDialects() { + static DialectRegistry reg = initDialects(); + return reg; +} + +static const char defaultEntryName[] = "_mlir_ciface_main_entry"; +llvm::Expected> +JitModule::create(Operation *op, bool transform, llvm::StringRef entry_name, + const ExecutionEngineOptions &options, + std::unique_ptr tm) { + if (transform) { + mlir::PassManager pm{op->getContext()}; + populateCPUPipeline(pm); + if (auto result = pm.run(op); failed(result)) { + return llvm::make_error( + "MLIR pass error", llvm::inconvertibleErrorCode()); + } + } + auto exec = ExecutionEngine::create(op, options, std::move(tm)); + if (!exec) { + return exec.takeError(); + } + auto &engine = *exec; + if (entry_name.empty()) { + entry_name = defaultEntryName; + } + auto mainEntry = engine->lookupPacked(entry_name); + if (!mainEntry) { + return mainEntry.takeError(); + } + return std::make_shared(std::move(engine), *mainEntry); +} + +JitModule::JitModule(std::unique_ptr engine, + JitModuleFuncT entry) + : engine{std::move(engine)}, entry{entry} {} +JitModule::~JitModule() = default; + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index b336684a1..e14f86589 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" 
#include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index c93735c63..a9bf31c38 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -18,4 +18,5 @@ function(add_mlir_unittest test_dirname) endfunction() add_subdirectory(Example) +add_subdirectory(ExecutionEngine) diff --git a/unittests/ExecutionEngine/CMakeLists.txt b/unittests/ExecutionEngine/CMakeLists.txt new file mode 100644 index 000000000..0e7315a0f --- /dev/null +++ b/unittests/ExecutionEngine/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(GCExecutionEngineTests + JitWrapper.cpp +) +target_link_libraries(GCExecutionEngineTests + PRIVATE + GCJitWrapper + GCCpuRuntime) diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp new file mode 100644 index 000000000..0653c97ee --- /dev/null +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -0,0 +1,69 @@ +//===- JitWrapper.cpp - Wrapper of JIT ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/ExecutionEngine/MemRefUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" + +using namespace mlir; + +static const char code1[] = R"mlir( +module { +func.func @main_entry(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + return %2 : tensor<128xf32> +} +} +)mlir"; + +extern "C" { +extern int gc_runtime_keep_alive; +} + +TEST(ExecutionEngine, JitWrapper) { + gc_runtime_keep_alive = 0; + MLIRContext ctx{gc::initAndGetDialects()}; + std::unique_ptr ir_buffer = + llvm::MemoryBuffer::getMemBuffer(code1); + // Parse the input mlir. 
+ llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &ctx); + ASSERT_TRUE(module); + auto jited = gc::JitModule::create(module.get(), true); + bool jit_success = static_cast(jited); + if (!jit_success) { + auto err = jited.takeError(); + llvm::errs() << err; + llvm::consumeError(std::move(err)); + } + ASSERT_TRUE(jit_success); + OwningMemRef bufA{ + {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; + OwningMemRef bufB{ + {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; + OwningMemRef bufC{{128}, {128}}; + void *args[] = {&*bufA, &*bufB, &*bufC}; + void *pargs[] = {&args[0], &args[1], &args[2]}; + jited.get()->call(pargs); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufC[{i}], 1.0f + i); + } +} From 382171bf38ec16e7b9f3edeb0d07834800aba98d Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:49:08 +0800 Subject: [PATCH 15/64] fix lint --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 7 +++---- src/gc-cpu-runner/CMakeLists.txt | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index 4a25a5ee0..6efb38142 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,5 +1,4 @@ //===- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// -//-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -28,9 +27,9 @@ struct barrier_t { char padding_[64 - 4 * sizeof(int32_t)]; }; -typedef uint64_t (*barrier_idle_func)(std::atomic *remaining, - int32_t expected_remain, int32_t tid, - void *args); +using barrier_idle_func = uint64_t (*)(std::atomic *remaining, + int32_t expected_remain, int32_t tid, + void *args); } // namespace extern "C" { diff --git a/src/gc-cpu-runner/CMakeLists.txt b/src/gc-cpu-runner/CMakeLists.txt index 85dbb6995..2599eef84 100644 --- a/src/gc-cpu-runner/CMakeLists.txt +++ b/src/gc-cpu-runner/CMakeLists.txt @@ -1,3 +1,20 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0 +################################################################################ + if(GC_DEV_LINK_LLVM_DYLIB) set(LLVM_LINK_COMPONENTS LLVM From f1fd0ae69fff9b17c8d7bf354e3e160a5d3be0aa Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:51:25 +0800 Subject: [PATCH 16/64] fix --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index 6efb38142..d81e8d3a9 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,4 +1,4 @@ -//===- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// +//===-- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From a773ea6bcbcfebc1370aa78552a703278ee502da Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:52:29 +0800 Subject: [PATCH 17/64] f --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index d81e8d3a9..ea7641417 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,9 +1,9 @@ -//===-- Parallel.cpp - Definitions for parallel runtime -----------*- C++ -*-=// -// +//===-- Parallel.cpp - parallel ---------------------------------*- C++ -*-===// +// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// +// //===----------------------------------------------------------------------===// #include From 84933c21b7f2c73361953166add0e1143238906d Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 23 May 2024 16:57:14 +0800 Subject: [PATCH 18/64] fix --- lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp index ea7641417..5591dc3af 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -1,9 +1,9 @@ //===-- Parallel.cpp - parallel ---------------------------------*- C++ -*-===// -// +// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
+//
 //===----------------------------------------------------------------------===//

 #include
From 4cca4dfe7bbcf9dbe4ddab9ff66c0c4ff59ffa86 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Thu, 23 May 2024 17:11:32 +0800
Subject: [PATCH 19/64] add reference
---
 lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp
index 5591dc3af..3a5b4c2c1 100644
--- a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp
+++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp
@@ -74,6 +74,8 @@ int WEAK_SYMBOL __kmpc_global_thread_num(void *loc) {
   return omp_get_thread_num();
 }

+// The implementation was extracted and simplified from LLVM libomp
+// at openmp/runtime/src/kmp_sched.cpp
 void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid,
                                            int32_t schedtype,
                                            int32_t *plastiter, uint64_t *plower,
From 678cef9885605e7d94199c0d1e0cecf6527fc6c9 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie"
Date: Fri, 24 May 2024 16:27:35 +0800
Subject: [PATCH 20/64] enable const cache
---
 .../CPURuntime/ConstantCache.hpp              | 137 ++++++++++++
 .../gc/ExecutionEngine/JitWrapper/Module.hpp  |  44 +++-
 .../ExecutionEngine/CPURuntime/CMakeLists.txt |   1 +
 .../CPURuntime/ConstantCache.cpp              |  40 ++++
 lib/gc/ExecutionEngine/JitWrapper/Module.cpp  | 209 ++++++++++++++++--
 unittests/ExecutionEngine/JitWrapper.cpp      | 114 +++++++++-
 6 files changed, 518 insertions(+), 27 deletions(-)
 create mode 100644 include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp
 create mode 100644 lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp

diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp
new file mode 100644
index 000000000..120d1d69f
--- /dev/null
+++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp
@@ -0,0 +1,137 @@
+#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
+#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include
+#include
+#include
+
+namespace mlir {
+namespace gc {
+/**
+ * The helper class to manage ref count manually for an object allocated with
+ * shared ptr. It holds an additional shared ptr reference to the object and
+ * contains an additional self-managed refcount. The refcount will be set to 1
+ * when the object is initialized (see init()). When the refcount counts down
+ * to 0, the additional shared ptr is reset.
+ */
+struct ref_count_managed {
+  ref_count_managed() = default;
+  ref_count_managed(const std::shared_ptr<void> &keep_alive) {
+    init(keep_alive);
+  }
+  void init(const std::shared_ptr<void> &keep_alive) {
+    keep_alive_ = keep_alive;
+    ref_count_.store(1);
+  }
+
+  void ref() { ++ref_count_; }
+  void deref() {
+    auto newv = --ref_count_;
+    if (newv == 0) {
+      keep_alive_ = nullptr;
+    }
+  }
+
+  // atomically check if ref_count_ > 0. if so, ref() the object and return
+  // true. Otherwise (if ref_count_==0), return false
+  bool check_alive_and_ref() {
+    auto oldv = ref_count_.load();
+    for (;;) {
+      if (oldv <= 0) {
+        return false;
+      }
+      if (ref_count_.compare_exchange_strong(oldv, oldv + 1)) {
+        return true;
+      }
+      // CAS failed, oldv has now the newest known value of ref_count_
+    }
+  }
+
+  bool is_alive() const { return ref_count_ > 0; }
+  void *unsafe_get_ptr() const { return keep_alive_.get(); }
+
+private:
+  std::shared_ptr<void> keep_alive_;
+  std::atomic<int32_t> ref_count_{0};
+};
+
+/**
+ * The proxy for the constant cache of Graph API. It holds a shared ptr
+ * pointing to the cache item in the cache manager (keep_alive) to extend the
+ * lifetime by refcount, @see ref_count_managed. To access the memory buffer
+ * of the const cache, use sc_acquire_const_cache and sc_release_const_cache
+ * functions. They will ref/deref the const_cache_proxy to make sure the cache
+ * is alive after calling sc_acquire_const_cache and before
+ * sc_release_const_cache. The cache manager of Graph API may evict the cache
+ * item by dereferencing this ref_count_managed object.
+ * sc_{acquire,release}_const_cache functions will find out that the cache has
+ * been invalidated and they will then use the memory allocator in the
+ * runtime::stream_t to re-allocate the buffer. Usually we expect JIT modules
+ * to hold shared ptr to const_cache_proxy via cached_const_graph_tensor.
+ * If is_lazy_ == true, the cache item's lifetime will be managed by the cache
+ * manager of Graph API and it is filled with data after the first execution
+ * of the computation. Otherwise, the cache item is always alive as long as
+ * the jit_module of the kernel is alive.
+ */
+struct const_cache_proxy : ref_count_managed {
+  const_cache_proxy(const std::shared_ptr<void> &keep_alive, void *buffer,
+                    size_t size, bool is_lazy)
+      : ref_count_managed(keep_alive), size_(size), is_lazy_(is_lazy),
+        buffer_(buffer) {}
+  ~const_cache_proxy();
+
+  // get the buffer and increment the refcount. If the buffer is evicted,
+  // returns null
+  void *acquire(int32_t *inited) {
+    if (check_alive_and_ref()) {
+      *inited = *inited && initialized_;
+      return buffer_;
+    }
+    return nullptr;
+  }
+  // decrement the refcount
+  bool release() {
+    if (is_alive()) {
+      deref();
+      initialized_ = 1;
+      return true;
+    }
+    return false;
+  }
+
+  // return the buffer. Do not directly use the buffer because it may be
+  // already released! To access the buffer, always acquire() before using it.
+  void *get_buffer_unsafe() const { return buffer_; }
+
+  size_t size_;
+  // if the buffer is lazy-initialized. If false, it should be filled before
+  // computation
+  bool is_lazy_;
+
+private:
+  // raw pointer to the buffer
+  void *buffer_;
+  // if the buffer has been initialized.
calling release() will set this to 1 + int32_t initialized_ = 0; +}; + +struct cached_graph_tensor { + std::shared_ptr base; + size_t offset; + cached_graph_tensor(const std::shared_ptr &base, + size_t offset); + friend class JitModule; + +private: + StridedMemRefType ref; +}; + +std::shared_ptr query_cached_tensor(uint64_t key); +bool reg_cached_tensor(uint64_t key, + const std::shared_ptr &base, + size_t offset); + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/JitWrapper/Module.hpp index ec259680f..afa9fb541 100644 --- a/include/gc/ExecutionEngine/JitWrapper/Module.hpp +++ b/include/gc/ExecutionEngine/JitWrapper/Module.hpp @@ -9,6 +9,8 @@ #ifndef GC_EXECUTIONENGINE_JITWRAPPER_H #define GC_EXECUTIONENGINE_JITWRAPPER_H +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include #include @@ -19,23 +21,53 @@ namespace gc { const DialectRegistry &initAndGetDialects(); +// the pointers to XXXMemRefType +using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); class JitModule : public std::enable_shared_from_this { public: static llvm::Expected> - create(Operation *op, bool transform, llvm::StringRef entry_name = {}, - const ExecutionEngineOptions &options = {}, - std::unique_ptr tm = nullptr); + create(Operation *op, const ExecutionEngineOptions &options = {}, + std::unique_ptr tm = nullptr, + bool transform = true); - void call(void **args) { entry(args); } + // args should be an array of XXXMemrefType* + void call(GeneralMemrefPtr *args); - JitModule(std::unique_ptr engine, JitModuleFuncT entry); + JitModule( + std::unique_ptr engine, JitModuleFuncT compute, + JitModuleFuncT fold, size_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive = {}); ~JitModule(); private: std::unique_ptr engine; - JitModuleFuncT entry; + JitModuleFuncT compute; + JitModuleFuncT fold; + size_t numOrigArgs; + // `keepAlive` has the ownership of the objects pointed by this vector + llvm::SmallVector cacheBases; + struct CacheBufferInfo { + // index in cacheBases + size_t baseIdx; + size_t offset; + }; + // the info for each folded cached buffer + llvm::SmallVector cacheInfo; + // holding the pointers to StridedMemRefType of folded cache + // `keepAlive` holds the the ownership of the pointers + llvm::SmallVector fastFoldBuffers; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs; + + std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 6be58e28f..97f039834 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") add_mlir_library(GCCpuRuntime SHARED Parallel.cpp + ConstantCache.cpp EXCLUDE_FROM_LIBMLIR ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp new file mode 100644 index 000000000..3be4b7dce --- /dev/null +++ 
b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -0,0 +1,40 @@ +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include +#include + +namespace mlir::gc { + +const_cache_proxy::~const_cache_proxy() = default; + + +cached_graph_tensor::cached_graph_tensor( + const std::shared_ptr &base, size_t offset) + : base{base}, offset{offset} { + // todo: fill in real values + ref.basePtr = (char *)base->get_buffer_unsafe() + offset; + ref.data = ref.basePtr; + ref.offset = 0; + memset(ref.sizes, 0, sizeof(ref.sizes)); + memset(ref.strides, 0, sizeof(ref.strides)); +} + +static std::unordered_map> cache; + +std::shared_ptr query_cached_tensor(uint64_t key) { + auto itr = cache.find(key); + if (itr != cache.end()) { + return itr->second; + } + return nullptr; +} + +bool reg_cached_tensor(uint64_t key, + const std::shared_ptr &base, + size_t offset) { + if (query_cached_tensor(key)) { + return false; + } + cache[key] = std::make_shared(base, offset); + return true; +} +} // namespace mlir::gc \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp index 7c80cc993..fa4a8151b 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp +++ b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp @@ -7,17 +7,20 @@ //===----------------------------------------------------------------------===// #include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" -#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "mlir/Target/LLVMIR/Dialect/All.h" - +#include "string.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + namespace mlir { namespace gc { @@ -38,15 +41,15 @@ static DialectRegistry initDialects() { } const DialectRegistry &initAndGetDialects() { - static DialectRegistry reg = initDialects(); - return reg; + static DialectRegistry reg = initDialects(); + return reg; } -static const char defaultEntryName[] = "_mlir_ciface_main_entry"; +static const char defaultComputeName[] = "_mlir_ciface_compute"; +static const char defaultFoldName[] = "_mlir_ciface_fold"; llvm::Expected> -JitModule::create(Operation *op, bool transform, llvm::StringRef entry_name, - const ExecutionEngineOptions &options, - std::unique_ptr tm) { +JitModule::create(Operation *op, const ExecutionEngineOptions &options, + std::unique_ptr tm, bool transform) { if (transform) { mlir::PassManager pm{op->getContext()}; populateCPUPipeline(pm); @@ -60,20 +63,192 @@ JitModule::create(Operation *op, bool transform, llvm::StringRef entry_name, return exec.takeError(); } auto &engine = *exec; - if (entry_name.empty()) { - entry_name = defaultEntryName; + uint32_t numOrigArgs; + { + auto expectArgs = engine->lookup("__num_orig_num_args"); + if (!expectArgs) { + return expectArgs.takeError(); + } + numOrigArgs = *reinterpret_cast(*expectArgs); } - auto mainEntry = engine->lookupPacked(entry_name); - if (!mainEntry) { - return mainEntry.takeError(); + JitModuleFuncT compute; + { + auto expectCompute = engine->lookupPacked(defaultComputeName); + if (!expectCompute) { + return expectCompute.takeError(); + } + compute = *expectCompute; + } + llvm::ArrayRef foldBufferIds; + JitModuleFuncT fold = 
nullptr; + llvm::ArrayRef computeArgs; + llvm::ArrayRef foldArgs; + do { + auto expectBufferIds = engine->lookup("__fold_buffer_ids"); + if (!expectBufferIds) { + // nothing to fold, It is OK. + llvm::consumeError(expectBufferIds.takeError()); + // break out of the scope, don't need to lookup "fold" function + break; + } else { + auto raw = reinterpret_cast(*expectBufferIds); + foldBufferIds = llvm::ArrayRef{raw + 1, raw[0]}; + } + + // find "fold" func + { + auto expectFold = engine->lookupPacked(defaultFoldName); + if (!expectFold) { + return expectFold.takeError(); + } + fold = *expectFold; + } + + // find "foldArgs" + { + auto expectFold = engine->lookup("__fold_args"); + if (!expectFold) { + return expectFold.takeError(); + } + auto raw = reinterpret_cast(*expectFold); + foldArgs = llvm::ArrayRef{raw + 1, raw[0]}; + } + + // find "computeArgs" + { + auto expect = engine->lookup("__compute_args"); + if (!expect) { + return expect.takeError(); + } + auto raw = reinterpret_cast(*expect); + computeArgs = llvm::ArrayRef{raw + 1, raw[0]}; + } + } while (0); + + std::vector> foldInfo; + foldInfo.reserve(foldBufferIds.size()); + for (auto bufId : foldBufferIds) { + auto ret = query_cached_tensor(bufId); + if (!ret) { + return llvm::make_error( + "Failed to query the folded cached tensor", + llvm::inconvertibleErrorCode()); + } + foldInfo.emplace_back(std::move(ret)); } - return std::make_shared(std::move(engine), *mainEntry); + + return std::make_shared(std::move(engine), compute, fold, + numOrigArgs, computeArgs, foldArgs, + std::move(foldInfo)); } -JitModule::JitModule(std::unique_ptr engine, - JitModuleFuncT entry) - : engine{std::move(engine)}, entry{entry} {} +JitModule::JitModule( + std::unique_ptr engine, JitModuleFuncT compute, + JitModuleFuncT fold, size_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive) + : engine{std::move(engine)}, compute{compute}, fold{fold}, + numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, computeArgs{computeArgs}, + keepAlive{std::move(cachekeepAlive)} { + for (const auto &cache : keepAlive) { + auto currentItr = + std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); + if (currentItr == cacheBases.end()) { + cacheBases.push_back(cache->base.get()); + currentItr = cacheBases.end() - 1; + } + cacheInfo.emplace_back(CacheBufferInfo{ + static_cast(currentItr - cacheBases.begin()), cache->offset}); + fastFoldBuffers.push_back(&cache->ref); + } +} JitModule::~JitModule() = default; +static void prepareCallArgs(llvm::SmallVector &realargs, + GeneralMemrefPtr *origargs, size_t numOrigArgs, + GeneralMemrefPtr *foldedCache, + llvm::ArrayRef realArgIdx) { + realargs.reserve(realArgIdx.size()); + for (auto argIdx : realArgIdx) { + if (argIdx < numOrigArgs) { + realargs.push_back(&origargs[argIdx]); + } else { + realargs.push_back(&foldedCache[argIdx - numOrigArgs]); + } + } +} + +void JitModule::call(GeneralMemrefPtr *args) { + if (unlikely(cacheInfo.empty())) { + // fast path, no folded cached buffers + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + realargs.reserve(numOrigArgs); + for (size_t i = 0; i < numOrigArgs; i++) { + realargs.push_back(&args[i]); + } + compute(realargs.data()); + return; + } + + // stage 1, acquire the foldBasePtr + llvm::SmallVector foldBasePtr; + int32_t inited = 1; + for (auto b : 
cacheBases) { + auto ptr = b->acquire(&inited); + if (unlikely(!ptr)) { + ptr = std::aligned_alloc(/*alignment*/ 64, b->size_); + inited = 0; + } + foldBasePtr.push_back((char *)ptr); + } + + // stage 2, run fold() if necessary + GeneralMemrefPtr *foldedCache; + // only used when !inited + std::vector slowFold; + std::vector> slowFoldObj; + if (likely(inited)) { + foldedCache = fastFoldBuffers.data(); + } else { + slowFold.reserve(cacheInfo.size()); + slowFoldObj.reserve(cacheInfo.size()); + for (auto &info : cacheInfo) { + slowFoldObj.emplace_back(); + auto &obj = slowFoldObj.back(); + obj.basePtr = foldBasePtr[info.baseIdx] + info.offset; + obj.data = obj.basePtr; + memset(obj.sizes, 0, sizeof(obj.sizes)); + memset(obj.strides, 0, sizeof(obj.strides)); + slowFold.push_back(&obj); + } + foldedCache = slowFold.data(); + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numOrigArgs, foldedCache, foldArgs); + fold(realargs.data()); + } + + // stage 3, call compute + { + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numOrigArgs, foldedCache, computeArgs); + compute(realargs.data()); + } + + // stage 4, cleanup + for (size_t i = 0; i < cacheBases.size(); i++) { + auto b = cacheBases[i]; + if (unlikely(!b->release())) { + // if the cached buffer is already free'd, foldBasePtr[i] is allocated via + // std::aligned_alloc by us, free it + std::free(foldBasePtr[i]); + } + } +} + } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 0653c97ee..d1a07c86f 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -19,12 +19,14 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include "gtest/gtest.h" +#include using namespace mlir; static const char code1[] = R"mlir( module { -func.func @main_entry(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { +llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32 +func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { %out = tensor.empty() : tensor<128xf32> %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> return %2 : tensor<128xf32> @@ -47,7 +49,7 @@ TEST(ExecutionEngine, JitWrapper) { mlir::OwningOpRef module = mlir::parseSourceFile(sourceMgr, &ctx); ASSERT_TRUE(module); - auto jited = gc::JitModule::create(module.get(), true); + auto jited = gc::JitModule::create(module.get()); bool jit_success = static_cast(jited); if (!jit_success) { auto err = jited.takeError(); @@ -61,9 +63,113 @@ TEST(ExecutionEngine, JitWrapper) { {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; OwningMemRef bufC{{128}, {128}}; void *args[] = {&*bufA, &*bufB, &*bufC}; - void *pargs[] = {&args[0], &args[1], &args[2]}; - jited.get()->call(pargs); + jited.get()->call(args); for (int i = 0; i < 128; i++) { ASSERT_EQ(bufC[{i}], 1.0f + i); } } + +// compute d = (a+a) + (b+b) + c, where a,b is marked constant +// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5 +static const char code2[] = R"mlir( +module { +llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 +llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> +// a,b, foldedA,foldedB +llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : 
!llvm.array<5 x i32> +// foldedA, foldedB, c, d +llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> + +func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { + %c0 = arith.constant 0 : index + cpuruntime.printf "HI%zu\n" %c0 : index + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + %out2 = tensor.empty() : tensor<128xf32> + %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> + return %2, %3 : tensor<128xf32>, tensor<128xf32> +} + +func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> + return %d : tensor<128xf32> +} +} +)mlir"; + +TEST(ExecutionEngine, JitWrapperCached) { + MLIRContext ctx{gc::initAndGetDialects()}; + std::unique_ptr ir_buffer = + llvm::MemoryBuffer::getMemBuffer(code2); + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &ctx); + + // foldedA and foldedB uses this buffer + auto ret = std::shared_ptr(new float[128 * 2]); + auto proxy = std::make_shared( + ret, ret.get(), 128 * 2 * sizeof(float), true); + + ASSERT_TRUE(gc::reg_cached_tensor(114514, proxy, 0)); + ASSERT_TRUE(gc::reg_cached_tensor(1919810, proxy, 128 * sizeof(float))); + + ASSERT_TRUE(module); + auto jited = gc::JitModule::create(module.get()); + bool jit_success = static_cast(jited); + if (!jit_success) { + auto err = jited.takeError(); + llvm::errs() << err; + llvm::consumeError(std::move(err)); + } + ASSERT_TRUE(jit_success); + OwningMemRef bufA{ + {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; + OwningMemRef bufB{ + {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; + OwningMemRef bufC{ + {128}, {128}, [](float &ptr, ArrayRef idx) { + ptr = -idx[0] * 3; + }}; + OwningMemRef bufD{{128}, {128}}; + void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD}; + + // first call, should run fold() + { + testing::internal::CaptureStdout(); + // first call, should run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_EQ(output, "HI0\n"); + } + + { + testing::internal::CaptureStdout(); + // second call, should not run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_TRUE(output.empty()); + } + + // the cache is evicted + proxy->deref(); + { + testing::internal::CaptureStdout(); + // third call, should run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_EQ(output, "HI0\n"); + } +} From c12156c02c77f0e4c7a1923f829b340acb7afc5e Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Fri, 24 May 2024 16:36:54 
+0800 Subject: [PATCH 21/64] reduce size --- lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt | 4 +++- lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt index 3a1d63d3d..8d92804d6 100644 --- a/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect) + add_mlir_dialect_library(MLIRCPURuntimeDialect CPURuntimeDialect.cpp CPURuntimeOps.cpp @@ -10,5 +12,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIRFuncDialect + ${MLIR_LINK_COMPONENTS} ) diff --git a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt index 3bc84f6c8..52c0d7441 100644 --- a/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt +++ b/lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt @@ -1,3 +1,5 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect) + add_mlir_dialect_library(MLIRCPURuntimeTransforms CPURuntimeToLLVM.cpp @@ -8,7 +10,7 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms MLIRCPURuntimePassesIncGen LINK_LIBS PUBLIC - MLIRFuncDialect + ${MLIR_LINK_COMPONENTS} MLIRCPURuntimeDialect ) From 6219935a44713c214efc4c7bc7dd5c9d04e74cbe Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Mon, 27 May 2024 13:43:29 +0800 Subject: [PATCH 22/64] Add single operand check --- lib/gc/Transforms/CST.cpp | 25 +++++++++- .../test_constant_weights_folding-1.mlir | 50 +++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 test/gc/Transforms/test_constant_weights_folding-1.mlir diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp index 2dac0d860..73dd075cc 100644 --- a/lib/gc/Transforms/CST.cpp +++ b/lib/gc/Transforms/CST.cpp @@ -66,6 +66,23 @@ int64_t getTensorSize(TensorType t) { return size; } +bool singleOperand(Operation *op) { + if (op->getNumOperands() > 1) { + Value firstOperand = op->getOperand(0); + for (int64_t i = 1; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (firstOperand == operand) { + continue; + } + auto parentOp = operand.getDefiningOp(); + if (parentOp && !isa(parentOp)) { + return false; + } + } + } + return true; +} + bool canMoveBefore(Operation *op) { if (op->getDialect()->getNamespace() == arith::ArithDialect::getDialectNamespace()) { @@ -341,7 +358,8 @@ void CST::runOnOperation() { while (!v.getUsers().empty()) { // v.getUsers().size() should be 1 Operation *user = *(v.getUsers().begin()); - if (!isInConstantSubgraph(user)) { + // If user is not const or user has multiple operand, we reach the end + if (!isInConstantSubgraph(user) || !singleOperand(user)) { outputTypes.push_back(v.getType()); outputValues.push_back(v); break; @@ -437,6 +455,7 @@ void CST::runOnOperation() { std::deque dq; SmallVector opsToErase; + std::unordered_set opsToEraseSet; dq.push_back(block.getArgument(id + 1)); while (!dq.empty()) { Value v = dq.front(); @@ -445,7 +464,11 @@ void CST::runOnOperation() { for (auto res : op->getResults()) { dq.push_back(res); } + if (opsToEraseSet.count(op)) { + break; + } opsToErase.push_back(op); + opsToEraseSet.insert(op); } } diff --git a/test/gc/Transforms/test_constant_weights_folding-1.mlir b/test/gc/Transforms/test_constant_weights_folding-1.mlir new file mode 100644 index 000000000..b446212c5 --- /dev/null +++ 
b/test/gc/Transforms/test_constant_weights_folding-1.mlir @@ -0,0 +1,50 @@ +// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(csa,cst)" %s | FileCheck %s + +// CHECK-LABEL: func.func @entry +module { + func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } { + %c0 = arith.constant 0 : index + cpuruntime.printf "HI%zu\n" %c0 : index + %ax2 = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%ax2 : tensor<128xf32>) -> tensor<128xf32> + %bx2 = tensor.empty() : tensor<128xf32> + %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%bx2 : tensor<128xf32>) -> tensor<128xf32> + %ax2pbx2 = tensor.empty() : tensor<128xf32> + %4 = linalg.add ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2 : tensor<128xf32>) -> tensor<128xf32> + %ax2pbx2pc = tensor.empty() : tensor<128xf32> + %d = linalg.add ins(%4, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32> + return %d : tensor<128xf32> + } +} + +// CHECK: cpuruntime.printf +// CHECK: linalg.add +// CHECK: linalg.add +// CHECK: func.func @fold +// CHECK: linalg.add +// CHECK: linalg.add + +// COM: expected output: +// COM: module { +// COM: llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 +// COM: llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> +// COM: // a,b, foldedA,foldedB +// COM: llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32> +// COM: // foldedA, foldedB, c, d +// COM: llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> +// COM: func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { +// COM: %c0 = arith.constant 0 : index +// COM: cpuruntime.printf "HI%zu\n" %c0 : index +// COM: %out = tensor.empty() : tensor<128xf32> +// COM: %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> +// COM: %out2 = tensor.empty() : tensor<128xf32> +// COM: %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> +// COM: return %2, %3 : tensor<128xf32>, tensor<128xf32> +// COM: } +// COM: func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { +// COM: %out = tensor.empty() : tensor<128xf32> +// COM: %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> +// COM: %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> +// COM: return %d : tensor<128xf32> +// COM: } +// COM: } \ No newline at end of file From 5eb0ac014cd3fdf48a27eee05c71bb8aed719274 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Mon, 27 May 2024 17:01:17 +0800 Subject: [PATCH 23/64] Add cache manager --- lib/gc/Transforms/CST.cpp | 136 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp index 73dd075cc..b8ceb6a2c 100644 --- a/lib/gc/Transforms/CST.cpp +++ b/lib/gc/Transforms/CST.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include 
"mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -29,6 +30,8 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" +// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" + namespace mlir { namespace gc { #define GEN_PASS_DEF_CST @@ -293,6 +296,101 @@ void postponeBroadcast(Block &block) { static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8; +// get from dnnl_graph_compiler_context +// void *allocator(size_t size) { return std::aligned_alloc(64, size); } +// void deallocator(void *ptr) { std::free(ptr); } + +// std::shared_ptr create_const_cache_proxy(size_t size) { +// // simply allocate buffer and return +// std::shared_ptr base = +// std::shared_ptr{std::aligned_alloc(64, size), [](void *p) { +// std::free(p); }}; +// return std::make_shared(base, base.get(), size, true); +// } + +size_t divide_and_ceil(size_t x, size_t y) { return (x + y - 1) / y; } + +// Manager +struct const_graph_tensor_cache_manager { + // dnnl_graph_compiler_context *ctx; + + uint64_t cached_tensor_global_id = 0; + + // singleton + static std::shared_ptr get() { + static std::shared_ptr c = + std::make_shared(); + return c; + } + + // alloc and set the buf_base_ and offset_ attributes of cache + std::vector alloc(std::vector buffers_size) { + size_t total_size = 0; + for (size_t i = 0; i < buffers_size.size(); i++) { + total_size += divide_and_ceil(buffers_size[i], 64) * 64; + } + llvm::dbgs() << "Alloc total size: " << total_size << '\n'; + // auto base = create_const_cache_proxy(total_size); + std::vector global_ids(buffers_size.size()); + size_t offset = 0; + for (size_t i = 0; i < buffers_size.size(); i++) { + llvm::dbgs() << "Alloc offset: " << offset << '\n'; + // reg_cached_tensor(cached_tensor_global_id, base, offset); + global_ids[i] = cached_tensor_global_id; + ++cached_tensor_global_id; + offset += divide_and_ceil(buffers_size[i], 64) * 64; + } + return global_ids; + } +}; + +// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, +// StringRef name, int64_t value) { +// OpBuilder::InsertionGuard insertGuard(builder); +// builder.setInsertionPointToStart(module.getBody()); + +// auto type = IntegerType::get(builder.getContext(), 8); +// LLVM::GlobalOp global = builder.create( +// loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, +// builder.getIndexAttr(value), +// /*alignment=*/0); +// } + +// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder, +// StringRef name, ArrayRef array) { +// OpBuilder::InsertionGuard insertGuard(builder); +// builder.setInsertionPointToStart(module.getBody()); + +// auto type = LLVM::LLVMArrayType::get( +// IntegerType::get(builder.getContext(), 8), array.size()); +// LLVM::GlobalOp global = builder.create( +// loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, +// builder.getIndexArrayAttr(array), +// /*alignment=*/0); +// } + +static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder, + StringRef name, ArrayRef array) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + + MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType()); + IntegerAttr memrefAlignment = IntegerAttr(); + auto global = builder.create( + loc, name, + 
/*sym_visibility=*/builder.getStringAttr("public"), + /*type=*/type, + /*initial_value=*/builder.getIndexTensorAttr(array), + /*constant=*/true, + /*alignment=*/memrefAlignment); +} + +static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, + StringRef name, int64_t value) { + SmallVector array{value}; + addGlobalArray(module, loc, builder, name, array); +} + // Operate on tensors. Create fold() and compute() on module. The // folded weights and first-run flag is maintained by upper-level runtime. void CST::runOnOperation() { @@ -436,15 +534,38 @@ void CST::runOnOperation() { }); } + // Allocate buffer for outputValuesInFold + std::vector buffersSize; + for (Value &tensor : outputValuesInFold) { + llvm::dbgs() << "Allocate buffer for tensor: " << tensor << "\n"; + buffersSize.push_back( + getTensorSize(dyn_cast(tensor.getType()))); + } + auto manager = const_graph_tensor_cache_manager::get(); + SmallVector globalIndexes; + for (auto id : manager->alloc(buffersSize)) { + globalIndexes.push_back(id); + } + globalIndexes.insert(globalIndexes.begin(), globalIndexes.size()); + addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids", + globalIndexes); + foldFunc.setVisibility(SymbolTable::Visibility::Public); moduleOp.push_back(foldFunc); symbolTable.insert(foldFunc); + SmallVector foldArgs; + SmallVector foldIds; + SmallVector computeArgs; + // modify the BlockArguments of block size_t oriNumArgs = block.getNumArguments(); size_t argIdx = 0; for (size_t id = 0; id < oriNumArgs; ++id) { if (constArgsIndexes.count(id) == 1) { + foldArgs.push_back(id); + foldIds.push_back(argIdx + oriNumArgs); + computeArgs.push_back(argIdx + oriNumArgs); auto loc = block.getArgument(id).getLoc(); BlockArgument foldArg = block.insertArgument(id, outputTypes[argIdx], loc); @@ -477,9 +598,24 @@ void CST::runOnOperation() { } block.eraseArgument(id + 1); ++argIdx; + } else { + computeArgs.push_back(id); } } + for (auto id : foldIds) { + foldArgs.insert(foldArgs.end(), id); + } + foldArgs.insert(foldArgs.begin(), foldArgs.size()); + addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__fold_args", foldArgs); + + computeArgs.insert(computeArgs.begin(), computeArgs.size()); + addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__compute_args", + computeArgs); + + addGlobal(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args", + oriNumArgs); + // modify the compute func signature func::FuncOp computeFunc = cast(topFunc); FunctionType computeFuncType = computeFunc.getFunctionType(); From c3e186d7e79c286056ef8fa8c51f4452821e8c64 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Tue, 28 May 2024 10:21:52 +0800 Subject: [PATCH 24/64] Use llvm global [need to cowork with yijie/mainfunc_wrapper] --- lib/gc/Transforms/CST.cpp | 94 +++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp index b8ceb6a2c..ecf7eff8f 100644 --- a/lib/gc/Transforms/CST.cpp +++ b/lib/gc/Transforms/CST.cpp @@ -30,7 +30,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" -// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" namespace mlir { namespace gc { @@ -300,13 +300,13 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8; // void *allocator(size_t size) { return std::aligned_alloc(64, size); } // void deallocator(void *ptr) { std::free(ptr); } -// std::shared_ptr 
create_const_cache_proxy(size_t size) { -// // simply allocate buffer and return -// std::shared_ptr base = -// std::shared_ptr{std::aligned_alloc(64, size), [](void *p) { -// std::free(p); }}; -// return std::make_shared(base, base.get(), size, true); -// } +std::shared_ptr create_const_cache_proxy(size_t size) { + // simply allocate buffer and return + std::shared_ptr base = + std::shared_ptr{std::aligned_alloc(64, size), [](void *p) { + std::free(p); }}; + return std::make_shared(base, base.get(), size, true); +} size_t divide_and_ceil(size_t x, size_t y) { return (x + y - 1) / y; } @@ -330,12 +330,12 @@ struct const_graph_tensor_cache_manager { total_size += divide_and_ceil(buffers_size[i], 64) * 64; } llvm::dbgs() << "Alloc total size: " << total_size << '\n'; - // auto base = create_const_cache_proxy(total_size); + auto base = create_const_cache_proxy(total_size); std::vector global_ids(buffers_size.size()); size_t offset = 0; for (size_t i = 0; i < buffers_size.size(); i++) { llvm::dbgs() << "Alloc offset: " << offset << '\n'; - // reg_cached_tensor(cached_tensor_global_id, base, offset); + reg_cached_tensor(cached_tensor_global_id, base, offset); global_ids[i] = cached_tensor_global_id; ++cached_tensor_global_id; offset += divide_and_ceil(buffers_size[i], 64) * 64; @@ -344,52 +344,52 @@ struct const_graph_tensor_cache_manager { } }; -// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, -// StringRef name, int64_t value) { -// OpBuilder::InsertionGuard insertGuard(builder); -// builder.setInsertionPointToStart(module.getBody()); - -// auto type = IntegerType::get(builder.getContext(), 8); -// LLVM::GlobalOp global = builder.create( -// loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, -// builder.getIndexAttr(value), -// /*alignment=*/0); -// } - -// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder, -// StringRef name, ArrayRef array) { -// OpBuilder::InsertionGuard insertGuard(builder); -// builder.setInsertionPointToStart(module.getBody()); +static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, + StringRef name, int64_t value) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); -// auto type = LLVM::LLVMArrayType::get( -// IntegerType::get(builder.getContext(), 8), array.size()); -// LLVM::GlobalOp global = builder.create( -// loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, -// builder.getIndexArrayAttr(array), -// /*alignment=*/0); -// } + auto type = IntegerType::get(builder.getContext(), 8); + LLVM::GlobalOp global = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, + builder.getIndexAttr(value), + /*alignment=*/0); +} static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder, StringRef name, ArrayRef array) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); - MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType()); - IntegerAttr memrefAlignment = IntegerAttr(); - auto global = builder.create( - loc, name, - /*sym_visibility=*/builder.getStringAttr("public"), - /*type=*/type, - /*initial_value=*/builder.getIndexTensorAttr(array), - /*constant=*/true, - /*alignment=*/memrefAlignment); + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 8), array.size()); + LLVM::GlobalOp global = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, + 
builder.getIndexArrayAttr(array), + /*alignment=*/0); } -static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, - StringRef name, int64_t value) { - SmallVector array{value}; - addGlobalArray(module, loc, builder, name, array); -} +// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder, +// StringRef name, ArrayRef array) { +// OpBuilder::InsertionGuard insertGuard(builder); +// builder.setInsertionPointToStart(module.getBody()); + +// MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType()); +// IntegerAttr memrefAlignment = IntegerAttr(); +// auto global = builder.create( +// loc, name, +// /*sym_visibility=*/builder.getStringAttr("public"), +// /*type=*/type, +// /*initial_value=*/builder.getIndexTensorAttr(array), +// /*constant=*/true, +// /*alignment=*/memrefAlignment); +// } + +// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, +// StringRef name, int64_t value) { +// SmallVector array{value}; +// addGlobalArray(module, loc, builder, name, array); +// } // Operate on tensors. Create fold() and compute() on module. The // folded weights and first-run flag is maintained by upper-level runtime. From e24b1df24a4f5b60557d756010c8217104e2c012 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 10:36:42 +0800 Subject: [PATCH 25/64] rename --- .../CPURuntime/ConstantCache.hpp | 22 +++++++++---------- .../gc/ExecutionEngine/JitWrapper/Module.hpp | 6 ++--- .../CPURuntime/ConstantCache.cpp | 18 +++++++-------- lib/gc/ExecutionEngine/JitWrapper/Module.cpp | 6 ++--- unittests/ExecutionEngine/JitWrapper.cpp | 6 ++--- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp index 120d1d69f..4000f53dc 100644 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp @@ -60,25 +60,25 @@ struct ref_count_managed { * to the cache item in the cache manager (keep_alive) to extend the lifetime by * refcount, @see ref_count_managed. To access the memory buffer of the const * cache, use sc_acquire_const_cache and sc_release_const_cache functions. They - * will ref/deref the const_cache_proxy to make sure the cache is alive after + * will ref/deref the ConstCacheProxy to make sure the cache is alive after * calling sc_acquire_const_cache and before sc_release_const_cache. The cache * manager of Graph API may evict the cache item by dereferenceing this * ref_count_managed object. sc_{acquire,release}_const_cache functions will * find out that the cache has been invalidated and they will then use the * memory allocator in the runtime::stream_t to re-allocate the buffer. Usually - * we expect JIT modules to hold shared ptr to const_cache_proxy via + * we expect JIT modules to hold shared ptr to ConstCacheProxy via * cached_const_graph_tensor. * If is_lazy_ == true, the cache item's lifetime will be managed by the cache * manager of Graph API and it is filled with data after the first execution of * the computation. Otherwise, the cache item is always alive as long as the * jit_module of the kernel is alive. 
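 * A minimal usage sketch of the acquire/release protocol above (illustrative
 * only, not part of this patch; `proxy` stands for any registered
 * ConstCacheProxy):
 *   int32_t inited = 1;
 *   if (void *buf = proxy->acquire(&inited)) {
 *     // use buf; `inited` tells whether previously folded data is valid
 *     proxy->release(); // also marks the buffer as initialized
 *   } else {
 *     // evicted: re-allocate the buffer and re-run the fold() function
 *   }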
*/ -struct const_cache_proxy : ref_count_managed { - const_cache_proxy(const std::shared_ptr &keep_alive, void *buffer, +struct ConstCacheProxy : ref_count_managed { + ConstCacheProxy(const std::shared_ptr &keep_alive, void *buffer, size_t size, bool is_lazy) : ref_count_managed(keep_alive), size_(size), is_lazy_(is_lazy), buffer_(buffer) {} - ~const_cache_proxy(); + ~ConstCacheProxy(); // get the buffer and increment the refcount. If the buffer is evicted, // returns null @@ -115,10 +115,10 @@ struct const_cache_proxy : ref_count_managed { int32_t initialized_ = 0; }; -struct cached_graph_tensor { - std::shared_ptr base; +struct CachedGraphTensor { + std::shared_ptr base; size_t offset; - cached_graph_tensor(const std::shared_ptr &base, + CachedGraphTensor(const std::shared_ptr &base, size_t offset); friend class JitModule; @@ -126,9 +126,9 @@ struct cached_graph_tensor { StridedMemRefType ref; }; -std::shared_ptr query_cached_tensor(uint64_t key); -bool reg_cached_tensor(uint64_t key, - const std::shared_ptr &base, +std::shared_ptr queryCacheTensor(uint64_t key); +bool regCachedTensor(uint64_t key, + const std::shared_ptr &base, size_t offset); } // namespace gc diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/JitWrapper/Module.hpp index afa9fb541..926aba3f6 100644 --- a/include/gc/ExecutionEngine/JitWrapper/Module.hpp +++ b/include/gc/ExecutionEngine/JitWrapper/Module.hpp @@ -42,7 +42,7 @@ class JitModule : public std::enable_shared_from_this { llvm::ArrayRef computeArgs, // The code inside `engine` has the ownership of the buffer llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive = {}); + std::vector> &&cachekeepAlive = {}); ~JitModule(); private: @@ -51,7 +51,7 @@ class JitModule : public std::enable_shared_from_this { JitModuleFuncT fold; size_t numOrigArgs; // `keepAlive` has the ownership of the objects pointed by this vector - llvm::SmallVector cacheBases; + llvm::SmallVector cacheBases; struct CacheBufferInfo { // index in cacheBases size_t baseIdx; @@ -67,7 +67,7 @@ class JitModule : public std::enable_shared_from_this { // The code inside `engine` has the ownership of the buffer llvm::ArrayRef computeArgs; - std::vector> keepAlive; + std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index 3be4b7dce..133cc362f 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -4,11 +4,11 @@ namespace mlir::gc { -const_cache_proxy::~const_cache_proxy() = default; +ConstCacheProxy::~ConstCacheProxy() = default; -cached_graph_tensor::cached_graph_tensor( - const std::shared_ptr &base, size_t offset) +CachedGraphTensor::CachedGraphTensor( + const std::shared_ptr &base, size_t offset) : base{base}, offset{offset} { // todo: fill in real values ref.basePtr = (char *)base->get_buffer_unsafe() + offset; @@ -18,9 +18,9 @@ cached_graph_tensor::cached_graph_tensor( memset(ref.strides, 0, sizeof(ref.strides)); } -static std::unordered_map> cache; +static std::unordered_map> cache; -std::shared_ptr query_cached_tensor(uint64_t key) { +std::shared_ptr queryCacheTensor(uint64_t key) { auto itr = cache.find(key); if (itr != cache.end()) { return itr->second; @@ -28,13 +28,13 @@ std::shared_ptr query_cached_tensor(uint64_t key) { return nullptr; } -bool reg_cached_tensor(uint64_t key, - const std::shared_ptr &base, +bool regCachedTensor(uint64_t key, + const std::shared_ptr 
&base, size_t offset) { - if (query_cached_tensor(key)) { + if (queryCacheTensor(key)) { return false; } - cache[key] = std::make_shared(base, offset); + cache[key] = std::make_shared(base, offset); return true; } } // namespace mlir::gc \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp index fa4a8151b..06f9c02b6 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp +++ b/lib/gc/ExecutionEngine/JitWrapper/Module.cpp @@ -125,10 +125,10 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, } } while (0); - std::vector> foldInfo; + std::vector> foldInfo; foldInfo.reserve(foldBufferIds.size()); for (auto bufId : foldBufferIds) { - auto ret = query_cached_tensor(bufId); + auto ret = queryCacheTensor(bufId); if (!ret) { return llvm::make_error( "Failed to query the folded cached tensor", @@ -149,7 +149,7 @@ JitModule::JitModule( llvm::ArrayRef computeArgs, // The code inside `engine` has the ownership of the buffer llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive) + std::vector> &&cachekeepAlive) : engine{std::move(engine)}, compute{compute}, fold{fold}, numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index d1a07c86f..222c73116 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -111,11 +111,11 @@ TEST(ExecutionEngine, JitWrapperCached) { // foldedA and foldedB uses this buffer auto ret = std::shared_ptr(new float[128 * 2]); - auto proxy = std::make_shared( + auto proxy = std::make_shared( ret, ret.get(), 128 * 2 * sizeof(float), true); - ASSERT_TRUE(gc::reg_cached_tensor(114514, proxy, 0)); - ASSERT_TRUE(gc::reg_cached_tensor(1919810, proxy, 128 * sizeof(float))); + ASSERT_TRUE(gc::regCachedTensor(114514, proxy, 0)); + ASSERT_TRUE(gc::regCachedTensor(1919810, proxy, 128 * sizeof(float))); ASSERT_TRUE(module); auto jited = gc::JitModule::create(module.get()); From 1e06c9844c7753400c6efc254898475c52b2c38b Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 10:42:03 +0800 Subject: [PATCH 26/64] fix license.py --- scripts/license.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/license.py b/scripts/license.py index 49f28eaa8..3ed2ce521 100644 --- a/scripts/license.py +++ b/scripts/license.py @@ -15,10 +15,10 @@ # SPDX-License-Identifier: Apache-2.0 import datetime, sys, re, argparse -from typing import Dict, Set +from typing import Dict, Set, List WIDTH: int = 80 -intel_license: list[str] = [ +intel_license: List[str] = [ 'Copyright \\(C\\) (\\d\\d\\d\\d-)?$YEAR Intel Corporation', '', 'Licensed under the Apache License, Version 2.0 (the "License");', @@ -35,7 +35,7 @@ 'SPDX-License-Identifier: Apache-2.0', ] -llvm_license: list[str] = [ +llvm_license: List[str] = [ "===-{1,2} $FILE - .* -*\\*- $LANG -\\*-===", '', 'This file is licensed under the Apache License v2.0 with LLVM Exceptions.', @@ -45,7 +45,7 @@ "===-*===", ] -def check_license(filepath: str, license: list[str], var: Dict[str, str], re_line: Set[int]): +def check_license(filepath: str, license: List[str], var: Dict[str, str], re_line: Set[int]): with open(filepath, 'r') as f: idx: int = 0 for line in f.readlines(): @@ -117,7 +117,7 @@ def use_llvm_license(path: str) -> bool: var: Dict[str, str] = {} re_line: Set[int] = set() - lic = list[str] + lic = List[str] 
if filepath.startswith("test/") or filepath.startswith("./test/"): continue From 7c32bc5c899872e2a5ed1f011b2446d808b5f32a Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 10:58:21 +0800 Subject: [PATCH 27/64] fix --- .../CPURuntime/ConstantCache.hpp | 88 +++++++++---------- .../Module.hpp => Driver/Driver.hpp} | 6 +- lib/gc/ExecutionEngine/CMakeLists.txt | 2 +- .../CPURuntime/ConstantCache.cpp | 10 ++- .../{JitWrapper => Driver}/CMakeLists.txt | 2 +- .../Module.cpp => Driver/Driver.cpp} | 4 +- unittests/ExecutionEngine/JitWrapper.cpp | 6 +- 7 files changed, 63 insertions(+), 55 deletions(-) rename include/gc/ExecutionEngine/{JitWrapper/Module.hpp => Driver/Driver.hpp} (94%) rename lib/gc/ExecutionEngine/{JitWrapper => Driver}/CMakeLists.txt (98%) rename lib/gc/ExecutionEngine/{JitWrapper/Module.cpp => Driver/Driver.cpp} (98%) diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp index 4000f53dc..0d96ae6b3 100644 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp @@ -1,3 +1,10 @@ +//===-- ConstantCache.hpp - Constant cache interfaces -----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// #ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H #define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H #include "mlir/ExecutionEngine/CRunnerUtils.h" @@ -14,76 +21,70 @@ namespace gc { * when the object is initialized (see init()). When the refcount counts down to * 0, the additional shared ptr is reset. */ -struct ref_count_managed { - ref_count_managed() = default; - ref_count_managed(const std::shared_ptr &keep_alive) { - init(keep_alive); - } +struct RefCountManaged { + RefCountManaged() = default; + RefCountManaged(const std::shared_ptr &keep_alive) { init(keep_alive); } void init(const std::shared_ptr &keep_alive) { - keep_alive_ = keep_alive; - ref_count_.store(1); + keepAlive = keep_alive; + refCount.store(1); } - void ref() { ++ref_count_; } + void ref() { ++refCount; } void deref() { - auto newv = --ref_count_; + auto newv = --refCount; if (newv == 0) { - keep_alive_ = nullptr; + keepAlive = nullptr; } } - // atomically check if ref_count_ > 0. if so, ref() the object and return - // true. Otherwise (if ref_count_==0), return false - bool check_alive_and_ref() { - auto oldv = ref_count_.load(); + // atomically check if refCount > 0. if so, ref() the object and return + // true. Otherwise (if refCount==0), return false + bool checkAliveAndRef() { + auto oldv = refCount.load(); for (;;) { if (oldv <= 0) { return false; } - if (ref_count_.compare_exchange_strong(oldv, oldv + 1)) { + if (refCount.compare_exchange_strong(oldv, oldv + 1)) { return true; } - // CAS failed, oldv has now the newest known value of ref_count_ + // CAS failed, oldv has now the newest known value of refCount } } - bool is_alive() const { return ref_count_ > 0; } - void *unsafe_get_ptr() const { return keep_alive_.get(); } + bool isAlive() const { return refCount > 0; } + void *getPtrUnsafe() const { return keepAlive.get(); } private: - std::shared_ptr keep_alive_; - std::atomic ref_count_{0}; + std::shared_ptr keepAlive; + std::atomic refCount{0}; }; /** * The proxy for the constant cache of Graph API. 
It holds a shared ptr pointing
 * to the cache item in the cache manager (keep_alive) to extend the lifetime by
- * refcount, @see ref_count_managed. To access the memory buffer of the const
- * cache, use sc_acquire_const_cache and sc_release_const_cache functions. They
- * will ref/deref the ConstCacheProxy to make sure the cache is alive after
- * calling sc_acquire_const_cache and before sc_release_const_cache. The cache
- * manager of Graph API may evict the cache item by dereferenceing this
- * ref_count_managed object. sc_{acquire,release}_const_cache functions will
- * find out that the cache has been invalidated and they will then use the
- * memory allocator in the runtime::stream_t to re-allocate the buffer. Usually
- * we expect JIT modules to hold shared ptr to ConstCacheProxy via
- * cached_const_graph_tensor.
- * If is_lazy_ == true, the cache item's lifetime will be managed by the cache
- * manager of Graph API and it is filled with data after the first execution of
- * the computation. Otherwise, the cache item is always alive as long as the
- * jit_module of the kernel is alive.
+ * refcount, @see RefCountManaged. To access the memory buffer of the const
+ * cache, use acquire/release functions. They will ref/deref the ConstCacheProxy
+ * to make sure the cache is alive after calling acquire and before release. The
+ * cache manager of Graph API may evict the cache item by dereferencing this
+ * RefCountManaged object. {acquire,release} functions will find out that the
+ * cache has been invalidated. Usually we expect JIT modules to hold shared ptr
+ * to ConstCacheProxy via CachedGraphTensor. If is_lazy_ == true, the cache
+ * item's lifetime will be managed by the cache manager of Graph API and it is
+ * filled with data after the first execution of the computation. Otherwise, the
+ * cache item is always alive as long as the jit_module of the kernel is alive.
  */
-struct ConstCacheProxy : ref_count_managed {
+struct ConstCacheProxy : RefCountManaged {
   ConstCacheProxy(const std::shared_ptr<void> &keep_alive, void *buffer,
-      size_t size, bool is_lazy)
-      : ref_count_managed(keep_alive), size_(size), is_lazy_(is_lazy),
+                  size_t size, bool is_lazy)
+      : RefCountManaged(keep_alive), size_(size), is_lazy_(is_lazy),
         buffer_(buffer) {}
   ~ConstCacheProxy();
 
   // get the buffer and increment the refcount. If the buffer is evicted,
   // returns null
   void *acquire(int32_t *inited) {
-    if (check_alive_and_ref()) {
+    if (checkAliveAndRef()) {
       *inited = *inited && initialized_;
       return buffer_;
     }
@@ -91,7 +92,7 @@ struct ConstCacheProxy : ref_count_managed {
   }
   // decrement the refcount
   bool release() {
-    if (is_alive()) {
+    if (isAlive()) {
       deref();
       initialized_ = 1;
       return true;
@@ -101,7 +102,7 @@ struct ConstCacheProxy : ref_count_managed {
 
   // return the buffer. Do not directly use the buffer because it may be already
   // release! To access the buffer, always acquire() before using it.
-  void *get_buffer_unsafe() const { return buffer_; }
+  void *getBufferUnsafe() const { return buffer_; }
 
   size_t size_;
   // if the buffer is lazy-initialized. If false, it should be filled before
@@ -119,7 +120,7 @@ struct CachedGraphTensor {
   std::shared_ptr<ConstCacheProxy> base;
   size_t offset;
   CachedGraphTensor(const std::shared_ptr<ConstCacheProxy> &base,
-                    size_t offset);
+                      size_t offset);
   friend class JitModule;
@@ -127,9 +128,8 @@
 };
 
 std::shared_ptr<CachedGraphTensor> queryCacheTensor(uint64_t key);
-bool regCachedTensor(uint64_t key,
-                     const std::shared_ptr<ConstCacheProxy> &base,
-                     size_t offset);
+bool regCachedTensor(uint64_t key, const std::shared_ptr<ConstCacheProxy> &base,
+                     size_t offset);
 
 } // namespace gc
 } // namespace mlir
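To make the renamed API concrete, here is a usage sketch of the acquire()/release() protocol that the comment above describes. It is illustrative only and not part of any patch in this series: `proxy` is an assumed live cache entry, and the eviction fallback mirrors the std::aligned_alloc path that JitModule::call takes later in the series.

#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
#include <cstdint>
#include <cstdlib>
#include <memory>

// Hypothetical caller; not part of the patch series.
void useFoldedBuffer(const std::shared_ptr<mlir::gc::ConstCacheProxy> &proxy) {
  int32_t inited = 1;
  void *buf = proxy->acquire(&inited); // refs the item if it is still alive
  if (!buf) {
    // The cache manager evicted the item: fall back to a private buffer.
    buf = std::aligned_alloc(/*alignment=*/64, proxy->size_);
    inited = 0;
  }
  if (!inited) {
    // First run (or evicted): the folded data must be recomputed into buf.
  }
  // ... read the folded constants from buf ...
  if (!proxy->release()) {
    // The proxy is dead, so buf is our private allocation; reclaim it.
    std::free(buf);
  }
}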
diff --git a/include/gc/ExecutionEngine/JitWrapper/Module.hpp b/include/gc/ExecutionEngine/Driver/Driver.hpp
similarity index 94%
rename from include/gc/ExecutionEngine/JitWrapper/Module.hpp
rename to include/gc/ExecutionEngine/Driver/Driver.hpp
index 926aba3f6..0a34514fa 100644
--- a/include/gc/ExecutionEngine/JitWrapper/Module.hpp
+++ b/include/gc/ExecutionEngine/Driver/Driver.hpp
@@ -1,4 +1,4 @@
-//===- Module.h - Jit module and Execution engine wrapper -------*- C++ -*-===//
+//===-- Driver.hpp - The top-level MLIR compiler driver ---------*- C++ -*-===//
 //
 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef GC_EXECUTIONENGINE_JITWRAPPER_H
-#define GC_EXECUTIONENGINE_JITWRAPPER_H
+#ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H
+#define GC_EXECUTIONENGINE_DRIVER_DRIVER_H
 
 #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include
diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt
index a1592d74d..ae0c1c8df 100644
--- a/lib/gc/ExecutionEngine/CMakeLists.txt
+++ b/lib/gc/ExecutionEngine/CMakeLists.txt
@@ -1,2 +1,2 @@
 add_subdirectory(CPURuntime)
-add_subdirectory(JitWrapper)
\ No newline at end of file
+add_subdirectory(Driver)
\ No newline at end of file
diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp
index 133cc362f..245f2ca89 100644
--- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp
+++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp
@@ -1,3 +1,11 @@
+//===-- ConstantCache.cpp - Constant cache ----------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" #include #include @@ -11,7 +19,7 @@ CachedGraphTensor::CachedGraphTensor( const std::shared_ptr &base, size_t offset) : base{base}, offset{offset} { // todo: fill in real values - ref.basePtr = (char *)base->get_buffer_unsafe() + offset; + ref.basePtr = (char *)base->getBufferUnsafe() + offset; ref.data = ref.basePtr; ref.offset = 0; memset(ref.sizes, 0, sizeof(ref.sizes)); diff --git a/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt similarity index 98% rename from lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt rename to lib/gc/ExecutionEngine/Driver/CMakeLists.txt index 3186b018d..8bda0a16c 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -27,7 +27,7 @@ else() endif() add_mlir_library(GCJitWrapper - Module.cpp + Driver.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include diff --git a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp similarity index 98% rename from lib/gc/ExecutionEngine/JitWrapper/Module.cpp rename to lib/gc/ExecutionEngine/Driver/Driver.cpp index 06f9c02b6..67b2ced6a 100644 --- a/lib/gc/ExecutionEngine/JitWrapper/Module.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -1,4 +1,4 @@ -//===- Module.cpp -----------------------------------------------*- C++ -*-===// +//===-- Driver.cpp - Top-level MLIR compiler driver -------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/ExecutionEngine/Driver/Driver.hpp" #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 222c73116..d902541bd 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -1,12 +1,12 @@ -//===- JitWrapper.cpp - Wrapper of JIT ------------------------------------===// +//===-- JitWrapper.cpp - Wrapper for JIT ------------------------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/JitWrapper/Module.hpp" +#include "gc/ExecutionEngine/Driver/Driver.hpp" #include "mlir/AsmParser/AsmParser.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/IR/AsmState.h" From 4540fb64834a2ebcc4f1297672ff41a080d72555 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 14:12:32 +0800 Subject: [PATCH 28/64] fix lint --- lib/gc/ExecutionEngine/Driver/Driver.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 67b2ced6a..0b3c19113 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -123,7 +123,7 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, auto raw = reinterpret_cast(*expect); computeArgs = llvm::ArrayRef{raw + 1, raw[0]}; } - } while (0); + } while (false); std::vector> foldInfo; foldInfo.reserve(foldBufferIds.size()); From 381677a27f0c796dd94e208a4c2f1378994b1f59 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Tue, 28 May 2024 14:32:18 +0800 Subject: [PATCH 29/64] fix comments --- lib/gc/Transforms/Pipeline.cpp | 64 ++++++++++++++-------------------- 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index b336684a1..81fae6877 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -27,40 +27,35 @@ namespace mlir::gc { +// linalg + linalgX + tensor void populateFrontendPasses(mlir::PassManager &pm) { // pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg()); } -// linalg + linalgX + tensor ==> GC V1 GIR +// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack void populateTensorPasses(mlir::PassManager &pm) { - // + padding propagation pass, upstream-able 127x127 -> tilling size:32 - // ->padding to 128x128 - // + layout propagation pass, upstream-able 4x32x4x32 -> - // tensor.pack/tensor.unpack - // + tensor constant propagation pass, down-stream pass, designed to support - // oneDNN graph spec - // + linalg.matmul lowering to (scf.loop + linalg.brgemm) pass, upstream-able - // + fine-grain fusion pass, upstream-able -> scf.for + linalgx.mask - // + lower linalg to arith/math on virtual vector pass, up-streamable + // todo: padding propagation pass + // todo: layout propagation pass + // todo: tensor constant propagation pass + // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass + // todo: fine-grain fusion pass + // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. 
Currently we add this
   // pass to make the pipeline work properly
   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
 }
 
-// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack ==>
-// GC V1 TIR
+// scf + arith + math + vector + tensor + linalg.brgemm
 void populateVectorPasses(mlir::PassManager &pm) {
-  // + bf16 promotion pass, down-stream pass, device dependent pass, maybe can
-  // upstream
-  // + bf16 cast elimilation pass, down-stream pass, fast-math kind pass,
-  // designed to support oneDNN graph spec
+  // todo: bf16 promotion pass, device dependent pass
+  // todo: bf16 cast elimination pass, fast-math kind pass, designed to support
+  // oneDNN graph spec
   pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
-  // + lower to physical vector pass, down-stream pass, device dependent pass,
-  // maybe can upstream
+  // todo: lower to physical vector pass, device dependent pass
 }
 
-// scf + arith + math + vector + tensor + linalg.brgemm
+// scf + arith + math + vector + memref + linalg.brgemm
 void populateBufferizationPasses(mlir::PassManager &pm) {
   bufferization::OneShotBufferizationOptions options;
   pm.addPass(bufferization::createOneShotBufferizePass(options));
@@ -73,34 +68,27 @@ void populateBufferizationPasses(mlir::PassManager &pm) {
   bufferization::BufferResultsToOutParamsOpts opt{};
   opt.hoistStaticAllocs = true;
   pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt));
-  // + buffer schedule pass, down-stream pass, to migrate buffer reschedule pass
-  // from GC V1.
-  pm.addNestedPass<func::FuncOp>(
-      bufferization::createBufferHoistingPass()); // Need to improve this pass
-                                                  // to support thread-local
-                                                  // allocator.
+  // todo: buffer schedule pass
+  // todo: Need to improve this pass to support nested parallel.
+  pm.addNestedPass<func::FuncOp>(bufferization::createBufferHoistingPass());
   pm.addNestedPass<func::FuncOp>(bufferization::createBufferLoopHoistingPass());
   pm.addNestedPass<func::FuncOp>(bufferization::createBufferDeallocationPass());
   pm.addPass(createBufferizationToMemRefPass());
 }
 
-// scf + arith + math + vector + memref + linalg.brgemm
+// scf + arith + math + vector + memref + func/microkernel
 void populateMicroKernelPasses(mlir::PassManager &pm) {
-  // + ConvertLinalgToMicrokernel pass, upstream-able,
-  // + CleanupInvalidMicrokernel pass, upstream-able
-  // + InvariantMicrokernelMotion pass, upstream-able
-  // + ConvertMicrokernelToDnnlFunc, down-stream pass, to lower brgemm to dnnl
-  // call
-  // + ConvertMicrokernelToXsmm, down-stream pass, to lower brgemm to libxsmm
-  // call
-  // + LowerMicrokernel pass, upstream-able
-  // + DispatchMicrokernel, down-stream pass
+  // todo: ConvertLinalgToMicrokernel pass
+  // todo: CleanupInvalidMicrokernel pass
+  // todo: InvariantMicrokernelMotion pass
+  // todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call
+  // todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call
+  // todo: LowerMicrokernel pass
+  // todo: DispatchMicrokernel
 }
 
-// scf + arith + math + vector + memref + func/microkernel
 void populateCPURuntimePasses(mlir::PassManager &pm) {
-  // + flatten nested parallel pass, down-stream pass, to support coarse-grain
-  // fusion
+  // todo: flatten nested parallel pass to support coarse-grain fusion
   // remove this pass after we add FlattenNestedParallel
   pm.addPass(createConvertSCFToOpenMPPass());
 }
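The populate* stages above are meant to be chained in this order. As a minimal sketch of a driver that composes them (such a combined entry point is assumed for illustration here, not defined by this patch):

#include "mlir/Pass/PassManager.h"

// Sketch: chain the stages from Pipeline.cpp top to bottom, from the
// linalg-on-tensor frontend down to the CPU runtime lowering.
void buildCompilerPipeline(mlir::PassManager &pm) {
  mlir::gc::populateFrontendPasses(pm);      // onednn graph -> linalg
  mlir::gc::populateTensorPasses(pm);        // tensor-level transforms
  mlir::gc::populateVectorPasses(pm);        // arith/math/vector lowering
  mlir::gc::populateBufferizationPasses(pm); // tensor -> memref
  mlir::gc::populateMicroKernelPasses(pm);   // linalg.brgemm -> microkernel
  mlir::gc::populateCPURuntimePasses(pm);    // scf.parallel -> OpenMP
}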
From 8c50b67e014dec09717ad1217703bb9132703ece Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Tue, 28 May 2024 16:09:00 +0800
Subject: [PATCH 30/64] Rename; Add llvm dependence

---
 include/gc/Transforms/Passes.td |  4 ++++
 lib/gc/Transforms/CST.cpp       | 30 +++++++++++++++---------------
 src/gc-opt/CMakeLists.txt       |  3 ++-
 3 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index f56e8b016..5fd0bd7a7 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -45,6 +45,10 @@ def CST : Pass<"cst"> {
     This pass implements a constant subgraph transform.
   }];
   let constructor = "mlir::gc::createCSTPass()";
+  let dependentDialects = [
+    "tensor::TensorDialect",
+    "linalg::LinalgDialect",
+    "LLVM::LLVMDialect"];
 }
 
 #endif // GC_DIALECT_GC_PASSES
diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp
index ecf7eff8f..5fb48a676 100644
--- a/lib/gc/Transforms/CST.cpp
+++ b/lib/gc/Transforms/CST.cpp
@@ -300,26 +300,26 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }
 
-std::shared_ptr<const_cache_proxy> create_const_cache_proxy(size_t size) {
+std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
   // simply allocate buffer and return
   std::shared_ptr<void> base =
       std::shared_ptr<void>{std::aligned_alloc(64, size), [](void *p) {
        std::free(p); }};
-  return std::make_shared<const_cache_proxy>(base, base.get(), size, true);
+  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
 }
 
-size_t divide_and_ceil(size_t x, size_t y) { return (x + y - 1) / y; }
+size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
 // Manager
-struct const_graph_tensor_cache_manager {
+struct constGraphTensorCacheManager {
   // dnnl_graph_compiler_context *ctx;
 
-  uint64_t cached_tensor_global_id = 0;
+  uint64_t cachedTensorGlobalId = 0;
 
   // singleton
-  static std::shared_ptr<const_graph_tensor_cache_manager> get() {
-    static std::shared_ptr<const_graph_tensor_cache_manager> c =
-        std::make_shared<const_graph_tensor_cache_manager>();
+  static std::shared_ptr<constGraphTensorCacheManager> get() {
+    static std::shared_ptr<constGraphTensorCacheManager> c =
+        std::make_shared<constGraphTensorCacheManager>();
     return c;
   }
 
@@ -327,18 +327,18 @@ struct const_graph_tensor_cache_manager {
   // alloc and set the buf_base_ and offset_ attributes of cache
   std::vector<uint64_t> alloc(std::vector<size_t> buffers_size) {
     size_t total_size = 0;
     for (size_t i = 0; i < buffers_size.size(); i++) {
-      total_size += divide_and_ceil(buffers_size[i], 64) * 64;
+      total_size += divideAndCeil(buffers_size[i], 64) * 64;
     }
     llvm::dbgs() << "Alloc total size: " << total_size << '\n';
-    auto base = create_const_cache_proxy(total_size);
+    auto base = createConstCacheProxy(total_size);
     std::vector<uint64_t> global_ids(buffers_size.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffers_size.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
-      reg_cached_tensor(cached_tensor_global_id, base, offset);
-      global_ids[i] = cached_tensor_global_id;
-      ++cached_tensor_global_id;
-      offset += divide_and_ceil(buffers_size[i], 64) * 64;
+      regCachedTensor(cachedTensorGlobalId, base, offset);
+      global_ids[i] = cachedTensorGlobalId;
+      ++cachedTensorGlobalId;
+      offset += divideAndCeil(buffers_size[i], 64) * 64;
     }
     return global_ids;
   }
 };
 
@@ -541,7 +541,7 @@ void CST::runOnOperation() {
     buffersSize.push_back(
         getTensorSize(dyn_cast(tensor.getType())));
   }
-  auto manager = const_graph_tensor_cache_manager::get();
+  auto manager = constGraphTensorCacheManager::get();
   SmallVector globalIndexes;
   for (auto id : manager->alloc(buffersSize)) {
     globalIndexes.push_back(id);
diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt
index ac7ed4ead..6b8def4be 100644
--- a/src/gc-opt/CMakeLists.txt
+++ b/src/gc-opt/CMakeLists.txt
@@ -17,7 +17,8 @@ set(gc_opt_libs
         ${conversion_libs}
         ${MLIR_LINK_COMPONENTS}
         GCPasses
-        GCAnalysis)
+        GCAnalysis
+        GCCpuRuntime)
 
 if(GC_MLIR_CXX_FLAGS)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}")
 endif()
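The alloc() routine renamed above packs all folded tensors into one base allocation: each requested size is rounded up to a multiple of 64 bytes and assigned the next 64-byte-aligned offset. A tiny self-contained sketch with made-up sizes (only divideAndCeil is taken from the pass; everything else is hypothetical):

#include <cstddef>
#include <cstdio>
#include <vector>

// Copied from CST.cpp: rounds x/y up to the next integer.
static size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

int main() {
  std::vector<size_t> buffersSize{100, 64, 1}; // hypothetical tensor sizes
  size_t totalSize = 0, offset = 0;
  for (size_t s : buffersSize)
    totalSize += divideAndCeil(s, 64) * 64; // 128 + 64 + 64 = 256
  std::printf("total: %zu\n", totalSize);
  for (size_t s : buffersSize) {
    std::printf("offset: %zu\n", offset); // 0, then 128, then 192
    offset += divideAndCeil(s, 64) * 64;
  }
  return 0;
}

Note also the length-prefix convention used for the generated globals: __fold_buffer_ids, __fold_args, and __compute_args each store their element count first, which is why the runtime reads them back as ArrayRef{raw + 1, raw[0]} and why the test payload later in the series begins dense<[2, 114514, 1919810]> for two buffer ids.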
From 25f611eceb7d57fb3365d30d72aa38e782405228 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Tue, 28 May 2024 16:56:17 +0800
Subject: [PATCH 31/64] Change dtype

---
 lib/gc/Transforms/CST.cpp | 97 ++++++++++++++++++---------------------
 1 file changed, 44 insertions(+), 53 deletions(-)

diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp
index 5fb48a676..9beaca812 100644
--- a/lib/gc/Transforms/CST.cpp
+++ b/lib/gc/Transforms/CST.cpp
@@ -302,9 +302,8 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 
 std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
   // simply allocate buffer and return
-  std::shared_ptr<void> base =
-      std::shared_ptr<void>{std::aligned_alloc(64, size), [](void *p) {
-        std::free(p); }};
+  std::shared_ptr<void> base = std::shared_ptr<void>{
+      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
 }
 
@@ -324,72 +323,63 @@ struct constGraphTensorCacheManager {
   }
 
   // alloc and set the buf_base_ and offset_ attributes of cache
-  std::vector<uint64_t> alloc(std::vector<size_t> buffers_size) {
-    size_t total_size = 0;
-    for (size_t i = 0; i < buffers_size.size(); i++) {
-      total_size += divideAndCeil(buffers_size[i], 64) * 64;
+  std::vector<uint64_t> alloc(std::vector<size_t> buffersSize) {
+    size_t totalSize = 0;
+    for (size_t i = 0; i < buffersSize.size(); i++) {
+      totalSize += divideAndCeil(buffersSize[i], 64) * 64;
     }
-    llvm::dbgs() << "Alloc total size: " << total_size << '\n';
-    auto base = createConstCacheProxy(total_size);
-    std::vector<uint64_t> global_ids(buffers_size.size());
+    llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
+    auto base = createConstCacheProxy(totalSize);
+    std::vector<uint64_t> globalIds(buffersSize.size());
     size_t offset = 0;
-    for (size_t i = 0; i < buffers_size.size(); i++) {
+    for (size_t i = 0; i < buffersSize.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
       regCachedTensor(cachedTensorGlobalId, base, offset);
-      global_ids[i] = cachedTensorGlobalId;
+      globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
-      offset += divideAndCeil(buffers_size[i], 64) * 64;
+      offset += divideAndCeil(buffersSize[i], 64) * 64;
     }
-    return global_ids;
+    return globalIds;
   }
 };
 
-static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
-                      StringRef name, int64_t value) {
+static void addGlobalI32(ModuleOp module, Location loc, OpBuilder &builder,
+                         StringRef name, int32_t value) {
   OpBuilder::InsertionGuard insertGuard(builder);
   builder.setInsertionPointToStart(module.getBody());
 
-  auto type = IntegerType::get(builder.getContext(), 8);
+  auto type = IntegerType::get(builder.getContext(), 32);
   LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
       loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
-      builder.getIndexAttr(value),
+      builder.getI32IntegerAttr(value),
       /*alignment=*/0);
 }
 
-static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
-                           StringRef name, ArrayRef<int64_t> array) {
+static void addGlobalI64Array(ModuleOp module, Location loc, OpBuilder &builder,
+                              StringRef name, ArrayRef<int64_t> array) {
   OpBuilder::InsertionGuard insertGuard(builder);
   builder.setInsertionPointToStart(module.getBody());
 
   auto type = LLVM::LLVMArrayType::get(
-      IntegerType::get(builder.getContext(), 8), array.size());
+      IntegerType::get(builder.getContext(), 64), array.size());
   LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
       loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
-      builder.getIndexArrayAttr(array),
+      builder.getI64ArrayAttr(array),
/*alignment=*/0); } -// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder, -// StringRef name, ArrayRef array) { -// OpBuilder::InsertionGuard insertGuard(builder); -// builder.setInsertionPointToStart(module.getBody()); - -// MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType()); -// IntegerAttr memrefAlignment = IntegerAttr(); -// auto global = builder.create( -// loc, name, -// /*sym_visibility=*/builder.getStringAttr("public"), -// /*type=*/type, -// /*initial_value=*/builder.getIndexTensorAttr(array), -// /*constant=*/true, -// /*alignment=*/memrefAlignment); -// } - -// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder, -// StringRef name, int64_t value) { -// SmallVector array{value}; -// addGlobalArray(module, loc, builder, name, array); -// } +static void addGlobalI32Array(ModuleOp module, Location loc, OpBuilder &builder, + StringRef name, ArrayRef array) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 32), array.size()); + LLVM::GlobalOp global = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, + builder.getI32ArrayAttr(array), + /*alignment=*/0); +} // Operate on tensors. Create fold() and compute() on module. The // folded weights and first-run flag is maintained by upper-level runtime. @@ -547,16 +537,16 @@ void CST::runOnOperation() { globalIndexes.push_back(id); } globalIndexes.insert(globalIndexes.begin(), globalIndexes.size()); - addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids", - globalIndexes); + addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids", + globalIndexes); foldFunc.setVisibility(SymbolTable::Visibility::Public); moduleOp.push_back(foldFunc); symbolTable.insert(foldFunc); - SmallVector foldArgs; - SmallVector foldIds; - SmallVector computeArgs; + SmallVector foldArgs; + SmallVector foldIds; + SmallVector computeArgs; // modify the BlockArguments of block size_t oriNumArgs = block.getNumArguments(); @@ -607,14 +597,15 @@ void CST::runOnOperation() { foldArgs.insert(foldArgs.end(), id); } foldArgs.insert(foldArgs.begin(), foldArgs.size()); - addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__fold_args", foldArgs); + addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args", + foldArgs); computeArgs.insert(computeArgs.begin(), computeArgs.size()); - addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__compute_args", - computeArgs); + addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args", + computeArgs); - addGlobal(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args", - oriNumArgs); + addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args", + oriNumArgs); // modify the compute func signature func::FuncOp computeFunc = cast(topFunc); From 43639154e9e5f231667795e30923aa7eb7fdfba6 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Wed, 29 May 2024 11:22:45 +0800 Subject: [PATCH 32/64] Fix visibility and type --- lib/gc/Transforms/CST.cpp | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp index 9beaca812..dc5c332e5 100644 --- a/lib/gc/Transforms/CST.cpp +++ b/lib/gc/Transforms/CST.cpp @@ -343,41 +343,43 @@ struct constGraphTensorCacheManager { } }; -static void addGlobalI32(ModuleOp module, Location loc, OpBuilder 
&builder, +static void addGlobalI32(ModuleOp &module, Location loc, OpBuilder &builder, StringRef name, int32_t value) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); auto type = IntegerType::get(builder.getContext(), 32); LLVM::GlobalOp global = builder.create( - loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, + loc, type, /*isConstant=*/true, LLVM::Linkage::External, name, builder.getI32IntegerAttr(value), /*alignment=*/0); } -static void addGlobalI64Array(ModuleOp module, Location loc, OpBuilder &builder, - StringRef name, ArrayRef array) { +static void addGlobalI64Array(ModuleOp &module, Location loc, + OpBuilder &builder, StringRef name, + ArrayRef array) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMArrayType::get( IntegerType::get(builder.getContext(), 64), array.size()); LLVM::GlobalOp global = builder.create( - loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, - builder.getI64ArrayAttr(array), + loc, type, /*isConstant=*/true, LLVM::Linkage::External, name, + builder.getI64TensorAttr(array), /*alignment=*/0); } -static void addGlobalI32Array(ModuleOp module, Location loc, OpBuilder &builder, - StringRef name, ArrayRef array) { +static void addGlobalI32Array(ModuleOp &module, Location loc, + OpBuilder &builder, StringRef name, + ArrayRef array) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMArrayType::get( IntegerType::get(builder.getContext(), 32), array.size()); LLVM::GlobalOp global = builder.create( - loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, - builder.getI32ArrayAttr(array), + loc, type, /*isConstant=*/true, LLVM::Linkage::External, name, + builder.getI32TensorAttr(array), /*alignment=*/0); } @@ -493,7 +495,7 @@ void CST::runOnOperation() { FunctionType foldFuncType = FunctionType::get(context, inputTypes, outputTypes); - auto foldFunc = + func::FuncOp foldFunc = builder.create(topFunc.getLoc(), funcName, foldFuncType); Block *foldBlock = foldFunc.addEntryBlock(); // values of folded constant weights in foldBlock @@ -541,6 +543,8 @@ void CST::runOnOperation() { globalIndexes); foldFunc.setVisibility(SymbolTable::Visibility::Public); + foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), + UnitAttr::get(context)); moduleOp.push_back(foldFunc); symbolTable.insert(foldFunc); From b54b310af2ebd75cd928c4edf13bc842e00eae34 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 29 May 2024 11:36:12 +0800 Subject: [PATCH 33/64] fix --- .../{ConstantCache.hpp => ConstantCache.h} | 34 +++++++++---------- .../Driver/{Driver.hpp => Driver.h} | 4 +-- .../CPURuntime/ConstantCache.cpp | 2 +- lib/gc/ExecutionEngine/Driver/Driver.cpp | 4 +-- unittests/ExecutionEngine/JitWrapper.cpp | 2 +- 5 files changed, 23 insertions(+), 23 deletions(-) rename include/gc/ExecutionEngine/CPURuntime/{ConstantCache.hpp => ConstantCache.h} (83%) rename include/gc/ExecutionEngine/Driver/{Driver.hpp => Driver.h} (95%) diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h similarity index 83% rename from include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp rename to include/gc/ExecutionEngine/CPURuntime/ConstantCache.h index 0d96ae6b3..f41cb09e8 100644 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.hpp +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h @@ 
-1,4 +1,4 @@ -//===-- ConstantCache.hpp - Constant cache interfaces -----------*- C++ -*-===// +//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -23,9 +23,9 @@ namespace gc { */ struct RefCountManaged { RefCountManaged() = default; - RefCountManaged(const std::shared_ptr &keep_alive) { init(keep_alive); } - void init(const std::shared_ptr &keep_alive) { - keepAlive = keep_alive; + RefCountManaged(const std::shared_ptr &vkeepAlive) { init(vkeepAlive); } + void init(const std::shared_ptr &vkeepAlive) { + keepAlive = vkeepAlive; refCount.store(1); } @@ -62,31 +62,31 @@ struct RefCountManaged { /** * The proxy for the constant cache of Graph API. It holds a shared ptr pointing - * to the cache item in the cache manager (keep_alive) to extend the lifetime by + * to the cache item in the cache manager (keepAlive) to extend the lifetime by * refcount, @see RefCountManaged. To access the memory buffer of the const * cache, use acauire/release functions. They will ref/deref the ConstCacheProxy * to make sure the cache is alive after calling acauire and before release. The * cache manager of Graph API may evict the cache item by dereferenceing this * RefCountManaged object. {acquire,release} functions will find out that the * cache has been invalidated. Usually we expect JIT modules to hold shared ptr - * to ConstCacheProxy via CachedGraphTensor. If is_lazy_ == true, the cache + * to ConstCacheProxy via CachedGraphTensor. If isLazy == true, the cache * item's lifetime will be managed by the cache manager of Graph API and it is * filled with data after the first execution of the computation. Otherwise, the * cache item is always alive as long as the jit_module of the kernel is alive. */ struct ConstCacheProxy : RefCountManaged { - ConstCacheProxy(const std::shared_ptr &keep_alive, void *buffer, + ConstCacheProxy(const std::shared_ptr &vkeepAlive, void *buffer, size_t size, bool is_lazy) - : RefCountManaged(keep_alive), size_(size), is_lazy_(is_lazy), - buffer_(buffer) {} + : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy), + buffer(buffer) {} ~ConstCacheProxy(); // get the buffer and increment the refcount. If the buffer is evicted, // returns null void *acquire(int32_t *inited) { if (checkAliveAndRef()) { - *inited = *inited && initialized_; - return buffer_; + *inited = *inited && initialized; + return buffer; } return nullptr; } @@ -94,7 +94,7 @@ struct ConstCacheProxy : RefCountManaged { bool release() { if (isAlive()) { deref(); - initialized_ = 1; + initialized = 1; return true; } return false; @@ -102,18 +102,18 @@ struct ConstCacheProxy : RefCountManaged { // return the buffer. Do not directly use the buffer because it may be already // release! To access the buffer, always acquire() before using it. - void *getBufferUnsafe() const { return buffer_; } + void *getBufferUnsafe() const { return buffer; } - size_t size_; + size_t size; // if the buffer is lazy-initialized. If false, it should be filled before // computation - bool is_lazy_; + bool isLazy; private: // raw pointer to the buffer - void *buffer_; + void *buffer; // if the buffer has been initialized. 
calling release() will set this to 1 - int32_t initialized_ = 0; + int32_t initialized = 0; }; struct CachedGraphTensor { diff --git a/include/gc/ExecutionEngine/Driver/Driver.hpp b/include/gc/ExecutionEngine/Driver/Driver.h similarity index 95% rename from include/gc/ExecutionEngine/Driver/Driver.hpp rename to include/gc/ExecutionEngine/Driver/Driver.h index 0a34514fa..1ca5aa9f4 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.hpp +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -1,4 +1,4 @@ -//===-- Driver.hpp - The top-level MLIR compiler driver ---------*- C++ -*-===// +//===-- Driver.h - The top-level MLIR compiler driver -----------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,7 +9,7 @@ #ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H #define GC_EXECUTIONENGINE_DRIVER_DRIVER_H -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index 245f2ca89..ea1c10364 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp" +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include #include diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 0b3c19113..ca1bda930 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/Driver/Driver.hpp" +#include "gc/ExecutionEngine/Driver/Driver.h" #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" @@ -201,7 +201,7 @@ void JitModule::call(GeneralMemrefPtr *args) { for (auto b : cacheBases) { auto ptr = b->acquire(&inited); if (unlikely(!ptr)) { - ptr = std::aligned_alloc(/*alignment*/ 64, b->size_); + ptr = std::aligned_alloc(/*alignment*/ 64, b->size); inited = 0; } foldBasePtr.push_back((char *)ptr); diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index d902541bd..71a73bf8b 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "gc/ExecutionEngine/Driver/Driver.hpp" +#include "gc/ExecutionEngine/Driver/Driver.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/IR/AsmState.h" From 9d04cd22b3f9218ed06fb38448f834b48df3401c Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Wed, 29 May 2024 11:44:11 +0800 Subject: [PATCH 34/64] format --- lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp | 6 ++---- lib/gc/ExecutionEngine/Driver/Driver.cpp | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp index ea1c10364..ff45cd180 100644 --- 
a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -14,7 +14,6 @@ namespace mlir::gc { ConstCacheProxy::~ConstCacheProxy() = default; - CachedGraphTensor::CachedGraphTensor( const std::shared_ptr &base, size_t offset) : base{base}, offset{offset} { @@ -36,9 +35,8 @@ std::shared_ptr queryCacheTensor(uint64_t key) { return nullptr; } -bool regCachedTensor(uint64_t key, - const std::shared_ptr &base, - size_t offset) { +bool regCachedTensor(uint64_t key, const std::shared_ptr &base, + size_t offset) { if (queryCacheTensor(key)) { return false; } diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index ca1bda930..4b9b6c42b 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -151,8 +151,8 @@ JitModule::JitModule( llvm::ArrayRef foldArgs, std::vector> &&cachekeepAlive) : engine{std::move(engine)}, compute{compute}, fold{fold}, - numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, computeArgs{computeArgs}, - keepAlive{std::move(cachekeepAlive)} { + numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, + computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { for (const auto &cache : keepAlive) { auto currentItr = std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); From 206c3f3df11b10a8384ff98d15feaf129a127c4e Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 30 May 2024 11:24:14 +0800 Subject: [PATCH 35/64] cleanup --- .../CPURuntime/ConstantCache.h | 137 ------------- include/gc/ExecutionEngine/Driver/Driver.h | 51 ++--- .../ExecutionEngine/CPURuntime/CMakeLists.txt | 1 - .../CPURuntime/ConstantCache.cpp | 46 ----- lib/gc/ExecutionEngine/Driver/Driver.cpp | 180 +----------------- unittests/ExecutionEngine/JitWrapper.cpp | 107 +---------- 6 files changed, 24 insertions(+), 498 deletions(-) delete mode 100644 include/gc/ExecutionEngine/CPURuntime/ConstantCache.h delete mode 100644 lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h deleted file mode 100644 index f41cb09e8..000000000 --- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h +++ /dev/null @@ -1,137 +0,0 @@ -//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H -#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H -#include "mlir/ExecutionEngine/CRunnerUtils.h" -#include -#include -#include - -namespace mlir { -namespace gc { -/** - * The helper class to manage ref count manually for an object allocated with - * shared ptr. It holds an additional shared ptr reference to the object and - * contains an additional self-managed refcount. The refcount will be set to 1 - * when the object is initialized (see init()). When the refcount counts down to - * 0, the additional shared ptr is reset. 
- */ -struct RefCountManaged { - RefCountManaged() = default; - RefCountManaged(const std::shared_ptr &vkeepAlive) { init(vkeepAlive); } - void init(const std::shared_ptr &vkeepAlive) { - keepAlive = vkeepAlive; - refCount.store(1); - } - - void ref() { ++refCount; } - void deref() { - auto newv = --refCount; - if (newv == 0) { - keepAlive = nullptr; - } - } - - // atomically check if refCount > 0. if so, ref() the object and return - // true. Otherwise (if refCount==0), return false - bool checkAliveAndRef() { - auto oldv = refCount.load(); - for (;;) { - if (oldv <= 0) { - return false; - } - if (refCount.compare_exchange_strong(oldv, oldv + 1)) { - return true; - } - // CAS failed, oldv has now the newest known value of refCount - } - } - - bool isAlive() const { return refCount > 0; } - void *getPtrUnsafe() const { return keepAlive.get(); } - -private: - std::shared_ptr keepAlive; - std::atomic refCount{0}; -}; - -/** - * The proxy for the constant cache of Graph API. It holds a shared ptr pointing - * to the cache item in the cache manager (keepAlive) to extend the lifetime by - * refcount, @see RefCountManaged. To access the memory buffer of the const - * cache, use acauire/release functions. They will ref/deref the ConstCacheProxy - * to make sure the cache is alive after calling acauire and before release. The - * cache manager of Graph API may evict the cache item by dereferenceing this - * RefCountManaged object. {acquire,release} functions will find out that the - * cache has been invalidated. Usually we expect JIT modules to hold shared ptr - * to ConstCacheProxy via CachedGraphTensor. If isLazy == true, the cache - * item's lifetime will be managed by the cache manager of Graph API and it is - * filled with data after the first execution of the computation. Otherwise, the - * cache item is always alive as long as the jit_module of the kernel is alive. - */ -struct ConstCacheProxy : RefCountManaged { - ConstCacheProxy(const std::shared_ptr &vkeepAlive, void *buffer, - size_t size, bool is_lazy) - : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy), - buffer(buffer) {} - ~ConstCacheProxy(); - - // get the buffer and increment the refcount. If the buffer is evicted, - // returns null - void *acquire(int32_t *inited) { - if (checkAliveAndRef()) { - *inited = *inited && initialized; - return buffer; - } - return nullptr; - } - // decrement the refcount - bool release() { - if (isAlive()) { - deref(); - initialized = 1; - return true; - } - return false; - } - - // return the buffer. Do not directly use the buffer because it may be already - // release! To access the buffer, always acquire() before using it. - void *getBufferUnsafe() const { return buffer; } - - size_t size; - // if the buffer is lazy-initialized. If false, it should be filled before - // computation - bool isLazy; - -private: - // raw pointer to the buffer - void *buffer; - // if the buffer has been initialized. 
calling release() will set this to 1 - int32_t initialized = 0; -}; - -struct CachedGraphTensor { - std::shared_ptr base; - size_t offset; - CachedGraphTensor(const std::shared_ptr &base, - size_t offset); - friend class JitModule; - -private: - StridedMemRefType ref; -}; - -std::shared_ptr queryCacheTensor(uint64_t key); -bool regCachedTensor(uint64_t key, const std::shared_ptr &base, - size_t offset); - -} // namespace gc -} // namespace mlir - -#endif \ No newline at end of file diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index 1ca5aa9f4..694cbbe5e 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -9,7 +9,6 @@ #ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H #define GC_EXECUTIONENGINE_DRIVER_DRIVER_H -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include @@ -25,7 +24,7 @@ const DialectRegistry &initAndGetDialects(); using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); -class JitModule : public std::enable_shared_from_this { +class JitModule { public: static llvm::Expected> create(Operation *op, const ExecutionEngineOptions &options = {}, @@ -33,41 +32,29 @@ class JitModule : public std::enable_shared_from_this { bool transform = true); // args should be an array of XXXMemrefType* - void call(GeneralMemrefPtr *args); - - JitModule( - std::unique_ptr engine, JitModuleFuncT compute, - JitModuleFuncT fold, size_t numOrigArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef computeArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive = {}); + void call(GeneralMemrefPtr *args, std::size_t numArgs) { + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + realargs.reserve(numArgs); + for (size_t i = 0; i < numArgs; i++) { + realargs.push_back(&args[i]); + } + compute(realargs.data()); + } + + // directly call compute(). args should be an array of void*. args[i] should + // be a pointer to the real data. 
For passing memref, users need to 1) create + // a pointer to XXXMemrefType 2) store the pointer to pointer to XXXMemrefType + // in args[i] + void callRaw(void **args) { compute(args); } + + JitModule(std::unique_ptr engine, JitModuleFuncT compute); ~JitModule(); private: std::unique_ptr engine; JitModuleFuncT compute; - JitModuleFuncT fold; - size_t numOrigArgs; - // `keepAlive` has the ownership of the objects pointed by this vector - llvm::SmallVector cacheBases; - struct CacheBufferInfo { - // index in cacheBases - size_t baseIdx; - size_t offset; - }; - // the info for each folded cached buffer - llvm::SmallVector cacheInfo; - // holding the pointers to StridedMemRefType of folded cache - // `keepAlive` holds the the ownership of the pointers - llvm::SmallVector fastFoldBuffers; - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef foldArgs; - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef computeArgs; - - std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 97f039834..6be58e28f 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -10,7 +10,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") add_mlir_library(GCCpuRuntime SHARED Parallel.cpp - ConstantCache.cpp EXCLUDE_FROM_LIBMLIR ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp deleted file mode 100644 index ff45cd180..000000000 --- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp +++ /dev/null @@ -1,46 +0,0 @@ -//===-- ConstantCache.cpp - Constant cache ----------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" -#include -#include - -namespace mlir::gc { - -ConstCacheProxy::~ConstCacheProxy() = default; - -CachedGraphTensor::CachedGraphTensor( - const std::shared_ptr &base, size_t offset) - : base{base}, offset{offset} { - // todo: fill in real values - ref.basePtr = (char *)base->getBufferUnsafe() + offset; - ref.data = ref.basePtr; - ref.offset = 0; - memset(ref.sizes, 0, sizeof(ref.sizes)); - memset(ref.strides, 0, sizeof(ref.strides)); -} - -static std::unordered_map> cache; - -std::shared_ptr queryCacheTensor(uint64_t key) { - auto itr = cache.find(key); - if (itr != cache.end()) { - return itr->second; - } - return nullptr; -} - -bool regCachedTensor(uint64_t key, const std::shared_ptr &base, - size_t offset) { - if (queryCacheTensor(key)) { - return false; - } - cache[key] = std::make_shared(base, offset); - return true; -} -} // namespace mlir::gc \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 4b9b6c42b..4f5c33867 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -18,9 +18,6 @@ #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" -#define likely(x) __builtin_expect(!!(x), 1) -#define unlikely(x) __builtin_expect(!!(x), 0) - namespace mlir { namespace gc { @@ -46,7 +43,7 @@ const DialectRegistry &initAndGetDialects() { } static const char defaultComputeName[] = "_mlir_ciface_compute"; -static const char defaultFoldName[] = "_mlir_ciface_fold"; + llvm::Expected> JitModule::create(Operation *op, const ExecutionEngineOptions &options, std::unique_ptr tm, bool transform) { @@ -63,14 +60,6 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, return exec.takeError(); } auto &engine = *exec; - uint32_t numOrigArgs; - { - auto expectArgs = engine->lookup("__num_orig_num_args"); - if (!expectArgs) { - return expectArgs.takeError(); - } - numOrigArgs = *reinterpret_cast(*expectArgs); - } JitModuleFuncT compute; { auto expectCompute = engine->lookupPacked(defaultComputeName); @@ -79,176 +68,15 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, } compute = *expectCompute; } - llvm::ArrayRef foldBufferIds; - JitModuleFuncT fold = nullptr; - llvm::ArrayRef computeArgs; - llvm::ArrayRef foldArgs; - do { - auto expectBufferIds = engine->lookup("__fold_buffer_ids"); - if (!expectBufferIds) { - // nothing to fold, It is OK. 
- llvm::consumeError(expectBufferIds.takeError()); - // break out of the scope, don't need to lookup "fold" function - break; - } else { - auto raw = reinterpret_cast(*expectBufferIds); - foldBufferIds = llvm::ArrayRef{raw + 1, raw[0]}; - } - - // find "fold" func - { - auto expectFold = engine->lookupPacked(defaultFoldName); - if (!expectFold) { - return expectFold.takeError(); - } - fold = *expectFold; - } - - // find "foldArgs" - { - auto expectFold = engine->lookup("__fold_args"); - if (!expectFold) { - return expectFold.takeError(); - } - auto raw = reinterpret_cast(*expectFold); - foldArgs = llvm::ArrayRef{raw + 1, raw[0]}; - } - - // find "computeArgs" - { - auto expect = engine->lookup("__compute_args"); - if (!expect) { - return expect.takeError(); - } - auto raw = reinterpret_cast(*expect); - computeArgs = llvm::ArrayRef{raw + 1, raw[0]}; - } - } while (false); - - std::vector> foldInfo; - foldInfo.reserve(foldBufferIds.size()); - for (auto bufId : foldBufferIds) { - auto ret = queryCacheTensor(bufId); - if (!ret) { - return llvm::make_error( - "Failed to query the folded cached tensor", - llvm::inconvertibleErrorCode()); - } - foldInfo.emplace_back(std::move(ret)); - } - - return std::make_shared(std::move(engine), compute, fold, - numOrigArgs, computeArgs, foldArgs, - std::move(foldInfo)); + return std::make_shared(std::move(engine), compute); } JitModule::JitModule( - std::unique_ptr engine, JitModuleFuncT compute, - JitModuleFuncT fold, size_t numOrigArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef computeArgs, - // The code inside `engine` has the ownership of the buffer - llvm::ArrayRef foldArgs, - std::vector> &&cachekeepAlive) - : engine{std::move(engine)}, compute{compute}, fold{fold}, - numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, - computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { - for (const auto &cache : keepAlive) { - auto currentItr = - std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); - if (currentItr == cacheBases.end()) { - cacheBases.push_back(cache->base.get()); - currentItr = cacheBases.end() - 1; - } - cacheInfo.emplace_back(CacheBufferInfo{ - static_cast(currentItr - cacheBases.begin()), cache->offset}); - fastFoldBuffers.push_back(&cache->ref); - } + std::unique_ptr engine, JitModuleFuncT compute) + : engine{std::move(engine)}, compute{compute} { } JitModule::~JitModule() = default; -static void prepareCallArgs(llvm::SmallVector &realargs, - GeneralMemrefPtr *origargs, size_t numOrigArgs, - GeneralMemrefPtr *foldedCache, - llvm::ArrayRef realArgIdx) { - realargs.reserve(realArgIdx.size()); - for (auto argIdx : realArgIdx) { - if (argIdx < numOrigArgs) { - realargs.push_back(&origargs[argIdx]); - } else { - realargs.push_back(&foldedCache[argIdx - numOrigArgs]); - } - } -} - -void JitModule::call(GeneralMemrefPtr *args) { - if (unlikely(cacheInfo.empty())) { - // fast path, no folded cached buffers - // Silly code, MLIR execution engine requires pointers of real args as - // inputs - llvm::SmallVector realargs; - realargs.reserve(numOrigArgs); - for (size_t i = 0; i < numOrigArgs; i++) { - realargs.push_back(&args[i]); - } - compute(realargs.data()); - return; - } - - // stage 1, acquire the foldBasePtr - llvm::SmallVector foldBasePtr; - int32_t inited = 1; - for (auto b : cacheBases) { - auto ptr = b->acquire(&inited); - if (unlikely(!ptr)) { - ptr = std::aligned_alloc(/*alignment*/ 64, b->size); - inited = 0; - } - foldBasePtr.push_back((char *)ptr); - } - - // stage 2, run fold() if 
necessary - GeneralMemrefPtr *foldedCache; - // only used when !inited - std::vector slowFold; - std::vector> slowFoldObj; - if (likely(inited)) { - foldedCache = fastFoldBuffers.data(); - } else { - slowFold.reserve(cacheInfo.size()); - slowFoldObj.reserve(cacheInfo.size()); - for (auto &info : cacheInfo) { - slowFoldObj.emplace_back(); - auto &obj = slowFoldObj.back(); - obj.basePtr = foldBasePtr[info.baseIdx] + info.offset; - obj.data = obj.basePtr; - memset(obj.sizes, 0, sizeof(obj.sizes)); - memset(obj.strides, 0, sizeof(obj.strides)); - slowFold.push_back(&obj); - } - foldedCache = slowFold.data(); - llvm::SmallVector realargs; - prepareCallArgs(realargs, args, numOrigArgs, foldedCache, foldArgs); - fold(realargs.data()); - } - - // stage 3, call compute - { - llvm::SmallVector realargs; - prepareCallArgs(realargs, args, numOrigArgs, foldedCache, computeArgs); - compute(realargs.data()); - } - - // stage 4, cleanup - for (size_t i = 0; i < cacheBases.size(); i++) { - auto b = cacheBases[i]; - if (unlikely(!b->release())) { - // if the cached buffer is already free'd, foldBasePtr[i] is allocated via - // std::aligned_alloc by us, free it - std::free(foldBasePtr[i]); - } - } -} } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 71a73bf8b..84748bb81 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -63,113 +63,8 @@ TEST(ExecutionEngine, JitWrapper) { {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; OwningMemRef bufC{{128}, {128}}; void *args[] = {&*bufA, &*bufB, &*bufC}; - jited.get()->call(args); + jited.get()->call(args, 3); for (int i = 0; i < 128; i++) { ASSERT_EQ(bufC[{i}], 1.0f + i); } } - -// compute d = (a+a) + (b+b) + c, where a,b is marked constant -// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5 -static const char code2[] = R"mlir( -module { -llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 -llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> -// a,b, foldedA,foldedB -llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32> -// foldedA, foldedB, c, d -llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> - -func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { - %c0 = arith.constant 0 : index - cpuruntime.printf "HI%zu\n" %c0 : index - %out = tensor.empty() : tensor<128xf32> - %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> - %out2 = tensor.empty() : tensor<128xf32> - %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> - return %2, %3 : tensor<128xf32>, tensor<128xf32> -} - -func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { - %out = tensor.empty() : tensor<128xf32> - %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> - %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> - return %d : tensor<128xf32> -} -} -)mlir"; - -TEST(ExecutionEngine, JitWrapperCached) { - MLIRContext ctx{gc::initAndGetDialects()}; - 
std::unique_ptr ir_buffer = - llvm::MemoryBuffer::getMemBuffer(code2); - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); - mlir::OwningOpRef module = - mlir::parseSourceFile(sourceMgr, &ctx); - - // foldedA and foldedB uses this buffer - auto ret = std::shared_ptr(new float[128 * 2]); - auto proxy = std::make_shared( - ret, ret.get(), 128 * 2 * sizeof(float), true); - - ASSERT_TRUE(gc::regCachedTensor(114514, proxy, 0)); - ASSERT_TRUE(gc::regCachedTensor(1919810, proxy, 128 * sizeof(float))); - - ASSERT_TRUE(module); - auto jited = gc::JitModule::create(module.get()); - bool jit_success = static_cast(jited); - if (!jit_success) { - auto err = jited.takeError(); - llvm::errs() << err; - llvm::consumeError(std::move(err)); - } - ASSERT_TRUE(jit_success); - OwningMemRef bufA{ - {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; - OwningMemRef bufB{ - {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; - OwningMemRef bufC{ - {128}, {128}, [](float &ptr, ArrayRef idx) { - ptr = -idx[0] * 3; - }}; - OwningMemRef bufD{{128}, {128}}; - void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD}; - - // first call, should run fold() - { - testing::internal::CaptureStdout(); - // first call, should run fold() - jited.get()->call(args); - for (int i = 0; i < 128; i++) { - ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); - } - std::string output = testing::internal::GetCapturedStdout(); - ASSERT_EQ(output, "HI0\n"); - } - - { - testing::internal::CaptureStdout(); - // second call, should not run fold() - jited.get()->call(args); - for (int i = 0; i < 128; i++) { - ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); - } - std::string output = testing::internal::GetCapturedStdout(); - ASSERT_TRUE(output.empty()); - } - - // the cache is evicted - proxy->deref(); - { - testing::internal::CaptureStdout(); - // third call, should run fold() - jited.get()->call(args); - for (int i = 0; i < 128; i++) { - ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); - } - std::string output = testing::internal::GetCapturedStdout(); - ASSERT_EQ(output, "HI0\n"); - } -} From 824946b01440b0879a74779d41c922b600e91b02 Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 30 May 2024 11:25:08 +0800 Subject: [PATCH 36/64] Revert "cleanup" This reverts commit 206c3f3df11b10a8384ff98d15feaf129a127c4e. --- .../CPURuntime/ConstantCache.h | 137 +++++++++++++ include/gc/ExecutionEngine/Driver/Driver.h | 51 +++-- .../ExecutionEngine/CPURuntime/CMakeLists.txt | 1 + .../CPURuntime/ConstantCache.cpp | 46 +++++ lib/gc/ExecutionEngine/Driver/Driver.cpp | 180 +++++++++++++++++- unittests/ExecutionEngine/JitWrapper.cpp | 107 ++++++++++- 6 files changed, 498 insertions(+), 24 deletions(-) create mode 100644 include/gc/ExecutionEngine/CPURuntime/ConstantCache.h create mode 100644 lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h new file mode 100644 index 000000000..f41cb09e8 --- /dev/null +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h @@ -0,0 +1,137 @@ +//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
+#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include <atomic>
+#include <cstddef>
+#include <memory>
+
+namespace mlir {
+namespace gc {
+/**
+ * The helper class to manage the ref count manually for an object allocated
+ * with a shared ptr. It holds an additional shared ptr reference to the object
+ * and contains an additional self-managed refcount. The refcount is set to 1
+ * when the object is initialized (see init()). When the refcount counts down
+ * to 0, the additional shared ptr is reset.
+ */
+struct RefCountManaged {
+  RefCountManaged() = default;
+  RefCountManaged(const std::shared_ptr<void> &vkeepAlive) {
+    init(vkeepAlive);
+  }
+  void init(const std::shared_ptr<void> &vkeepAlive) {
+    keepAlive = vkeepAlive;
+    refCount.store(1);
+  }
+
+  void ref() { ++refCount; }
+  void deref() {
+    auto newv = --refCount;
+    if (newv == 0) {
+      keepAlive = nullptr;
+    }
+  }
+
+  // atomically check if refCount > 0. If so, ref() the object and return
+  // true. Otherwise (if refCount == 0), return false.
+  bool checkAliveAndRef() {
+    auto oldv = refCount.load();
+    for (;;) {
+      if (oldv <= 0) {
+        return false;
+      }
+      if (refCount.compare_exchange_strong(oldv, oldv + 1)) {
+        return true;
+      }
+      // CAS failed; oldv now holds the newest known value of refCount
+    }
+  }
+
+  bool isAlive() const { return refCount > 0; }
+  void *getPtrUnsafe() const { return keepAlive.get(); }
+
+private:
+  std::shared_ptr<void> keepAlive;
+  std::atomic<int32_t> refCount{0};
+};
+
+/**
+ * The proxy for the constant cache of the Graph API. It holds a shared ptr
+ * pointing to the cache item in the cache manager (keepAlive) to extend the
+ * lifetime by refcount, @see RefCountManaged. To access the memory buffer of
+ * the const cache, use the acquire/release functions. They ref/deref the
+ * ConstCacheProxy to make sure the cache is alive after calling acquire and
+ * before release. The cache manager of the Graph API may evict the cache item
+ * by dereferencing this RefCountManaged object; the {acquire,release}
+ * functions will then find out that the cache has been invalidated. Usually we
+ * expect JIT modules to hold a shared ptr to the ConstCacheProxy via
+ * CachedGraphTensor. If isLazy == true, the cache item's lifetime is managed
+ * by the cache manager of the Graph API and it is filled with data after the
+ * first execution of the computation. Otherwise, the cache item stays alive as
+ * long as the jit_module of the kernel is alive.
+ */
+struct ConstCacheProxy : RefCountManaged {
+  ConstCacheProxy(const std::shared_ptr<void> &vkeepAlive, void *buffer,
+                  size_t size, bool is_lazy)
+      : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy),
+        buffer(buffer) {}
+  ~ConstCacheProxy();
+
+  // get the buffer and increment the refcount. If the buffer is evicted,
+  // returns null.
+  void *acquire(int32_t *inited) {
+    if (checkAliveAndRef()) {
+      *inited = *inited && initialized;
+      return buffer;
+    }
+    return nullptr;
+  }
+  // decrement the refcount
+  bool release() {
+    if (isAlive()) {
+      deref();
+      initialized = 1;
+      return true;
+    }
+    return false;
+  }
+
+  // return the buffer. Do not use the buffer directly, because it may already
+  // have been released! To access the buffer, always acquire() before using it.
+  void *getBufferUnsafe() const { return buffer; }
+
+  size_t size;
+  // if the buffer is lazy-initialized.
If false, it should be filled before + // computation + bool isLazy; + +private: + // raw pointer to the buffer + void *buffer; + // if the buffer has been initialized. calling release() will set this to 1 + int32_t initialized = 0; +}; + +struct CachedGraphTensor { + std::shared_ptr base; + size_t offset; + CachedGraphTensor(const std::shared_ptr &base, + size_t offset); + friend class JitModule; + +private: + StridedMemRefType ref; +}; + +std::shared_ptr queryCacheTensor(uint64_t key); +bool regCachedTensor(uint64_t key, const std::shared_ptr &base, + size_t offset); + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index 694cbbe5e..1ca5aa9f4 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -9,6 +9,7 @@ #ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H #define GC_EXECUTIONENGINE_DRIVER_DRIVER_H +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include @@ -24,7 +25,7 @@ const DialectRegistry &initAndGetDialects(); using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); -class JitModule { +class JitModule : public std::enable_shared_from_this { public: static llvm::Expected> create(Operation *op, const ExecutionEngineOptions &options = {}, @@ -32,29 +33,41 @@ class JitModule { bool transform = true); // args should be an array of XXXMemrefType* - void call(GeneralMemrefPtr *args, std::size_t numArgs) { - // Silly code, MLIR execution engine requires pointers of real args as - // inputs - llvm::SmallVector realargs; - realargs.reserve(numArgs); - for (size_t i = 0; i < numArgs; i++) { - realargs.push_back(&args[i]); - } - compute(realargs.data()); - } - - // directly call compute(). args should be an array of void*. args[i] should - // be a pointer to the real data. 
For passing memref, users need to 1) create - // a pointer to XXXMemrefType 2) store the pointer to pointer to XXXMemrefType - // in args[i] - void callRaw(void **args) { compute(args); } - - JitModule(std::unique_ptr engine, JitModuleFuncT compute); + void call(GeneralMemrefPtr *args); + + JitModule( + std::unique_ptr engine, JitModuleFuncT compute, + JitModuleFuncT fold, size_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive = {}); ~JitModule(); private: std::unique_ptr engine; JitModuleFuncT compute; + JitModuleFuncT fold; + size_t numOrigArgs; + // `keepAlive` has the ownership of the objects pointed by this vector + llvm::SmallVector cacheBases; + struct CacheBufferInfo { + // index in cacheBases + size_t baseIdx; + size_t offset; + }; + // the info for each folded cached buffer + llvm::SmallVector cacheInfo; + // holding the pointers to StridedMemRefType of folded cache + // `keepAlive` holds the the ownership of the pointers + llvm::SmallVector fastFoldBuffers; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs; + + std::vector> keepAlive; }; } // namespace gc diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 6be58e28f..97f039834 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") add_mlir_library(GCCpuRuntime SHARED Parallel.cpp + ConstantCache.cpp EXCLUDE_FROM_LIBMLIR ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp new file mode 100644 index 000000000..ff45cd180 --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp @@ -0,0 +1,46 @@ +//===-- ConstantCache.cpp - Constant cache ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h"
+#include <cstring>
+#include <unordered_map>
+
+namespace mlir::gc {
+
+ConstCacheProxy::~ConstCacheProxy() = default;
+
+CachedGraphTensor::CachedGraphTensor(
+    const std::shared_ptr<ConstCacheProxy> &base, size_t offset)
+    : base{base}, offset{offset} {
+  // todo: fill in real values
+  ref.basePtr = (char *)base->getBufferUnsafe() + offset;
+  ref.data = ref.basePtr;
+  ref.offset = 0;
+  memset(ref.sizes, 0, sizeof(ref.sizes));
+  memset(ref.strides, 0, sizeof(ref.strides));
+}
+
+static std::unordered_map<uint64_t, std::shared_ptr<CachedGraphTensor>> cache;
+
+std::shared_ptr<CachedGraphTensor> queryCacheTensor(uint64_t key) {
+  auto itr = cache.find(key);
+  if (itr != cache.end()) {
+    return itr->second;
+  }
+  return nullptr;
+}
+
+bool regCachedTensor(uint64_t key,
+                     const std::shared_ptr<ConstCacheProxy> &base,
+                     size_t offset) {
+  if (queryCacheTensor(key)) {
+    return false;
+  }
+  cache[key] = std::make_shared<CachedGraphTensor>(base, offset);
+  return true;
+}
+} // namespace mlir::gc
\ No newline at end of file
diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp
index 4f5c33867..4b9b6c42b 100644
--- a/lib/gc/ExecutionEngine/Driver/Driver.cpp
+++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp
@@ -18,6 +18,9 @@
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"
 
+#define likely(x) __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+
 namespace mlir {
 namespace gc {
 
@@ -43,7 +46,7 @@ const DialectRegistry &initAndGetDialects() {
 }
 
 static const char defaultComputeName[] = "_mlir_ciface_compute";
-
+static const char defaultFoldName[] = "_mlir_ciface_fold";
 llvm::Expected<std::shared_ptr<JitModule>>
 JitModule::create(Operation *op, const ExecutionEngineOptions &options,
                   std::unique_ptr<llvm::TargetMachine> tm, bool transform) {
@@ -60,6 +63,14 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options,
     return exec.takeError();
   }
   auto &engine = *exec;
+  uint32_t numOrigArgs;
+  {
+    auto expectArgs = engine->lookup("__num_orig_num_args");
+    if (!expectArgs) {
+      return expectArgs.takeError();
+    }
+    numOrigArgs = *reinterpret_cast<uint32_t *>(*expectArgs);
+  }
   JitModuleFuncT compute;
   {
     auto expectCompute = engine->lookupPacked(defaultComputeName);
@@ -68,15 +79,176 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options,
     }
     compute = *expectCompute;
   }
-  return std::make_shared<JitModule>(std::move(engine), compute);
+  llvm::ArrayRef<uint64_t> foldBufferIds;
+  JitModuleFuncT fold = nullptr;
+  llvm::ArrayRef<uint32_t> computeArgs;
+  llvm::ArrayRef<uint32_t> foldArgs;
+  do {
+    auto expectBufferIds = engine->lookup("__fold_buffer_ids");
+    if (!expectBufferIds) {
+      // nothing to fold; that is OK.
+ llvm::consumeError(expectBufferIds.takeError()); + // break out of the scope, don't need to lookup "fold" function + break; + } else { + auto raw = reinterpret_cast(*expectBufferIds); + foldBufferIds = llvm::ArrayRef{raw + 1, raw[0]}; + } + + // find "fold" func + { + auto expectFold = engine->lookupPacked(defaultFoldName); + if (!expectFold) { + return expectFold.takeError(); + } + fold = *expectFold; + } + + // find "foldArgs" + { + auto expectFold = engine->lookup("__fold_args"); + if (!expectFold) { + return expectFold.takeError(); + } + auto raw = reinterpret_cast(*expectFold); + foldArgs = llvm::ArrayRef{raw + 1, raw[0]}; + } + + // find "computeArgs" + { + auto expect = engine->lookup("__compute_args"); + if (!expect) { + return expect.takeError(); + } + auto raw = reinterpret_cast(*expect); + computeArgs = llvm::ArrayRef{raw + 1, raw[0]}; + } + } while (false); + + std::vector> foldInfo; + foldInfo.reserve(foldBufferIds.size()); + for (auto bufId : foldBufferIds) { + auto ret = queryCacheTensor(bufId); + if (!ret) { + return llvm::make_error( + "Failed to query the folded cached tensor", + llvm::inconvertibleErrorCode()); + } + foldInfo.emplace_back(std::move(ret)); + } + + return std::make_shared(std::move(engine), compute, fold, + numOrigArgs, computeArgs, foldArgs, + std::move(foldInfo)); } JitModule::JitModule( - std::unique_ptr engine, JitModuleFuncT compute) - : engine{std::move(engine)}, compute{compute} { + std::unique_ptr engine, JitModuleFuncT compute, + JitModuleFuncT fold, size_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef computeArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive) + : engine{std::move(engine)}, compute{compute}, fold{fold}, + numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, + computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} { + for (const auto &cache : keepAlive) { + auto currentItr = + std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); + if (currentItr == cacheBases.end()) { + cacheBases.push_back(cache->base.get()); + currentItr = cacheBases.end() - 1; + } + cacheInfo.emplace_back(CacheBufferInfo{ + static_cast(currentItr - cacheBases.begin()), cache->offset}); + fastFoldBuffers.push_back(&cache->ref); + } } JitModule::~JitModule() = default; +static void prepareCallArgs(llvm::SmallVector &realargs, + GeneralMemrefPtr *origargs, size_t numOrigArgs, + GeneralMemrefPtr *foldedCache, + llvm::ArrayRef realArgIdx) { + realargs.reserve(realArgIdx.size()); + for (auto argIdx : realArgIdx) { + if (argIdx < numOrigArgs) { + realargs.push_back(&origargs[argIdx]); + } else { + realargs.push_back(&foldedCache[argIdx - numOrigArgs]); + } + } +} + +void JitModule::call(GeneralMemrefPtr *args) { + if (unlikely(cacheInfo.empty())) { + // fast path, no folded cached buffers + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + realargs.reserve(numOrigArgs); + for (size_t i = 0; i < numOrigArgs; i++) { + realargs.push_back(&args[i]); + } + compute(realargs.data()); + return; + } + + // stage 1, acquire the foldBasePtr + llvm::SmallVector foldBasePtr; + int32_t inited = 1; + for (auto b : cacheBases) { + auto ptr = b->acquire(&inited); + if (unlikely(!ptr)) { + ptr = std::aligned_alloc(/*alignment*/ 64, b->size); + inited = 0; + } + foldBasePtr.push_back((char *)ptr); + } + + // stage 2, run fold() if necessary + GeneralMemrefPtr *foldedCache; + // only 
used when !inited + std::vector slowFold; + std::vector> slowFoldObj; + if (likely(inited)) { + foldedCache = fastFoldBuffers.data(); + } else { + slowFold.reserve(cacheInfo.size()); + slowFoldObj.reserve(cacheInfo.size()); + for (auto &info : cacheInfo) { + slowFoldObj.emplace_back(); + auto &obj = slowFoldObj.back(); + obj.basePtr = foldBasePtr[info.baseIdx] + info.offset; + obj.data = obj.basePtr; + memset(obj.sizes, 0, sizeof(obj.sizes)); + memset(obj.strides, 0, sizeof(obj.strides)); + slowFold.push_back(&obj); + } + foldedCache = slowFold.data(); + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numOrigArgs, foldedCache, foldArgs); + fold(realargs.data()); + } + + // stage 3, call compute + { + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numOrigArgs, foldedCache, computeArgs); + compute(realargs.data()); + } + + // stage 4, cleanup + for (size_t i = 0; i < cacheBases.size(); i++) { + auto b = cacheBases[i]; + if (unlikely(!b->release())) { + // if the cached buffer is already free'd, foldBasePtr[i] is allocated via + // std::aligned_alloc by us, free it + std::free(foldBasePtr[i]); + } + } +} } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 84748bb81..71a73bf8b 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -63,8 +63,113 @@ TEST(ExecutionEngine, JitWrapper) { {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; OwningMemRef bufC{{128}, {128}}; void *args[] = {&*bufA, &*bufB, &*bufC}; - jited.get()->call(args, 3); + jited.get()->call(args); for (int i = 0; i < 128; i++) { ASSERT_EQ(bufC[{i}], 1.0f + i); } } + +// compute d = (a+a) + (b+b) + c, where a,b is marked constant +// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5 +static const char code2[] = R"mlir( +module { +llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 +llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> +// a,b, foldedA,foldedB +llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32> +// foldedA, foldedB, c, d +llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> + +func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { + %c0 = arith.constant 0 : index + cpuruntime.printf "HI%zu\n" %c0 : index + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + %out2 = tensor.empty() : tensor<128xf32> + %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> + return %2, %3 : tensor<128xf32>, tensor<128xf32> +} + +func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> + return %d : tensor<128xf32> +} +} +)mlir"; + +TEST(ExecutionEngine, JitWrapperCached) { + MLIRContext ctx{gc::initAndGetDialects()}; + std::unique_ptr ir_buffer = + 
llvm::MemoryBuffer::getMemBuffer(code2); + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &ctx); + + // foldedA and foldedB uses this buffer + auto ret = std::shared_ptr(new float[128 * 2]); + auto proxy = std::make_shared( + ret, ret.get(), 128 * 2 * sizeof(float), true); + + ASSERT_TRUE(gc::regCachedTensor(114514, proxy, 0)); + ASSERT_TRUE(gc::regCachedTensor(1919810, proxy, 128 * sizeof(float))); + + ASSERT_TRUE(module); + auto jited = gc::JitModule::create(module.get()); + bool jit_success = static_cast(jited); + if (!jit_success) { + auto err = jited.takeError(); + llvm::errs() << err; + llvm::consumeError(std::move(err)); + } + ASSERT_TRUE(jit_success); + OwningMemRef bufA{ + {128}, {128}, [](float &ptr, ArrayRef) { ptr = 1.0f; }}; + OwningMemRef bufB{ + {128}, {128}, [](float &ptr, ArrayRef idx) { ptr = idx[0]; }}; + OwningMemRef bufC{ + {128}, {128}, [](float &ptr, ArrayRef idx) { + ptr = -idx[0] * 3; + }}; + OwningMemRef bufD{{128}, {128}}; + void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD}; + + // first call, should run fold() + { + testing::internal::CaptureStdout(); + // first call, should run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_EQ(output, "HI0\n"); + } + + { + testing::internal::CaptureStdout(); + // second call, should not run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_TRUE(output.empty()); + } + + // the cache is evicted + proxy->deref(); + { + testing::internal::CaptureStdout(); + // third call, should run fold() + jited.get()->call(args); + for (int i = 0; i < 128; i++) { + ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i); + } + std::string output = testing::internal::GetCapturedStdout(); + ASSERT_EQ(output, "HI0\n"); + } +} From bc9a7ad97751996e73092045ec49395c410b53aa Mon Sep 17 00:00:00 2001 From: "Mei, Yijie" Date: Thu, 30 May 2024 11:36:55 +0800 Subject: [PATCH 37/64] refine options --- include/gc/ExecutionEngine/Driver/Driver.h | 14 ++++++++++---- lib/gc/ExecutionEngine/Driver/Driver.cpp | 12 +++++++----- unittests/ExecutionEngine/JitWrapper.cpp | 2 +- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index 694cbbe5e..2ce9531bd 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -18,18 +18,24 @@ namespace mlir { class DialectRegistry; namespace gc { -const DialectRegistry &initAndGetDialects(); +const DialectRegistry &initCompilerAndGetDialects(); // the pointers to XXXMemRefType using GeneralMemrefPtr = void *; using JitModuleFuncT = void (*)(void **); +struct DriverOptions { + // the optimization level for the LLVM-JIT + llvm::CodeGenOptLevel jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; + // whether to run the MLIR transformation passes + bool runTransforms = true; + // todo: target machine, etc. 
+}; + class JitModule { public: static llvm::Expected> - create(Operation *op, const ExecutionEngineOptions &options = {}, - std::unique_ptr tm = nullptr, - bool transform = true); + create(Operation *op, const DriverOptions &options = {}); // args should be an array of XXXMemrefType* void call(GeneralMemrefPtr *args, std::size_t numArgs) { diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 4f5c33867..6fc8025be 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -37,7 +37,7 @@ static DialectRegistry initDialects() { return registry; } -const DialectRegistry &initAndGetDialects() { +const DialectRegistry &initCompilerAndGetDialects() { static DialectRegistry reg = initDialects(); return reg; } @@ -45,9 +45,8 @@ const DialectRegistry &initAndGetDialects() { static const char defaultComputeName[] = "_mlir_ciface_compute"; llvm::Expected> -JitModule::create(Operation *op, const ExecutionEngineOptions &options, - std::unique_ptr tm, bool transform) { - if (transform) { +JitModule::create(Operation *op, const DriverOptions &options) { + if (options.runTransforms) { mlir::PassManager pm{op->getContext()}; populateCPUPipeline(pm); if (auto result = pm.run(op); failed(result)) { @@ -55,7 +54,10 @@ JitModule::create(Operation *op, const ExecutionEngineOptions &options, "MLIR pass error", llvm::inconvertibleErrorCode()); } } - auto exec = ExecutionEngine::create(op, options, std::move(tm)); + ExecutionEngineOptions exeOptions; + exeOptions.jitCodeGenOptLevel = options.jitCodeGenOptLevel; + std::unique_ptr tm = nullptr; + auto exec = ExecutionEngine::create(op, exeOptions, std::move(tm)); if (!exec) { return exec.takeError(); } diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp index 84748bb81..f7b93eaa6 100644 --- a/unittests/ExecutionEngine/JitWrapper.cpp +++ b/unittests/ExecutionEngine/JitWrapper.cpp @@ -40,7 +40,7 @@ extern int gc_runtime_keep_alive; TEST(ExecutionEngine, JitWrapper) { gc_runtime_keep_alive = 0; - MLIRContext ctx{gc::initAndGetDialects()}; + MLIRContext ctx{gc::initCompilerAndGetDialects()}; std::unique_ptr ir_buffer = llvm::MemoryBuffer::getMemBuffer(code1); // Parse the input mlir. 
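With the options refactoring above, callers no longer build ExecutionEngineOptions or a TargetMachine by hand. A minimal usage sketch of the refined API; the opt level and the already-lowered module are example assumptions, not part of the patch:

  mlir::gc::DriverOptions opts;
  opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Default;
  opts.runTransforms = false; // the module was already lowered elsewhere
  auto jited = mlir::gc::JitModule::create(module.get(), opts);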
From 94f28137c2d71428deeb71e4948a7375a4be216c Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Thu, 30 May 2024 15:49:12 +0800
Subject: [PATCH 38/64] Support complex topo

---
 lib/gc/Transforms/CST.cpp                     | 161 +++++++++++-------
 .../test_constant_weights_folding-1.mlir      |  43 +++--
 .../test_constant_weights_folding.mlir        |  12 +-
 3 files changed, 128 insertions(+), 88 deletions(-)

diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp
index dc5c332e5..c60cea97e 100644
--- a/lib/gc/Transforms/CST.cpp
+++ b/lib/gc/Transforms/CST.cpp
@@ -30,7 +30,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
 
 namespace mlir {
 namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }
 
-std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-  // simply allocate buffer and return
-  std::shared_ptr<void> base = std::shared_ptr<void>{
-      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-}
+// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+//   // simply allocate buffer and return
+//   std::shared_ptr<void> base = std::shared_ptr<void>{
+//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+// }
 
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
       totalSize += divideAndCeil(buffersSize[i], 64) * 64;
     }
     llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
-    auto base = createConstCacheProxy(totalSize);
+    // auto base = createConstCacheProxy(totalSize);
     std::vector<uint64_t> globalIds(buffersSize.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffersSize.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
-      regCachedTensor(cachedTensorGlobalId, base, offset);
+      // regCachedTensor(cachedTensorGlobalId, base, offset);
       globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
       offset += divideAndCeil(buffersSize[i], 64) * 64;
     }
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
   // values of folded constant weights in original block
   SmallVector<Value> outputValues;
   Value v;
-  // TODO: solve complicated topology. Currently we only handle simple topology
-  // where one constant weight input will and only will produce one constant
-  // output and each constant weight only contributes to one constant output.
+  // Support complicated topology.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
    if (constArgsIndexes.count(id) == 1) {
+      // The constant ops are all single-input single-output.
+ bool simpleTopo = true; auto arg = block.getArgument(id); if (!isa(arg.getType())) { continue; @@ -444,54 +444,72 @@ void CST::runOnOperation() { v = dyn_cast(arg); inputValues.push_back(v); SmallVector valuesOnTheWay = {v}; // the constant tensors + std::deque dq; + dq.push_back(v); // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2 - while (!v.getUsers().empty()) { - // v.getUsers().size() should be 1 - Operation *user = *(v.getUsers().begin()); - // If user is not const or user has multiple operand, we reach the end - if (!isInConstantSubgraph(user) || !singleOperand(user)) { - outputTypes.push_back(v.getType()); - outputValues.push_back(v); - break; + while (!dq.empty()) { + v = dq.front(); + dq.pop_front(); + // if the children ops of v are not all constant, we end at v + if (std::any_of(v.getUsers().begin(), v.getUsers().end(), + [](Operation *child) { + return !isInConstantSubgraph(child); + })) { + if (std::find(outputValues.begin(), outputValues.end(), v) == + outputValues.end()) { + outputTypes.push_back(v.getType()); + outputValues.push_back(v); + } + continue; + } + if (!v.hasOneUse()) { + simpleTopo = false; + } + // the children ops of v are all constant, we push their results to + // queue + for (Operation *child : v.getUsers()) { + if (!singleOperand(child) || child->getResults().size() > 1) { + simpleTopo = false; + } + for (OpResult result : child->getResults()) { + auto r = dyn_cast(result); + dq.push_back(r); + valuesOnTheWay.push_back(r); + } } - // user should has only 1 output value - OpResult result = *(user->result_begin()); - v = dyn_cast(result); - valuesOnTheWay.push_back(v); } // If data size of outputValue is too greater than size of inputValue, do // not fold it. Compare data size changes during traverse to find the last // op that satisfies this condition. 
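To make the DATA_SIZE_EXPANDING_THRESHOLD check below concrete, a worked example with invented shapes: folding stops at the last value whose size stays within 8x of the first constant input.

  // hedged illustration of the size heuristic; the shapes are made up
  int64_t initSize = 256 * 2;             // 256xbf16 input: 512 bytes
  int64_t outSize = 256 * 1024 * 2;       // 256x1024xbf16 result: 512 KiB
  bool tooLarge = initSize * 8 < outSize; // true, so this op is not folded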
- int64_t initSize = - getTensorSize(dyn_cast(valuesOnTheWay[0].getType())); - if (!isa(outputTypes.back()) || - initSize * DATA_SIZE_EXPANDING_THRESHOLD < - getTensorSize(dyn_cast(outputTypes.back()))) { - size_t lastIdx = 0; - for (size_t i = 1; i < valuesOnTheWay.size(); ++i) { - int64_t size = - getTensorSize(dyn_cast(valuesOnTheWay[i].getType())); - if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) { - lastIdx = i; + if (simpleTopo) { + int64_t initSize = + getTensorSize(dyn_cast(valuesOnTheWay[0].getType())); + if (!isa(outputTypes.back()) || + initSize * DATA_SIZE_EXPANDING_THRESHOLD < + getTensorSize(dyn_cast(outputTypes.back()))) { + size_t lastIdx = 0; + for (size_t i = 1; i < valuesOnTheWay.size(); ++i) { + int64_t size = getTensorSize( + dyn_cast(valuesOnTheWay[i].getType())); + if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) { + lastIdx = i; + } + } + if (lastIdx == 0) { // no suitable value found + inputTypes.pop_back(); + outputTypes.pop_back(); + inputValues.pop_back(); + outputValues.pop_back(); + constArgsIndexes.erase(id); + } else { + outputTypes.back() = valuesOnTheWay[lastIdx].getType(); + outputValues.back() = valuesOnTheWay[lastIdx]; } - } - if (lastIdx == 0) { // no suitable value found - inputTypes.pop_back(); - outputTypes.pop_back(); - inputValues.pop_back(); - outputValues.pop_back(); - constArgsIndexes.erase(id); - } else { - outputTypes.back() = valuesOnTheWay[lastIdx].getType(); - outputValues.back() = valuesOnTheWay[lastIdx]; } } } } - if (inputTypes.size() != outputTypes.size()) { - return; - } FunctionType foldFuncType = FunctionType::get(context, inputTypes, outputTypes); @@ -548,30 +566,34 @@ void CST::runOnOperation() { moduleOp.push_back(foldFunc); symbolTable.insert(foldFunc); + // the indexes of args to the folding func. SmallVector foldArgs; + // the indexes of folded args. SmallVector foldIds; + // the indexes of args to the computing func. 
SmallVector computeArgs; // modify the BlockArguments of block size_t oriNumArgs = block.getNumArguments(); - size_t argIdx = 0; + // Add the folded args to the end of BlockArguments list + for (size_t id = 0; id < outputValues.size(); ++id) { + auto loc = block.getArgument(id).getLoc(); + BlockArgument foldArg = + block.insertArgument(oriNumArgs + id, outputTypes[id], loc); + outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) { + Operation *op = val.getOwner(); + return op->getBlock() == █ + }); + foldIds.push_back(id + oriNumArgs); + } + // Erase the operations on constant args for (size_t id = 0; id < oriNumArgs; ++id) { if (constArgsIndexes.count(id) == 1) { foldArgs.push_back(id); - foldIds.push_back(argIdx + oriNumArgs); - computeArgs.push_back(argIdx + oriNumArgs); - auto loc = block.getArgument(id).getLoc(); - BlockArgument foldArg = - block.insertArgument(id, outputTypes[argIdx], loc); - outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) { - Operation *op = val.getOwner(); - return op->getBlock() == █ - }); - std::deque dq; SmallVector opsToErase; std::unordered_set opsToEraseSet; - dq.push_back(block.getArgument(id + 1)); + dq.push_back(block.getArgument(id)); while (!dq.empty()) { Value v = dq.front(); dq.pop_front(); @@ -586,16 +608,26 @@ void CST::runOnOperation() { opsToEraseSet.insert(op); } } - for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) { (*it)->erase(); } - block.eraseArgument(id + 1); - ++argIdx; } else { computeArgs.push_back(id); } } + // Erase the constant args in BlockArguments list + llvm::BitVector argsToErase; + for (size_t id = 0; id < oriNumArgs; ++id) { + if (constArgsIndexes.count(id) == 1) { + argsToErase.push_back(true); + } else { + argsToErase.push_back(false); + } + } + for (size_t id = 0; id < outputValues.size(); ++id) { + argsToErase.push_back(false); + } + block.eraseArguments(argsToErase); for (auto id : foldIds) { foldArgs.insert(foldArgs.end(), id); @@ -604,6 +636,9 @@ void CST::runOnOperation() { addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args", foldArgs); + for (auto id : foldIds) { + computeArgs.insert(computeArgs.end(), id); + } computeArgs.insert(computeArgs.begin(), computeArgs.size()); addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args", computeArgs); diff --git a/test/gc/Transforms/test_constant_weights_folding-1.mlir b/test/gc/Transforms/test_constant_weights_folding-1.mlir index b446212c5..940255f60 100644 --- a/test/gc/Transforms/test_constant_weights_folding-1.mlir +++ b/test/gc/Transforms/test_constant_weights_folding-1.mlir @@ -19,32 +19,31 @@ module { // CHECK: cpuruntime.printf // CHECK: linalg.add -// CHECK: linalg.add // CHECK: func.func @fold // CHECK: linalg.add // CHECK: linalg.add +// CHECK: linalg.add // COM: expected output: // COM: module { -// COM: llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32 -// COM: llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64> -// COM: // a,b, foldedA,foldedB -// COM: llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32> -// COM: // foldedA, foldedB, c, d -// COM: llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32> -// COM: func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } { -// COM: %c0 = arith.constant 0 : index -// COM: cpuruntime.printf 
"HI%zu\n" %c0 : index -// COM: %out = tensor.empty() : tensor<128xf32> -// COM: %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> -// COM: %out2 = tensor.empty() : tensor<128xf32> -// COM: %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32> -// COM: return %2, %3 : tensor<128xf32>, tensor<128xf32> -// COM: } -// COM: func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { -// COM: %out = tensor.empty() : tensor<128xf32> -// COM: %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> -// COM: %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> -// COM: return %d : tensor<128xf32> -// COM: } +// COM: llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32 +// COM: llvm.mlir.global external constant @__compute_args(dense<[2, 2, 3]> : tensor<3xi32>) {addr_space = 0 : i32} : !llvm.array<3 x i32> +// COM: llvm.mlir.global external constant @__fold_args(dense<[3, 0, 1, 3]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32> +// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[1, 0]> : tensor<2xi64>) {addr_space = 0 : i32} : !llvm.array<2 x i64> +// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} { +// COM: %c0 = arith.constant 0 : index +// COM: cpuruntime.printf "HI%zu\0A" %c0 : index +// COM: %0 = tensor.empty() : tensor<128xf32> +// COM: %1 = linalg.add ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> +// COM: return %1 : tensor<128xf32> +// COM: } +// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface} { +// COM: %0 = tensor.empty() : tensor<128xf32> +// COM: %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> +// COM: %2 = tensor.empty() : tensor<128xf32> +// COM: %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> +// COM: %4 = tensor.empty() : tensor<128xf32> +// COM: %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32> +// COM: return %5 : tensor<128xf32> +// COM: } // COM: } \ No newline at end of file diff --git a/test/gc/Transforms/test_constant_weights_folding.mlir b/test/gc/Transforms/test_constant_weights_folding.mlir index 52885ae7d..485c11e4f 100644 --- a/test/gc/Transforms/test_constant_weights_folding.mlir +++ b/test/gc/Transforms/test_constant_weights_folding.mlir @@ -9,7 +9,7 @@ module { // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear. // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. 
- func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { %1 = tensor.empty() : tensor<2x16x32x32xbf16> %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> %2 = tensor.empty() : tensor<8x16x32x32xbf16> @@ -71,6 +71,12 @@ module { // CHECK: func.func @fold // CHECK: arith.extf // CHECK: arith.truncf + // COM: expected output: -// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> -// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) +// COM: module { +// COM: llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32 +// COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> +// COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> +// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} +// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} From 0f67f75deb7874bf7cb8c0f2f9d4fede92ad3d47 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Mon, 3 Jun 2024 16:18:09 +0800 Subject: [PATCH 39/64] Rename --- ...hAnalysis.h => ConstantSubgraphAnalyser.h} | 18 +++++++------- include/gc/Transforms/Passes.h | 8 +++---- include/gc/Transforms/Passes.td | 12 +++++----- lib/gc/Analysis/CMakeLists.txt | 2 +- ...lysis.cpp => ConstantSubgraphAnalyser.cpp} | 24 +++++++++---------- lib/gc/Transforms/CMakeLists.txt | 4 ++-- .../{CSA.cpp => ConstantSubgraphAnalysis.cpp} | 20 +++++++++------- .../{CST.cpp => ConstantTensorFolding.cpp} | 14 +++++++---- src/gc-opt/CMakeLists.txt | 3 +-- ...ir => test_constant_tensor_folding-1.mlir} | 2 +- ...mlir => test_constant_tensor_folding.mlir} | 2 +- 11 files changed, 58 insertions(+), 51 deletions(-) rename include/gc/Analysis/DataFlow/{ConstantSubgraphAnalysis.h => ConstantSubgraphAnalyser.h} (90%) rename lib/gc/Analysis/DataFlow/{ConstantSubgraphAnalysis.cpp => ConstantSubgraphAnalyser.cpp} (90%) rename lib/gc/Transforms/{CSA.cpp => ConstantSubgraphAnalysis.cpp} (67%) rename 
lib/gc/Transforms/{CST.cpp => ConstantTensorFolding.cpp} (98%) rename test/gc/Transforms/{test_constant_weights_folding-1.mlir => test_constant_tensor_folding-1.mlir} (97%) rename test/gc/Transforms/{test_constant_weights_folding.mlir => test_constant_tensor_folding.mlir} (98%) diff --git a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h similarity index 90% rename from include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h rename to include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h index fcb2939d8..a5a199914 100644 --- a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h +++ b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h @@ -1,4 +1,4 @@ -//===- ConstantSubgraphAnalysis.h - Constant subgraph analysis ------===// +//===- ConstantSubgraphAnalyser.h - Constant subgraph analysis ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -13,8 +13,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSIS_H -#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSIS_H +#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H +#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include @@ -87,10 +87,10 @@ class InConstantSubgraph { }; //===----------------------------------------------------------------------===// -// ConstantSubgraphAnalysis +// ConstantSubgraphAnalyser //===----------------------------------------------------------------------===// -class ConstantSubgraphAnalysis +class ConstantSubgraphAnalyser : public SparseForwardDataFlowAnalysis> { public: using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; @@ -103,13 +103,13 @@ class ConstantSubgraphAnalysis }; //===----------------------------------------------------------------------===// -// RunConstantSubgraphAnalysis +// RunConstantSubgraphAnalyser //===----------------------------------------------------------------------===// /// Runs constant subgraph analysis on the IR defined by `op`. 
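A minimal usage sketch of this renamed driver, assuming a function-level op and some value of interest; both names are placeholders:

  mlir::dataflow::RunConstantSubgraphAnalyser analyser;
  analyser.run(funcOp); // runs the underlying dataflow solver
  bool isConst = analyser.getInConstantSubgraph(someValue);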
-struct RunConstantSubgraphAnalysis { +struct RunConstantSubgraphAnalyser { public: - RunConstantSubgraphAnalysis(); + RunConstantSubgraphAnalyser(); void run(Operation *op); @@ -124,4 +124,4 @@ struct RunConstantSubgraphAnalysis { } // end namespace dataflow } // end namespace mlir -#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSIS_H +#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h index 34d2fd487..84096279f 100644 --- a/include/gc/Transforms/Passes.h +++ b/include/gc/Transforms/Passes.h @@ -15,12 +15,12 @@ namespace mlir { namespace gc { #define GEN_PASS_DECL -#define GEN_PASS_DECL_CSA -#define GEN_PASS_DECL_CST +#define GEN_PASS_DECL_CONSTANTSUBGRAPHANALYSIS +#define GEN_PASS_DECL_CONSTANTTENSORFOLDING #include "gc/Transforms/Passes.h.inc" -std::unique_ptr createCSAPass(); -std::unique_ptr createCSTPass(); +std::unique_ptr createConstantSubgraphAnalysisPass(); +std::unique_ptr createConstantTensorFoldingPass(); #define GEN_PASS_REGISTRATION #include "gc/Transforms/Passes.h.inc" diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 5fd0bd7a7..bba7ea0d7 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -31,20 +31,20 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { ]; } -def CSA : Pass<"csa"> { +def ConstantSubgraphAnalysis : Pass<"constant-subgraph-analysis"> { let summary = "Constant Subgraph Analysis"; let description = [{ This pass implements a constant subgraph analysis. }]; - let constructor = "mlir::gc::createCSAPass()"; + let constructor = "mlir::gc::createConstantSubgraphAnalysisPass()"; } -def CST : Pass<"cst"> { - let summary = "Constant Subgraph Transform"; +def ConstantTensorFolding : Pass<"constant-tensor-folding"> { + let summary = "Constant Tensor Folding Transform"; let description = [{ - This pass implements a constant subgraph transform. + This pass implements a constant tensor folding transform. }]; - let constructor = "mlir::gc::createCSTPass()"; + let constructor = "mlir::gc::createConstantTensorFoldingPass()"; let dependentDialects = [ "tensor::TensorDialect", "linalg::LinalgDialect", diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index 42c3d5541..9b5994f3d 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_library(GCAnalysis - DataFlow/ConstantSubgraphAnalysis.cpp + DataFlow/ConstantSubgraphAnalyser.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp similarity index 90% rename from lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp rename to lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp index 2de9e5b4a..741af4697 100644 --- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalysis.cpp +++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp @@ -1,11 +1,11 @@ -//===- ConstantSubgraphAnalysis.cpp - Constant subgraph analysis ----===// +//===- ConstantSubgraphAnalyser.cpp - Constant subgraph analysis ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h" +#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -46,13 +46,13 @@ void InConstantSubgraph::print(raw_ostream &os) const { } //===----------------------------------------------------------------------===// -// ConstantSubgraphAnalysis +// ConstantSubgraphAnalyser //===----------------------------------------------------------------------===// -void ConstantSubgraphAnalysis::visitOperation( +void ConstantSubgraphAnalyser::visitOperation( Operation *op, ArrayRef *> operands, ArrayRef *> results) { - LLVM_DEBUG(llvm::dbgs() << "ConstantSubgraphAnalysis: Visiting operation:\n" + LLVM_DEBUG(llvm::dbgs() << "ConstantSubgraphAnalyser: Visiting operation:\n" << *op << "\n"); bool in = true; @@ -92,7 +92,7 @@ void ConstantSubgraphAnalysis::visitOperation( } } -void ConstantSubgraphAnalysis::setToEntryState( +void ConstantSubgraphAnalyser::setToEntryState( Lattice *lattice) { if (auto blockArg = cast(lattice->getPoint())) { auto parent_op = blockArg.getParentBlock()->getParentOp(); @@ -121,12 +121,12 @@ void ConstantSubgraphAnalysis::setToEntryState( } //===----------------------------------------------------------------------===// -// RunConstantSubgraphAnalysis +// RunConstantSubgraphAnalyser //===----------------------------------------------------------------------===// /// Get the operations whose inputs and outputs are all constant values. /// These operations will be put into a seperate subgraph. -void RunConstantSubgraphAnalysis::getConstantSubgraph(DataFlowSolver &solver, +void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver, Operation *topFunc) { OpBuilder builder(topFunc->getContext()); SmallVector constantOperations; @@ -161,19 +161,19 @@ void RunConstantSubgraphAnalysis::getConstantSubgraph(DataFlowSolver &solver, } } -RunConstantSubgraphAnalysis::RunConstantSubgraphAnalysis() { +RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() { solver.load(); - solver.load(); + solver.load(); } -void RunConstantSubgraphAnalysis::run(Operation *topFunc) { +void RunConstantSubgraphAnalyser::run(Operation *topFunc) { if (failed(solver.initializeAndRun(topFunc))) { return; } getConstantSubgraph(solver, topFunc); } -bool RunConstantSubgraphAnalysis::getInConstantSubgraph(Value val) { +bool RunConstantSubgraphAnalyser::getInConstantSubgraph(Value val) { auto *lattice = solver.lookupState>(val); const InConstantSubgraph &latticeValue = lattice->getValue(); return latticeValue.getInConstantSubgraph(); diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 86a58b407..205538e63 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -7,8 +7,8 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS add_mlir_library(GCPasses OneDNNGraphToLinalg.cpp TileNamed.cpp - CSA.cpp - CST.cpp + ConstantSubgraphAnalysis.cpp + ConstantTensorFolding.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include diff --git a/lib/gc/Transforms/CSA.cpp b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp similarity index 67% rename from lib/gc/Transforms/CSA.cpp rename to lib/gc/Transforms/ConstantSubgraphAnalysis.cpp index 5175be2f5..b78ecd956 100644 --- a/lib/gc/Transforms/CSA.cpp +++ 
-void RunConstantSubgraphAnalysis::getConstantSubgraph(DataFlowSolver &solver,
+void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver,
                                                       Operation *topFunc) {
   OpBuilder builder(topFunc->getContext());
   SmallVector<Operation *> constantOperations;
 
@@ -161,19 +161,19 @@ void RunConstantSubgraphAnalysis::getConstantSubgraph(DataFlowSolver &solver,
   }
 }
 
-RunConstantSubgraphAnalysis::RunConstantSubgraphAnalysis() {
+RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() {
  solver.load<DeadCodeAnalysis>();
-  solver.load<ConstantSubgraphAnalysis>();
+  solver.load<ConstantSubgraphAnalyser>();
 }
 
-void RunConstantSubgraphAnalysis::run(Operation *topFunc) {
+void RunConstantSubgraphAnalyser::run(Operation *topFunc) {
   if (failed(solver.initializeAndRun(topFunc))) {
     return;
   }
   getConstantSubgraph(solver, topFunc);
 }
 
-bool RunConstantSubgraphAnalysis::getInConstantSubgraph(Value val) {
+bool RunConstantSubgraphAnalyser::getInConstantSubgraph(Value val) {
   auto *lattice = solver.lookupState<Lattice<InConstantSubgraph>>(val);
   const InConstantSubgraph &latticeValue = lattice->getValue();
   return latticeValue.getInConstantSubgraph();
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index 86a58b407..205538e63 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -7,8 +7,8 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
 add_mlir_library(GCPasses
   OneDNNGraphToLinalg.cpp
   TileNamed.cpp
-  CSA.cpp
-  CST.cpp
+  ConstantSubgraphAnalysis.cpp
+  ConstantTensorFolding.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${PROJECT_SOURCE_DIR}/include
diff --git a/lib/gc/Transforms/CSA.cpp b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
similarity index 67%
rename from lib/gc/Transforms/CSA.cpp
rename to lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
index 5175be2f5..b78ecd956 100644
--- a/lib/gc/Transforms/CSA.cpp
+++ b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
@@ -1,4 +1,5 @@
-//===- CSA.cpp - Constant Subgraph Analysis -----------------===//
+//===- ConstantSubgraphAnalysis.cpp - Constant Subgraph Analysis
+//-----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,7 +11,7 @@
 // in MLIR.
 //
 //===----------------------------------------------------------------------===//
-#include "gc/Analysis/DataFlow/ConstantSubgraphAnalysis.h"
+#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Pass/Pass.h"
@@ -18,7 +19,7 @@
 namespace mlir {
 namespace gc {
-#define GEN_PASS_DEF_CSA
+#define GEN_PASS_DEF_CONSTANTSUBGRAPHANALYSIS
 #include "gc/Transforms/Passes.h.inc"
 } // namespace gc
@@ -27,11 +28,12 @@ using namespace mlir::dataflow;
 namespace gc {
 
-struct CSA : public impl::CSABase<CSA> {
+struct ConstantSubgraphAnalysis
+    : public impl::ConstantSubgraphAnalysisBase<ConstantSubgraphAnalysis> {
   void runOnOperation() override;
 };
 
-void CSA::runOnOperation() {
+void ConstantSubgraphAnalysis::runOnOperation() {
   Operation *op = getOperation();
   auto &func =
       op->getRegions().front().getBlocks().front().getOperations().front();
@@ -41,11 +43,13 @@ void CSA::runOnOperation() {
   // func.setAttr("onednn_graph.const_args",
   //              builder.getI32ArrayAttr({1,2,3,4}));
 
-  RunConstantSubgraphAnalysis csa;
-  (void)csa.run(&func);
+  RunConstantSubgraphAnalyser runAnalyser;
+  (void)runAnalyser.run(&func);
 }
 
-std::unique_ptr<Pass> createCSAPass() { return std::make_unique<CSA>(); }
+std::unique_ptr<Pass> createConstantSubgraphAnalysisPass() {
+  return std::make_unique<ConstantSubgraphAnalysis>();
+}
 
 } // namespace gc
 } // namespace mlir
\ No newline at end of file
diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
similarity index 98%
rename from lib/gc/Transforms/CST.cpp
rename to lib/gc/Transforms/ConstantTensorFolding.cpp
index c60cea97e..49e69f7d8 100644
--- a/lib/gc/Transforms/CST.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -1,4 +1,5 @@
-//===- CST.cpp - Constant Subgraph Transform -----------------===//
+//===- ConstantTensorFolding.cpp - Constant Subgraph Transform
+//-----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -34,7 +35,7 @@
 namespace mlir {
 namespace gc {
-#define GEN_PASS_DEF_CST
+#define GEN_PASS_DEF_CONSTANTTENSORFOLDING
 #include "gc/Transforms/Passes.h.inc"
 } // namespace gc
@@ -42,7 +43,8 @@ using namespace mlir;
 namespace gc {
 
-struct CST : public impl::CSTBase<CST> {
+struct ConstantTensorFolding
+    : public impl::ConstantTensorFoldingBase<ConstantTensorFolding> {
   void runOnOperation() override;
 };
 
@@ -385,7 +387,7 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
 
 // Operate on tensors. Create fold() and compute() on module. The
 // folded weights and first-run flag is maintained by upper-level runtime.
-void CST::runOnOperation() {
+void ConstantTensorFolding::runOnOperation() {
   Operation *topOp = getOperation();
   MLIRContext *context = topOp->getContext();
   // A ModuleOp contains a single region, which contains a single block.
@@ -679,7 +681,9 @@ void CST::runOnOperation() {
   }
 }
 
-std::unique_ptr<Pass> createCSTPass() { return std::make_unique<CST>(); }
+std::unique_ptr<Pass> createConstantTensorFoldingPass() {
+  return std::make_unique<ConstantTensorFolding>();
+}
 
 } // namespace gc
 } // namespace mlir
diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt
index 6b8def4be..ac7ed4ead 100644
--- a/src/gc-opt/CMakeLists.txt
+++ b/src/gc-opt/CMakeLists.txt
@@ -17,8 +17,7 @@ set(gc_opt_libs
   ${conversion_libs}
   ${MLIR_LINK_COMPONENTS}
   GCPasses
-  GCAnalysis
-  GCCpuRuntime)
+  GCAnalysis)
 
 if(GC_MLIR_CXX_FLAGS)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GC_MLIR_CXX_FLAGS}")
diff --git a/test/gc/Transforms/test_constant_weights_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
similarity index 97%
rename from test/gc/Transforms/test_constant_weights_folding-1.mlir
rename to test/gc/Transforms/test_constant_tensor_folding-1.mlir
index 940255f60..d54b56bad 100644
--- a/test/gc/Transforms/test_constant_weights_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -1,4 +1,4 @@
-// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(csa,cst)" %s | FileCheck %s
+// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s
 
 // CHECK-LABEL: func.func @entry
 module {
diff --git a/test/gc/Transforms/test_constant_weights_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir
similarity index 98%
rename from test/gc/Transforms/test_constant_weights_folding.mlir
rename to test/gc/Transforms/test_constant_tensor_folding.mlir
index 485c11e4f..1256c52cf 100644
--- a/test/gc/Transforms/test_constant_weights_folding.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding.mlir
@@ -1,4 +1,4 @@
-// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(csa,cst)" %s | FileCheck %s
+// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s
 
 // CHECK-LABEL: func.func @entry
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>

From d7663a51a9f435e2203d9ec15bc9fa6b316dde54 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Tue, 4 Jun 2024 09:48:58 +0800
Subject: [PATCH 40/64] Split into short functions

---
 lib/gc/Transforms/ConstantTensorFolding.cpp | 164 ++++++++++++--------
 1 file changed, 99 insertions(+), 65 deletions(-)

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 49e69f7d8..59a2c75f5 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -312,15 +312,15 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
 // Manager
-struct constGraphTensorCacheManager {
+struct ConstGraphTensorCacheManager {
   // dnnl_graph_compiler_context *ctx;
 
   uint64_t cachedTensorGlobalId = 0;
 
   // singleton
-  static std::shared_ptr<constGraphTensorCacheManager> get() {
-    static std::shared_ptr<constGraphTensorCacheManager> c =
-        std::make_shared<constGraphTensorCacheManager>();
+  static std::shared_ptr<ConstGraphTensorCacheManager> get() {
+    static std::shared_ptr<ConstGraphTensorCacheManager> c =
+        std::make_shared<ConstGraphTensorCacheManager>();
     return c;
   }
 
@@ -385,18 +385,7 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
     /*alignment=*/0);
 }
 
-// Operate on tensors. Create fold() and compute() on module. The
-// folded weights and first-run flag is maintained by upper-level runtime.
-void ConstantTensorFolding::runOnOperation() {
-  Operation *topOp = getOperation();
-  MLIRContext *context = topOp->getContext();
-  // A ModuleOp contains a single region, which contains a single block.
-  auto moduleOp = dyn_cast<ModuleOp>(topOp);
-  SymbolTable symbolTable(moduleOp);
-  auto &topFunc =
-      topOp->getRegions().front().getBlocks().front().getOperations().front();
-  OpBuilder builder(context);
-
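+// Collects the indexes of the constant arguments of the top-level function
+// from its `onednn_graph.const_args` attribute, e.g. `[0 : i32, 1 : i32]`.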
+std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
   auto topFuncAttr = topFunc.getAttrDictionary();
   std::optional<NamedAttribute> constArgs =
       topFuncAttr.getNamed("onednn_graph.const_args");
@@ -406,32 +395,16 @@ void ConstantTensorFolding::runOnOperation() {
     for (auto id : constArgsArray) {
       constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
     }
-  } else {
-    return;
-  }
-  if (constArgsIndexes.empty()) {
-    return;
-  }
-
-  Region &region = topFunc.getRegions().front();
-  Block &block = region.getBlocks().front();
-
-  postponeBroadcast(block);
-
-  SmallVector<Operation *> constOps;
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    if (isInConstantSubgraph(&op)) {
-      constOps.push_back(&op);
-    }
   }
+  return constArgsIndexes;
+}
 
-  std::string funcName("fold");
-  SmallVector<Type> inputTypes; // types of constant weights
-  // values of constant weights in original block
-  SmallVector<Value> inputValues;
-  SmallVector<Type> outputTypes; // types of folded constant weights
-  // values of folded constant weights in original block
-  SmallVector<Value> outputValues;
+void getInputsAndOutputs(Block &block,
+                         std::unordered_set<int> &constArgsIndexes,
+                         SmallVector<Type> &inputTypes,
+                         SmallVector<Value> &inputValues,
+                         SmallVector<Type> &outputTypes,
+                         SmallVector<Value> &outputValues) {
   Value v;
   // Support complicated topology.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
@@ -512,11 +485,19 @@ void ConstantTensorFolding::runOnOperation() {
     }
   }
 }
+}
 
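+// Builds the @fold function from the collected constant ops: the ops are
+// cloned into its entry block, cache buffers for the folded outputs are
+// reserved through ConstGraphTensorCacheManager, and the buffer ids are
+// recorded in the __fold_buffer_ids global.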
+func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
+                           Operation *topOp, SmallVector<Operation *> constOps,
+                           SmallVector<Type> &inputTypes,
+                           SmallVector<Value> &inputValues,
+                           SmallVector<Type> &outputTypes,
+                           SmallVector<Value> &outputValues) {
+  std::string funcName("fold");
   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
   func::FuncOp foldFunc =
-      builder.create<func::FuncOp>(topFunc.getLoc(), funcName, foldFuncType);
+      builder.create<func::FuncOp>(topOp->getLoc(), funcName, foldFuncType);
   Block *foldBlock = foldFunc.addEntryBlock();
   // values of folded constant weights in foldBlock
   SmallVector<Value> outputValuesInFold;
   IRMapping mapper;
   for (Operation *op : constOps) {
@@ -535,17 +516,6 @@ void ConstantTensorFolding::runOnOperation() {
     });
   }
 
-  auto returnOp =
-      builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
-  foldBlock->getOperations().push_back(returnOp);
-  for (size_t i = 0; i < inputValues.size(); ++i) {
-    inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
-                                     [&](OpOperand &val) {
-                                       Operation *op = val.getOwner();
-                                       return op->getBlock() == foldBlock;
-                                     });
-  }
-
   // Allocate buffer for outputValuesInFold
   std::vector<size_t> buffersSize;
   for (Value &tensor : outputValuesInFold) {
@@ -553,21 +523,43 @@ void ConstantTensorFolding::runOnOperation() {
     buffersSize.push_back(
         getTensorSize(dyn_cast<TensorType>(tensor.getType())));
   }
-  auto manager = constGraphTensorCacheManager::get();
+  auto manager = ConstGraphTensorCacheManager::get();
   SmallVector<int64_t> globalIndexes;
   for (auto id : manager->alloc(buffersSize)) {
     globalIndexes.push_back(id);
   }
   globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
+  auto moduleOp = dyn_cast<ModuleOp>(topOp);
   addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids",
                     globalIndexes);
 
+  auto returnOp =
+      builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
+  foldBlock->getOperations().push_back(returnOp);
+  for (size_t i = 0; i < inputValues.size(); ++i) {
+    inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
+                                     [&](OpOperand &val) {
+                                       Operation *op = val.getOwner();
+                                       return op->getBlock() == foldBlock;
+                                     });
+  }
+
   foldFunc.setVisibility(SymbolTable::Visibility::Public);
   foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                     UnitAttr::get(context));
+  moduleOp.push_back(foldFunc);
+  SymbolTable symbolTable(moduleOp);
   symbolTable.insert(foldFunc);
 
+  return foldFunc;
+}
+
+void modifyComputeFunc(MLIRContext *context, OpBuilder &builder,
+                       Operation *topOp, Operation &func, Block &block,
+                       std::unordered_set<int> &constArgsIndexes,
+                       SmallVector<Type> &outputTypes,
+                       SmallVector<Value> &outputValues) {
   // the indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
   // the indexes of folded args.
@@ -631,6 +623,13 @@ void ConstantTensorFolding::runOnOperation() {
   }
   block.eraseArguments(argsToErase);
 
+  // modify the compute func signature
+  func::FuncOp computeFunc = cast<func::FuncOp>(func);
+  FunctionType computeFuncType = computeFunc.getFunctionType();
+  computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
+                                        computeFuncType.getResults()));
+
+  auto moduleOp = dyn_cast<ModuleOp>(topOp);
   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
   }
@@ -647,13 +646,9 @@ void ConstantTensorFolding::runOnOperation() {
   addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args",
                oriNumArgs);
+}
 
-  // modify the compute func signature
-  func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
-  FunctionType computeFuncType = computeFunc.getFunctionType();
-  computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
-                                        computeFuncType.getResults()));
-
+void canonicalizeAndClean(MLIRContext *context, Operation *topOp) {
   // Delete dead operations by dialects' canonicalizer
   RewritePatternSet owningPatterns(context);
   for (auto *dialect : context->getLoadedDialects())
@@ -669,16 +664,55 @@ void ConstantTensorFolding::runOnOperation() {
   (void)converged;
 
   // clean up the constant-related attrs on ops
-  for (auto &op : block.getOperations()) {
-    if (op.getAttr("onednn_graph.in_const_subgraph")) {
-      op.removeAttr("onednn_graph.in_const_subgraph");
+  topOp->walk([&](Operation *op) {
+    if (op->getAttr("onednn_graph.in_const_subgraph")) {
+      op->removeAttr("onednn_graph.in_const_subgraph");
     }
+  });
+}
+
+// Operate on tensors. Create fold() and compute() on module. The
+// folded weights and first-run flag is maintained by upper-level runtime.
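+// For example, assuming one of two entry arguments is constant (an
+// illustrative sketch; see the tests for the exact expected IR):
+//   func.func @fold(%w : tensor<128xf32>) -> tensor<128xf32>   // runs once
+//   func.func @entry(%x : tensor<128xf32>, %folded_w : tensor<128xf32>) -> ...
+// together with the __fold_args, __compute_args and __fold_buffer_ids globals
+// that tell the runtime how to wire the two functions together.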
+void ConstantTensorFolding::runOnOperation() {
+  Operation *topOp = getOperation();
+  MLIRContext *context = topOp->getContext();
+  auto &topFunc =
+      topOp->getRegions().front().getBlocks().front().getOperations().front();
+  OpBuilder builder(context);
+  Region &region = topFunc.getRegions().front();
+  Block &block = region.getBlocks().front();
+
+  std::unordered_set<int> constArgsIndexes = getConstArgsIndexes(topFunc);
+  if (constArgsIndexes.empty()) {
+    return;
   }
-
-  for (auto &op : foldBlock->getOperations()) {
-    if (op.getAttr("onednn_graph.in_const_subgraph")) {
-      op.removeAttr("onednn_graph.in_const_subgraph");
+
+  postponeBroadcast(block);
+
+  SmallVector<Operation *> constOps;
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    if (isInConstantSubgraph(&op)) {
+      constOps.push_back(&op);
     }
   }
+
+  SmallVector<Type> inputTypes; // types of constant weights
+  // values of constant weights in original block
+  SmallVector<Value> inputValues;
+  SmallVector<Type> outputTypes; // types of folded constant weights
+  // values of folded constant weights in original block
+  SmallVector<Value> outputValues;
+  getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues,
+                      outputTypes, outputValues);
+
+  func::FuncOp foldFunc =
+      buildFoldFunc(context, builder, topOp, constOps, inputTypes, inputValues,
+                    outputTypes, outputValues);
+
+  modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes,
+                    outputTypes, outputValues);
+
+  canonicalizeAndClean(context, topOp);
 }
 
 std::unique_ptr<Pass> createConstantTensorFoldingPass() {

From 3f34e971f8f72e1dec4d2d48d428965e82ddbdd2 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Wed, 5 Jun 2024 11:13:31 +0800
Subject: [PATCH 41/64] Add a test

---
 .../test_constant_tensor_folding-1.mlir       | 32 ++++++++++++-------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index d54b56bad..8324c9aae 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @entry
 module {
-  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
+  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
     %c0 = arith.constant 0 : index
     cpuruntime.printf "HI%zu\n" %c0 : index
     %ax2 = tensor.empty() : tensor<128xf32>
@@ -11,39 +11,49 @@ module {
     %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%bx2 : tensor<128xf32>) -> tensor<128xf32>
     %ax2pbx2 = tensor.empty() : tensor<128xf32>
     %4 = linalg.add ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2 : tensor<128xf32>) -> tensor<128xf32>
+    %ax2mbx2 = tensor.empty() : tensor<128xf32>
+    %5 = linalg.mul ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2mbx2 : tensor<128xf32>) -> tensor<128xf32>
     %ax2pbx2pc = tensor.empty() : tensor<128xf32>
-    %d = linalg.add ins(%4, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32>
-    return %d : tensor<128xf32>
+    %6 = linalg.add ins(%4, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32>
+    %ax2mbx2mc = tensor.empty() : tensor<128xf32>
+    %7 = linalg.mul ins(%5, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2mbx2mc :
tensor<128xf32>) -> tensor<128xf32> + return %6, %7 : tensor<128xf32>, tensor<128xf32> } } // CHECK: cpuruntime.printf // CHECK: linalg.add +// CHECK: linalg.mul // CHECK: func.func @fold // CHECK: linalg.add // CHECK: linalg.add // CHECK: linalg.add +// CHECK: linalg.mul // COM: expected output: // COM: module { // COM: llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32 -// COM: llvm.mlir.global external constant @__compute_args(dense<[2, 2, 3]> : tensor<3xi32>) {addr_space = 0 : i32} : !llvm.array<3 x i32> -// COM: llvm.mlir.global external constant @__fold_args(dense<[3, 0, 1, 3]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32> -// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[1, 0]> : tensor<2xi64>) {addr_space = 0 : i32} : !llvm.array<2 x i64> -// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} { +// COM: llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32> +// COM: llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32> +// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64> +// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} { // COM: %c0 = arith.constant 0 : index // COM: cpuruntime.printf "HI%zu\0A" %c0 : index // COM: %0 = tensor.empty() : tensor<128xf32> -// COM: %1 = linalg.add ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> -// COM: return %1 : tensor<128xf32> +// COM: %1 = linalg.add ins(%arg2, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> +// COM: %2 = tensor.empty() : tensor<128xf32> +// COM: %3 = linalg.mul ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> +// COM: return %1, %3 : tensor<128xf32>, tensor<128xf32> // COM: } -// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface} { +// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface} { // COM: %0 = tensor.empty() : tensor<128xf32> // COM: %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> // COM: %2 = tensor.empty() : tensor<128xf32> // COM: %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> // COM: %4 = tensor.empty() : tensor<128xf32> // COM: %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32> -// COM: return %5 : tensor<128xf32> +// COM: %6 = tensor.empty() : tensor<128xf32> +// COM: %7 = linalg.mul ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%6 : tensor<128xf32>) -> tensor<128xf32> +// COM: return %7, %5 : tensor<128xf32>, tensor<128xf32> // COM: } // COM: } \ No newline at end of file From 22c3d76a69f5745cb06a12e0bb9bcb50dea25e4e Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Tue, 11 Jun 2024 10:22:58 +0800 Subject: [PATCH 
42/64] Adapt to constant PropertyType

---
 .../DataFlow/ConstantSubgraphAnalyser.cpp     | 47 ++++++++++++-------
 lib/gc/Transforms/ConstantTensorFolding.cpp   | 27 +++++++----
 .../test_constant_tensor_folding-1.mlir       |  4 +-
 .../test_constant_tensor_folding.mlir         |  6 +--
 4 files changed, 55 insertions(+), 29 deletions(-)

diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
index 741af4697..e01190bc3 100644
--- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
+++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
@@ -6,6 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 #include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
+
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
+#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
+
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -96,22 +101,32 @@ void ConstantSubgraphAnalyser::setToEntryState(
     Lattice<InConstantSubgraph> *lattice) {
   if (auto blockArg = cast<BlockArgument>(lattice->getPoint())) {
     auto parent_op = blockArg.getParentBlock()->getParentOp();
-    auto parent_op_attr = parent_op->getAttrDictionary();
-    std::optional<NamedAttribute> const_args =
-        parent_op_attr.getNamed("onednn_graph.const_args");
-    if (const_args.has_value()) {
-      ArrayAttr const_args_indexes =
-          llvm::dyn_cast<ArrayAttr>(const_args->getValue());
-      for (auto id : const_args_indexes) {
-        auto idint = llvm::cast<IntegerAttr>(id).getInt();
-        if (blockArg.getArgNumber() == idint) {
-          LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
-                                  << " is marked as constant\n");
-          propagateIfChanged(lattice,
-                             lattice->join(InConstantSubgraph(true, true)));
-          return;
-        }
-      }
+    // auto parent_op_attr = parent_op->getAttrDictionary();
+    // std::optional<NamedAttribute> const_args =
+    //     parent_op_attr.getNamed("onednn_graph.const_args");
+    // if (const_args.has_value()) {
+    //   ArrayAttr const_args_indexes =
+    //       llvm::dyn_cast<ArrayAttr>(const_args->getValue());
+    //   for (auto id : const_args_indexes) {
+    //     auto idint = llvm::cast<IntegerAttr>(id).getInt();
+    //     if (blockArg.getArgNumber() == idint) {
+    //       LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+    //                               << " is marked as constant\n");
+    //       propagateIfChanged(lattice,
+    //                          lattice->join(InConstantSubgraph(true, true)));
+    //       return;
+    //     }
+    //   }
+    // }
+    auto funcOp = cast<func::FuncOp>(parent_op);
+    mlir::onednn_graph::LogicalTensorInfo info(funcOp);
+    if (info.queryPropertyType(blockArg) ==
+        mlir::onednn_graph::PropertyType::constant) {
+      LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+                              << " is marked as constant\n");
+      propagateIfChanged(lattice,
+                         lattice->join(InConstantSubgraph(true, true)));
+      return;
     }
     propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false)));
  } else {
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 59a2c75f5..f0ed58449 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -14,8 +14,11 @@
 #include <deque>
 #include <unordered_set>
 
-#include "mlir/Transforms/Passes.h"
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
+#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
+#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
 
+#include "mlir/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
@@ -386,14 +389,22 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
 }
 
 std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
-  auto topFuncAttr = topFunc.getAttrDictionary();
-  std::optional<NamedAttribute> constArgs =
-      topFuncAttr.getNamed("onednn_graph.const_args");
   std::unordered_set<int> constArgsIndexes;
-  if (constArgs.has_value()) {
-    ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
-    for (auto id : constArgsArray) {
-      constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+  // auto topFuncAttr = topFunc.getAttrDictionary();
+  // std::optional<NamedAttribute> constArgs =
+  //     topFuncAttr.getNamed("onednn_graph.const_args");
+  // if (constArgs.has_value()) {
+  //   ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
+  //   for (auto id : constArgsArray) {
+  //     constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+  //   }
+  // }
+  auto funcOp = cast<func::FuncOp>(topFunc);
+  mlir::onednn_graph::LogicalTensorInfo info(funcOp);
+  for (int i = 0; i < funcOp.getArguments().size(); ++i) {
+    if (info.queryPropertyType(funcOp.getArguments()[i]) ==
+        mlir::onednn_graph::PropertyType::constant) {
+      constArgsIndexes.insert(i);
     }
   }
   return constArgsIndexes;
diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index 8324c9aae..ec84937dc 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @entry
 module {
-  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
+  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32], onednn_graph.property_types = [#onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type] } {
     %c0 = arith.constant 0 : index
     cpuruntime.printf "HI%zu\n" %c0 : index
     %ax2 = tensor.empty() : tensor<128xf32>
@@ -36,7 +36,7 @@ module {
 // COM:   llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
 // COM:   llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
 // COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>
-// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
+// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32], onednn_graph.property_types = [#onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type]} {
 // COM:     %c0 = arith.constant 0 : index
 // COM:     cpuruntime.printf "HI%zu\0A" %c0 : index
 // COM:     %0 = tensor.empty() : tensor<128xf32>
diff --git a/test/gc/Transforms/test_constant_tensor_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir
index 1256c52cf..2c82b3e67 100644
--- a/test/gc/Transforms/test_constant_tensor_folding.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding.mlir
@@ -9,7 +9,7 @@ module { // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear. // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. - func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32], onednn_graph.property_types = [#onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type]} { %1 = tensor.empty() : tensor<2x16x32x32xbf16> %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> %2 = tensor.empty() : tensor<8x16x32x32xbf16> @@ -78,5 +78,5 @@ module { // COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> // COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> -// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} -// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> +// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) From 9218762cf8f1a6ea4c8981e0ad6504348a4693d7 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Tue, 23 Jul 2024 20:22:26 -0700 Subject: [PATCH 43/64] Revert "Adapt to constant PropertyType" This reverts commit 22c3d76a69f5745cb06a12e0bb9bcb50dea25e4e. 
---
 .../DataFlow/ConstantSubgraphAnalyser.cpp     | 47 +++++++------------
 lib/gc/Transforms/ConstantTensorFolding.cpp   | 27 ++++------
 .../test_constant_tensor_folding-1.mlir       |  4 +-
 .../test_constant_tensor_folding.mlir         |  6 +--
 4 files changed, 29 insertions(+), 55 deletions(-)

diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
index e01190bc3..741af4697 100644
--- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
+++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
@@ -6,11 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 #include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
-
-#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
-#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
-#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
-
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -101,32 +96,22 @@ void ConstantSubgraphAnalyser::setToEntryState(
     Lattice<InConstantSubgraph> *lattice) {
   if (auto blockArg = cast<BlockArgument>(lattice->getPoint())) {
     auto parent_op = blockArg.getParentBlock()->getParentOp();
-    // auto parent_op_attr = parent_op->getAttrDictionary();
-    // std::optional<NamedAttribute> const_args =
-    //     parent_op_attr.getNamed("onednn_graph.const_args");
-    // if (const_args.has_value()) {
-    //   ArrayAttr const_args_indexes =
-    //       llvm::dyn_cast<ArrayAttr>(const_args->getValue());
-    //   for (auto id : const_args_indexes) {
-    //     auto idint = llvm::cast<IntegerAttr>(id).getInt();
-    //     if (blockArg.getArgNumber() == idint) {
-    //       LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
-    //                               << " is marked as constant\n");
-    //       propagateIfChanged(lattice,
-    //                          lattice->join(InConstantSubgraph(true, true)));
-    //       return;
-    //     }
-    //   }
-    // }
-    auto funcOp = cast<func::FuncOp>(parent_op);
-    mlir::onednn_graph::LogicalTensorInfo info(funcOp);
-    if (info.queryPropertyType(blockArg) ==
-        mlir::onednn_graph::PropertyType::constant) {
-      LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
-                              << " is marked as constant\n");
-      propagateIfChanged(lattice,
-                         lattice->join(InConstantSubgraph(true, true)));
-      return;
+    auto parent_op_attr = parent_op->getAttrDictionary();
+    std::optional<NamedAttribute> const_args =
+        parent_op_attr.getNamed("onednn_graph.const_args");
+    if (const_args.has_value()) {
+      ArrayAttr const_args_indexes =
+          llvm::dyn_cast<ArrayAttr>(const_args->getValue());
+      for (auto id : const_args_indexes) {
+        auto idint = llvm::cast<IntegerAttr>(id).getInt();
+        if (blockArg.getArgNumber() == idint) {
+          LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+                                  << " is marked as constant\n");
+          propagateIfChanged(lattice,
+                             lattice->join(InConstantSubgraph(true, true)));
+          return;
+        }
+      }
     }
     propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false)));
   } else {
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index f0ed58449..59a2c75f5 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -14,11 +14,8 @@
 #include <deque>
 #include <unordered_set>
 
-#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
-#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
-#include "gc/Dialect/OneDNNGraph/Utils/Utils.h"
-
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Passes.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
@@ -389,22 +386,14 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
 }
 
 std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
+  auto topFuncAttr = topFunc.getAttrDictionary();
+  std::optional<NamedAttribute> constArgs =
+      topFuncAttr.getNamed("onednn_graph.const_args");
   std::unordered_set<int> constArgsIndexes;
-  // auto topFuncAttr = topFunc.getAttrDictionary();
-  // std::optional<NamedAttribute> constArgs =
-  //     topFuncAttr.getNamed("onednn_graph.const_args");
-  // if (constArgs.has_value()) {
-  //   ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
-  //   for (auto id : constArgsArray) {
-  //     constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
-  //   }
-  // }
-  auto funcOp = cast<func::FuncOp>(topFunc);
-  mlir::onednn_graph::LogicalTensorInfo info(funcOp);
-  for (int i = 0; i < funcOp.getArguments().size(); ++i) {
-    if (info.queryPropertyType(funcOp.getArguments()[i]) ==
-        mlir::onednn_graph::PropertyType::constant) {
-      constArgsIndexes.insert(i);
+  if (constArgs.has_value()) {
+    ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
+    for (auto id : constArgsArray) {
+      constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
    }
  }
   return constArgsIndexes;
diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index ec84937dc..8324c9aae 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @entry
 module {
-  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32], onednn_graph.property_types = [#onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type] } {
+  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
     %c0 = arith.constant 0 : index
     cpuruntime.printf "HI%zu\n" %c0 : index
     %ax2 = tensor.empty() : tensor<128xf32>
@@ -36,7 +36,7 @@ module {
 // COM:   llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
 // COM:   llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
 // COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>
-// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32], onednn_graph.property_types = [#onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type]} {
+// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
 // COM:     %c0 = arith.constant 0 : index
 // COM:     cpuruntime.printf "HI%zu\0A" %c0 : index
 // COM:     %0 = tensor.empty() : tensor<128xf32>
diff --git a/test/gc/Transforms/test_constant_tensor_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir
index 2c82b3e67..1256c52cf 100644
--- a/test/gc/Transforms/test_constant_tensor_folding.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding.mlir
arg1: weight of #1 linear. arg2: bias of #1 linear. // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. - func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32], onednn_graph.property_types = [#onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type, #onednn_graph.property_type]} { + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { %1 = tensor.empty() : tensor<2x16x32x32xbf16> %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> %2 = tensor.empty() : tensor<8x16x32x32xbf16> @@ -78,5 +78,5 @@ module { // COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> // COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> -// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> -// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} +// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} From 4e447dd52e67899deba8d17321667ea62beda70d Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Tue, 23 Jul 2024 22:17:07 -0700 Subject: [PATCH 44/64] Fix link --- lib/gc/ExecutionEngine/Driver/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt index d04dbbb4e..688607b56 100644 --- a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -37,6 +37,7 @@ add_mlir_library(GCJitWrapper ${dialect_libs} ${conversion_libs} GCPasses + GCAnalysis GCGPUPasses ) From d4d81a62b030e205f04274cc51fec57d54fa15bf Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Wed, 24 Jul 2024 23:08:43 -0700 Subject: [PATCH 45/64] Fold arith.constant --- .../DataFlow/ConstantSubgraphAnalyser.cpp | 45 ++++++----- .../Transforms/ConstantSubgraphAnalysis.cpp | 2 +- lib/gc/Transforms/ConstantTensorFolding.cpp | 76 ++++++++++++++++--- 
 .../test_constant_tensor_folding-1.mlir       |  4 +-
 .../test_constant_tensor_folding-2.mlir       | 61 +++++++++++++++
 .../test_constant_tensor_folding.mlir         |  6 +--
 6 files changed, 160 insertions(+), 32 deletions(-)
 create mode 100644 test/gc/Transforms/test_constant_tensor_folding-2.mlir

diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
index 741af4697..584c7e8ce 100644
--- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
+++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
@@ -5,6 +5,9 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+#include <deque>
+#include <unordered_set>
+
 #include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -25,7 +28,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
-#include <deque>
 
 #define DEBUG_TYPE "in-constant-subgraph"
@@ -95,24 +97,33 @@ void ConstantSubgraphAnalyser::visitOperation(
 void ConstantSubgraphAnalyser::setToEntryState(
     Lattice<InConstantSubgraph> *lattice) {
   if (auto blockArg = cast<BlockArgument>(lattice->getPoint())) {
-    auto parent_op = blockArg.getParentBlock()->getParentOp();
-    auto parent_op_attr = parent_op->getAttrDictionary();
-    std::optional<NamedAttribute> const_args =
-        parent_op_attr.getNamed("onednn_graph.const_args");
-    if (const_args.has_value()) {
-      ArrayAttr const_args_indexes =
-          llvm::dyn_cast<ArrayAttr>(const_args->getValue());
-      for (auto id : const_args_indexes) {
-        auto idint = llvm::cast<IntegerAttr>(id).getInt();
-        if (blockArg.getArgNumber() == idint) {
-          LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
-                                  << " is marked as constant\n");
-          propagateIfChanged(lattice,
-                             lattice->join(InConstantSubgraph(true, true)));
-          return;
-        }
+    auto parentOp = blockArg.getParentBlock()->getParentOp();
+    auto parentOpAttr = parentOp->getAttrDictionary();
+
+    std::unordered_set<int> constArgsIndexes;
+    std::optional<NamedAttribute> compiletimeConstArgs =
+        parentOpAttr.getNamed("compiletime_const_args_index");
+    if (compiletimeConstArgs.has_value()) {
+      for (auto id :
+           llvm::dyn_cast<ArrayAttr>(compiletimeConstArgs->getValue())) {
+        constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+      }
+    }
+    std::optional<NamedAttribute> runtimeConstArgs =
+        parentOpAttr.getNamed("runtime_const_args_index");
+    if (runtimeConstArgs.has_value()) {
+      for (auto id : llvm::dyn_cast<ArrayAttr>(runtimeConstArgs->getValue())) {
+        constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
       }
     }
+
+    if (constArgsIndexes.count(blockArg.getArgNumber())) {
+      LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
+                              << " is marked as constant\n");
+      propagateIfChanged(lattice,
+                         lattice->join(InConstantSubgraph(true, true)));
+      return;
+    }
     propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false)));
   } else {
     propagateIfChanged(lattice,
diff --git a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
index b78ecd956..d4f183326 100644
--- a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
+++ b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
@@ -40,7 +40,7 @@ void ConstantSubgraphAnalysis::runOnOperation() {
 
   // Hard-code: set the #1 argument to be constant.
  // OpBuilder builder(op->getContext());
-  // func.setAttr("onednn_graph.const_args",
+  // func.setAttr("runtime_const_args_index",
   //              builder.getI32ArrayAttr({1,2,3,4}));
 
   RunConstantSubgraphAnalyser runAnalyser;
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 59a2c75f5..2df13adcc 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -355,6 +355,7 @@ static void addGlobalI32(ModuleOp &module, Location loc, OpBuilder &builder,
       loc, type, /*isConstant=*/true, LLVM::Linkage::External, name,
       builder.getI32IntegerAttr(value),
       /*alignment=*/0);
+  (void)global;
 }
 
 static void addGlobalI64Array(ModuleOp &module, Location loc,
@@ -369,6 +370,7 @@ static void addGlobalI64Array(ModuleOp &module, Location loc,
       loc, type, /*isConstant=*/true, LLVM::Linkage::External, name,
       builder.getI64TensorAttr(array),
       /*alignment=*/0);
+  (void)global;
 }
 
 static void addGlobalI32Array(ModuleOp &module, Location loc,
@@ -383,22 +385,74 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
       loc, type, /*isConstant=*/true, LLVM::Linkage::External, name,
       builder.getI32TensorAttr(array),
       /*alignment=*/0);
+  (void)global;
 }
 
 std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
   auto topFuncAttr = topFunc.getAttrDictionary();
-  std::optional<NamedAttribute> constArgs =
-      topFuncAttr.getNamed("onednn_graph.const_args");
   std::unordered_set<int> constArgsIndexes;
-  if (constArgs.has_value()) {
-    ArrayAttr constArgsArray = llvm::dyn_cast<ArrayAttr>(constArgs->getValue());
-    for (auto id : constArgsArray) {
+  std::optional<NamedAttribute> compiletimeConstArgs =
+      topFuncAttr.getNamed("compiletime_const_args_index");
+  if (compiletimeConstArgs.has_value()) {
+    for (auto id :
+         llvm::dyn_cast<ArrayAttr>(compiletimeConstArgs->getValue())) {
+      constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
+    }
+  }
+  std::optional<NamedAttribute> runtimeConstArgs =
+      topFuncAttr.getNamed("runtime_const_args_index");
+  if (runtimeConstArgs.has_value()) {
+    for (auto id : llvm::dyn_cast<ArrayAttr>(runtimeConstArgs->getValue())) {
       constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
     }
   }
   return constArgsIndexes;
 }
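+// Walks forward from each arith.constant tensor in the block and records the
+// last values whose users are still all-constant; these become extra outputs
+// of fold(). E.g. for constant -> pack1 -> pack2 -> matmul, the result of
+// pack2 is recorded, since the matmul also consumes a non-constant input.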
+void getArithConstantOutputs(Block &block, SmallVector<Type> &outputTypes,
+                             SmallVector<Value> &outputValues) {
+  for (Operation &op : block.getOperations()) {
+    if (isa<arith::ConstantOp>(&op)) {
+      Operation *constOp = &op;
+      auto constTensor = constOp->getResults().front();
+      if (!isa<TensorType>(constTensor.getType())) {
+        continue;
+      }
+      auto v = dyn_cast<Value>(constTensor);
+      SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      std::deque<Value> dq;
+      dq.push_back(v);
+      // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
+      while (!dq.empty()) {
+        v = dq.front();
+        dq.pop_front();
+        // if the children ops of v are not all constant, we end at v
+        if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
+                        [](Operation *child) {
+                          return !isInConstantSubgraph(child);
+                        })) {
+          if (std::find(outputValues.begin(), outputValues.end(), v) ==
+              outputValues.end()) {
+            outputTypes.push_back(v.getType());
+            outputValues.push_back(v);
+          }
+          continue;
+        }
+
+        // the children ops of v are all constant, we push their results to
+        // queue
+        for (Operation *child : v.getUsers()) {
+          for (OpResult result : child->getResults()) {
+            auto r = dyn_cast<Value>(result);
+            dq.push_back(r);
+            valuesOnTheWay.push_back(r);
+          }
+        }
+      }
+    }
+  }
+}
+
 void getInputsAndOutputs(Block &block,
                          std::unordered_set<int> &constArgsIndexes,
                          SmallVector<Type> &inputTypes,
                          SmallVector<Value> &inputValues,
                          SmallVector<Type> &outputTypes,
                          SmallVector<Value> &outputValues) {
@@ -499,7 +553,7 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
   func::FuncOp foldFunc =
       builder.create<func::FuncOp>(topOp->getLoc(), funcName, foldFuncType);
   Block *foldBlock = foldFunc.addEntryBlock();
-  // values of folded constant weights in foldBlock
+  // values of folded constant tensors in foldBlock
   SmallVector<Value> outputValuesInFold;
   IRMapping mapper;
   for (Operation *op : constOps) {
@@ -696,18 +750,20 @@ void ConstantTensorFolding::runOnOperation() {
     }
   }
 
-  SmallVector<Type> inputTypes; // types of constant weights
-  // values of constant weights in original block
+  SmallVector<Type> inputTypes; // types of constant tensors
+  // values of constant tensors in original block
   SmallVector<Value> inputValues;
-  SmallVector<Type> outputTypes; // types of folded constant weights
-  // values of folded constant weights in original block
+  SmallVector<Type> outputTypes; // types of folded constant tensors
+  // values of folded constant tensors in original block
   SmallVector<Value> outputValues;
+  getArithConstantOutputs(block, outputTypes, outputValues);
   getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues,
                       outputTypes, outputValues);
 
   func::FuncOp foldFunc =
       buildFoldFunc(context, builder, topOp, constOps, inputTypes, inputValues,
                     outputTypes, outputValues);
+  (void)foldFunc;
 
   modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes,
                     outputTypes, outputValues);
diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index 8324c9aae..fa4fcb210 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @entry
 module {
-  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
+  func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, runtime_const_args_index = [0 : i32, 1 : i32] } {
     %c0 = arith.constant 0 : index
     cpuruntime.printf "HI%zu\n" %c0 : index
     %ax2 = tensor.empty() : tensor<128xf32>
@@ -36,7 +36,7 @@ module {
 // COM:   llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
 // COM:   llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
 // COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>
-// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
+// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, runtime_const_args_index = [0 : i32, 1 : i32]} {
 // COM:     %c0 = arith.constant 0 : index
 // COM:     cpuruntime.printf "HI%zu\0A" %c0 : index
 // COM:     %0 = tensor.empty() : tensor<128xf32>
diff --git a/test/gc/Transforms/test_constant_tensor_folding-2.mlir b/test/gc/Transforms/test_constant_tensor_folding-2.mlir
new file mode 100644
index 000000000..8d9e4ed53
--- /dev/null
+++ b/test/gc/Transforms/test_constant_tensor_folding-2.mlir
@@ -0,0 +1,61 @@
+// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)"
%s | FileCheck %s + +// CHECK-LABEL: func.func @entry +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +module { + // COM: A three-layer mlp. %arg0: input feature. %arg1, %arg2, %arg3: weight of #1, #2 and #3 linear. + func.func @entry(%arg0: tensor<64x32xbf16>, %arg2: tensor<32x256xbf16>, %arg3: tensor<256x1024xbf16>) + -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, compiletime_const_args_index = [1 : i32], runtime_const_args_index = [2 : i32]} { + %1 = tensor.empty() : tensor<2x1x32x32xbf16> + %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x32xbf16> -> tensor<2x1x32x32xbf16> + + %arg1 = arith.constant dense<"0x99571CBE05BA1C3D926AFCBD782B34BE67A737BEBF181ABE3C4E253B5D32F73D1566963D9B6F313EB74C0FBD430B253AE8E2E23DB0CBC53C46C014BE2E0981BD4D313F3E833D37BEAB70E13D65B6CA3DB194983D1E60983D950B71BD1815FDBB32DF9A3DD106EBBDB4A8233E841EC3BDE8C7C13D3734073EFF067DBD070206BEF6AF633DB209843C2135C3BD4F85B83C43BD1CBE04841A3E3E78BD3DE9D0D0BCF660093ED074083E14D43E3ECDA735BE8C8C0E3E40C60FBE4F73C9BDB4358DBD263D323C64E61EBEE535D23D238F013C727EA73DBDBAA1BD79D53EBE392981BDC06B453D10E37D3D2D2B41BEE1FA6BBD410E513D05588BBD514AB0BB0624243E3D80993C8E6A113EE57CFD3D23FE37BE001573BD86AD143E7F052D3E97C07DBD19B4113D3E87F6BDB971E83DFEA12BBC5D51F9BD4F203A3ED454043E22775BBD2EE8313EB027D03D8FEFD7BD0E56B7BDBF963FBE5B64E93D9291FBBD027101BE573DFD3D0CD6EB3D809B863DA9E8263E9EF2A43D717AB73D3CF597BD9FB7243DC603003D61780E3E3992293D8B1B25BE6B0024BE806DCB3D5BAB91BD9A33AFBDD5BC3BBE6D920FBE0D90F53D4513383E2219A0BBE8B6FBBD341C42BD42F235BED91A1ABDC3AEB0BD5AC1383DE0EADC3D303D11BE850D263E8281163E5CB78A3D19EB34BE33150F3E84F8EE3D18FC823DB26CCBBD09AB06BED909FFBA605EFE3B9014B7BD1606DA3D75ACE13D0910753C33C6843DE9951CBECD220ABD0EF2BF3D14BB2E3C798718BD60A53A3E8B83E53D18663DBE4D07CABD37CE043EA6B18E3D3D0F303EE392073EC92A1ABED6900E3E72D3E73D8CEF803D1B4D3D3E997D283E210F923BC2D131BECEAF913DB981EFBDCBCCCCBA2B6711BE4E32FE3C5D5D33BD2F34313EB7EC48BC26CDFD3D07170B3E1CD816BE310DD2BD9E03023E1EA8F3BD8B99EEBBFC97433E047F8DBDDD6BA03DA3B2433E34D7C0BC7FDB89BA1980333EF3FC8D3DC05C203E9C7213BD8385403E2F971A3E4357CF3DB39BFBBC784FF8BC7DBD0C3E8301E23D77BF1ABB04F3243CFBA3B1BD5A46C6BD1745A8BDD6950ABD939CC5BDB4226EBCAC622EBD6748FBBDAFF9D53DF29D433E41991C3D4DD7353EE2EF8E3D21EF3B3DF679973D31DEFDBDF0AF303E8D34DFBB31B895BD6A633A3EACE125BEE94E95BDA58043BEC9F233BE915F03BD1B7C8F3DE1D367BDD7BBD63D6E990A3E23222F3D4B6CD73DB869C53D8697383E3A86853D973F2C3EFC3827BC4E87FA3DD5903BBE4BB8403E34A9A33D41C8843D4BC8FABD3CD5E8BD4946233D955052BDA5F841BC6C81AFBD5DD8883DB71A753CD0A1263D88690ABE35DAA73CA3557D3D8C09D23D5A27273DECEFDBBCD220023EE036ACBD6CD2443E8F630FBEBC43B73DF03AA4BDC709133E1B94E73D362CE4BCB15F33BE3139443E5FCF62BD0E3C1B3EE99DF93D9E1BB3BA70DB213E38EBDDBC47F10CBEF817293DAD3DEB3B730942BE535C87BD448D7B3B1C8094BD97962B3D5B0F3B3EA3F42A3E4ED46DBD6D72C33C687CC63DEA34C53D1CCC3EBEDCA640BE638ABCBD4B63AFBDA699063E92861E3E98219FBC8E0B233ED3ED573DC856B8BD13880F3EFA0763BD5A8C89BD194519BE89C6CF3D73A219BC5ECBD43D41EFA33D27D8493D756B1ABEC796C93D9A25133C6A5A363E13FB8DBD601755BD3935FABD14D6883D0EF2D33DB8E914BD527347397200433DE72A3F3B62C52F3ED164EF3CD8806FBD05528B3D89701EBE0A09C23DA19B103D05922EBE7A100E3E31C0503D8ED53BBE08463E3E5168013E55F3E53D782EC53DA8BBD93C1711223E05FDB2BDA740113EA27A20BD1685A23D7E35293E02BD8B3CC43F163E4AE6613DE4280F3EEEF20BBE965C1DBEFAAD233E75754E3D
96C33BBCB6D7013E0D8E7ABD703C82BDEA0875BC6F57A6BCE83609BE8A8EB53DAB7D3C3E39A50ABEB878A33D9FCEA1BC124AD33C22C34A3DB5F338BE0307BF3C2F0881BD7E15E8BDBEE8C8BDBBFFA63C342F303E15B1CCBB2590153EEA05EF3DE778F2BCE9E1233ECEC244BDBF92D5BDECDEAE3C29750CBDD969FCBD7DC236BE571D1DBEC8FA7DBC243BAD3C38673D3ED15943BEFE4D913D5329273E18AB2EBE19AB5F3D30A62F3E94303CBE1421DABCBE6E133E355D073EEC76633DEB2AB83DA2BF16BC9A46C2BD4EB47EBC4C82343EC1D1E63D13D314BED232E3BD3E5CF1BDC78F9EBD6483233E7290293E514A163E255F0FBE1AEF7BBD5259173EF12524BEDF47793C886BE8BD57B408BE351980BD0FF71ABD24643ABEA79920BED2603A3EEB75393EC6D52B3E458B29BC22C45ABC02BB40BCED4BDEBCA6E9CABC11FB213EC4FB363E5AC2DCBDAD6B4F3CBB85B1BD8093343E487518BEDFA316BD7FFFAEBB9375963DF68A88BD6876013C9FA1C63D95CDB23C911721BE04B5F9BD1B7C8F3DE1D367BDD7BBD63D6E990A3E23222F3D4B6CD73DB869C53D8697383E3A86853D973F2C3EFC3827BC4E87FA3DD5903BBE4BB8403E34A9A33D41C8843D4BC8FABD3CD5E8BD4946233D955052BDA5F841BC6C81AFBD5DD8883DB71A753CD0A1263D88690ABE35DAA73CA3557D3D8C09D23D5A27273DECEFDBBCD220023EE036ACBD6CD2443E8F630FBEBC43B73DF03AA4BDC709133E1B94E73D362CE4BCB15F33BE3139443E5FCF62BD0E3C1B3EE99DF93D9E1BB3BA70DB213E38EBDDBC47F10CBEF817293DAD3997D283E210F923BC2D131BECEAF913DB981EFBDCBCCCCBA2B6711BE4E32FE3C5D5D33BD2F34313EB7EC48BC26CDFD3D07170B3E1CD816BE310DD2BD9E03023E1EA8F3BD8B99EEBBFC97433E047F8DBDDD6BA03DA3B2433E34D7C0BC7FDB89BA1980333EF3EB7EC48B383DE0E383DE0E383DE0E383DE0"> : tensor<32x32xbf16> + %2 = tensor.empty() : tensor<1x1x32x32xbf16> + %packed_arg1 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<32x32xbf16> -> tensor<1x1x32x32xbf16> + %3 = tensor.empty() : tensor<1x1x16x32x2xbf16> + %packed_packed_arg1 = tensor.pack %packed_arg1 inner_dims_pos = [2] inner_tiles = [2] into %3 : tensor<1x1x32x32xbf16> -> tensor<1x1x16x32x2xbf16> + + %4 = tensor.empty() : tensor<2x1x32x32xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %5 = linalg.fill ins(%cst_0 : bf16) outs(%4 : tensor<2x1x32x32xbf16>) -> tensor<2x1x32x32xbf16> + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%packed_arg0, %packed_packed_arg1 : tensor<2x1x32x32xbf16>, tensor<1x1x16x32x2xbf16>) outs(%5 : tensor<2x1x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x1x32x32xbf16> + + %7 = tensor.empty() : tensor<8x1x32x32xbf16> + %packed_arg2 = tensor.pack %arg2 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %7 : tensor<32x256xbf16> -> tensor<8x1x32x32xbf16> + %8 = tensor.empty() : tensor<8x1x16x32x2xbf16> + %packed_packed_arg2 = tensor.pack %packed_arg2 inner_dims_pos = [2] inner_tiles = [2] into %8 : tensor<8x1x32x32xbf16> -> tensor<8x1x16x32x2xbf16> + %9 = tensor.empty() : tensor<2x8x32x32xbf16> + %10 = linalg.fill ins(%cst_0 : bf16) outs(%9 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %11 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%6, %packed_packed_arg2 : tensor<2x1x32x32xbf16>, tensor<8x1x16x32x2xbf16>) outs(%10 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x8x32x32xbf16> + + %12 = tensor.empty() : tensor<32x8x32x32xbf16> 
+ %packed_arg3 = tensor.pack %arg3 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %12 : tensor<256x1024xbf16> -> tensor<32x8x32x32xbf16> + %13 = tensor.empty() : tensor<32x8x16x32x2xbf16> + %packed_packed_arg3 = tensor.pack %packed_arg3 inner_dims_pos = [2] inner_tiles = [2] into %13 : tensor<32x8x32x32xbf16> -> tensor<32x8x16x32x2xbf16> + + %14 = tensor.empty() : tensor<2x32x32x32xbf16> + %15 = linalg.fill ins(%cst_0 : bf16) outs(%14 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %16 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%11, %packed_packed_arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%15 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %46 = arith.mulf %in, %in_0 : bf16 + %56 = arith.addf %out, %46 : bf16 + linalg.yield %56 : bf16 + } -> tensor<2x32x32x32xbf16> + + %17 = tensor.empty() : tensor<64x1024xbf16> + %unpack = tensor.unpack %16 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %17 : tensor<2x32x32x32xbf16> -> tensor<64x1024xbf16> + return %unpack : tensor<64x1024xbf16> + } +} diff --git a/test/gc/Transforms/test_constant_tensor_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir index 1256c52cf..d55f42039 100644 --- a/test/gc/Transforms/test_constant_tensor_folding.mlir +++ b/test/gc/Transforms/test_constant_tensor_folding.mlir @@ -9,7 +9,7 @@ module { // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear. // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. - func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { %1 = tensor.empty() : tensor<2x16x32x32xbf16> %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> %2 = tensor.empty() : tensor<8x16x32x32xbf16> @@ -78,5 +78,5 @@ module { // COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> // COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> -// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} // COM: 
func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface}

From afec52ae8f788972cef6a9573330aa15e0a60526 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Thu, 25 Jul 2024 01:40:16 -0700
Subject: [PATCH 46/64] Add compile_time_fold and runtime_fold.

---
 lib/gc/Transforms/ConstantTensorFolding.cpp | 97 +++++++++++++--------
 1 file changed, 62 insertions(+), 35 deletions(-)

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 2df13adcc..ce75e4409 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -11,11 +11,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Transforms/Passes.h"
 #include <deque>
 #include <unordered_set>
 
-#include "mlir/Transforms/Passes.h"
-
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
@@ -388,21 +387,15 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
   (void)global;
 }
 
-std::unordered_set<int64_t> getConstArgsIndexes(Operation &topFunc) {
+std::unordered_set<int64_t> getConstArgsIndexes(Operation &topFunc,
+                                                bool compiletime) {
   auto topFuncAttr = topFunc.getAttrDictionary();
   std::unordered_set<int64_t> constArgsIndexes;
-  std::optional<NamedAttribute> compiletimeConstArgs =
-      topFuncAttr.getNamed("compiletime_const_args_index");
-  if (compiletimeConstArgs.has_value()) {
-    for (auto id :
-         llvm::dyn_cast<ArrayAttr>(compiletimeConstArgs->getValue())) {
-      constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
-    }
-  }
-  std::optional<NamedAttribute> runtimeConstArgs =
-      topFuncAttr.getNamed("runtime_const_args_index");
-  if (runtimeConstArgs.has_value()) {
-    for (auto id : llvm::dyn_cast<ArrayAttr>(runtimeConstArgs->getValue())) {
+  std::string attrName = compiletime ? "compiletime_const_args_index"
+                                     : "runtime_const_args_index";
+  std::optional<NamedAttribute> constArgs = topFuncAttr.getNamed(attrName);
+  if (constArgs.has_value()) {
+    for (auto id : llvm::dyn_cast<ArrayAttr>(constArgs->getValue())) {
       constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
     }
   }
@@ -542,16 +535,16 @@ void getInputsAndOutputs(Block &block,
 }
 
 func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
-                           Operation *topOp, SmallVector<Operation *> constOps,
+                           Operation *topOp, std::string name,
+                           SmallVector<Operation *> constOps,
                            SmallVector<Type> &inputTypes,
                            SmallVector<Value> &inputValues,
                            SmallVector<Type> &outputTypes,
                            SmallVector<Value> &outputValues) {
-  std::string funcName("fold");
   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
   func::FuncOp foldFunc =
-      builder.create<func::FuncOp>(topOp->getLoc(), funcName, foldFuncType);
+      builder.create<func::FuncOp>(topOp->getLoc(), name, foldFuncType);
   Block *foldBlock = foldFunc.addEntryBlock();
   // values of folded constant tensors in foldBlock
   SmallVector<Value> outputValuesInFold;
@@ -584,8 +577,8 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
   }
   globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
   auto moduleOp = dyn_cast<ModuleOp>(topOp);
-  addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids",
-                    globalIndexes);
+  addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder,
+                    "__" + name + "_buffer_ids_", globalIndexes);
 
   auto returnOp =
       builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
@@ -736,8 +729,11 @@ void ConstantTensorFolding::runOnOperation() {
   Region &region = topFunc.getRegions().front();
   Block &block = region.getBlocks().front();
 
-  std::unordered_set<int64_t> constArgsIndexes = getConstArgsIndexes(topFunc);
-  if (constArgsIndexes.empty()) {
+  std::unordered_set<int64_t> compiletimeConstArgsIndexes =
+      getConstArgsIndexes(topFunc, true);
+  std::unordered_set<int64_t> runtimeConstArgsIndexes =
+      getConstArgsIndexes(topFunc, false);
+  if (compiletimeConstArgsIndexes.empty() && runtimeConstArgsIndexes.empty()) {
     return;
   }
 
@@ -750,21 +746,52 @@ void ConstantTensorFolding::runOnOperation() {
     }
   }
 
-  SmallVector<Type> inputTypes; // types of constant tensors
+  // ===== build compile time folding function =====
+  SmallVector<Type> compiletimeInputTypes; // types of constant tensors
   // values of constant tensors in original block
-  SmallVector<Value> inputValues;
-  SmallVector<Type> outputTypes; // types of folded constant tensors
+  SmallVector<Value> compiletimeInputValues;
+  SmallVector<Type> compiletimeOutputTypes; // types of folded constant tensors
   // values of folded constant tensors in original block
-  SmallVector<Value> outputValues;
-  getArithConstantOutputs(block, outputTypes, outputValues);
-  getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues,
-                      outputTypes, outputValues);
-
-  func::FuncOp foldFunc =
-      buildFoldFunc(context, builder, topOp, constOps, inputTypes, inputValues,
-                    outputTypes, outputValues);
-  (void)foldFunc;
-
+  SmallVector<Value> compiletimeOutputValues;
+  getArithConstantOutputs(block, compiletimeOutputTypes,
+                          compiletimeOutputValues);
+  getInputsAndOutputs(block, compiletimeConstArgsIndexes, compiletimeInputTypes,
+                      compiletimeInputValues, compiletimeOutputTypes,
+                      compiletimeOutputValues);
+
+  func::FuncOp compiletimeFoldFunc =
+      buildFoldFunc(context, builder, topOp, "compiletime_fold", constOps,
+                    compiletimeInputTypes, compiletimeInputValues,
+                    compiletimeOutputTypes, compiletimeOutputValues);
+  (void)compiletimeFoldFunc;
+  canonicalizeAndClean(context, compiletimeFoldFunc.getOperation());
+
+  // ===== build runtime folding function =====
+  SmallVector<Type> runtimeInputTypes; // types of
constant tensors + // values of constant tensors in original block + SmallVector runtimeInputValues; + SmallVector runtimeOutputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector runtimeOutputValues; + getInputsAndOutputs(block, runtimeConstArgsIndexes, runtimeInputTypes, + runtimeInputValues, runtimeOutputTypes, + runtimeOutputValues); + + func::FuncOp runtimeFoldFunc = buildFoldFunc( + context, builder, topOp, "runtime_fold", constOps, runtimeInputTypes, + runtimeInputValues, runtimeOutputTypes, runtimeOutputValues); + (void)runtimeFoldFunc; + canonicalizeAndClean(context, runtimeFoldFunc.getOperation()); + + // ===== build computing function ===== + std::unordered_set constArgsIndexes = compiletimeConstArgsIndexes; + constArgsIndexes.merge(runtimeConstArgsIndexes); + SmallVector outputTypes = compiletimeOutputTypes; + outputTypes.insert(outputTypes.end(), runtimeOutputTypes.begin(), + runtimeOutputTypes.end()); + SmallVector outputValues = compiletimeOutputValues; + outputValues.insert(outputValues.end(), runtimeOutputValues.begin(), + runtimeOutputValues.end()); modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes, outputTypes, outputValues); From 9c4fd70a0d1a1ce23a233f9b0d6a4d3481821bb4 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Thu, 25 Jul 2024 19:38:14 -0700 Subject: [PATCH 47/64] Fix license and tidy --- .../gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h | 4 ++-- .../Analysis/DataFlow/ConstantSubgraphAnalyser.cpp | 13 ++++++------- lib/gc/Transforms/ConstantSubgraphAnalysis.cpp | 5 ++--- lib/gc/Transforms/ConstantTensorFolding.cpp | 13 ++++++------- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h index a5a199914..d2dc4ffa4 100644 --- a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h +++ b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h @@ -1,6 +1,6 @@ -//===- ConstantSubgraphAnalyser.h - Constant subgraph analysis ------===// +//===-- ConstantSubgraphAnalyser.h - Constant subgraph ----------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp index 584c7e8ce..640b3ef59 100644 --- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp +++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp @@ -1,6 +1,6 @@ -//===- ConstantSubgraphAnalyser.cpp - Constant subgraph analysis ----===// +//===-- ConstantSubgraphAnalyser.cpp - Constant subgraph -------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // @@ -44,7 +44,6 @@ void InConstantSubgraph::print(raw_ostream &os) const { return; } os << getInConstantSubgraph(); - return; } //===----------------------------------------------------------------------===// @@ -61,7 +60,7 @@ void ConstantSubgraphAnalyser::visitOperation( if (op->hasTrait()) { LLVM_DEBUG(llvm::dbgs() << "Curr op is a Constant op\n"); in = true; - } else if (operands.size() == 0) { // For example, tensor.empty() + } else if (operands.empty()) { // For example, tensor.empty() LLVM_DEBUG(llvm::dbgs() << "Curr op has 0 operand, constant\n"); in = true; } else { @@ -177,11 +176,11 @@ RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() { solver.load(); } -void RunConstantSubgraphAnalyser::run(Operation *topFunc) { - if (failed(solver.initializeAndRun(topFunc))) { +void RunConstantSubgraphAnalyser::run(Operation *op) { + if (failed(solver.initializeAndRun(op))) { return; } - getConstantSubgraph(solver, topFunc); + getConstantSubgraph(solver, op); } bool RunConstantSubgraphAnalyser::getInConstantSubgraph(Value val) { diff --git a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp index d4f183326..ed481720b 100644 --- a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp +++ b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp @@ -1,7 +1,6 @@ -//===- ConstantSubgraphAnalysis.cpp - Constant Subgraph Analysis -//-----------------===// +//===-- ConstantSubgraphAnalysis.cpp - Constant Subgraph --------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp index ce75e4409..5b9e7a5b4 100644 --- a/lib/gc/Transforms/ConstantTensorFolding.cpp +++ b/lib/gc/Transforms/ConstantTensorFolding.cpp @@ -1,7 +1,6 @@ -//===- ConstantTensorFolding.cpp - Constant Subgraph Transform -//-----------------===// +//===-- ConstantTensorFolding.cpp - Constant Folding ------------*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // @@ -326,8 +325,8 @@ struct ConstGraphTensorCacheManager { // alloc and set the buf_base_ and offset_ attributes of cache std::vector alloc(std::vector buffersSize) { size_t totalSize = 0; - for (size_t i = 0; i < buffersSize.size(); i++) { - totalSize += divideAndCeil(buffersSize[i], 64) * 64; + for (size_t size : buffersSize) { + totalSize += divideAndCeil(size, 64) * 64; } llvm::dbgs() << "Alloc total size: " << totalSize << '\n'; // auto base = createConstCacheProxy(totalSize); @@ -535,8 +534,8 @@ void getInputsAndOutputs(Block &block, } func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder, - Operation *topOp, std::string name, - SmallVector constOps, + Operation *topOp, const std::string &name, + const SmallVector &constOps, SmallVector &inputTypes, SmallVector &inputValues, SmallVector &outputTypes, From fad5f92f94f5681b62abc4f522871625b41b0d50 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Fri, 26 Jul 2024 01:54:29 -0700 Subject: [PATCH 48/64] Fix link --- CMakeLists.txt | 2 ++ lib/gc/Analysis/CMakeLists.txt | 6 ++++++ lib/gc/CAPI/CMakeLists.txt | 1 + .../Transforms/test_constant_tensor_folding-1.mlir | 2 +- .../Transforms/test_constant_tensor_folding-2.mlir | 14 ++++++++++++++ .../Transforms/test_constant_tensor_folding.mlir | 2 +- 6 files changed, 25 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 636b33ad2..07164d7da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,8 @@ endif() set(GC_LIB_LINKED_LIBS GCJitWrapper GCCpuRuntime + GCPasses + GCAnalysis ) add_mlir_library(graph_compiler SHARED ${GC_LIB_SOURCES}) target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES}) diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index 9b5994f3d..403748041 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -1,3 +1,7 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS + MLIRIR + MLIRSupport) + add_mlir_library(GCAnalysis DataFlow/ConstantSubgraphAnalyser.cpp @@ -14,3 +18,5 @@ add_mlir_library(GCAnalysis MLIRBufferizationToMemRef MLIRBufferizationPipelines ) + +set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis) diff --git a/lib/gc/CAPI/CMakeLists.txt b/lib/gc/CAPI/CMakeLists.txt index 1d2e7687e..aca399ad7 100644 --- a/lib/gc/CAPI/CMakeLists.txt +++ b/lib/gc/CAPI/CMakeLists.txt @@ -6,5 +6,6 @@ add_mlir_public_c_api_library(GcCAPI MLIRCPURuntimeDialect GCPasses GCGPUPasses + GCAnalysis MLIRCPURuntimeTransforms ) \ No newline at end of file diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir index fa4fcb210..0664edafb 100644 --- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir +++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir @@ -24,7 +24,7 @@ module { // CHECK: cpuruntime.printf // CHECK: linalg.add // CHECK: linalg.mul -// CHECK: func.func @fold +// CHECK: func.func @runtime_fold // CHECK: linalg.add // CHECK: linalg.add // CHECK: linalg.add diff --git a/test/gc/Transforms/test_constant_tensor_folding-2.mlir b/test/gc/Transforms/test_constant_tensor_folding-2.mlir index 8d9e4ed53..85208815e 100644 --- a/test/gc/Transforms/test_constant_tensor_folding-2.mlir +++ b/test/gc/Transforms/test_constant_tensor_folding-2.mlir @@ -59,3 +59,17 @@ module { return %unpack : tensor<64x1024xbf16> } } + +// COM: 1 pack in entry for input feature, +// COM: 4 packs in compiletime_fold for 2 weights, +// COM: 2 packs 
in runtime_fold for 1 weights + +// CHECK: tensor.pack +// CHECK: func.func @compiletime_fold +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: func.func @runtime_fold +// CHECK: tensor.pack +// CHECK: tensor.pack diff --git a/test/gc/Transforms/test_constant_tensor_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir index d55f42039..71f475c00 100644 --- a/test/gc/Transforms/test_constant_tensor_folding.mlir +++ b/test/gc/Transforms/test_constant_tensor_folding.mlir @@ -68,7 +68,7 @@ module { } } // CHECK: linalg.broadcast -// CHECK: func.func @fold +// CHECK: func.func @runtime_fold // CHECK: arith.extf // CHECK: arith.truncf From 57f887dbee1671337ab3d367ee8573568b6fbaaa Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Sun, 28 Jul 2024 19:34:20 -0700 Subject: [PATCH 49/64] Only enable runtime folding --- lib/gc/Transforms/ConstantTensorFolding.cpp | 124 +++++++++++------- .../test_constant_tensor_folding-2.mlir | 16 ++- 2 files changed, 89 insertions(+), 51 deletions(-) diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp index 5b9e7a5b4..3f38dda77 100644 --- a/lib/gc/Transforms/ConstantTensorFolding.cpp +++ b/lib/gc/Transforms/ConstantTensorFolding.cpp @@ -745,54 +745,82 @@ void ConstantTensorFolding::runOnOperation() { } } - // ===== build compile time folding function ===== - SmallVector compiletimeInputTypes; // types of constant tensors - // values of constant tensors in original block - SmallVector compiletimeInputValues; - SmallVector compiletimeOutputTypes; // types of folded constant tensors - // values of folded constant tensors in original block - SmallVector compiletimeOutputValues; - getArithConstantOutputs(block, compiletimeOutputTypes, - compiletimeOutputValues); - getInputsAndOutputs(block, compiletimeConstArgsIndexes, compiletimeInputTypes, - compiletimeInputValues, compiletimeOutputTypes, - compiletimeOutputValues); - - func::FuncOp compiletimeFoldFunc = - buildFoldFunc(context, builder, topOp, "compiletime_fold", constOps, - compiletimeInputTypes, compiletimeInputValues, - compiletimeOutputTypes, compiletimeOutputValues); - (void)compiletimeFoldFunc; - canonicalizeAndClean(context, compiletimeFoldFunc.getOperation()); - - // ===== build runtime folding function ===== - SmallVector runtimeInputTypes; // types of constant tensors - // values of constant tensors in original block - SmallVector runtimeInputValues; - SmallVector runtimeOutputTypes; // types of folded constant tensors - // values of folded constant tensors in original block - SmallVector runtimeOutputValues; - getInputsAndOutputs(block, runtimeConstArgsIndexes, runtimeInputTypes, - runtimeInputValues, runtimeOutputTypes, - runtimeOutputValues); - - func::FuncOp runtimeFoldFunc = buildFoldFunc( - context, builder, topOp, "runtime_fold", constOps, runtimeInputTypes, - runtimeInputValues, runtimeOutputTypes, runtimeOutputValues); - (void)runtimeFoldFunc; - canonicalizeAndClean(context, runtimeFoldFunc.getOperation()); - - // ===== build computing function ===== - std::unordered_set constArgsIndexes = compiletimeConstArgsIndexes; - constArgsIndexes.merge(runtimeConstArgsIndexes); - SmallVector outputTypes = compiletimeOutputTypes; - outputTypes.insert(outputTypes.end(), runtimeOutputTypes.begin(), - runtimeOutputTypes.end()); - SmallVector outputValues = compiletimeOutputValues; - outputValues.insert(outputValues.end(), runtimeOutputValues.begin(), - runtimeOutputValues.end()); - 
modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes, - outputTypes, outputValues); + bool enableCompiletimeFolding = false; + if (enableCompiletimeFolding) { + // ===== build compile time folding function ===== + SmallVector compiletimeInputTypes; // types of constant tensors + // values of constant tensors in original block + SmallVector compiletimeInputValues; + SmallVector + compiletimeOutputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector compiletimeOutputValues; + getArithConstantOutputs(block, compiletimeOutputTypes, + compiletimeOutputValues); + getInputsAndOutputs(block, compiletimeConstArgsIndexes, + compiletimeInputTypes, compiletimeInputValues, + compiletimeOutputTypes, compiletimeOutputValues); + + func::FuncOp compiletimeFoldFunc = + buildFoldFunc(context, builder, topOp, "compiletime_fold", constOps, + compiletimeInputTypes, compiletimeInputValues, + compiletimeOutputTypes, compiletimeOutputValues); + (void)compiletimeFoldFunc; + canonicalizeAndClean(context, compiletimeFoldFunc.getOperation()); + + // ===== build runtime folding function ===== + SmallVector runtimeInputTypes; // types of constant tensors + // values of constant tensors in original block + SmallVector runtimeInputValues; + SmallVector runtimeOutputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector runtimeOutputValues; + getInputsAndOutputs(block, runtimeConstArgsIndexes, runtimeInputTypes, + runtimeInputValues, runtimeOutputTypes, + runtimeOutputValues); + + func::FuncOp runtimeFoldFunc = buildFoldFunc( + context, builder, topOp, "runtime_fold", constOps, runtimeInputTypes, + runtimeInputValues, runtimeOutputTypes, runtimeOutputValues); + (void)runtimeFoldFunc; + canonicalizeAndClean(context, runtimeFoldFunc.getOperation()); + + // ===== build computing function ===== + std::unordered_set constArgsIndexes = compiletimeConstArgsIndexes; + constArgsIndexes.merge(runtimeConstArgsIndexes); + SmallVector outputTypes = compiletimeOutputTypes; + outputTypes.insert(outputTypes.end(), runtimeOutputTypes.begin(), + runtimeOutputTypes.end()); + SmallVector outputValues = compiletimeOutputValues; + outputValues.insert(outputValues.end(), runtimeOutputValues.begin(), + runtimeOutputValues.end()); + modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes, + outputTypes, outputValues); + } else { + std::unordered_set constArgsIndexes = compiletimeConstArgsIndexes; + constArgsIndexes.merge(runtimeConstArgsIndexes); + + // ===== build runtime folding function ===== + SmallVector inputTypes; // types of constant tensors + // values of constant tensors in original block + SmallVector inputValues; + SmallVector outputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector outputValues; + getArithConstantOutputs(block, outputTypes, outputValues); + getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues, + outputTypes, outputValues); + + func::FuncOp foldFunc = + buildFoldFunc(context, builder, topOp, "runtime_fold", constOps, + inputTypes, inputValues, outputTypes, outputValues); + (void)foldFunc; + canonicalizeAndClean(context, foldFunc.getOperation()); + + // ===== build computing function ===== + modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes, + outputTypes, outputValues); + } canonicalizeAndClean(context, topOp); } diff --git 
a/test/gc/Transforms/test_constant_tensor_folding-2.mlir b/test/gc/Transforms/test_constant_tensor_folding-2.mlir index 85208815e..a5e123085 100644 --- a/test/gc/Transforms/test_constant_tensor_folding-2.mlir +++ b/test/gc/Transforms/test_constant_tensor_folding-2.mlir @@ -60,16 +60,26 @@ module { } } +// COM: If enable compile time folding, // COM: 1 pack in entry for input feature, // COM: 4 packs in compiletime_fold for 2 weights, -// COM: 2 packs in runtime_fold for 1 weights +// COM: 2 packs in runtime_fold for 1 weights: +// COM: CHECK: tensor.pack +// COM: CHECK: func.func @compiletime_fold +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: CHECK: func.func @runtime_fold +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: else, // CHECK: tensor.pack -// CHECK: func.func @compiletime_fold +// CHECK: func.func @runtime_fold // CHECK: tensor.pack // CHECK: tensor.pack // CHECK: tensor.pack // CHECK: tensor.pack -// CHECK: func.func @runtime_fold // CHECK: tensor.pack // CHECK: tensor.pack From 1fc3b9f2c28e4ce99d15956b56fd2794ea4362a0 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Sun, 28 Jul 2024 22:33:05 -0700 Subject: [PATCH 50/64] Rename and polish --- .../DataFlow/ConstantSubgraphAnalyser.h | 66 +++++++++---------- .../DataFlow/ConstantSubgraphAnalyser.cpp | 41 ++++++------ .../Transforms/ConstantSubgraphAnalysis.cpp | 2 +- lib/gc/Transforms/ConstantTensorFolding.cpp | 16 +++-- 4 files changed, 64 insertions(+), 61 deletions(-) diff --git a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h index d2dc4ffa4..288ee74c4 100644 --- a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h +++ b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h @@ -5,68 +5,66 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This file implements constant subgraph analysis. In this file are: -// 1. the lattice value class that represents operations with constant inputs -// and outputs in the program, and -// 2. a sparse constant subgraph analysis. -// -//===----------------------------------------------------------------------===// +/// +/// This file implements constant subgraph analysis. In this file are: +/// 1. the lattice value class that represents operations with constant inputs +/// and outputs in the program, and +/// 2. a sparse constant subgraph analysis. +/// +///===----------------------------------------------------------------------===// #ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H #define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H #include "mlir/Analysis/DataFlow/SparseAnalysis.h" -#include namespace mlir { namespace dataflow { //===----------------------------------------------------------------------===// -// InConstantSubgraph +// IsConstantTensor //===----------------------------------------------------------------------===// -/// This lattice represents a boolean integer indicating if an operation is with -/// constant inputs and constant outputs and hence in constant subgraph. -class InConstantSubgraph { +/// This lattice represents a boolean indicating if a value is constant. +class IsConstantTensor { public: /// Construct as uninitialized. - explicit InConstantSubgraph() = default; + explicit IsConstantTensor() = default; /// Construct with a known state. 
-  explicit InConstantSubgraph(bool initialized, bool inConstantSubgraph)
-      : initialized(initialized), inConstantSubgraph(inConstantSubgraph) {}
+  explicit IsConstantTensor(bool initialized, bool isConstantTensor)
+      : initialized(initialized), isConstantTensor(isConstantTensor) {}
 
-  /// Get the state. Returns null if no value was determined.
-  bool getInConstantSubgraph() const {
+  /// Get the state. Must be initialized before.
+  bool getIsConstantTensor() const {
     assert(!isUninitialized());
-    return inConstantSubgraph;
+    return isConstantTensor;
   }
 
   /// Compare.
-  bool operator==(const InConstantSubgraph &rhs) const {
+  bool operator==(const IsConstantTensor &rhs) const {
     return initialized == rhs.initialized &&
-           inConstantSubgraph == rhs.inConstantSubgraph;
+           isConstantTensor == rhs.isConstantTensor;
   }
 
   void print(raw_ostream &os) const;
 
   /// Get uninitialized state. This happens when the
   /// state hasn't been set during the analysis.
-  static InConstantSubgraph getUninitialized() { return InConstantSubgraph{}; }
+  static IsConstantTensor getUninitialized() { return IsConstantTensor{}; }
 
   /// Whether the state is uninitialized.
   bool isUninitialized() const { return !initialized; }
 
   /// Get unknown state.
-  static InConstantSubgraph getUnknown() {
-    return InConstantSubgraph{/*initialized=*/false,
-                              /*inConstantSubgraph=*/false};
+  static IsConstantTensor getUnknown() {
+    return IsConstantTensor{/*initialized=*/false,
+                            /*isConstantTensor*/ false};
   }
 
   // Join two states.
-  static InConstantSubgraph join(const InConstantSubgraph &lhs,
-                                 const InConstantSubgraph &rhs) {
+  static IsConstantTensor join(const IsConstantTensor &lhs,
+                               const IsConstantTensor &rhs) {
     // if one is uninitialized, use another
     if (lhs.isUninitialized())
       return rhs;
@@ -75,15 +73,15 @@ class InConstantSubgraph {
 
     // both are initialized, intersect them
     if (!lhs.isUninitialized() && !rhs.isUninitialized()) {
-      return InConstantSubgraph(true, lhs.getInConstantSubgraph() &&
-                                          rhs.getInConstantSubgraph());
+      return IsConstantTensor(true, lhs.getIsConstantTensor() &&
+                                        rhs.getIsConstantTensor());
     }
     return getUninitialized();
   }
 
 private:
   bool initialized = false;
-  bool inConstantSubgraph = false;
+  bool isConstantTensor = false;
 };
 
 //===----------------------------------------------------------------------===//
@@ -91,15 +89,15 @@ class InConstantSubgraph {
 
 class ConstantSubgraphAnalyser
-    : public SparseForwardDataFlowAnalysis<Lattice<InConstantSubgraph>> {
+    : public SparseForwardDataFlowAnalysis<Lattice<IsConstantTensor>> {
 public:
   using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
 
   void visitOperation(Operation *op,
-                      ArrayRef<const Lattice<InConstantSubgraph> *> operands,
-                      ArrayRef<Lattice<InConstantSubgraph> *> results) override;
+                      ArrayRef<const Lattice<IsConstantTensor> *> operands,
+                      ArrayRef<Lattice<IsConstantTensor> *> results) override;
 
-  void setToEntryState(Lattice<InConstantSubgraph> *lattice) override;
+  void setToEntryState(Lattice<IsConstantTensor> *lattice) override;
 };
 
 //===----------------------------------------------------------------------===//
@@ -113,7 +111,7 @@ struct RunConstantSubgraphAnalyser {
 
   void run(Operation *op);
 
-  bool getInConstantSubgraph(Value val);
+  bool getIsConstantTensor(Value val);
 
 private:
   /// Stores the result of the analysis.
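
Editor's note: a minimal usage sketch of the renamed analysis, not part of the
patch. `RunConstantSubgraphAnalyser`, `run`, and `getIsConstantTensor` are taken
from the header above; the wrapper function and the attribute name it sets are
hypothetical.

    // Sketch only: mark every op whose results the analysis proves constant.
    #include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"

    static void markConstantResults(mlir::Operation *top) {
      mlir::dataflow::RunConstantSubgraphAnalyser analyser;
      analyser.run(top); // initializes and runs the sparse dataflow solver
      top->walk([&](mlir::Operation *op) {
        for (mlir::Value res : op->getResults())
          if (analyser.getIsConstantTensor(res))
            op->setAttr("in_const_subgraph", // hypothetical attribute name
                        mlir::UnitAttr::get(op->getContext()));
      });
    }
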
diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp index 640b3ef59..ff291d6b0 100644 --- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp +++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp @@ -35,15 +35,15 @@ using namespace mlir; using namespace mlir::dataflow; //===----------------------------------------------------------------------===// -// InConstantSubgraph +// IsConstantTensor //===----------------------------------------------------------------------===// -void InConstantSubgraph::print(raw_ostream &os) const { +void IsConstantTensor::print(raw_ostream &os) const { if (isUninitialized()) { os << ""; return; } - os << getInConstantSubgraph(); + os << getIsConstantTensor(); } //===----------------------------------------------------------------------===// @@ -51,8 +51,8 @@ void InConstantSubgraph::print(raw_ostream &os) const { //===----------------------------------------------------------------------===// void ConstantSubgraphAnalyser::visitOperation( - Operation *op, ArrayRef *> operands, - ArrayRef *> results) { + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { LLVM_DEBUG(llvm::dbgs() << "ConstantSubgraphAnalyser: Visiting operation:\n" << *op << "\n"); @@ -67,7 +67,7 @@ void ConstantSubgraphAnalyser::visitOperation( LLVM_DEBUG(llvm::dbgs() << "Curr op has " << operands.size() << " operands, check if constant\n"); for (auto *operandLattice : operands) { - auto operandState = operandLattice->getValue().getInConstantSubgraph(); + auto operandState = operandLattice->getValue().getIsConstantTensor(); LLVM_DEBUG(llvm::dbgs() << "Operand: " << operandLattice->getPoint() << ", lattice value: " << operandState << "\n"); if (!operandState) { @@ -81,20 +81,18 @@ void ConstantSubgraphAnalyser::visitOperation( if (!in) { LLVM_DEBUG(llvm::dbgs() << "Curr op not in constant subgraph\n"); for (auto lattice : results) { - propagateIfChanged(lattice, - lattice->join(InConstantSubgraph(true, false))); + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, false))); } } else { LLVM_DEBUG(llvm::dbgs() << "Curr op in constant subgraph\n"); for (auto lattice : results) { - propagateIfChanged(lattice, - lattice->join(InConstantSubgraph(true, true))); + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, true))); } } } void ConstantSubgraphAnalyser::setToEntryState( - Lattice *lattice) { + Lattice *lattice) { if (auto blockArg = cast(lattice->getPoint())) { auto parentOp = blockArg.getParentBlock()->getParentOp(); auto parentOpAttr = parentOp->getAttrDictionary(); @@ -119,14 +117,13 @@ void ConstantSubgraphAnalyser::setToEntryState( if (constArgsIndexes.count(blockArg.getArgNumber())) { LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg << " is marked as constant\n"); - propagateIfChanged(lattice, - lattice->join(InConstantSubgraph(true, true))); + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, true))); return; } - propagateIfChanged(lattice, lattice->join(InConstantSubgraph(true, false))); + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, false))); } else { propagateIfChanged(lattice, - lattice->join(InConstantSubgraph::getUninitialized())); + lattice->join(IsConstantTensor::getUninitialized())); } } @@ -149,13 +146,13 @@ void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver, continue; } for (Value res : op.getResults()) { - auto *lattice = solver.lookupState>(res); + auto *lattice = solver.lookupState>(res); if 
(!lattice || lattice->getValue().isUninitialized()) {
         resultsAllConstant = false;
         break;
       }
-      const InConstantSubgraph &latticeValue = lattice->getValue();
-      if (!latticeValue.getInConstantSubgraph()) {
+      const IsConstantTensor &latticeValue = lattice->getValue();
+      if (!latticeValue.getIsConstantTensor()) {
         resultsAllConstant = false;
         break;
       }
@@ -183,8 +180,8 @@ void RunConstantSubgraphAnalyser::run(Operation *op) {
   getConstantSubgraph(solver, op);
 }
 
-bool RunConstantSubgraphAnalyser::getInConstantSubgraph(Value val) {
-  auto *lattice = solver.lookupState<Lattice<InConstantSubgraph>>(val);
-  const InConstantSubgraph &latticeValue = lattice->getValue();
-  return latticeValue.getInConstantSubgraph();
+bool RunConstantSubgraphAnalyser::getIsConstantTensor(Value val) {
+  auto *lattice = solver.lookupState<Lattice<IsConstantTensor>>(val);
+  const IsConstantTensor &latticeValue = lattice->getValue();
+  return latticeValue.getIsConstantTensor();
 }
\ No newline at end of file
diff --git a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
index ed481720b..511d76f21 100644
--- a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
+++ b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
@@ -37,7 +37,7 @@ void ConstantSubgraphAnalysis::runOnOperation() {
   auto &func =
       op->getRegions().front().getBlocks().front().getOperations().front();
 
-  // Hard-code: set the #1 argument to be constant.
+  // Hard-code example: set some arguments to be constant.
   // OpBuilder builder(op->getContext());
   // func.setAttr("runtime_const_args_index",
   //              builder.getI32ArrayAttr({1,2,3,4}));
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 3f38dda77..d7174ec6e 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -10,10 +10,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/Passes.h"
 #include <deque>
 #include <unordered_set>
 
+#include "mlir/Transforms/Passes.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
@@ -31,6 +32,8 @@
 // #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+#define DEBUG_TYPE "constant-tensor-folding"
+
 namespace mlir {
 namespace gc {
 #define GEN_PASS_DEF_CONSTANTTENSORFOLDING
@@ -69,6 +72,10 @@ int64_t getTensorSize(TensorType t) {
   return size;
 }
 
+/// @brief Returns true if op has only one operand, or if all of its operands
+/// are one same value, or if each operand is either that value or the result
+/// of a tensor.EmptyOp.
+/// @param op +/// @return bool singleOperand(Operation *op) { if (op->getNumOperands() > 1) { Value firstOperand = op->getOperand(0); @@ -328,12 +335,12 @@ struct ConstGraphTensorCacheManager { for (size_t size : buffersSize) { totalSize += divideAndCeil(size, 64) * 64; } - llvm::dbgs() << "Alloc total size: " << totalSize << '\n'; + LLVM_DEBUG(llvm::dbgs() << "Alloc total size: " << totalSize << '\n'); // auto base = createConstCacheProxy(totalSize); std::vector globalIds(buffersSize.size()); size_t offset = 0; for (size_t i = 0; i < buffersSize.size(); i++) { - llvm::dbgs() << "Alloc offset: " << offset << '\n'; + LLVM_DEBUG(llvm::dbgs() << "Alloc offset: " << offset << '\n'); // regCachedTensor(cachedTensorGlobalId, base, offset); globalIds[i] = cachedTensorGlobalId; ++cachedTensorGlobalId; @@ -565,7 +572,8 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder, // Allocate buffer for outputValuesInFold std::vector buffersSize; for (Value &tensor : outputValuesInFold) { - llvm::dbgs() << "Allocate buffer for tensor: " << tensor << "\n"; + LLVM_DEBUG(llvm::dbgs() + << "Allocate buffer for tensor: " << tensor << "\n"); buffersSize.push_back( getTensorSize(dyn_cast(tensor.getType()))); } From bfc12c71ce4c2a184a0b20206ba604b9d3efb524 Mon Sep 17 00:00:00 2001 From: "Niu, Xiaoguang" Date: Tue, 6 Aug 2024 19:58:13 -0700 Subject: [PATCH 51/64] Add accuracy tests on mlp --- .../test_constant_tensor_folding_bf16_4D5D.py | 101 +++++++ ...constant_tensor_folding_bf16_two_layers.py | 258 ++++++++++++++++++ .../test_constant_tensor_folding_f32_4D4D.py | 96 +++++++ ..._constant_tensor_folding_f32_two_layers.py | 225 +++++++++++++++ 4 files changed, 680 insertions(+) create mode 100644 test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py create mode 100644 test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py create mode 100644 test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py create mode 100644 test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py diff --git a/test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py b/test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py new file mode 100644 index 000000000..0fafbd080 --- /dev/null +++ b/test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py @@ -0,0 +1,101 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +from enum import Flag +import os +import sys +import ml_dtypes +import numpy as np +from gc_mlir import ir +from gc_mlir.graph_compiler import GraphCompiler +from numpy.testing import assert_allclose + +project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_dir not in sys.path: + sys.path.insert(0, project_dir) + +import torch +# from bench import py_timeit_bench +from utils import get_mlir_args + +if __name__ == "__main__": + with ir.Context() as ctx: + ctx.allow_unregistered_dialects = True + + M = 64 + N = 256 + K = 512 + MBlock = 32 + NBlock = 32 + KBlock = 32 + vnni_size = 2 + shapeA = [M // MBlock, K // KBlock, MBlock, KBlock] + shapeB = [N // NBlock, K // KBlock, KBlock // vnni_size, NBlock, vnni_size] + shapeC = [M // MBlock, N // NBlock, MBlock, NBlock] + + block_start = "{" + block_end = "}" + mlir_str = f''' +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +module {block_start} + func.func @entry(%arg0: tensor<{M // MBlock}x{K // KBlock}x{MBlock}x{KBlock}xbf16>, %cst: tensor<{N // NBlock}x{K // KBlock}x{KBlock // vnni_size}x{NBlock}x{vnni_size}xbf16>) -> tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> attributes {block_start}llvm.emit_c_interface{block_end} {block_start} + %cst_0 = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> + %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16>) -> tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> + %2 = linalg.generic {block_start}indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]{block_end} ins(%arg0, %cst : tensor<{M // MBlock}x{K // KBlock}x{MBlock}x{KBlock}xbf16>, tensor<{N // NBlock}x{K // KBlock}x{KBlock // vnni_size}x{NBlock}x{vnni_size}xbf16>) outs(%1 : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16>) {block_start} + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %3 = arith.mulf %in, %in_1 : bf16 + %4 = arith.addf %out, %3 : bf16 + linalg.yield %4 : bf16 + {block_end} -> tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> + return %2 : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> + {block_end} +{block_end} + ''' + print(mlir_str) + + # 4D x 5D, inputs transposed + module_in = ir.Module.parse(mlir_str) + + # entry(%transposed: tensor<2x16x32x32xbf16>, %transposed_5: tensor<8x16x16x32x2xbf16>) -> tensor<2x8x32x32xbf16> + torch_arg0 = torch.rand((M, K), dtype=torch.bfloat16) + torch_arg1 = torch.rand((K, N), dtype=torch.bfloat16) + ref_res = torch_arg0 @ torch_arg1 + + passes = "any(gc-cpu-pipeline)" + shared_libs = [ + os.environ["MLIR_C_RUNNER_UTILS"], + os.environ["MLIR_RUNNER_UTILS"], + ] + compiler = GraphCompiler(passes) + ctx.enable_multithreading(False) + + arg0 = torch_arg0.view(shapeA).permute([0, 2, 1, 3]).contiguous() # MK -> MKmk + np_arg0 = arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + arg1 = torch_arg1.view(shapeB).permute([3, 0, 1, 4, 2]).contiguous() # KN -> NKkn2k + np_arg1 = arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + gc_res = np.ones(shapeC, dtype=ml_dtypes.bfloat16) + + entry = "entry" + 
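+        # Editor note, not part of the original test: gc_res is appended as the
+        # last argument so the JIT-ed entry writes its result into that buffer
+        # in place; get_mlir_args presumably wraps each numpy array in the
+        # memref descriptor ABI that llvm.emit_c_interface wrappers expect.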
mlir_args = get_mlir_args(module_in, entry, [np_arg0, np_arg1, gc_res]) + engine_in = compiler.compile_and_jit(module_in, ir_printing=False) + engine_in.invoke(entry, *mlir_args) + gc_res = np.reshape(np.transpose(gc_res, (0, 2, 1, 3)), (M, N)) # MNmn -> MN + + assert_allclose(gc_res.astype(np.float32), ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5) diff --git a/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py b/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py new file mode 100644 index 000000000..d444416e7 --- /dev/null +++ b/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py @@ -0,0 +1,258 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import os +import sys + +import numpy as np +import ml_dtypes + +from gc_mlir import ir +from gc_mlir.graph_compiler import GraphCompiler +from numpy.testing import assert_allclose + +project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_dir not in sys.path: + sys.path.insert(0, project_dir) + +import torch +# from bench import py_timeit_bench +from utils import get_mlir_args + +if __name__ == "__main__": + with ir.Context() as ctx: + ctx.allow_unregistered_dialects = True + # ctx.enable_multithreading = False + module_in = ir.Module.parse( + """ +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +module { + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %0 = tensor.empty() : tensor<2x16x32x32xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %padded = tensor.pad %arg0 low[0, 0] high[0, 0] { + ^bb0(%arg5: index, %arg6: index): + tensor.yield %cst : bf16 + } : tensor<64x512xbf16> to tensor<64x512xbf16> + %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xbf16> into tensor<2x32x16x32xbf16> + %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xbf16>) outs(%0 : tensor<2x16x32x32xbf16>) permutation = [0, 2, 1, 3] + %1 = tensor.empty() : tensor<8x16x32x32xbf16> + %padded_0 = tensor.pad %arg1 low[0, 0] high[0, 0] { + ^bb0(%arg5: index, %arg6: index): + tensor.yield %cst : bf16 + } : tensor<512x256xbf16> to tensor<512x256xbf16> + %expanded_1 = tensor.expand_shape %padded_0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xbf16> into 
tensor<16x32x8x32xbf16> + %transposed_2 = linalg.transpose ins(%expanded_1 : tensor<16x32x8x32xbf16>) outs(%1 : tensor<8x16x32x32xbf16>) permutation = [2, 0, 1, 3] + %2 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %padded_3 = tensor.pad %transposed_2 low[0, 0, 0, 0] high[0, 0, 0, 0] { + ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index): + tensor.yield %cst : bf16 + } : tensor<8x16x32x32xbf16> to tensor<8x16x32x32xbf16> + %expanded_4 = tensor.expand_shape %padded_3 [[0], [1], [2, 3], [4]] output_shape [8, 16, 16, 2, 32] : tensor<8x16x32x32xbf16> into tensor<8x16x16x2x32xbf16> + %transposed_5 = linalg.transpose ins(%expanded_4 : tensor<8x16x16x2x32xbf16>) outs(%2 : tensor<8x16x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] + %3 = tensor.empty() : tensor<2x8x32x32xbf16> + %4 = linalg.fill ins(%cst : bf16) outs(%3 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %5 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %transposed_5 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%4 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_19: bf16, %out: bf16): + %17 = arith.mulf %in, %in_19 : bf16 + %18 = arith.addf %out, %17 : bf16 + linalg.yield %18 : bf16 + } -> tensor<2x8x32x32xbf16> + %6 = tensor.empty() : tensor<8x32xbf16> + %padded_6 = tensor.pad %arg2 low[0] high[0] { + ^bb0(%arg5: index): + tensor.yield %cst : bf16 + } : tensor<256xbf16> to tensor<256xbf16> + %expanded_7 = tensor.expand_shape %padded_6 [[0, 1]] output_shape [8, 32] : tensor<256xbf16> into tensor<8x32xbf16> + %transposed_8 = linalg.transpose ins(%expanded_7 : tensor<8x32xbf16>) outs(%6 : tensor<8x32xbf16>) permutation = [0, 1] + %broadcasted = linalg.broadcast ins(%transposed_8 : tensor<8x32xbf16>) outs(%3 : tensor<2x8x32x32xbf16>) dimensions = [0, 2] + %7 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xbf16>) outs(%5 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %17 = arith.addf %in, %out : bf16 + linalg.yield %17 : bf16 + } -> tensor<2x8x32x32xbf16> + %8 = tensor.empty() : tensor<32x8x32x32xbf16> + %padded_9 = tensor.pad %arg3 low[0, 0] high[0, 0] { + ^bb0(%arg5: index, %arg6: index): + tensor.yield %cst : bf16 + } : tensor<256x1024xbf16> to tensor<256x1024xbf16> + %expanded_10 = tensor.expand_shape %padded_9 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xbf16> into tensor<8x32x32x32xbf16> + %transposed_11 = linalg.transpose ins(%expanded_10 : tensor<8x32x32x32xbf16>) outs(%8 : tensor<32x8x32x32xbf16>) permutation = [2, 0, 1, 3] + %9 = tensor.empty() : tensor<32x8x16x32x2xbf16> + %padded_12 = tensor.pad %transposed_11 low[0, 0, 0, 0] high[0, 0, 0, 0] { + ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index): + tensor.yield %cst : bf16 + } : tensor<32x8x32x32xbf16> to tensor<32x8x32x32xbf16> + %expanded_13 = tensor.expand_shape %padded_12 [[0], [1], [2, 3], [4]] output_shape [32, 8, 16, 2, 32] : tensor<32x8x32x32xbf16> into tensor<32x8x16x2x32xbf16> + %transposed_14 = linalg.transpose ins(%expanded_13 : tensor<32x8x16x2x32xbf16>) outs(%9 : tensor<32x8x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] + %10 = tensor.empty() : tensor<2x32x32x32xbf16> + %11 = linalg.fill ins(%cst : bf16) outs(%10 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %12 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", 
"parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%7, %transposed_14 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%11 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_19: bf16, %out: bf16): + %17 = arith.mulf %in, %in_19 : bf16 + %18 = arith.addf %out, %17 : bf16 + linalg.yield %18 : bf16 + } -> tensor<2x32x32x32xbf16> + %13 = tensor.empty() : tensor<32x32xbf16> + %padded_15 = tensor.pad %arg4 low[0] high[0] { + ^bb0(%arg5: index): + tensor.yield %cst : bf16 + } : tensor<1024xbf16> to tensor<1024xbf16> + %expanded_16 = tensor.expand_shape %padded_15 [[0, 1]] output_shape [32, 32] : tensor<1024xbf16> into tensor<32x32xbf16> + %transposed_17 = linalg.transpose ins(%expanded_16 : tensor<32x32xbf16>) outs(%13 : tensor<32x32xbf16>) permutation = [0, 1] + %14 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%transposed_17 : tensor<32x32xbf16>) outs(%12 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %17 = arith.addf %in, %out : bf16 + linalg.yield %17 : bf16 + } -> tensor<2x32x32x32xbf16> + %15 = tensor.empty() : tensor<64x1024xbf16> + %transposed_18 = linalg.transpose ins(%14 : tensor<2x32x32x32xbf16>) outs(%10 : tensor<2x32x32x32xbf16>) permutation = [0, 2, 1, 3] + %collapsed = tensor.collapse_shape %transposed_18 [[0, 1], [2, 3]] : tensor<2x32x32x32xbf16> into tensor<64x1024xbf16> + %extracted_slice = tensor.extract_slice %collapsed[0, 0] [64, 1024] [1, 1] : tensor<64x1024xbf16> to tensor<64x1024xbf16> + %16 = linalg.copy ins(%extracted_slice : tensor<64x1024xbf16>) outs(%15 : tensor<64x1024xbf16>) -> tensor<64x1024xbf16> + return %16 : tensor<64x1024xbf16> + } +} + """ + ) + module_out = ir.Module.parse( + """ +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +module { + llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32 + llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> + llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> + llvm.mlir.global external constant @__runtime_fold_buffer_ids_(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<2x16x32x32xbf16> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xbf16> into tensor<2x32x16x32xbf16> + %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xbf16>) outs(%0 : tensor<2x16x32x32xbf16>) permutation = [0, 2, 1, 3] + %1 = tensor.empty() : tensor<2x8x32x32xbf16> + %2 = linalg.fill ins(%cst : bf16) outs(%1 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", 
"reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %arg1 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%2 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %11 = arith.mulf %in, %in_1 : bf16 + %12 = arith.addf %out, %11 : bf16 + linalg.yield %12 : bf16 + } -> tensor<2x8x32x32xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<8x32xbf16>) outs(%1 : tensor<2x8x32x32xbf16>) dimensions = [0, 2] + %4 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xbf16>) outs(%3 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %11 = arith.addf %in, %out : bf16 + linalg.yield %11 : bf16 + } -> tensor<2x8x32x32xbf16> + %5 = tensor.empty() : tensor<2x32x32x32xbf16> + %6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%4, %arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%6 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %11 = arith.mulf %in, %in_1 : bf16 + %12 = arith.addf %out, %11 : bf16 + linalg.yield %12 : bf16 + } -> tensor<2x32x32x32xbf16> + %8 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<32x32xbf16>) outs(%7 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %11 = arith.addf %in, %out : bf16 + linalg.yield %11 : bf16 + } -> tensor<2x32x32x32xbf16> + %9 = tensor.empty() : tensor<64x1024xbf16> + %transposed_0 = linalg.transpose ins(%8 : tensor<2x32x32x32xbf16>) outs(%5 : tensor<2x32x32x32xbf16>) permutation = [0, 2, 1, 3] + %collapsed = tensor.collapse_shape %transposed_0 [[0, 1], [2, 3]] : tensor<2x32x32x32xbf16> into tensor<64x1024xbf16> + %10 = linalg.copy ins(%collapsed : tensor<64x1024xbf16>) outs(%9 : tensor<64x1024xbf16>) -> tensor<64x1024xbf16> + return %10 : tensor<64x1024xbf16> + } + func.func @runtime_fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<8x16x32x32xbf16> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xbf16> into tensor<16x32x8x32xbf16> + %transposed = linalg.transpose ins(%expanded : tensor<16x32x8x32xbf16>) outs(%0 : tensor<8x16x32x32xbf16>) permutation = [2, 0, 1, 3] + %1 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %expanded_0 = tensor.expand_shape %transposed [[0], [1], [2, 3], [4]] output_shape [8, 16, 16, 2, 32] : tensor<8x16x32x32xbf16> into tensor<8x16x16x2x32xbf16> + %transposed_1 = linalg.transpose ins(%expanded_0 : tensor<8x16x16x2x32xbf16>) outs(%1 : tensor<8x16x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] + %expanded_2 = tensor.expand_shape %arg1 [[0, 1]] output_shape [8, 32] : tensor<256xbf16> into tensor<8x32xbf16> + %2 = tensor.empty() : tensor<32x8x32x32xbf16> + %expanded_3 = tensor.expand_shape %arg2 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xbf16> into tensor<8x32x32x32xbf16> + %transposed_4 = linalg.transpose ins(%expanded_3 : tensor<8x32x32x32xbf16>) outs(%2 : tensor<32x8x32x32xbf16>) permutation = [2, 0, 1, 3] + %3 = 
tensor.empty() : tensor<32x8x16x32x2xbf16> + %expanded_5 = tensor.expand_shape %transposed_4 [[0], [1], [2, 3], [4]] output_shape [32, 8, 16, 2, 32] : tensor<32x8x32x32xbf16> into tensor<32x8x16x2x32xbf16> + %transposed_6 = linalg.transpose ins(%expanded_5 : tensor<32x8x16x2x32xbf16>) outs(%3 : tensor<32x8x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] + %expanded_7 = tensor.expand_shape %arg3 [[0, 1]] output_shape [32, 32] : tensor<1024xbf16> into tensor<32x32xbf16> + return %transposed_1, %expanded_2, %transposed_6, %expanded_7 : tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16> + } +} + """ + ) + + # module_in entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> + torch_arg0 = torch.rand((64, 512), dtype=torch.bfloat16) + torch_arg1 = torch.rand((512, 256), dtype=torch.bfloat16) + torch_arg2 = torch.rand((256), dtype=torch.bfloat16) + torch_arg3 = torch.rand((256, 1024), dtype=torch.bfloat16) + torch_arg4 = torch.rand((1024), dtype=torch.bfloat16) + + ref_res = (torch_arg0 @ torch_arg1 + torch_arg2) @ torch_arg3 + torch_arg4 + + passes = "any(gc-cpu-pipeline)" + compiler = GraphCompiler(passes) + ctx.enable_multithreading(False) + + arg0 = torch_arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + arg1 = torch_arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + arg2 = torch_arg2.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + arg3 = torch_arg3.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + arg4 = torch_arg4.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + gc_res = np.ones((64, 1024), dtype=ml_dtypes.bfloat16) + + entry = "entry" + mlir_args = get_mlir_args(module_in, entry, [arg0, arg1, arg2, arg3, arg4, gc_res]) + engine_in = compiler.compile_and_jit(module_in, ir_printing=True) + engine_in.invoke(entry, *mlir_args) + + assert_allclose(gc_res.astype(np.float32), ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5) + + + # module_out entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> + # module_out runtime_fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) + fold_arg0 = arg1 + fold_arg1 = arg2 + fold_arg2 = arg3 + fold_arg3 = arg4 + fold_res0 = np.zeros((8, 16, 16, 32, 2), dtype=ml_dtypes.bfloat16) + fold_res1 = np.zeros((8, 32), dtype=ml_dtypes.bfloat16) + fold_res2 = np.zeros((32, 8, 16, 32, 2), dtype=ml_dtypes.bfloat16) + fold_res3 = np.zeros((32, 32), dtype=ml_dtypes.bfloat16) + + runtime_fold = "runtime_fold" + fold_mlir_args = get_mlir_args(module_out, runtime_fold, [fold_arg0, fold_arg1, fold_arg2, fold_arg3, fold_res0, fold_res1, fold_res2, fold_res3]) + + gc_res_out = np.zeros((64, 1024), dtype=ml_dtypes.bfloat16) + entry = "entry" + mlir_args = get_mlir_args(module_out, entry, [arg0, fold_res0, fold_res1, fold_res2, fold_res3, gc_res_out]) + + engine_out = compiler.compile_and_jit(module_out, ir_printing=True) + engine_out.invoke(runtime_fold, *fold_mlir_args) + engine_out.invoke(entry, *mlir_args) + + assert_allclose(gc_res.astype(np.float32), gc_res_out.astype(np.float32), rtol=1e-5, atol=1e-5) + diff --git a/test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py 
b/test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py new file mode 100644 index 000000000..465d390fd --- /dev/null +++ b/test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py @@ -0,0 +1,96 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +from enum import Flag +import os +import sys + +import numpy as np +from gc_mlir import ir +from gc_mlir.graph_compiler import GraphCompiler +from numpy.testing import assert_allclose + +project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_dir not in sys.path: + sys.path.insert(0, project_dir) + +import torch +# from bench import py_timeit_bench +from utils import get_mlir_args + +if __name__ == "__main__": + with ir.Context() as ctx: + ctx.allow_unregistered_dialects = True + + M = 64 + N = 256 + K = 512 + MBlock = 32 + NBlock = 32 + KBlock = 32 + vnni_size = 1 + shapeA = [M // MBlock, K // KBlock, MBlock, KBlock] + shapeB = [N // NBlock, K // KBlock, KBlock, NBlock] + shapeC = [M // MBlock, N // NBlock, MBlock, NBlock] + + # 4D x 4D, inputs transposed + mlir_str = """ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +module { + func.func @main_entry(%arg0: tensor<2x16x32x32xf32>, %arg1: tensor<8x16x32x32xf32>) -> tensor<2x8x32x32xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2x8x32x32xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x8x32x32xf32>) -> tensor<2x8x32x32xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%1 : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x8x32x32xf32> + return %2 : tensor<2x8x32x32xf32> + } +} + """ + module = ir.Module.parse(mlir_str) + + torch_arg0 = torch.rand((M, K), dtype=torch.float32) + torch_arg1 = torch.rand((K, N), dtype=torch.float32) + ref_res = torch.matmul(torch_arg0, torch_arg1) + + arg0_0 = torch_arg0.view([M // MBlock, MBlock, K // KBlock, KBlock]).permute([0, 2, 1, 3]).contiguous().numpy().view(np.dtype("float32")) + arg0_1 = np.transpose(np.reshape(torch_arg0.contiguous().numpy().view(np.dtype("float32")), (M // MBlock, MBlock, K // KBlock, KBlock)), (0, 2, 1, 3)) # MK -> MKmk + print("arg0_0 arg0_1 close: ", np.allclose(arg0_0, arg0_1, rtol=1e-5, atol=1e-5)) + + arg1 = torch_arg1.view([K // KBlock, KBlock, N // NBlock, NBlock]).permute([2, 0, 1, 3]).contiguous().numpy().view(np.dtype("float32")) 
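+        # KN -> NKkn blocking, for reference: B[K, N] is viewed as
+        # [K // KBlock, KBlock, N // NBlock, NBlock] and permuted to
+        # [N // NBlock, K // KBlock, KBlock, NBlock] (8x16x32x32 here),
+        # matching #map1 in the IR above; the commented line below is an
+        # equivalent pure-numpy formulation of the same packing.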
+ # arg1 = np.transpose(np.reshape(torch_arg1.contiguous().numpy(), (16, 32, 8, 32)), (2, 0, 1, 3)).view(np.dtype("float32")) # KN -> NKkn, 8x16x32x32 + + gc_res = np.ones(shapeC, dtype=np.dtype("float32")) + + entry = "main_entry" + mlir_args = get_mlir_args(module, entry, [arg0_1, arg1, gc_res]) + + passes = "any(gc-cpu-pipeline)" + compiler = GraphCompiler(passes) + engine_in = compiler.compile_and_jit(module) + engine_in.invoke(entry, *mlir_args) + gc_res = np.reshape(np.transpose(gc_res, (0, 2, 1, 3)), (64, 256)) # MNmn -> MN + + print("gc_res ref_res close: ", np.allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5)) + assert_allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5) + diff --git a/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py b/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py new file mode 100644 index 000000000..377e28a36 --- /dev/null +++ b/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py @@ -0,0 +1,225 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import os +import sys + +import numpy as np +from gc_mlir import ir +from gc_mlir.graph_compiler import GraphCompiler +from numpy.testing import assert_allclose + +project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_dir not in sys.path: + sys.path.insert(0, project_dir) + +import torch +# from bench import py_timeit_bench +from utils import get_mlir_args + +if __name__ == "__main__": + with ir.Context() as ctx: + ctx.allow_unregistered_dialects = True + + # 4D x 4D, inputs plain, two layers + mlir_str_4D4D = """ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +module { + func.func @entry(%arg0: tensor<64x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256x1024xf32>, %arg4: tensor<1024xf32>) -> tensor<64x1024xf32> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %0 = tensor.empty() : tensor<2x16x32x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %arg0 low[0, 0] high[0, 0] { + ^bb0(%arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor<64x512xf32> to tensor<64x512xf32> + %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xf32> into tensor<2x32x16x32xf32> + %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xf32>) outs(%0 : tensor<2x16x32x32xf32>) permutation = [0, 2, 1, 3] + %1 = tensor.empty() : tensor<8x16x32x32xf32> + 
%padded_0 = tensor.pad %arg1 low[0, 0] high[0, 0] { + ^bb0(%arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor<512x256xf32> to tensor<512x256xf32> + %expanded_1 = tensor.expand_shape %padded_0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xf32> into tensor<16x32x8x32xf32> + %transposed_2 = linalg.transpose ins(%expanded_1 : tensor<16x32x8x32xf32>) outs(%1 : tensor<8x16x32x32xf32>) permutation = [2, 0, 1, 3] + %2 = tensor.empty() : tensor<2x8x32x32xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x8x32x32xf32>) -> tensor<2x8x32x32xf32> + %4 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %transposed_2 : tensor<2x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%3 : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %in_8: f32, %out: f32): + %14 = arith.mulf %in, %in_8 : f32 + %15 = arith.addf %out, %14 : f32 + linalg.yield %15 : f32 + } -> tensor<2x8x32x32xf32> + %expanded_3 = tensor.expand_shape %arg2 [[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32> + %broadcasted = linalg.broadcast ins(%expanded_3 : tensor<8x32xf32>) outs(%2 : tensor<2x8x32x32xf32>) dimensions = [0, 2] + %5 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xf32>) outs(%4 : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %14 = arith.addf %in, %out : f32 + linalg.yield %14 : f32 + } -> tensor<2x8x32x32xf32> + %6 = tensor.empty() : tensor<32x8x32x32xf32> + %expanded_4 = tensor.expand_shape %arg3 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xf32> into tensor<8x32x32x32xf32> + %transposed_5 = linalg.transpose ins(%expanded_4 : tensor<8x32x32x32xf32>) outs(%6 : tensor<32x8x32x32xf32>) permutation = [2, 0, 1, 3] + %7 = tensor.empty() : tensor<2x32x32x32xf32> + %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<2x32x32x32xf32>) -> tensor<2x32x32x32xf32> + %9 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%5, %transposed_5 : tensor<2x8x32x32xf32>, tensor<32x8x32x32xf32>) outs(%8 : tensor<2x32x32x32xf32>) { + ^bb0(%in: f32, %in_8: f32, %out: f32): + %14 = arith.mulf %in, %in_8 : f32 + %15 = arith.addf %out, %14 : f32 + linalg.yield %15 : f32 + } -> tensor<2x32x32x32xf32> + %expanded_6 = tensor.expand_shape %arg4 [[0, 1]] output_shape [32, 32] : tensor<1024xf32> into tensor<32x32xf32> + %10 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_6 : tensor<32x32xf32>) outs(%9 : tensor<2x32x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %14 = arith.addf %in, %out : f32 + linalg.yield %14 : f32 + } -> tensor<2x32x32x32xf32> + %11 = tensor.empty() : tensor<2x32x32x32xf32> + %transposed_7 = linalg.transpose ins(%10 : tensor<2x32x32x32xf32>) outs(%11 : tensor<2x32x32x32xf32>) permutation = [0, 2, 1, 3] + %collapsed = tensor.collapse_shape %transposed_7 [[0, 1], [2, 3]] : tensor<2x32x32x32xf32> into tensor<64x1024xf32> + %12 = tensor.empty() : tensor<64x1024xf32> + %13 = linalg.copy ins(%collapsed : tensor<64x1024xf32>) outs(%12 : tensor<64x1024xf32>) -> tensor<64x1024xf32> + return %13 : tensor<64x1024xf32> + } +} + """ + + module_in = ir.Module.parse(mlir_str_4D4D) + + + mlir_str_4D4D_out = """ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map1 = 
affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +module { + llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32 + llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> + llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> + llvm.mlir.global external constant @__runtime_fold_buffer_ids_(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> + func.func @entry(%arg0: tensor<64x512xf32>, %arg1: tensor<8x16x32x32xf32>, %arg2: tensor<8x32xf32>, %arg3: tensor<32x8x32x32xf32>, %arg4: tensor<32x32xf32>) -> tensor<64x1024xf32> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2x16x32x32xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xf32> into tensor<2x32x16x32xf32> + %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xf32>) outs(%0 : tensor<2x16x32x32xf32>) permutation = [0, 2, 1, 3] + %1 = tensor.empty() : tensor<2x8x32x32xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x8x32x32xf32>) -> tensor<2x8x32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %arg1 : tensor<2x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%2 : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %12 = arith.mulf %in, %in_1 : f32 + %13 = arith.addf %out, %12 : f32 + linalg.yield %13 : f32 + } -> tensor<2x8x32x32xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<8x32xf32>) outs(%1 : tensor<2x8x32x32xf32>) dimensions = [0, 2] + %4 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xf32>) outs(%3 : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %12 = arith.addf %in, %out : f32 + linalg.yield %12 : f32 + } -> tensor<2x8x32x32xf32> + %5 = tensor.empty() : tensor<2x32x32x32xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x32x32x32xf32>) -> tensor<2x32x32x32xf32> + %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%4, %arg3 : tensor<2x8x32x32xf32>, tensor<32x8x32x32xf32>) outs(%6 : tensor<2x32x32x32xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %12 = arith.mulf %in, %in_1 : f32 + %13 = arith.addf %out, %12 : f32 + linalg.yield %13 : f32 + } -> tensor<2x32x32x32xf32> + %8 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<32x32xf32>) outs(%7 : tensor<2x32x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %12 = arith.addf %in, %out : f32 + linalg.yield %12 : f32 + } -> tensor<2x32x32x32xf32> + %9 = tensor.empty() : tensor<2x32x32x32xf32> + %transposed_0 = linalg.transpose ins(%8 : tensor<2x32x32x32xf32>) outs(%9 : tensor<2x32x32x32xf32>) permutation = [0, 2, 1, 3] + %collapsed = tensor.collapse_shape %transposed_0 [[0, 1], [2, 3]] : 
tensor<2x32x32x32xf32> into tensor<64x1024xf32> + %10 = tensor.empty() : tensor<64x1024xf32> + %11 = linalg.copy ins(%collapsed : tensor<64x1024xf32>) outs(%10 : tensor<64x1024xf32>) -> tensor<64x1024xf32> + return %11 : tensor<64x1024xf32> + } + + func.func @runtime_fold(%arg0: tensor<512x256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256x1024xf32>, %arg3: tensor<1024xf32>) -> (tensor<8x16x32x32xf32>, tensor<8x32xf32>, tensor<32x8x32x32xf32>, tensor<32x32xf32>) attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<8x16x32x32xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xf32> into tensor<16x32x8x32xf32> + %transposed = linalg.transpose ins(%expanded : tensor<16x32x8x32xf32>) outs(%0 : tensor<8x16x32x32xf32>) permutation = [2, 0, 1, 3] + %expanded_0 = tensor.expand_shape %arg1 [[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32> + %1 = tensor.empty() : tensor<32x8x32x32xf32> + %expanded_1 = tensor.expand_shape %arg2 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xf32> into tensor<8x32x32x32xf32> + %transposed_2 = linalg.transpose ins(%expanded_1 : tensor<8x32x32x32xf32>) outs(%1 : tensor<32x8x32x32xf32>) permutation = [2, 0, 1, 3] + %expanded_3 = tensor.expand_shape %arg3 [[0, 1]] output_shape [32, 32] : tensor<1024xf32> into tensor<32x32xf32> + return %transposed, %expanded_0, %transposed_2, %expanded_3 : tensor<8x16x32x32xf32>, tensor<8x32xf32>, tensor<32x8x32x32xf32>, tensor<32x32xf32> + } +} + """ + module_out = ir.Module.parse(mlir_str_4D4D_out) + + # module_in entry(%arg0: tensor<64x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256x1024xf32>, %arg4: tensor<1024xf32>) -> tensor<64x1024xf32> + torch_arg0 = torch.rand((64, 512), dtype=torch.float32) + torch_arg1 = torch.rand((512, 256), dtype=torch.float32) + torch_arg2 = torch.rand((256), dtype=torch.float32) + torch_arg3 = torch.rand((256, 1024), dtype=torch.float32) + torch_arg4 = torch.rand((1024), dtype=torch.float32) + + ref_res = (torch_arg0 @ torch_arg1 + torch_arg2) @ torch_arg3 + torch_arg4 + + passes = "any(gc-cpu-pipeline)" + compiler = GraphCompiler(passes) + ctx.enable_multithreading(False) + + arg0 = torch_arg0.contiguous().numpy() + arg1 = torch_arg1.contiguous().numpy() + arg2 = torch_arg2.contiguous().numpy() + arg3 = torch_arg3.contiguous().numpy() + arg4 = torch_arg4.contiguous().numpy() + gc_res = np.zeros((64, 1024), dtype=np.float32) + + entry = "entry" + mlir_args = get_mlir_args(module_in, entry, [arg0, arg1, arg2, arg3, arg4, gc_res]) + engine_in = compiler.compile_and_jit(module_in, ir_printing=False) + engine_in.invoke(entry, *mlir_args) + + print("Reference vs GC input IR close: ", np.allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5)) + assert_allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5) + + + # module_out entry(%arg0: tensor<64x512xf32>, %arg1: tensor<8x16x32x32xf32>, %arg2: tensor<8x32xf32>, %arg3: tensor<32x8x32x32xf32>, %arg4: tensor<32x32xf32>) -> tensor<64x1024xf32> + # module_out runtime_fold(%arg0: tensor<512x256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256x1024xf32>, %arg3: tensor<1024xf32>) -> (tensor<8x16x32x32xf32>, tensor<8x32xf32>, tensor<32x8x32x32xf32>, tensor<32x32xf32>) + fold_arg0 = arg1 + fold_arg1 = arg2 + fold_arg2 = arg3 + fold_arg3 = arg4 + fold_res0 = np.zeros((8, 16, 32, 32), dtype=np.float32) + fold_res1 = np.zeros((8, 32), dtype=np.float32) + fold_res2 = np.zeros((32, 8, 32, 32), 
dtype=np.float32)
+        fold_res3 = np.zeros((32, 32), dtype=np.float32)
+
+        runtime_fold = "runtime_fold"
+        fold_mlir_args = get_mlir_args(module_out, runtime_fold, [fold_arg0, fold_arg1, fold_arg2, fold_arg3, fold_res0, fold_res1, fold_res2, fold_res3])
+
+        gc_res_out = np.zeros((64, 1024), dtype=np.float32)
+        entry = "entry"
+        entry_mlir_args = get_mlir_args(module_out, entry, [arg0, fold_res0, fold_res1, fold_res2, fold_res3, gc_res_out])
+
+        engine_out = compiler.compile_and_jit(module_out, ir_printing=False)
+        engine_out.invoke(runtime_fold, *fold_mlir_args)
+        engine_out.invoke(entry, *entry_mlir_args)
+
+        print("GC input IR vs GC output IR close: ", np.allclose(gc_res, gc_res_out, rtol=1e-5, atol=1e-5))
+        assert_allclose(gc_res, gc_res_out, rtol=1e-5, atol=1e-5)

From f9c24256b1605dcf8b734f2f3976c3929e00f1cd Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Mon, 19 Aug 2024 18:23:38 -0700
Subject: [PATCH 52/64] Support MemRef args

---
 lib/gc/Transforms/ConstantTensorFolding.cpp | 31 +++++++++++++--------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index d7174ec6e..9b1aa27cb 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -53,6 +54,8 @@ bool isInConstantSubgraph(Operation *op) {
   auto opNamespace = op->getDialect()->getNamespace();
   if (opNamespace == linalg::LinalgDialect::getDialectNamespace() ||
       opNamespace == tensor::TensorDialect::getDialectNamespace() ||
+      opNamespace ==
+          bufferization::BufferizationDialect::getDialectNamespace() ||
       opNamespace == arith::ArithDialect::getDialectNamespace()) {
     if (op->getAttr("onednn_graph.in_const_subgraph")) {
       return true;
@@ -61,7 +64,7 @@
   return false;
 }

-int64_t getTensorSize(TensorType t) {
+template <typename T> int64_t getDataSize(T t) {
   Type eleType = t.getElementType();
   unsigned bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes
   ArrayRef<int64_t> shape = t.getShape();
@@ -72,6 +75,16 @@
   return size;
 }

+int64_t getValueSize(Value v) {
+  if (isa<TensorType>(v.getType())) {
+    auto t = dyn_cast<TensorType>(v.getType());
+    return getDataSize(t);
+  } else {
+    auto t = dyn_cast<MemRefType>(v.getType());
+    return getDataSize(t);
+  }
+}
+
 /// @brief op has only one operand, or operands of op are one same value, or
 /// operands of op are one same value or from tensor.EmptyOp.
 /// @param op
@@ -465,7 +478,7 @@ void getInputsAndOutputs(Block &block,
     // The constant ops are all single-input single-output.
     bool simpleTopo = true;
     auto arg = block.getArgument(id);
-    if (!isa<TensorType>(arg.getType())) {
+    if (!isa<TensorType>(arg.getType()) && !isa<MemRefType>(arg.getType())) {
       continue;
     }
     inputTypes.push_back(arg.getType());
@@ -511,15 +524,12 @@
     // not fold it. Compare data size changes during traverse to find the last
     // op that satisfies this condition.
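+    // For example, with DATA_SIZE_EXPANDING_THRESHOLD = 8, only the prefix of
+    // the chain whose cached results stay within 8x the initial constant's
+    // size keeps being folded.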
     if (simpleTopo) {
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+      int64_t initSize = getValueSize(valuesOnTheWay[0]);
+      if (initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+          getValueSize(valuesOnTheWay.back())) {
         size_t lastIdx = 0;
         for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size = getTensorSize(
-              dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+          int64_t size = getValueSize(valuesOnTheWay[i]);
           if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
             lastIdx = i;
           }
@@ -574,8 +584,7 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
   for (Value &tensor : outputValuesInFold) {
     LLVM_DEBUG(llvm::dbgs() << "Allocate buffer for tensor: " << tensor
                             << "\n");
-    buffersSize.push_back(
-        getTensorSize(dyn_cast<TensorType>(tensor.getType())));
+    buffersSize.push_back(getValueSize(tensor));
   }
   auto manager = ConstGraphTensorCacheManager::get();
   SmallVector<int64_t> globalIndexes;

From d8d2d7998dcc71e29db3a414c953bd87cd847f92 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Mon, 19 Aug 2024 18:24:10 -0700
Subject: [PATCH 53/64] Add to pipeline

---
 lib/gc/Transforms/Pipeline.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 74da09bf4..c0ebfb175 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -51,6 +51,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
   // todo: padding propagation pass
   // todo: layout propagation pass
   // todo: tensor constant propagation pass
+  pm.addPass(createConstantSubgraphAnalysisPass());
+  pm.addPass(createConstantTensorFoldingPass());

   // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
   pm.addNestedPass<func::FuncOp>(createDeepTileContractionNamedOp());

From 22c4474dac1302e0c2696f1d16fb54cb4e36d817 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Mon, 26 Aug 2024 00:03:07 -0700
Subject: [PATCH 54/64] Forbid buffer_to_tensor case

---
 lib/gc/Transforms/ConstantTensorFolding.cpp   | 72 ++++++++++++-------
 .../unittests/ExecutionEngine/JitWrapper.cpp  |  1 -
 2 files changed, 47 insertions(+), 26 deletions(-)

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 9b1aa27cb..f1d85e62e 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -9,8 +9,8 @@
 // This transformation pass performs a constant subgraph transform in MLIR.
 //
 //===----------------------------------------------------------------------===//
-
 #include
+#include
 #include
 #include "mlir/Transforms/Passes.h"
@@ -496,6 +496,14 @@ void getInputsAndOutputs(Block &block,
                     [](Operation *child) {
                       return !isInConstantSubgraph(child);
                     })) {
+      // skip case: memref v -> bufferization.to_tensor -> tensor t.
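+      // Such an argument is a buffer merely exposed as a tensor: there is no
+      // computation to fold, so drop it from the const-args set and leave the
+      // conversion to the compute function.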
+      if (valuesOnTheWay.size() == 2 && v.hasOneUse() &&
+          isa<bufferization::ToTensorOp>(v.getDefiningOp())) {
+        inputTypes.pop_back();
+        inputValues.pop_back();
+        constArgsIndexes.erase(id);
+        continue;
+      }
       if (std::find(outputValues.begin(), outputValues.end(), v) ==
           outputValues.end()) {
         outputTypes.push_back(v.getType());
@@ -777,13 +785,17 @@ void ConstantTensorFolding::runOnOperation() {
     getInputsAndOutputs(block, compiletimeConstArgsIndexes,
                         compiletimeInputTypes, compiletimeInputValues,
                         compiletimeOutputTypes, compiletimeOutputValues);
+    assert(compiletimeInputTypes.size() == compiletimeInputValues.size());
+    assert(compiletimeOutputTypes.size() == compiletimeOutputValues.size());

-    func::FuncOp compiletimeFoldFunc =
-        buildFoldFunc(context, builder, topOp, "compiletime_fold", constOps,
-                      compiletimeInputTypes, compiletimeInputValues,
-                      compiletimeOutputTypes, compiletimeOutputValues);
-    (void)compiletimeFoldFunc;
-    canonicalizeAndClean(context, compiletimeFoldFunc.getOperation());
+    if (!compiletimeOutputTypes.empty()) {
+      func::FuncOp compiletimeFoldFunc =
+          buildFoldFunc(context, builder, topOp, "compiletime_fold", constOps,
+                        compiletimeInputTypes, compiletimeInputValues,
+                        compiletimeOutputTypes, compiletimeOutputValues);
+      (void)compiletimeFoldFunc;
+      canonicalizeAndClean(context, compiletimeFoldFunc.getOperation());
+    }

     // ===== build runtime folding function =====
     SmallVector<Type> runtimeInputTypes; // types of constant tensors
@@ -795,12 +807,16 @@
     getInputsAndOutputs(block, runtimeConstArgsIndexes, runtimeInputTypes,
                         runtimeInputValues, runtimeOutputTypes,
                         runtimeOutputValues);
-
-    func::FuncOp runtimeFoldFunc = buildFoldFunc(
-        context, builder, topOp, "runtime_fold", constOps, runtimeInputTypes,
-        runtimeInputValues, runtimeOutputTypes, runtimeOutputValues);
-    (void)runtimeFoldFunc;
-    canonicalizeAndClean(context, runtimeFoldFunc.getOperation());
+    assert(runtimeInputTypes.size() == runtimeInputValues.size());
+    assert(runtimeOutputTypes.size() == runtimeOutputValues.size());
+
+    if (!runtimeOutputTypes.empty()) {
+      func::FuncOp runtimeFoldFunc = buildFoldFunc(
+          context, builder, topOp, "runtime_fold", constOps, runtimeInputTypes,
+          runtimeInputValues, runtimeOutputTypes, runtimeOutputValues);
+      (void)runtimeFoldFunc;
+      canonicalizeAndClean(context, runtimeFoldFunc.getOperation());
+    }

     // ===== build computing function =====
     std::unordered_set<size_t> constArgsIndexes = compiletimeConstArgsIndexes;
@@ -811,8 +827,10 @@
     SmallVector<Value> outputValues = compiletimeOutputValues;
     outputValues.insert(outputValues.end(), runtimeOutputValues.begin(),
                         runtimeOutputValues.end());
-    modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes,
-                      outputTypes, outputValues);
+    if (!outputTypes.empty()) {
+      modifyComputeFunc(context, builder, topOp, topFunc, block,
+                        constArgsIndexes, outputTypes, outputValues);
+    }
   } else {
     std::unordered_set<size_t> constArgsIndexes = compiletimeConstArgsIndexes;
     constArgsIndexes.merge(runtimeConstArgsIndexes);
@@ -827,16 +845,20 @@
     getArithConstantOutputs(block, outputTypes, outputValues);
     getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues,
                         outputTypes, outputValues);
-
-    func::FuncOp foldFunc =
-        buildFoldFunc(context, builder, topOp, "runtime_fold", constOps,
-                      inputTypes, inputValues, outputTypes, outputValues);
-    (void)foldFunc;
-    canonicalizeAndClean(context, foldFunc.getOperation());
-
-    // ===== build computing function =====
-    modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes,
-                      outputTypes, outputValues);
+    assert(inputTypes.size() == inputValues.size());
+    assert(outputTypes.size() == outputValues.size());
+
+    if (!outputTypes.empty()) {
+      func::FuncOp foldFunc =
+          buildFoldFunc(context, builder, topOp, "runtime_fold", constOps,
+                        inputTypes, inputValues, outputTypes, outputValues);
+      (void)foldFunc;
+      canonicalizeAndClean(context, foldFunc.getOperation());
+
+      // ===== build computing function =====
+      modifyComputeFunc(context, builder, topOp, topFunc, block,
+                        constArgsIndexes, outputTypes, outputValues);
+    }
   }

   canonicalizeAndClean(context, topOp);

diff --git a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
index f7b93eaa6..48b27975e 100644
--- a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
+++ b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
@@ -25,7 +25,6 @@ using namespace mlir;

 static const char code1[] = R"mlir(
 module {
-llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32
 func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
   %out = tensor.empty() : tensor<128xf32>
   %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>

From e20d059ef539a8e256990fe0396a45da27e78a45 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Fri, 6 Sep 2024 11:14:31 +0800
Subject: [PATCH 55/64] Add shape info to global

---
 lib/gc/Transforms/ConstantTensorFolding.cpp   | 38 ++++++++++++++++---
 .../test_constant_tensor_folding-1.mlir       |  2 +-
 .../test_constant_tensor_folding.mlir         |  2 +-
 ...constant_tensor_folding_bf16_two_layers.py |  2 +-
 ..._constant_tensor_folding_f32_two_layers.py |  2 +-
 5 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index f1d85e62e..17270d54f 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -602,7 +602,7 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
   globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
   auto moduleOp = dyn_cast<ModuleOp>(topOp);
   addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder,
-                    "__" + name + "_buffer_ids_", globalIndexes);
+                    "__" + name + "_buffer_ids", globalIndexes);

   auto returnOp =
       builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
@@ -615,6 +615,24 @@
     });
   }

+  // the ranks of folded results.
+  SmallVector<int32_t> foldRanks;
+  // the shapes of folded results.
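+  // Flattened per result as [dim0, ..., dimN-1, elementByteWidth]; e.g. an
+  // 8x32xf32 result contributes rank 2 and the entry [8, 32, 4].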
+  SmallVector<int64_t> foldShapes;
+  for (Value &tensor : outputValuesInFold) {
+    auto t = dyn_cast<TensorType>(tensor.getType());
+    Type eleType = t.getElementType();
+    int64_t bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes
+    ArrayRef<int64_t> shape = t.getShape();
+    foldRanks.push_back(shape.size());
+    foldShapes.insert(foldShapes.end(), shape.begin(), shape.end());
+    foldShapes.push_back(bitWidth);
+  }
+  addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__folded_ranks",
+                    foldRanks);
+  addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__folded_shapes",
+                    foldShapes);
+
   foldFunc.setVisibility(SymbolTable::Visibility::Public);
   foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                     UnitAttr::get(context));
@@ -631,11 +649,13 @@ void modifyComputeFunc(MLIRContext *context, OpBuilder &builder,
                        std::unordered_set<size_t> &constArgsIndexes,
                        SmallVector<Type> &outputTypes,
                        SmallVector<Value> &outputValues) {
-  // the indexes of args to the folding func.
+  // the indexes of args to the folding func, including to-fold tensors and
+  // folded results.
   SmallVector<int32_t> foldArgs;
-  // the indexes of folded args.
+  // the indexes of folded results.
   SmallVector<int32_t> foldIds;
-  // the indexes of args to the computing func.
+  // the indexes of args to the computing func, including non-fold tensors and
+  // folded results.
   SmallVector<int32_t> computeArgs;

   // modify the BlockArguments of block
@@ -715,7 +735,7 @@
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
                     computeArgs);

-  addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args",
+  addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_args",
               oriNumArgs);
 }

@@ -740,6 +760,14 @@ void canonicalizeAndClean(MLIRContext *context, Operation *topOp) {
       op->removeAttr("onednn_graph.in_const_subgraph");
     }
   });
+  topOp->walk([&](func::FuncOp op) {
+    if (op.getOperation()->getAttr("compiletime_const_args_index")) {
+      op.getOperation()->removeAttr("compiletime_const_args_index");
+    }
+    if (op.getOperation()->getAttr("runtime_const_args_index")) {
+      op.getOperation()->removeAttr("runtime_const_args_index");
+    }
+  });
 }

 // Operate on tensors. Create fold() and compute() on module. The
diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index 0664edafb..cdb5d1397 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -32,7 +32,7 @@ module {

 // COM: expected output:
 // COM: module {
-// COM:   llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32
+// COM:   llvm.mlir.global external constant @__num_orig_args(3 : i32) {addr_space = 0 : i32} : i32
 // COM:   llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
 // COM:   llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
 // COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>

diff --git a/test/gc/Transforms/test_constant_tensor_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir
index 71f475c00..59fe90236 100644
--- a/test/gc/Transforms/test_constant_tensor_folding.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding.mlir
@@ -74,7 +74,7 @@ module {

 // COM: expected output:
 // COM: module {
-// COM:   llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32
+// COM:   llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32
 // COM:   llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
 // COM:   llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
 // COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>

diff --git a/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py b/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py
index d444416e7..4e66b1ebf 100644
--- a/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py
+++ b/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py
@@ -141,7 +141,7 @@
 #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 module {
-  llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32
+  llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32
   llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
   llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
   llvm.mlir.global external constant @__runtime_fold_buffer_ids_(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>

diff --git a/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py b/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py
index 377e28a36..e05e2ac15 100644
--- a/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py
+++ b/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py
@@ -111,7 +111,7 @@
 #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 module {
-  llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32
+  llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32
   llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
   llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
   llvm.mlir.global external constant @__runtime_fold_buffer_ids_(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>

From 36fc758d13532c2feb8079fefa78281e4034dd7d Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Fri, 13 Sep 2024 10:52:56 +0800
Subject: [PATCH 56/64] Make things work

---
 .../CPURuntime/ConstantCache.h                |  48 ++++-
 include/gc/ExecutionEngine/Driver/Driver.h    |  53 +++---
 .../ExecutionEngine/CPURuntime/CMakeLists.txt |   1 -
 .../CPURuntime/ConstantCache.cpp              |  46 -----
 lib/gc/ExecutionEngine/Driver/Driver.cpp      | 135 ++++++++------
 lib/gc/Transforms/ConstantTensorFolding.cpp   |  19 +-
 .../unittests/ExecutionEngine/JitWrapper.cpp  |  93 +++++++++-
 unittests/ExecutionEngine/CMakeLists.txt      |   7 -
 unittests/ExecutionEngine/JitWrapper.cpp      | 175 ------------------
 9 files changed, 239 insertions(+), 338 deletions(-)
 delete mode 100644 lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp
 delete mode 100644 unittests/ExecutionEngine/CMakeLists.txt
 delete mode 100644 unittests/ExecutionEngine/JitWrapper.cpp

diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
index f41cb09e8..a3756220d 100644
--- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
+++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
@@ -9,9 +9,11 @@
 #define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include <atomic>
+#include <cstdlib>
+#include <cstring>
 #include <memory>
 #include <stdint.h>
-
+#include <unordered_map>
 namespace mlir {
 namespace gc {
 /**
@@ -79,7 +81,7 @@ struct ConstCacheProxy : RefCountManaged {
                   size_t size, bool is_lazy)
       : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy),
         buffer(buffer) {}
-  ~ConstCacheProxy();
+  ~ConstCacheProxy() = default;

   // get the buffer and increment the refcount. If the buffer is evicted,
   // returns null
@@ -117,20 +119,52 @@
 };

 struct CachedGraphTensor {
+  // Multiple tensors can reside in one common ConstCacheProxy `base`, with
+  // different offsets.
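+  // E.g. two 256-byte folded tensors may share one 512-byte base:
+  // {base, offset = 0} and {base, offset = 256}.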
   std::shared_ptr<ConstCacheProxy> base;
   size_t offset;

-  CachedGraphTensor(const std::shared_ptr<ConstCacheProxy> &base,
-                    size_t offset);
+  CachedGraphTensor(const std::shared_ptr<ConstCacheProxy> &base, size_t offset)
+      : base{base}, offset{offset} {
+    // todo: fill in real values
+    ref.basePtr = (char *)base->getBufferUnsafe() + offset;
+    ref.data = ref.basePtr;
+    ref.offset = 0;
+    memset(ref.sizes, 0, sizeof(ref.sizes));
+    memset(ref.strides, 0, sizeof(ref.strides));
+  }
   friend class JitModule;

private:
   StridedMemRefType<char, 8> ref;
 };

-std::shared_ptr<CachedGraphTensor> queryCacheTensor(uint64_t key);
-bool regCachedTensor(uint64_t key, const std::shared_ptr<ConstCacheProxy> &base,
-                     size_t offset);
+static std::unordered_map<int64_t, std::shared_ptr<CachedGraphTensor>> cache;
+
+inline std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+  // simply allocate buffer and return
+  std::shared_ptr<void> base = std::shared_ptr<void>{
+      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+}
+
+inline std::shared_ptr<CachedGraphTensor> queryCacheTensor(int64_t key) {
+  auto itr = cache.find(key);
+  if (itr != cache.end()) {
+    return itr->second;
+  }
+  return nullptr;
+}
+
+inline bool regCachedTensor(int64_t key,
+                            const std::shared_ptr<ConstCacheProxy> &base,
+                            size_t offset) {
+  if (queryCacheTensor(key)) {
+    return false;
+  }
+  cache[key] = std::make_shared<CachedGraphTensor>(base, offset);
+  return true;
+}
 } // namespace gc
 } // namespace mlir

diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h
index 2f6a50909..d80fb4e51 100644
--- a/include/gc/ExecutionEngine/Driver/Driver.h
+++ b/include/gc/ExecutionEngine/Driver/Driver.h
@@ -39,43 +39,42 @@ class JitModule {
   create(Operation *op, const DriverOptions &options = {});

   // args should be an array of XXXMemrefType*
-  void call(GeneralMemrefPtr *args);
-
-  /// args should be an array of XXXMemrefType*
-  void call(GeneralMemrefPtr *args, std::size_t numArgs) {
-    // Silly code, MLIR execution engine requires pointers of real args as
-    // inputs
-    llvm::SmallVector<void *> realargs;
-    realargs.reserve(numArgs);
-    for (size_t i = 0; i < numArgs; i++) {
-      realargs.push_back(&args[i]);
-    }
-    compute(realargs.data());
-  }
-
-  /// directly call compute(). args should be an array of void*. args[i] should
+  // numArgs: including input and output args.
+  void call(GeneralMemrefPtr *args, int32_t numArgs);
+
+  /// directly call entry(). args should be an array of void*. args[i] should
   /// be a pointer to the real data. For passing memref, users need to 1) create
   /// a pointer to XXXMemrefType 2) store the pointer to pointer to
   /// XXXMemrefType in args[i]
-  void callRaw(void **args) { compute(args); }
+  void callRaw(void **args) { entry(args); }

-  JitModule(std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT compute);
+  JitModule(std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT entry);
   JitModule(
-      std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT compute,
-      JitModuleFuncT fold, size_t numOrigArgs,
+      std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT entry,
+      JitModuleFuncT fold, int32_t numOrigArgs,
       // The code inside `engine` has the ownership of the buffer
-      llvm::ArrayRef<uint32_t> computeArgs,
+      llvm::ArrayRef<int32_t> entryArgs,
       // The code inside `engine` has the ownership of the buffer
-      llvm::ArrayRef<uint32_t> foldArgs,
+      llvm::ArrayRef<int32_t> foldArgs,
       std::vector<std::shared_ptr<CachedGraphTensor>> &&cachekeepAlive = {});
   ~JitModule();

private:
   std::unique_ptr<ExecutionEngine> engine;
-  JitModuleFuncT compute;
+  JitModuleFuncT entry;
   JitModuleFuncT fold;
-  size_t numOrigArgs;
+  int32_t numOrigArgs; // only input args
+  // The code inside `engine` has the ownership of the buffer
+  llvm::ArrayRef<int32_t> foldArgs;
+  // The code inside `engine` has the ownership of the buffer
+  llvm::ArrayRef<int32_t> entryArgs;
+
+  // The bases of CachedGraphTensors. For example, tensor1 (size 256) and
+  // tensor2 (size 256) are in ConstCacheProxy base1, and tensor3 (size 256) in
+  // base2. Then cacheBases is {base1, base2}, cacheInfo is {{baseIdx=0,
+  // offset=0}, {baseIdx=0, offset=256}, {baseIdx=1, offset=0}}.
+  // `keepAlive` has the ownership of the objects pointed by this vector
   llvm::SmallVector<ConstCacheProxy *> cacheBases;
   struct CacheBufferInfo {
@@ -85,14 +84,10 @@
   };
   // the info for each folded cached buffer
   llvm::SmallVector<CacheBufferInfo> cacheInfo;
+  // holding the pointers to StridedMemRefType of folded cache
-  // `keepAlive` holds the the ownership of the pointers
   llvm::SmallVector<GeneralMemrefPtr> fastFoldBuffers;
-  // The code inside `engine` has the ownership of the buffer
-  llvm::ArrayRef<uint32_t> foldArgs;
-  // The code inside `engine` has the ownership of the buffer
-  llvm::ArrayRef<uint32_t> computeArgs;
-
+  // `keepAlive` holds the ownership of the pointers
   std::vector<std::shared_ptr<CachedGraphTensor>> keepAlive;
 };

diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt
index 95e8b5915..f678bb88d 100644
--- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt
+++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt
@@ -18,7 +18,6 @@ gc_add_mlir_library(GcCpuRuntime
   SHARED
   Parallel.cpp
   MemoryPool.cpp
-  ConstantCache.cpp

   ${MICROKERNEL_RUNTIME_SOURCES}

   DEPENDS

diff --git a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp b/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp
deleted file mode 100644
index ff45cd180..000000000
--- a/lib/gc/ExecutionEngine/CPURuntime/ConstantCache.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-//===-- ConstantCache.cpp - Constant cache ----------------------*- C++ -*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h"
-#include <cstring>
-#include <unordered_map>
-
-namespace mlir::gc {
-
-ConstCacheProxy::~ConstCacheProxy() = default;
-
-CachedGraphTensor::CachedGraphTensor(
-    const std::shared_ptr<ConstCacheProxy> &base, size_t offset)
-    : base{base}, offset{offset} {
-  // todo: fill in real values
-  ref.basePtr = (char *)base->getBufferUnsafe() + offset;
-  ref.data = ref.basePtr;
-  ref.offset = 0;
-  memset(ref.sizes, 0, sizeof(ref.sizes));
-  memset(ref.strides, 0, sizeof(ref.strides));
-}
-
-static std::unordered_map<uint64_t, std::shared_ptr<CachedGraphTensor>> cache;
-
-std::shared_ptr<CachedGraphTensor> queryCacheTensor(uint64_t key) {
-  auto itr = cache.find(key);
-  if (itr != cache.end()) {
-    return itr->second;
-  }
-  return nullptr;
-}
-
-bool regCachedTensor(uint64_t key, const std::shared_ptr<ConstCacheProxy> &base,
-                     size_t offset) {
-  if (queryCacheTensor(key)) {
-    return false;
-  }
-  cache[key] = std::make_shared<CachedGraphTensor>(base, offset);
-  return true;
-}
-} // namespace mlir::gc
\ No newline at end of file

diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp
index a784c0b55..10afbeb20 100644
--- a/lib/gc/ExecutionEngine/Driver/Driver.cpp
+++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp
@@ -20,6 +20,8 @@
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/TargetSelect.h"

+#define DEBUG_TYPE "driver"
+
 #define likely(x) __builtin_expect(!!(x), 1)
 #define unlikely(x) __builtin_expect(!!(x), 0)
@@ -49,8 +51,8 @@ const DialectRegistry &initCompilerAndGetDialects() {
   return reg;
 }

-static const char defaultComputeName[] = "_mlir_ciface_compute";
-static const char defaultFoldName[] = "_mlir_ciface_fold";
+static const char defaultEntryName[] = "_mlir_ciface_entry";
+static const char defaultFoldName[] = "_mlir_ciface_runtime_fold";
 llvm::Expected<std::shared_ptr<JitModule>>
 JitModule::create(Operation *op, const DriverOptions &options) {
   if (options.runTransforms) {
@@ -69,43 +71,49 @@
     return exec.takeError();
   }
   auto &engine = *exec;
-  uint32_t numOrigArgs;
-  {
-    auto expectArgs = engine->lookup("__num_orig_num_args");
-    if (!expectArgs) {
-      return expectArgs.takeError();
-    }
-    numOrigArgs = *reinterpret_cast<uint32_t *>(*expectArgs);
-  }
-  JitModuleFuncT compute;
-  {
-    auto expectCompute = engine->lookupPacked(defaultComputeName);
-    if (!expectCompute) {
-      return expectCompute.takeError();
-    }
-    compute = *expectCompute;
+
+  auto expectEntry = engine->lookupPacked(defaultEntryName);
+  if (!expectEntry) {
+    // entry function must exist
+    return expectEntry.takeError();
   }
-  llvm::ArrayRef<uint64_t> foldBufferIds;
+  JitModuleFuncT entry = *expectEntry;
+
+  int32_t numOrigArgs = 0;
+  llvm::ArrayRef<int64_t> foldBufferIds;
   JitModuleFuncT fold = nullptr;
-  llvm::ArrayRef<uint32_t> computeArgs;
-  llvm::ArrayRef<uint32_t> foldArgs;
+  llvm::ArrayRef<int32_t> entryArgs;
+  llvm::ArrayRef<int32_t> foldArgs;
   do {
-    auto expectBufferIds = engine->lookup("__fold_buffer_ids");
-    if (!expectBufferIds) {
-      // nothing to fold, It is OK.
-      llvm::consumeError(expectBufferIds.takeError());
-      // break out of the scope, don't need to lookup "fold" function
-      break;
-    } else {
-      auto raw = reinterpret_cast<uint64_t *>(*expectBufferIds);
-      foldBufferIds = llvm::ArrayRef<uint64_t>{raw + 1, raw[0]};
+    {
+      auto expectArgs = engine->lookup("__num_orig_num_args");
+      if (!expectArgs) { // nothing to fold, it is OK.
+        llvm::consumeError(expectArgs.takeError());
+        // break out of the scope, don't need to lookup other things
+        break;
+      }
+      numOrigArgs = *reinterpret_cast<int32_t *>(*expectArgs);
+    }
+
+    // If lookup("__num_orig_num_args") succeeds, then all the following should
+    // also succeed.
+    {
+      auto expectBufferIds = engine->lookup("__runtime_fold_buffer_ids_");
+      if (!expectBufferIds) {
+        llvm::consumeError(expectBufferIds.takeError());
+        break;
+      }
+      auto raw = reinterpret_cast<int64_t *>(*expectBufferIds);
+      foldBufferIds =
+          llvm::ArrayRef<int64_t>{raw + 1, static_cast<size_t>(raw[0])};
     }

     // find "fold" func
     {
       auto expectFold = engine->lookupPacked(defaultFoldName);
       if (!expectFold) {
-        return expectFold.takeError();
+        llvm::consumeError(expectFold.takeError());
+        break;
       }
       fold = *expectFold;
     }
@@ -114,20 +122,22 @@
     {
       auto expectFold = engine->lookup("__fold_args");
       if (!expectFold) {
-        return expectFold.takeError();
+        llvm::consumeError(expectFold.takeError());
+        break;
       }
-      auto raw = reinterpret_cast<uint32_t *>(*expectFold);
-      foldArgs = llvm::ArrayRef<uint32_t>{raw + 1, raw[0]};
+      auto raw = reinterpret_cast<int32_t *>(*expectFold);
+      foldArgs = llvm::ArrayRef<int32_t>{raw + 1, static_cast<size_t>(raw[0])};
     }

-    // find "computeArgs"
+    // find "entryArgs"
     {
       auto expect = engine->lookup("__compute_args");
       if (!expect) {
-        return expect.takeError();
+        llvm::consumeError(expect.takeError());
+        break;
       }
-      auto raw = reinterpret_cast<uint32_t *>(*expect);
-      computeArgs = llvm::ArrayRef<uint32_t>{raw + 1, raw[0]};
+      auto raw = reinterpret_cast<int32_t *>(*expect);
+      entryArgs = llvm::ArrayRef<int32_t>{raw + 1, static_cast<size_t>(raw[0])};
     }
   } while (false);
@@ -143,22 +153,22 @@
     foldInfo.emplace_back(std::move(ret));
   }

-  return std::make_shared<JitModule>(std::move(engine), compute, fold,
-                                     numOrigArgs, computeArgs, foldArgs,
+  return std::make_shared<JitModule>(std::move(engine), entry, fold,
+                                     numOrigArgs, entryArgs, foldArgs,
                                      std::move(foldInfo));
 }

 JitModule::JitModule(
-    std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT compute,
-    JitModuleFuncT fold, size_t numOrigArgs,
+    std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT entry,
+    JitModuleFuncT fold, int32_t numOrigArgs,
     // The code inside `engine` has the ownership of the buffer
-    llvm::ArrayRef<uint32_t> computeArgs,
+    llvm::ArrayRef<int32_t> entryArgs,
     // The code inside `engine` has the ownership of the buffer
-    llvm::ArrayRef<uint32_t> foldArgs,
+    llvm::ArrayRef<int32_t> foldArgs,
     std::vector<std::shared_ptr<CachedGraphTensor>> &&cachekeepAlive)
-    : engine{std::move(engine)}, compute{compute}, fold{fold},
-      numOrigArgs{numOrigArgs}, foldArgs{foldArgs},
-      computeArgs{computeArgs}, keepAlive{std::move(cachekeepAlive)} {
+    : engine{std::move(engine)}, entry{entry}, fold{fold},
+      numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, entryArgs{entryArgs},
+      keepAlive{std::move(cachekeepAlive)} {
   for (const auto &cache : keepAlive) {
     auto currentItr =
         std::find(cacheBases.begin(), cacheBases.end(), cache->base.get());
@@ -174,9 +184,10 @@
 JitModule::~JitModule() = default;

 static void prepareCallArgs(llvm::SmallVector<void *> &realargs,
-                            GeneralMemrefPtr *origargs, size_t numOrigArgs,
-                            GeneralMemrefPtr *foldedCache,
-                            llvm::ArrayRef<uint32_t> realArgIdx) {
+                            GeneralMemrefPtr *origargs, int32_t numArgs,
+                            int32_t numOrigArgs, GeneralMemrefPtr *foldedCache,
+                            llvm::ArrayRef<int32_t> realArgIdx) {
+  // inputs, including unfolded and folded
   realargs.reserve(realArgIdx.size());
   for (auto argIdx : realArgIdx) {
     if (argIdx < numOrigArgs) {
@@ -185,19 +196,23 @@
       realargs.push_back(&foldedCache[argIdx - numOrigArgs]);
     }
   }
+  // outputs
+  for (int i = numOrigArgs; i < numArgs; ++i) {
+    realargs.push_back(&origargs[i]);
+  }
 }

-void JitModule::call(GeneralMemrefPtr *args) {
+void JitModule::call(GeneralMemrefPtr *args, int32_t numArgs) {
   if (unlikely(cacheInfo.empty())) {
     // fast path, no folded cached buffers
     // Silly code, MLIR execution engine requires pointers of real args as
     // inputs
     llvm::SmallVector<void *> realargs;
-    realargs.reserve(numOrigArgs);
-    for (size_t i = 0; i < numOrigArgs; i++) {
+    realargs.reserve(numArgs);
+    for (int i = 0; i < numArgs; i++) {
       realargs.push_back(&args[i]);
     }
-    compute(realargs.data());
+    entry(realargs.data());
     return;
   }
@@ -234,15 +249,21 @@
     }
     foldedCache = slowFold.data();
     llvm::SmallVector<void *> realargs;
-    prepareCallArgs(realargs, args, numOrigArgs, foldedCache, foldArgs);
+    prepareCallArgs(realargs, args, numArgs, numOrigArgs, foldedCache,
+                    foldArgs);
+    LLVM_DEBUG(llvm::dbgs() << "foldArgs size: " << foldArgs.size() << '\n');
     fold(realargs.data());
   }

-  // stage 3, call compute
+  // stage 3, call entry
   {
     llvm::SmallVector<void *> realargs;
-    prepareCallArgs(realargs, args, numOrigArgs, foldedCache, computeArgs);
-    compute(realargs.data());
+    prepareCallArgs(realargs, args, numArgs, numOrigArgs, foldedCache,
+                    entryArgs);
+    LLVM_DEBUG(llvm::dbgs()
+               << "entryArgs size: " << entryArgs.size()
+               << ", Entry real args size: " << realargs.size() << '\n');
+    entry(realargs.data());
   }

   // stage 4, cleanup

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index f1d85e62e..c442e95ec 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -31,7 +31,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"

-// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h"

 #define DEBUG_TYPE "constant-tensor-folding"
@@ -316,17 +316,6 @@ void postponeBroadcast(Block &block) {

 static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;

-// get from dnnl_graph_compiler_context
-// void *allocator(size_t size) { return std::aligned_alloc(64, size); }
-// void deallocator(void *ptr) { std::free(ptr); }
-
-// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-//   // simply allocate buffer and return
-//   std::shared_ptr<void> base = std::shared_ptr<void>{
-//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-// }
-
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

 // Manager
@@ -349,12 +338,13 @@ struct ConstGraphTensorCacheManager {
       totalSize += divideAndCeil(size, 64) * 64;
     }
     LLVM_DEBUG(llvm::dbgs() << "Alloc total size: " << totalSize << '\n');
-    // auto base = createConstCacheProxy(totalSize);
+    auto base = createConstCacheProxy(totalSize);
     std::vector<int64_t> globalIds(buffersSize.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffersSize.size(); i++) {
       LLVM_DEBUG(llvm::dbgs() << "Alloc offset: " << offset << '\n');
-      // regCachedTensor(cachedTensorGlobalId, base, offset);
+      bool regRes = regCachedTensor(cachedTensorGlobalId, base, offset);
+      assert(regRes && "Register constant tensor failed");
       globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
       offset += divideAndCeil(buffersSize[i], 64) * 64;
@@ -862,6 +852,7 @@
   }

   canonicalizeAndClean(context, topOp);
+  LLVM_DEBUG(topOp->dump());
 }

 std::unique_ptr<Pass> createConstantTensorFoldingPass() {

diff --git a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
diff --git a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
index 3eb4717c0..6d7a489c9 100644
--- a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
+++ b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
@@ -25,8 +25,7 @@ using namespace mlir;
 static const char code1[] = R"mlir(
 module {
-llvm.mlir.global external constant @__num_orig_args(3 : i32) {addr_space = 0 : i32} : i32
-func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
+func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
 %out = tensor.empty() : tensor<128xf32>
 %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
 return %2 : tensor<128xf32>
@@ -68,3 +67,93 @@ TEST(ExecutionEngine, JitWrapper) {
     ASSERT_EQ(bufC[{i}], 1.0f + i);
   }
 }
+
+// compute d = (a+a) + (b+b) + c, where a, b are marked constant
+// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5
+static const char code2[] = R"mlir(
+module {
+func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface, runtime_const_args_index = [0 : i32, 1 : i32] } {
+  %out = tensor.empty() : tensor<128xf32>
+  %ax2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
+  %out2 = tensor.empty() : tensor<128xf32>
+  %bx2 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32>
+  %out3 = tensor.empty() : tensor<128xf32>
+  %ax2pbx2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out3 : tensor<128xf32>) -> tensor<128xf32>
+  %out4 = tensor.empty() : tensor<128xf32>
+  %d = linalg.add ins(%ax2pbx2, %c : tensor<128xf32>,tensor<128xf32>) outs(%out4 : tensor<128xf32>) -> tensor<128xf32>
+  return %d : tensor<128xf32>
+}
+}
+)mlir";
+
+TEST(ExecutionEngine, JitWrapperCached) {
+  MLIRContext ctx{gc::initCompilerAndGetDialects()};
+  std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
+      llvm::MemoryBuffer::getMemBuffer(code2);
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
+  mlir::OwningOpRef<mlir::ModuleOp> module =
+      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
+
+  ASSERT_TRUE(module);
+  auto jited = gc::JitModule::create(module.get());
+  bool jit_success = static_cast<bool>(jited);
+  if (!jit_success) {
+    auto err = jited.takeError();
+    llvm::errs() << err;
+    llvm::consumeError(std::move(err));
+  }
+  ASSERT_TRUE(jit_success);
+
+  auto ret = std::shared_ptr<float[]>(new float[128]);
+  auto proxy = std::make_shared<gc::ConstCacheProxy>(
+      ret, ret.get(), 128 * sizeof(float), true);
+  // Cannot register with an already existing key.
+  ASSERT_FALSE(gc::regCachedTensor(0, proxy, 0));
+
+  proxy = gc::queryCacheTensor(0)->base;
+  auto data = (float *)proxy->getBufferUnsafe();
+
+  OwningMemRef<float, 1> bufA{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 1.0f; }};
+  OwningMemRef<float, 1> bufB{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) { ptr = idx[0]; }};
+  OwningMemRef<float, 1> bufC{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) {
+        ptr = -idx[0] * 3;
+      }};
+  OwningMemRef<float, 1> bufD{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 100.0f; }};
+  void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD};
+
+  {
+    // first call, should run fold()
+    jited.get()->call(args, 4);
+
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(*(data + i), 2 * 1.0f + 2 * i);
+    }
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
+    }
+  }
+
+  {
+    // second call, should not run fold()
+    jited.get()->call(args, 4);
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
+    }
+  }
+
+  // the cache is evicted
+  proxy->deref();
+  {
+    // third call, should run fold()
+    jited.get()->call(args, 4);
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
+    }
+  }
+}
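A sketch (annotation, not part of the patch) of the cache contract this test
exercises, using the ConstantCache.h names above; the parameter-name comments
are assumptions added for illustration:

    // Back a cached tensor with a 128-float buffer under global id 0.
    auto buf = std::shared_ptr<float[]>(new float[128]);
    auto proxy = std::make_shared<gc::ConstCacheProxy>(
        buf, buf.get(), 128 * sizeof(float), true);
    bool ok = gc::regCachedTensor(/*globalId=*/0, proxy, /*offset=*/0);
    // Registering an id twice returns false; gc::queryCacheTensor(0)->base
    // recovers the proxy. proxy->deref() marks the contents evicted, so the
    // next JitModule::call() re-runs fold() before entry().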
diff --git a/unittests/ExecutionEngine/CMakeLists.txt b/unittests/ExecutionEngine/CMakeLists.txt
deleted file mode 100644
index 0e7315a0f..000000000
--- a/unittests/ExecutionEngine/CMakeLists.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-add_mlir_unittest(GCExecutionEngineTests
-  JitWrapper.cpp
-)
-target_link_libraries(GCExecutionEngineTests
-  PRIVATE
-  GCJitWrapper
-  GCCpuRuntime)
diff --git a/unittests/ExecutionEngine/JitWrapper.cpp b/unittests/ExecutionEngine/JitWrapper.cpp
deleted file mode 100644
index 9069eef79..000000000
--- a/unittests/ExecutionEngine/JitWrapper.cpp
+++ /dev/null
@@ -1,175 +0,0 @@
-//===-- JitWrapper.cpp - Wrapper for JIT ------------------------*- C++ -*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "gc/ExecutionEngine/Driver/Driver.h"
-#include "mlir/AsmParser/AsmParser.h"
-#include "mlir/ExecutionEngine/MemRefUtils.h"
-#include "mlir/IR/AsmState.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Parser/Parser.h"
-#include "mlir/Pass/PassManager.h"
-#include "llvm/Support/ErrorOr.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include "gtest/gtest.h"
-#include <memory>
-
-using namespace mlir;
-
-static const char code1[] = R"mlir(
-module {
-llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32
-func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
-  %out = tensor.empty() : tensor<128xf32>
-  %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
-  return %2 : tensor<128xf32>
-}
-}
-)mlir";
-
-extern "C" {
-extern int gc_runtime_keep_alive;
-}
-
-TEST(ExecutionEngine, JitWrapper) {
-  gc_runtime_keep_alive = 0;
-  MLIRContext ctx{gc::initCompilerAndGetDialects()};
-  std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
-      llvm::MemoryBuffer::getMemBuffer(code1);
-  // Parse the input mlir.
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
-  mlir::OwningOpRef<mlir::ModuleOp> module =
-      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
-  ASSERT_TRUE(module);
-  auto jited = gc::JitModule::create(module.get());
-  bool jit_success = static_cast<bool>(jited);
-  if (!jit_success) {
-    auto err = jited.takeError();
-    llvm::errs() << err;
-    llvm::consumeError(std::move(err));
-  }
-  ASSERT_TRUE(jit_success);
-  OwningMemRef<float, 1> bufA{
-      {128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 1.0f; }};
-  OwningMemRef<float, 1> bufB{
-      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) { ptr = idx[0]; }};
-  OwningMemRef<float, 1> bufC{{128}, {128}};
-  void *args[] = {&*bufA, &*bufB, &*bufC};
-  jited.get()->call(args);
-  for (int i = 0; i < 128; i++) {
-    ASSERT_EQ(bufC[{i}], 1.0f + i);
-  }
-}
-
-// compute d = (a+a) + (b+b) + c, where a,b is marked constant
-// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5
-static const char code2[] = R"mlir(
-module {
-llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32
-llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64>
-// a,b, foldedA,foldedB
-llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32>
-// foldedA, foldedB, c, d
-llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32>
-
-func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } {
-  %c0 = arith.constant 0 : index
-  cpuruntime.printf "HI%zu\n" %c0 : index
-  %out = tensor.empty() : tensor<128xf32>
-  %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
-  %out2 = tensor.empty() : tensor<128xf32>
-  %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32>
-  return %2, %3 : tensor<128xf32>, tensor<128xf32>
-}
-
-func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
-  %out = tensor.empty() : tensor<128xf32>
-  %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
-  %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
-  return %d : tensor<128xf32>
-}
-}
-)mlir";
-
-TEST(ExecutionEngine, JitWrapperCached) {
-  MLIRContext ctx{gc::initCompilerAndGetDialects()};
-  std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
-      llvm::MemoryBuffer::getMemBuffer(code2);
-  // Parse the input mlir.
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
-  mlir::OwningOpRef<mlir::ModuleOp> module =
-      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
-
-  // foldedA and foldedB uses this buffer
-  auto ret = std::shared_ptr<float[]>(new float[128 * 2]);
-  auto proxy = std::make_shared<gc::ConstCacheProxy>(
-      ret, ret.get(), 128 * 2 * sizeof(float), true);
-
-  ASSERT_TRUE(gc::regCachedTensor(114514, proxy, 0));
-  ASSERT_TRUE(gc::regCachedTensor(1919810, proxy, 128 * sizeof(float)));
-
-  ASSERT_TRUE(module);
-  auto jited = gc::JitModule::create(module.get());
-  bool jit_success = static_cast<bool>(jited);
-  if (!jit_success) {
-    auto err = jited.takeError();
-    llvm::errs() << err;
-    llvm::consumeError(std::move(err));
-  }
-  ASSERT_TRUE(jit_success);
-  OwningMemRef<float, 1> bufA{
-      {128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 1.0f; }};
-  OwningMemRef<float, 1> bufB{
-      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) { ptr = idx[0]; }};
-  OwningMemRef<float, 1> bufC{
-      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) {
-        ptr = -idx[0] * 3;
-      }};
-  OwningMemRef<float, 1> bufD{{128}, {128}};
-  void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD};
-
-  // first call, should run fold()
-  {
-    testing::internal::CaptureStdout();
-    // first call, should run fold()
-    jited.get()->call(args);
-    for (int i = 0; i < 128; i++) {
-      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
-    }
-    std::string output = testing::internal::GetCapturedStdout();
-    ASSERT_EQ(output, "HI0\n");
-  }
-
-  {
-    testing::internal::CaptureStdout();
-    // second call, should not run fold()
-    jited.get()->call(args);
-    for (int i = 0; i < 128; i++) {
-      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
-    }
-    std::string output = testing::internal::GetCapturedStdout();
-    ASSERT_TRUE(output.empty());
-  }
-
-  // the cache is evicted
-  proxy->deref();
-  {
-    testing::internal::CaptureStdout();
-    // third call, should run fold()
-    jited.get()->call(args);
-    for (int i = 0; i < 128; i++) {
-      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
-    }
-    std::string output = testing::internal::GetCapturedStdout();
-    ASSERT_EQ(output, "HI0\n");
-  }
-}

From 8d08752f5e80559d107325f78428233f82194437 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Fri, 13 Sep 2024 11:24:29 +0800
Subject: [PATCH 57/64] Unify attr name

---
 lib/gc/ExecutionEngine/Driver/Driver.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp
index 10afbeb20..8d4a0e5c4 100644
--- a/lib/gc/ExecutionEngine/Driver/Driver.cpp
+++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp
@@ -86,7 +86,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
   llvm::ArrayRef<int32_t> foldArgs;
   do {
     {
-      auto expectArgs = engine->lookup("__num_orig_num_args");
+      auto expectArgs = engine->lookup("__num_orig_args");
       if (!expectArgs) {
         // nothing to fold, It is OK.
         llvm::consumeError(expectArgs.takeError());
         // break out of the scope, don't need to lookup other things
         break;
@@ -98,7 +98,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
     // If lookup("__num_orig_num_args") succeeds, then all the following should
     // also succeed.
     {
-      auto expectBufferIds = engine->lookup("__runtime_fold_buffer_ids_");
+      auto expectBufferIds = engine->lookup("__runtime_fold_buffer_ids");
       if (!expectBufferIds) {
         llvm::consumeError(expectBufferIds.takeError());
         break;

From edbb708155112a9e3e42577789dc2ee9351434ad Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Sat, 14 Sep 2024 10:42:33 +0800
Subject: [PATCH 58/64] Clean tests.
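Keep only the FileCheck-based MLIR cases: rename
test_constant_tensor_folding-2.mlir to test_constant_tensor_folding-0.mlir,
move the two-layer mlp case from the deleted test_constant_tensor_folding.mlir
into test_constant_tensor_folding-1.mlir, and drop the Python end-to-end
folding scripts.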
---
 ...ir => test_constant_tensor_folding-0.mlir} |   2 +
 .../test_constant_tensor_folding-1.mlir       | 130 +++++----
 .../test_constant_tensor_folding.mlir         |  82 ------
 .../test_constant_tensor_folding_bf16_4D5D.py | 101 -------
 ...constant_tensor_folding_bf16_two_layers.py | 258 ------------------
 .../test_constant_tensor_folding_f32_4D4D.py  |  96 -------
 ..._constant_tensor_folding_f32_two_layers.py | 225 ---------------
 7 files changed, 83 insertions(+), 811 deletions(-)
 rename test/gc/Transforms/{test_constant_tensor_folding-2.mlir => test_constant_tensor_folding-0.mlir} (99%)
 delete mode 100644 test/gc/Transforms/test_constant_tensor_folding.mlir
 delete mode 100644 test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py
 delete mode 100644 test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py
 delete mode 100644 test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py
 delete mode 100644 test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py

diff --git a/test/gc/Transforms/test_constant_tensor_folding-2.mlir b/test/gc/Transforms/test_constant_tensor_folding-0.mlir
similarity index 99%
rename from test/gc/Transforms/test_constant_tensor_folding-2.mlir
rename to test/gc/Transforms/test_constant_tensor_folding-0.mlir
index a5e123085..eabdacc93 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-2.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-0.mlir
@@ -1,5 +1,7 @@
 // RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s

+// COM: A complete example of compile-time and runtime folding.
+
 // CHECK-LABEL: func.func @entry
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>
diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index cdb5d1397..92231703d 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -1,59 +1,91 @@
 // RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s
+ // CHECK-LABEL: func.func @entry +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> module { - func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, runtime_const_args_index = [0 : i32, 1 : i32] } { - %c0 = arith.constant 0 : index - cpuruntime.printf "HI%zu\n" %c0 : index - %ax2 = tensor.empty() : tensor<128xf32> - %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%ax2 : tensor<128xf32>) -> tensor<128xf32> - %bx2 = tensor.empty() : tensor<128xf32> - %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%bx2 : tensor<128xf32>) -> tensor<128xf32> - %ax2pbx2 = tensor.empty() : tensor<128xf32> - %4 = linalg.add ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2 : tensor<128xf32>) -> tensor<128xf32> - %ax2mbx2 = tensor.empty() : tensor<128xf32> - %5 = linalg.mul ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2mbx2 : tensor<128xf32>) -> tensor<128xf32> - %ax2pbx2pc = tensor.empty() : tensor<128xf32> - %6 = linalg.add ins(%4, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32> - %ax2mbx2mc = tensor.empty() : tensor<128xf32> - %7 = linalg.mul ins(%5, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2mbx2mc : tensor<128xf32>) -> tensor<128xf32> - return %6, %7 : tensor<128xf32>, tensor<128xf32> - } + // COM: A two-layer mlp. arg0: input feature. + // COM: arg1: weight of #1 linear. arg2: bias of #1 linear. + // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. + func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %1 = tensor.empty() : tensor<2x16x32x32xbf16> + %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> + %2 = tensor.empty() : tensor<8x16x32x32xbf16> + %packed_arg1 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<512x256xbf16> -> tensor<8x16x32x32xbf16> + %3 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %packed_packed_arg1 = tensor.pack %packed_arg1 inner_dims_pos = [2] inner_tiles = [2] into %3 : tensor<8x16x32x32xbf16> -> tensor<8x16x16x32x2xbf16> + %4 = tensor.empty() : tensor<2x8x32x32xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %5 = linalg.fill ins(%cst_0 : bf16) outs(%4 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%packed_arg0, %packed_packed_arg1 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%5 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x8x32x32xbf16> + + // COM: Operations on %arg2: {pack, broadcast, extf, mul, truncf, bias_add} in entry(). 
+ %15 = tensor.empty() : tensor<8x32xbf16> + %packed_arg2 = tensor.pack %arg2 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %15 : tensor<256xbf16> -> tensor<8x32xbf16> + %bc_arg2_init = tensor.empty() : tensor<2x8x32x32xbf16> + %bc_arg2 = linalg.broadcast ins(%packed_arg2 : tensor<8x32xbf16>) outs(%bc_arg2_init : tensor<2x8x32x32xbf16>) dimensions = [0, 2] + %extf32 = arith.extf %bc_arg2 : tensor<2x8x32x32xbf16> to tensor<2x8x32x32xf32> + %cst_2 = arith.constant 2.000000e+00 : f32 + %extf32_mul2_init = tensor.empty() : tensor<2x8x32x32xf32> + %extf32_mul2 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extf32 : tensor<2x8x32x32xf32>) outs(%extf32_mul2_init : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %8 = arith.mulf %in, %cst_2 : f32 + linalg.yield %8 : f32 + } -> tensor<2x8x32x32xf32> + %truncbf16 = arith.truncf %extf32_mul2 : tensor<2x8x32x32xf32> to tensor<2x8x32x32xbf16> + + %7 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%truncbf16 : tensor<2x8x32x32xbf16>) outs(%6 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %45 = arith.addf %in, %out : bf16 + linalg.yield %45 : bf16 + } -> tensor<2x8x32x32xbf16> + + %8 = tensor.empty() : tensor<32x8x32x32xbf16> + %packed_arg3 = tensor.pack %arg3 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %8 : tensor<256x1024xbf16> -> tensor<32x8x32x32xbf16> + %9 = tensor.empty() : tensor<32x8x16x32x2xbf16> + %packed_packed_arg3 = tensor.pack %packed_arg3 inner_dims_pos = [2] inner_tiles = [2] into %9 : tensor<32x8x32x32xbf16> -> tensor<32x8x16x32x2xbf16> + %10 = tensor.empty() : tensor<2x32x32x32xbf16> + %11 = linalg.fill ins(%cst_0 : bf16) outs(%10 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %12 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%7, %packed_packed_arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%11 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %46 = arith.mulf %in, %in_0 : bf16 + %56 = arith.addf %out, %46 : bf16 + linalg.yield %56 : bf16 + } -> tensor<2x32x32x32xbf16> + %16 = tensor.empty() : tensor<32x32xbf16> + %packed_arg4 = tensor.pack %arg4 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %16 : tensor<1024xbf16> -> tensor<32x32xbf16> + %13 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%packed_arg4 : tensor<32x32xbf16>) outs(%12 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %47 = arith.addf %in, %out : bf16 + linalg.yield %47 : bf16 + } -> tensor<2x32x32x32xbf16> + %14 = tensor.empty() : tensor<64x1024xbf16> + %unpack = tensor.unpack %13 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %14 : tensor<2x32x32x32xbf16> -> tensor<64x1024xbf16> + return %unpack : tensor<64x1024xbf16> + } } -// CHECK: cpuruntime.printf -// CHECK: linalg.add -// CHECK: linalg.mul +// COM: After transform, operations on %arg2: {pack, extf, mul, truncf} in fold(), {broadcast, bias_add} in entry(). 
+// CHECK: linalg.broadcast // CHECK: func.func @runtime_fold -// CHECK: linalg.add -// CHECK: linalg.add -// CHECK: linalg.add -// CHECK: linalg.mul +// CHECK: arith.extf +// CHECK: arith.truncf // COM: expected output: // COM: module { -// COM: llvm.mlir.global external constant @__num_orig_args(3 : i32) {addr_space = 0 : i32} : i32 -// COM: llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32> -// COM: llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32> -// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64> -// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, runtime_const_args_index = [0 : i32, 1 : i32]} { -// COM: %c0 = arith.constant 0 : index -// COM: cpuruntime.printf "HI%zu\0A" %c0 : index -// COM: %0 = tensor.empty() : tensor<128xf32> -// COM: %1 = linalg.add ins(%arg2, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> -// COM: %2 = tensor.empty() : tensor<128xf32> -// COM: %3 = linalg.mul ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> -// COM: return %1, %3 : tensor<128xf32>, tensor<128xf32> -// COM: } -// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface} { -// COM: %0 = tensor.empty() : tensor<128xf32> -// COM: %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32> -// COM: %2 = tensor.empty() : tensor<128xf32> -// COM: %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32> -// COM: %4 = tensor.empty() : tensor<128xf32> -// COM: %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32> -// COM: %6 = tensor.empty() : tensor<128xf32> -// COM: %7 = linalg.mul ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%6 : tensor<128xf32>) -> tensor<128xf32> -// COM: return %7, %5 : tensor<128xf32>, tensor<128xf32> -// COM: } -// COM: } \ No newline at end of file +// COM: llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32 +// COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> +// COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> +// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} +// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} diff --git 
a/test/gc/Transforms/test_constant_tensor_folding.mlir b/test/gc/Transforms/test_constant_tensor_folding.mlir deleted file mode 100644 index 59fe90236..000000000 --- a/test/gc/Transforms/test_constant_tensor_folding.mlir +++ /dev/null @@ -1,82 +0,0 @@ -// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s - -// CHECK-LABEL: func.func @entry -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> -#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -module { - // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear. - // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. - func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { - %1 = tensor.empty() : tensor<2x16x32x32xbf16> - %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> - %2 = tensor.empty() : tensor<8x16x32x32xbf16> - %packed_arg1 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<512x256xbf16> -> tensor<8x16x32x32xbf16> - %3 = tensor.empty() : tensor<8x16x16x32x2xbf16> - %packed_packed_arg1 = tensor.pack %packed_arg1 inner_dims_pos = [2] inner_tiles = [2] into %3 : tensor<8x16x32x32xbf16> -> tensor<8x16x16x32x2xbf16> - %4 = tensor.empty() : tensor<2x8x32x32xbf16> - %cst_0 = arith.constant 0.000000e+00 : bf16 - %5 = linalg.fill ins(%cst_0 : bf16) outs(%4 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> - %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%packed_arg0, %packed_packed_arg1 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%5 : tensor<2x8x32x32xbf16>) { - ^bb0(%in: bf16, %in_0: bf16, %out: bf16): - %44 = arith.mulf %in, %in_0 : bf16 - %55 = arith.addf %out, %44 : bf16 - linalg.yield %55 : bf16 - } -> tensor<2x8x32x32xbf16> - %15 = tensor.empty() : tensor<8x32xbf16> - %packed_arg2 = tensor.pack %arg2 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %15 : tensor<256xbf16> -> tensor<8x32xbf16> - %bc_arg2_init = tensor.empty() : tensor<2x8x32x32xbf16> - %bc_arg2 = linalg.broadcast ins(%packed_arg2 : tensor<8x32xbf16>) outs(%bc_arg2_init : tensor<2x8x32x32xbf16>) dimensions = [0, 2] - %extf32 = arith.extf %bc_arg2 : tensor<2x8x32x32xbf16> to tensor<2x8x32x32xf32> - %cst_2 = arith.constant 2.000000e+00 : f32 - %extf32_mul2_init = tensor.empty() : tensor<2x8x32x32xf32> - %extf32_mul2 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extf32 : tensor<2x8x32x32xf32>) outs(%extf32_mul2_init : tensor<2x8x32x32xf32>) { - ^bb0(%in: f32, %out: f32): - %8 = arith.mulf %in, %cst_2 : f32 - linalg.yield %8 : f32 - } -> tensor<2x8x32x32xf32> - %truncbf16 = arith.truncf %extf32_mul2 : tensor<2x8x32x32xf32> to tensor<2x8x32x32xbf16> - %7 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", 
"parallel", "parallel", "parallel"]} ins(%truncbf16 : tensor<2x8x32x32xbf16>) outs(%6 : tensor<2x8x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %45 = arith.addf %in, %out : bf16 - linalg.yield %45 : bf16 - } -> tensor<2x8x32x32xbf16> - %8 = tensor.empty() : tensor<32x8x32x32xbf16> - %packed_arg3 = tensor.pack %arg3 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %8 : tensor<256x1024xbf16> -> tensor<32x8x32x32xbf16> - %9 = tensor.empty() : tensor<32x8x16x32x2xbf16> - %packed_packed_arg3 = tensor.pack %packed_arg3 inner_dims_pos = [2] inner_tiles = [2] into %9 : tensor<32x8x32x32xbf16> -> tensor<32x8x16x32x2xbf16> - %10 = tensor.empty() : tensor<2x32x32x32xbf16> - %11 = linalg.fill ins(%cst_0 : bf16) outs(%10 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> - %12 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%7, %packed_packed_arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%11 : tensor<2x32x32x32xbf16>) { - ^bb0(%in: bf16, %in_0: bf16, %out: bf16): - %46 = arith.mulf %in, %in_0 : bf16 - %56 = arith.addf %out, %46 : bf16 - linalg.yield %56 : bf16 - } -> tensor<2x32x32x32xbf16> - %16 = tensor.empty() : tensor<32x32xbf16> - %packed_arg4 = tensor.pack %arg4 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %16 : tensor<1024xbf16> -> tensor<32x32xbf16> - %13 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%packed_arg4 : tensor<32x32xbf16>) outs(%12 : tensor<2x32x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %47 = arith.addf %in, %out : bf16 - linalg.yield %47 : bf16 - } -> tensor<2x32x32x32xbf16> - %14 = tensor.empty() : tensor<64x1024xbf16> - %unpack = tensor.unpack %13 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %14 : tensor<2x32x32x32xbf16> -> tensor<64x1024xbf16> - return %unpack : tensor<64x1024xbf16> - } -} -// CHECK: linalg.broadcast -// CHECK: func.func @runtime_fold -// CHECK: arith.extf -// CHECK: arith.truncf - -// COM: expected output: -// COM: module { -// COM: llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32 -// COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> -// COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> -// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> -// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} -// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} diff --git a/test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py b/test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py deleted file mode 100644 index 0fafbd080..000000000 --- a/test/gc/Transforms/test_constant_tensor_folding_bf16_4D5D.py +++ /dev/null @@ 
-1,101 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -from enum import Flag -import os -import sys -import ml_dtypes -import numpy as np -from gc_mlir import ir -from gc_mlir.graph_compiler import GraphCompiler -from numpy.testing import assert_allclose - -project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if project_dir not in sys.path: - sys.path.insert(0, project_dir) - -import torch -# from bench import py_timeit_bench -from utils import get_mlir_args - -if __name__ == "__main__": - with ir.Context() as ctx: - ctx.allow_unregistered_dialects = True - - M = 64 - N = 256 - K = 512 - MBlock = 32 - NBlock = 32 - KBlock = 32 - vnni_size = 2 - shapeA = [M // MBlock, K // KBlock, MBlock, KBlock] - shapeB = [N // NBlock, K // KBlock, KBlock // vnni_size, NBlock, vnni_size] - shapeC = [M // MBlock, N // NBlock, MBlock, NBlock] - - block_start = "{" - block_end = "}" - mlir_str = f''' -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> -module {block_start} - func.func @entry(%arg0: tensor<{M // MBlock}x{K // KBlock}x{MBlock}x{KBlock}xbf16>, %cst: tensor<{N // NBlock}x{K // KBlock}x{KBlock // vnni_size}x{NBlock}x{vnni_size}xbf16>) -> tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> attributes {block_start}llvm.emit_c_interface{block_end} {block_start} - %cst_0 = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> - %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16>) -> tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> - %2 = linalg.generic {block_start}indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]{block_end} ins(%arg0, %cst : tensor<{M // MBlock}x{K // KBlock}x{MBlock}x{KBlock}xbf16>, tensor<{N // NBlock}x{K // KBlock}x{KBlock // vnni_size}x{NBlock}x{vnni_size}xbf16>) outs(%1 : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16>) {block_start} - ^bb0(%in: bf16, %in_1: bf16, %out: bf16): - %3 = arith.mulf %in, %in_1 : bf16 - %4 = arith.addf %out, %3 : bf16 - linalg.yield %4 : bf16 - {block_end} -> tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> - return %2 : tensor<{M // MBlock}x{N // NBlock}x{MBlock}x{NBlock}xbf16> - {block_end} -{block_end} - ''' - print(mlir_str) - - # 4D x 5D, inputs transposed - module_in = ir.Module.parse(mlir_str) - - # entry(%transposed: tensor<2x16x32x32xbf16>, %transposed_5: tensor<8x16x16x32x2xbf16>) -> tensor<2x8x32x32xbf16> - torch_arg0 = torch.rand((M, K), dtype=torch.bfloat16) 
- torch_arg1 = torch.rand((K, N), dtype=torch.bfloat16) - ref_res = torch_arg0 @ torch_arg1 - - passes = "any(gc-cpu-pipeline)" - shared_libs = [ - os.environ["MLIR_C_RUNNER_UTILS"], - os.environ["MLIR_RUNNER_UTILS"], - ] - compiler = GraphCompiler(passes) - ctx.enable_multithreading(False) - - arg0 = torch_arg0.view(shapeA).permute([0, 2, 1, 3]).contiguous() # MK -> MKmk - np_arg0 = arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - arg1 = torch_arg1.view(shapeB).permute([3, 0, 1, 4, 2]).contiguous() # KN -> NKkn2k - np_arg1 = arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - gc_res = np.ones(shapeC, dtype=ml_dtypes.bfloat16) - - entry = "entry" - mlir_args = get_mlir_args(module_in, entry, [np_arg0, np_arg1, gc_res]) - engine_in = compiler.compile_and_jit(module_in, ir_printing=False) - engine_in.invoke(entry, *mlir_args) - gc_res = np.reshape(np.transpose(gc_res, (0, 2, 1, 3)), (M, N)) # MNmn -> MN - - assert_allclose(gc_res.astype(np.float32), ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5) diff --git a/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py b/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py deleted file mode 100644 index 4e66b1ebf..000000000 --- a/test/gc/Transforms/test_constant_tensor_folding_bf16_two_layers.py +++ /dev/null @@ -1,258 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. 
-# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import os -import sys - -import numpy as np -import ml_dtypes - -from gc_mlir import ir -from gc_mlir.graph_compiler import GraphCompiler -from numpy.testing import assert_allclose - -project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if project_dir not in sys.path: - sys.path.insert(0, project_dir) - -import torch -# from bench import py_timeit_bench -from utils import get_mlir_args - -if __name__ == "__main__": - with ir.Context() as ctx: - ctx.allow_unregistered_dialects = True - # ctx.enable_multithreading = False - module_in = ir.Module.parse( - """ -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> -#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -module { - func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { - %0 = tensor.empty() : tensor<2x16x32x32xbf16> - %cst = arith.constant 0.000000e+00 : bf16 - %padded = tensor.pad %arg0 low[0, 0] high[0, 0] { - ^bb0(%arg5: index, %arg6: index): - tensor.yield %cst : bf16 - } : tensor<64x512xbf16> to tensor<64x512xbf16> - %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xbf16> into tensor<2x32x16x32xbf16> - %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xbf16>) outs(%0 : tensor<2x16x32x32xbf16>) permutation = [0, 2, 1, 3] - %1 = tensor.empty() : tensor<8x16x32x32xbf16> - %padded_0 = tensor.pad %arg1 low[0, 0] high[0, 0] { - ^bb0(%arg5: index, %arg6: index): - tensor.yield %cst : bf16 - } : tensor<512x256xbf16> to tensor<512x256xbf16> - %expanded_1 = tensor.expand_shape %padded_0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xbf16> into tensor<16x32x8x32xbf16> - %transposed_2 = linalg.transpose ins(%expanded_1 : tensor<16x32x8x32xbf16>) outs(%1 : tensor<8x16x32x32xbf16>) permutation = [2, 0, 1, 3] - %2 = tensor.empty() : tensor<8x16x16x32x2xbf16> - %padded_3 = tensor.pad %transposed_2 low[0, 0, 0, 0] high[0, 0, 0, 0] { - ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index): - tensor.yield %cst : bf16 - } : tensor<8x16x32x32xbf16> to tensor<8x16x32x32xbf16> - %expanded_4 = tensor.expand_shape %padded_3 [[0], [1], [2, 3], [4]] output_shape [8, 16, 16, 2, 32] : tensor<8x16x32x32xbf16> into tensor<8x16x16x2x32xbf16> - %transposed_5 = linalg.transpose ins(%expanded_4 : tensor<8x16x16x2x32xbf16>) outs(%2 : tensor<8x16x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] - %3 = tensor.empty() : tensor<2x8x32x32xbf16> - %4 = linalg.fill ins(%cst : bf16) outs(%3 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> - %5 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %transposed_5 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%4 : tensor<2x8x32x32xbf16>) { - ^bb0(%in: bf16, %in_19: bf16, %out: bf16): - %17 = arith.mulf %in, %in_19 : bf16 - %18 = arith.addf %out, %17 : bf16 - linalg.yield %18 : bf16 - } -> tensor<2x8x32x32xbf16> - %6 
= tensor.empty() : tensor<8x32xbf16> - %padded_6 = tensor.pad %arg2 low[0] high[0] { - ^bb0(%arg5: index): - tensor.yield %cst : bf16 - } : tensor<256xbf16> to tensor<256xbf16> - %expanded_7 = tensor.expand_shape %padded_6 [[0, 1]] output_shape [8, 32] : tensor<256xbf16> into tensor<8x32xbf16> - %transposed_8 = linalg.transpose ins(%expanded_7 : tensor<8x32xbf16>) outs(%6 : tensor<8x32xbf16>) permutation = [0, 1] - %broadcasted = linalg.broadcast ins(%transposed_8 : tensor<8x32xbf16>) outs(%3 : tensor<2x8x32x32xbf16>) dimensions = [0, 2] - %7 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xbf16>) outs(%5 : tensor<2x8x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %17 = arith.addf %in, %out : bf16 - linalg.yield %17 : bf16 - } -> tensor<2x8x32x32xbf16> - %8 = tensor.empty() : tensor<32x8x32x32xbf16> - %padded_9 = tensor.pad %arg3 low[0, 0] high[0, 0] { - ^bb0(%arg5: index, %arg6: index): - tensor.yield %cst : bf16 - } : tensor<256x1024xbf16> to tensor<256x1024xbf16> - %expanded_10 = tensor.expand_shape %padded_9 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xbf16> into tensor<8x32x32x32xbf16> - %transposed_11 = linalg.transpose ins(%expanded_10 : tensor<8x32x32x32xbf16>) outs(%8 : tensor<32x8x32x32xbf16>) permutation = [2, 0, 1, 3] - %9 = tensor.empty() : tensor<32x8x16x32x2xbf16> - %padded_12 = tensor.pad %transposed_11 low[0, 0, 0, 0] high[0, 0, 0, 0] { - ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index): - tensor.yield %cst : bf16 - } : tensor<32x8x32x32xbf16> to tensor<32x8x32x32xbf16> - %expanded_13 = tensor.expand_shape %padded_12 [[0], [1], [2, 3], [4]] output_shape [32, 8, 16, 2, 32] : tensor<32x8x32x32xbf16> into tensor<32x8x16x2x32xbf16> - %transposed_14 = linalg.transpose ins(%expanded_13 : tensor<32x8x16x2x32xbf16>) outs(%9 : tensor<32x8x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] - %10 = tensor.empty() : tensor<2x32x32x32xbf16> - %11 = linalg.fill ins(%cst : bf16) outs(%10 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> - %12 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%7, %transposed_14 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%11 : tensor<2x32x32x32xbf16>) { - ^bb0(%in: bf16, %in_19: bf16, %out: bf16): - %17 = arith.mulf %in, %in_19 : bf16 - %18 = arith.addf %out, %17 : bf16 - linalg.yield %18 : bf16 - } -> tensor<2x32x32x32xbf16> - %13 = tensor.empty() : tensor<32x32xbf16> - %padded_15 = tensor.pad %arg4 low[0] high[0] { - ^bb0(%arg5: index): - tensor.yield %cst : bf16 - } : tensor<1024xbf16> to tensor<1024xbf16> - %expanded_16 = tensor.expand_shape %padded_15 [[0, 1]] output_shape [32, 32] : tensor<1024xbf16> into tensor<32x32xbf16> - %transposed_17 = linalg.transpose ins(%expanded_16 : tensor<32x32xbf16>) outs(%13 : tensor<32x32xbf16>) permutation = [0, 1] - %14 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%transposed_17 : tensor<32x32xbf16>) outs(%12 : tensor<2x32x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %17 = arith.addf %in, %out : bf16 - linalg.yield %17 : bf16 - } -> tensor<2x32x32x32xbf16> - %15 = tensor.empty() : tensor<64x1024xbf16> - %transposed_18 = linalg.transpose ins(%14 : tensor<2x32x32x32xbf16>) outs(%10 : tensor<2x32x32x32xbf16>) permutation = [0, 2, 1, 3] - %collapsed = tensor.collapse_shape 
%transposed_18 [[0, 1], [2, 3]] : tensor<2x32x32x32xbf16> into tensor<64x1024xbf16> - %extracted_slice = tensor.extract_slice %collapsed[0, 0] [64, 1024] [1, 1] : tensor<64x1024xbf16> to tensor<64x1024xbf16> - %16 = linalg.copy ins(%extracted_slice : tensor<64x1024xbf16>) outs(%15 : tensor<64x1024xbf16>) -> tensor<64x1024xbf16> - return %16 : tensor<64x1024xbf16> - } -} - """ - ) - module_out = ir.Module.parse( - """ -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> -#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -module { - llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32 - llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> - llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> - llvm.mlir.global external constant @__runtime_fold_buffer_ids_(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> - func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<2x16x32x32xbf16> - %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xbf16> into tensor<2x32x16x32xbf16> - %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xbf16>) outs(%0 : tensor<2x16x32x32xbf16>) permutation = [0, 2, 1, 3] - %1 = tensor.empty() : tensor<2x8x32x32xbf16> - %2 = linalg.fill ins(%cst : bf16) outs(%1 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> - %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %arg1 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%2 : tensor<2x8x32x32xbf16>) { - ^bb0(%in: bf16, %in_1: bf16, %out: bf16): - %11 = arith.mulf %in, %in_1 : bf16 - %12 = arith.addf %out, %11 : bf16 - linalg.yield %12 : bf16 - } -> tensor<2x8x32x32xbf16> - %broadcasted = linalg.broadcast ins(%arg2 : tensor<8x32xbf16>) outs(%1 : tensor<2x8x32x32xbf16>) dimensions = [0, 2] - %4 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xbf16>) outs(%3 : tensor<2x8x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %11 = arith.addf %in, %out : bf16 - linalg.yield %11 : bf16 - } -> tensor<2x8x32x32xbf16> - %5 = tensor.empty() : tensor<2x32x32x32xbf16> - %6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> - %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%4, %arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%6 : tensor<2x32x32x32xbf16>) { - ^bb0(%in: bf16, %in_1: bf16, %out: bf16): - %11 = arith.mulf %in, %in_1 : bf16 - %12 = arith.addf %out, %11 : bf16 - linalg.yield %12 : bf16 
- } -> tensor<2x32x32x32xbf16> - %8 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<32x32xbf16>) outs(%7 : tensor<2x32x32x32xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %11 = arith.addf %in, %out : bf16 - linalg.yield %11 : bf16 - } -> tensor<2x32x32x32xbf16> - %9 = tensor.empty() : tensor<64x1024xbf16> - %transposed_0 = linalg.transpose ins(%8 : tensor<2x32x32x32xbf16>) outs(%5 : tensor<2x32x32x32xbf16>) permutation = [0, 2, 1, 3] - %collapsed = tensor.collapse_shape %transposed_0 [[0, 1], [2, 3]] : tensor<2x32x32x32xbf16> into tensor<64x1024xbf16> - %10 = linalg.copy ins(%collapsed : tensor<64x1024xbf16>) outs(%9 : tensor<64x1024xbf16>) -> tensor<64x1024xbf16> - return %10 : tensor<64x1024xbf16> - } - func.func @runtime_fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} { - %0 = tensor.empty() : tensor<8x16x32x32xbf16> - %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xbf16> into tensor<16x32x8x32xbf16> - %transposed = linalg.transpose ins(%expanded : tensor<16x32x8x32xbf16>) outs(%0 : tensor<8x16x32x32xbf16>) permutation = [2, 0, 1, 3] - %1 = tensor.empty() : tensor<8x16x16x32x2xbf16> - %expanded_0 = tensor.expand_shape %transposed [[0], [1], [2, 3], [4]] output_shape [8, 16, 16, 2, 32] : tensor<8x16x32x32xbf16> into tensor<8x16x16x2x32xbf16> - %transposed_1 = linalg.transpose ins(%expanded_0 : tensor<8x16x16x2x32xbf16>) outs(%1 : tensor<8x16x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] - %expanded_2 = tensor.expand_shape %arg1 [[0, 1]] output_shape [8, 32] : tensor<256xbf16> into tensor<8x32xbf16> - %2 = tensor.empty() : tensor<32x8x32x32xbf16> - %expanded_3 = tensor.expand_shape %arg2 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xbf16> into tensor<8x32x32x32xbf16> - %transposed_4 = linalg.transpose ins(%expanded_3 : tensor<8x32x32x32xbf16>) outs(%2 : tensor<32x8x32x32xbf16>) permutation = [2, 0, 1, 3] - %3 = tensor.empty() : tensor<32x8x16x32x2xbf16> - %expanded_5 = tensor.expand_shape %transposed_4 [[0], [1], [2, 3], [4]] output_shape [32, 8, 16, 2, 32] : tensor<32x8x32x32xbf16> into tensor<32x8x16x2x32xbf16> - %transposed_6 = linalg.transpose ins(%expanded_5 : tensor<32x8x16x2x32xbf16>) outs(%3 : tensor<32x8x16x32x2xbf16>) permutation = [0, 1, 2, 4, 3] - %expanded_7 = tensor.expand_shape %arg3 [[0, 1]] output_shape [32, 32] : tensor<1024xbf16> into tensor<32x32xbf16> - return %transposed_1, %expanded_2, %transposed_6, %expanded_7 : tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16> - } -} - """ - ) - - # module_in entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> - torch_arg0 = torch.rand((64, 512), dtype=torch.bfloat16) - torch_arg1 = torch.rand((512, 256), dtype=torch.bfloat16) - torch_arg2 = torch.rand((256), dtype=torch.bfloat16) - torch_arg3 = torch.rand((256, 1024), dtype=torch.bfloat16) - torch_arg4 = torch.rand((1024), dtype=torch.bfloat16) - - ref_res = (torch_arg0 @ torch_arg1 + torch_arg2) @ torch_arg3 + torch_arg4 - - passes = "any(gc-cpu-pipeline)" - compiler = GraphCompiler(passes) - ctx.enable_multithreading(False) - - arg0 = 
torch_arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - arg1 = torch_arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - arg2 = torch_arg2.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - arg3 = torch_arg3.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - arg4 = torch_arg4.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - gc_res = np.ones((64, 1024), dtype=ml_dtypes.bfloat16) - - entry = "entry" - mlir_args = get_mlir_args(module_in, entry, [arg0, arg1, arg2, arg3, arg4, gc_res]) - engine_in = compiler.compile_and_jit(module_in, ir_printing=True) - engine_in.invoke(entry, *mlir_args) - - assert_allclose(gc_res.astype(np.float32), ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5) - - - # module_out entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> - # module_out runtime_fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) - fold_arg0 = arg1 - fold_arg1 = arg2 - fold_arg2 = arg3 - fold_arg3 = arg4 - fold_res0 = np.zeros((8, 16, 16, 32, 2), dtype=ml_dtypes.bfloat16) - fold_res1 = np.zeros((8, 32), dtype=ml_dtypes.bfloat16) - fold_res2 = np.zeros((32, 8, 16, 32, 2), dtype=ml_dtypes.bfloat16) - fold_res3 = np.zeros((32, 32), dtype=ml_dtypes.bfloat16) - - runtime_fold = "runtime_fold" - fold_mlir_args = get_mlir_args(module_out, runtime_fold, [fold_arg0, fold_arg1, fold_arg2, fold_arg3, fold_res0, fold_res1, fold_res2, fold_res3]) - - gc_res_out = np.zeros((64, 1024), dtype=ml_dtypes.bfloat16) - entry = "entry" - mlir_args = get_mlir_args(module_out, entry, [arg0, fold_res0, fold_res1, fold_res2, fold_res3, gc_res_out]) - - engine_out = compiler.compile_and_jit(module_out, ir_printing=True) - engine_out.invoke(runtime_fold, *fold_mlir_args) - engine_out.invoke(entry, *mlir_args) - - assert_allclose(gc_res.astype(np.float32), gc_res_out.astype(np.float32), rtol=1e-5, atol=1e-5) - diff --git a/test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py b/test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py deleted file mode 100644 index 465d390fd..000000000 --- a/test/gc/Transforms/test_constant_tensor_folding_f32_4D4D.py +++ /dev/null @@ -1,96 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. 
-# SPDX-License-Identifier: Apache-2.0
-################################################################################
-
-from enum import Flag
-import os
-import sys
-
-import numpy as np
-from gc_mlir import ir
-from gc_mlir.graph_compiler import GraphCompiler
-from numpy.testing import assert_allclose
-
-project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-if project_dir not in sys.path:
-    sys.path.insert(0, project_dir)
-
-import torch
-# from bench import py_timeit_bench
-from utils import get_mlir_args
-
-if __name__ == "__main__":
-    with ir.Context() as ctx:
-        ctx.allow_unregistered_dialects = True
-
-        M = 64
-        N = 256
-        K = 512
-        MBlock = 32
-        NBlock = 32
-        KBlock = 32
-        vnni_size = 1
-        shapeA = [M // MBlock, K // KBlock, MBlock, KBlock]
-        shapeB = [N // NBlock, K // KBlock, KBlock, NBlock]
-        shapeC = [M // MBlock, N // NBlock, MBlock, NBlock]
-
-        # 4D x 4D, inputs transposed
-        mlir_str = """
-#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-module {
-  func.func @main_entry(%arg0: tensor<2x16x32x32xf32>, %arg1: tensor<8x16x32x32xf32>) -> tensor<2x8x32x32xf32> attributes {llvm.emit_c_interface} {
-    %cst = arith.constant 0.000000e+00 : f32
-    %0 = tensor.empty() : tensor<2x8x32x32xf32>
-    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x8x32x32xf32>) -> tensor<2x8x32x32xf32>
-    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%1 : tensor<2x8x32x32xf32>) {
-    ^bb0(%in: f32, %in_0: f32, %out: f32):
-      %3 = arith.mulf %in, %in_0 : f32
-      %4 = arith.addf %out, %3 : f32
-      linalg.yield %4 : f32
-    } -> tensor<2x8x32x32xf32>
-    return %2 : tensor<2x8x32x32xf32>
-  }
-}
-        """
-        module = ir.Module.parse(mlir_str)
-
-        torch_arg0 = torch.rand((M, K), dtype=torch.float32)
-        torch_arg1 = torch.rand((K, N), dtype=torch.float32)
-        ref_res = torch.matmul(torch_arg0, torch_arg1)
-
-        arg0_0 = torch_arg0.view([M // MBlock, MBlock, K // KBlock, KBlock]).permute([0, 2, 1, 3]).contiguous().numpy().view(np.dtype("float32"))
-        arg0_1 = np.transpose(np.reshape(torch_arg0.contiguous().numpy().view(np.dtype("float32")), (M // MBlock, MBlock, K // KBlock, KBlock)), (0, 2, 1, 3))  # MK -> MKmk
-        print("arg0_0 arg0_1 close: ", np.allclose(arg0_0, arg0_1, rtol=1e-5, atol=1e-5))
-
-        arg1 = torch_arg1.view([K // KBlock, KBlock, N // NBlock, NBlock]).permute([2, 0, 1, 3]).contiguous().numpy().view(np.dtype("float32"))
-        # arg1 = np.transpose(np.reshape(torch_arg1.contiguous().numpy(), (16, 32, 8, 32)), (2, 0, 1, 3)).view(np.dtype("float32"))  # KN -> NKkn, 8x16x32x32
-
-        gc_res = np.ones(shapeC, dtype=np.dtype("float32"))
-
-        entry = "main_entry"
-        mlir_args = get_mlir_args(module, entry, [arg0_1, arg1, gc_res])
-
-        passes = "any(gc-cpu-pipeline)"
-        compiler = GraphCompiler(passes)
-        engine_in = compiler.compile_and_jit(module)
-        engine_in.invoke(entry, *mlir_args)
-        gc_res = np.reshape(np.transpose(gc_res, (0, 2, 1, 3)), (64, 256))  # MNmn -> MN
-
-        print("gc_res ref_res close: ", np.allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5))
-        assert_allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5)
-
diff --git a/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py b/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py
deleted file mode 100644
index e05e2ac15..000000000
--- a/test/gc/Transforms/test_constant_tensor_folding_f32_two_layers.py
+++ /dev/null
@@ -1,225 +0,0 @@
-################################################################################
-# Copyright (C) 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions
-# and limitations under the License.
-# SPDX-License-Identifier: Apache-2.0
-################################################################################
-
-import os
-import sys
-
-import numpy as np
-from gc_mlir import ir
-from gc_mlir.graph_compiler import GraphCompiler
-from numpy.testing import assert_allclose
-
-project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-if project_dir not in sys.path:
-    sys.path.insert(0, project_dir)
-
-import torch
-# from bench import py_timeit_bench
-from utils import get_mlir_args
-
-if __name__ == "__main__":
-    with ir.Context() as ctx:
-        ctx.allow_unregistered_dialects = True
-
-        # 4D x 4D, inputs plain, two layers
-        mlir_str_4D4D = """
-#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-module {
-  func.func @entry(%arg0: tensor<64x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256x1024xf32>, %arg4: tensor<1024xf32>) -> tensor<64x1024xf32> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
-    %0 = tensor.empty() : tensor<2x16x32x32xf32>
-    %cst = arith.constant 0.000000e+00 : f32
-    %padded = tensor.pad %arg0 low[0, 0] high[0, 0] {
-    ^bb0(%arg5: index, %arg6: index):
-      tensor.yield %cst : f32
-    } : tensor<64x512xf32> to tensor<64x512xf32>
-    %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xf32> into tensor<2x32x16x32xf32>
-    %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xf32>) outs(%0 : tensor<2x16x32x32xf32>) permutation = [0, 2, 1, 3]
-    %1 = tensor.empty() : tensor<8x16x32x32xf32>
-    %padded_0 = tensor.pad %arg1 low[0, 0] high[0, 0] {
-    ^bb0(%arg5: index, %arg6: index):
-      tensor.yield %cst : f32
-    } : tensor<512x256xf32> to tensor<512x256xf32>
-    %expanded_1 = tensor.expand_shape %padded_0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xf32> into tensor<16x32x8x32xf32>
-    %transposed_2 = linalg.transpose ins(%expanded_1 : tensor<16x32x8x32xf32>) outs(%1 : tensor<8x16x32x32xf32>) permutation = [2, 0, 1, 3]
-    %2 = tensor.empty() : tensor<2x8x32x32xf32>
-    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x8x32x32xf32>) -> tensor<2x8x32x32xf32>
-    %4 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %transposed_2 : tensor<2x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%3 : tensor<2x8x32x32xf32>) {
-    ^bb0(%in: f32, %in_8: f32, %out: f32):
-      %14 = arith.mulf %in, %in_8 : f32
-      %15 = arith.addf %out, %14 : f32
-      linalg.yield %15 : f32
-    } -> tensor<2x8x32x32xf32>
-    %expanded_3 = tensor.expand_shape %arg2 [[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
-    %broadcasted = linalg.broadcast ins(%expanded_3 : tensor<8x32xf32>) outs(%2 : tensor<2x8x32x32xf32>) dimensions = [0, 2]
-    %5 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xf32>) outs(%4 : tensor<2x8x32x32xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %14 = arith.addf %in, %out : f32
-      linalg.yield %14 : f32
-    } -> tensor<2x8x32x32xf32>
-    %6 = tensor.empty() : tensor<32x8x32x32xf32>
-    %expanded_4 = tensor.expand_shape %arg3 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xf32> into tensor<8x32x32x32xf32>
-    %transposed_5 = linalg.transpose ins(%expanded_4 : tensor<8x32x32x32xf32>) outs(%6 : tensor<32x8x32x32xf32>) permutation = [2, 0, 1, 3]
-    %7 = tensor.empty() : tensor<2x32x32x32xf32>
-    %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<2x32x32x32xf32>) -> tensor<2x32x32x32xf32>
-    %9 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%5, %transposed_5 : tensor<2x8x32x32xf32>, tensor<32x8x32x32xf32>) outs(%8 : tensor<2x32x32x32xf32>) {
-    ^bb0(%in: f32, %in_8: f32, %out: f32):
-      %14 = arith.mulf %in, %in_8 : f32
-      %15 = arith.addf %out, %14 : f32
-      linalg.yield %15 : f32
-    } -> tensor<2x32x32x32xf32>
-    %expanded_6 = tensor.expand_shape %arg4 [[0, 1]] output_shape [32, 32] : tensor<1024xf32> into tensor<32x32xf32>
-    %10 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_6 : tensor<32x32xf32>) outs(%9 : tensor<2x32x32x32xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %14 = arith.addf %in, %out : f32
-      linalg.yield %14 : f32
-    } -> tensor<2x32x32x32xf32>
-    %11 = tensor.empty() : tensor<2x32x32x32xf32>
-    %transposed_7 = linalg.transpose ins(%10 : tensor<2x32x32x32xf32>) outs(%11 : tensor<2x32x32x32xf32>) permutation = [0, 2, 1, 3]
-    %collapsed = tensor.collapse_shape %transposed_7 [[0, 1], [2, 3]] : tensor<2x32x32x32xf32> into tensor<64x1024xf32>
-    %12 = tensor.empty() : tensor<64x1024xf32>
-    %13 = linalg.copy ins(%collapsed : tensor<64x1024xf32>) outs(%12 : tensor<64x1024xf32>) -> tensor<64x1024xf32>
-    return %13 : tensor<64x1024xf32>
-  }
-}
-        """
-
-        module_in = ir.Module.parse(mlir_str_4D4D)
-
-
-        mlir_str_4D4D_out = """
-#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-module {
-  llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32
-  llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
-  llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
-  llvm.mlir.global external constant @__runtime_fold_buffer_ids_(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>
-  func.func @entry(%arg0: tensor<64x512xf32>, %arg1: tensor<8x16x32x32xf32>, %arg2: tensor<8x32xf32>, %arg3: tensor<32x8x32x32xf32>, %arg4: tensor<32x32xf32>) -> tensor<64x1024xf32> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
-    %cst = arith.constant 0.000000e+00 : f32
-    %0 = tensor.empty() : tensor<2x16x32x32xf32>
-    %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [2, 32, 16, 32] : tensor<64x512xf32> into tensor<2x32x16x32xf32>
-    %transposed = linalg.transpose ins(%expanded : tensor<2x32x16x32xf32>) outs(%0 : tensor<2x16x32x32xf32>) permutation = [0, 2, 1, 3]
-    %1 = tensor.empty() : tensor<2x8x32x32xf32>
-    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x8x32x32xf32>) -> tensor<2x8x32x32xf32>
-    %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%transposed, %arg1 : tensor<2x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%2 : tensor<2x8x32x32xf32>) {
-    ^bb0(%in: f32, %in_1: f32, %out: f32):
-      %12 = arith.mulf %in, %in_1 : f32
-      %13 = arith.addf %out, %12 : f32
-      linalg.yield %13 : f32
-    } -> tensor<2x8x32x32xf32>
-    %broadcasted = linalg.broadcast ins(%arg2 : tensor<8x32xf32>) outs(%1 : tensor<2x8x32x32xf32>) dimensions = [0, 2]
-    %4 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%broadcasted : tensor<2x8x32x32xf32>) outs(%3 : tensor<2x8x32x32xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %12 = arith.addf %in, %out : f32
-      linalg.yield %12 : f32
-    } -> tensor<2x8x32x32xf32>
-    %5 = tensor.empty() : tensor<2x32x32x32xf32>
-    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x32x32x32xf32>) -> tensor<2x32x32x32xf32>
-    %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%4, %arg3 : tensor<2x8x32x32xf32>, tensor<32x8x32x32xf32>) outs(%6 : tensor<2x32x32x32xf32>) {
-    ^bb0(%in: f32, %in_1: f32, %out: f32):
-      %12 = arith.mulf %in, %in_1 : f32
-      %13 = arith.addf %out, %12 : f32
-      linalg.yield %13 : f32
-    } -> tensor<2x32x32x32xf32>
-    %8 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<32x32xf32>) outs(%7 : tensor<2x32x32x32xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %12 = arith.addf %in, %out : f32
-      linalg.yield %12 : f32
-    } -> tensor<2x32x32x32xf32>
-    %9 = tensor.empty() : tensor<2x32x32x32xf32>
-    %transposed_0 = linalg.transpose ins(%8 : tensor<2x32x32x32xf32>) outs(%9 : tensor<2x32x32x32xf32>) permutation = [0, 2, 1, 3]
-    %collapsed = tensor.collapse_shape %transposed_0 [[0, 1], [2, 3]] : tensor<2x32x32x32xf32> into tensor<64x1024xf32>
-    %10 = tensor.empty() : tensor<64x1024xf32>
-    %11 = linalg.copy ins(%collapsed : tensor<64x1024xf32>) outs(%10 : tensor<64x1024xf32>) -> tensor<64x1024xf32>
-    return %11 : tensor<64x1024xf32>
-  }
-
-  func.func @runtime_fold(%arg0: tensor<512x256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256x1024xf32>, %arg3: tensor<1024xf32>) -> (tensor<8x16x32x32xf32>, tensor<8x32xf32>, tensor<32x8x32x32xf32>, tensor<32x32xf32>) attributes {llvm.emit_c_interface} {
-    %0 = tensor.empty() : tensor<8x16x32x32xf32>
-    %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [16, 32, 8, 32] : tensor<512x256xf32> into tensor<16x32x8x32xf32>
-    %transposed = linalg.transpose ins(%expanded : tensor<16x32x8x32xf32>) outs(%0 : tensor<8x16x32x32xf32>) permutation = [2, 0, 1, 3]
-    %expanded_0 = tensor.expand_shape %arg1 [[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
-    %1 = tensor.empty() : tensor<32x8x32x32xf32>
-    %expanded_1 = tensor.expand_shape %arg2 [[0, 1], [2, 3]] output_shape [8, 32, 32, 32] : tensor<256x1024xf32> into tensor<8x32x32x32xf32>
-    %transposed_2 = linalg.transpose ins(%expanded_1 : tensor<8x32x32x32xf32>) outs(%1 : tensor<32x8x32x32xf32>) permutation = [2, 0, 1, 3]
-    %expanded_3 = tensor.expand_shape %arg3 [[0, 1]] output_shape [32, 32] : tensor<1024xf32> into tensor<32x32xf32>
-    return %transposed, %expanded_0, %transposed_2, %expanded_3 : tensor<8x16x32x32xf32>, tensor<8x32xf32>, tensor<32x8x32x32xf32>, tensor<32x32xf32>
-  }
-}
-        """
-        module_out = ir.Module.parse(mlir_str_4D4D_out)
-
-        # module_in entry(%arg0: tensor<64x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256x1024xf32>, %arg4: tensor<1024xf32>) -> tensor<64x1024xf32>
-        torch_arg0 = torch.rand((64, 512), dtype=torch.float32)
-        torch_arg1 = torch.rand((512, 256), dtype=torch.float32)
-        torch_arg2 = torch.rand((256), dtype=torch.float32)
-        torch_arg3 = torch.rand((256, 1024), dtype=torch.float32)
-        torch_arg4 = torch.rand((1024), dtype=torch.float32)
-
-        ref_res = (torch_arg0 @ torch_arg1 + torch_arg2) @ torch_arg3 + torch_arg4
-
-        passes = "any(gc-cpu-pipeline)"
-        compiler = GraphCompiler(passes)
-        ctx.enable_multithreading(False)
-
-        arg0 = torch_arg0.contiguous().numpy()
-        arg1 = torch_arg1.contiguous().numpy()
-        arg2 = torch_arg2.contiguous().numpy()
-        arg3 = torch_arg3.contiguous().numpy()
-        arg4 = torch_arg4.contiguous().numpy()
-        gc_res = np.zeros((64, 1024), dtype=np.float32)
-
-        entry = "entry"
-        mlir_args = get_mlir_args(module_in, entry, [arg0, arg1, arg2, arg3, arg4, gc_res])
-        engine_in = compiler.compile_and_jit(module_in, ir_printing=False)
-        engine_in.invoke(entry, *mlir_args)
-
-        print("Reference vs GC input IR close: ", np.allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5))
-        assert_allclose(gc_res, ref_res.to(torch.float32).numpy(), rtol=1e-5, atol=1e-5)
-
-
-        # module_out entry(%arg0: tensor<64x512xf32>, %arg1: tensor<8x16x32x32xf32>, %arg2: tensor<8x32xf32>, %arg3: tensor<32x8x32x32xf32>, %arg4: tensor<32x32xf32>) -> tensor<64x1024xf32>
-        # module_out runtime_fold(%arg0: tensor<512x256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256x1024xf32>, %arg3: tensor<1024xf32>) -> (tensor<8x16x32x32xf32>, tensor<8x32xf32>, tensor<32x8x32x32xf32>, tensor<32x32xf32>)
-        fold_arg0 = arg1
-        fold_arg1 = arg2
-        fold_arg2 = arg3
-        fold_arg3 = arg4
-        fold_res0 = np.zeros((8, 16, 32, 32), dtype=np.float32)
-        fold_res1 = np.zeros((8, 32), dtype=np.float32)
-        fold_res2 = np.zeros((32, 8, 32, 32), dtype=np.float32)
-        fold_res3 = np.zeros((32, 32), dtype=np.float32)
-
-        runtime_fold = "runtime_fold"
-        fold_mlir_args = get_mlir_args(module_out, runtime_fold, [fold_arg0, fold_arg1, fold_arg2, fold_arg3, fold_res0, fold_res1, fold_res2, fold_res3])
-
-        gc_res_out = np.zeros((64, 1024), dtype=np.float32)
-        entry = "entry"
-        entry_mlir_args = get_mlir_args(module_out, entry, [arg0, fold_res0, fold_res1, fold_res2, fold_res3, gc_res_out])
-
-        engine_out = compiler.compile_and_jit(module_out, ir_printing=False)
-        engine_out.invoke(runtime_fold, *fold_mlir_args)
-        engine_out.invoke(entry, *entry_mlir_args)
-
-        print("GC input IR vs GC output IR close: ", np.allclose(gc_res, gc_res_out, rtol=1e-5, atol=1e-5))
-        assert_allclose(gc_res, gc_res_out, rtol=1e-5, atol=1e-5)
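The deleted drivers above pack plain matrices into blocked layouts (MK -> MKmk, KN -> NKkn) with `view`/`permute` before invoking the compiled kernels. A minimal standalone C++ sketch of the same MK -> MKmk packing; all names are illustrative, not part of the repository:

```cpp
#include <vector>

// Pack a row-major MxK matrix into the blocked layout
// [M/MB][K/KB][MB][KB], assuming MB divides M and KB divides K.
std::vector<float> packMKmk(const std::vector<float> &plain, int M, int K,
                            int MB, int KB) {
  std::vector<float> blocked(M * K);
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)
      // destination index: [m / MB][k / KB][m % MB][k % KB]
      blocked[(((m / MB) * (K / KB) + k / KB) * MB + m % MB) * KB + k % KB] =
          plain[m * K + k];
  return blocked;
}
```

With M = 64, K = 512 and 32x32 blocks this yields the tensor<2x16x32x32xf32> operand that the folded `entry` function expects.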
From fa30e4a5f5d9546a2fe2c29fc1e409e28c3ac782 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Sat, 14 Sep 2024 11:33:00 +0800
Subject: [PATCH 59/64] Updates

---
 .../DataFlow/ConstantSubgraphAnalyser.cpp   | 11 ++-
 lib/gc/Transforms/ConstantTensorFolding.cpp | 73 ++++++++-----------
 2 files changed, 36 insertions(+), 48 deletions(-)

diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
index e4a2130f3..b3c6b51ba 100644
--- a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
+++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
@@ -143,9 +143,9 @@ void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver,
   for (Operation &op : llvm::make_early_inc_range(block)) {
     // If all the result values of an op are const, we mark this op as const.
     bool resultsAllConstant = true;
-    if (op.getNumResults() == 0) {
+    if (op.getNumResults() == 0)
      continue;
-    }
+
    for (Value res : op.getResults()) {
      auto *lattice = solver.lookupState<Lattice<InConstantSubgraph>>(res);
      if (!lattice || lattice->getValue().isUninitialized()) {
@@ -164,9 +164,8 @@ void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver,
    }
  }
 
-  if (constantOperations.empty()) {
+  if (constantOperations.empty())
    return;
-  }
 }
 
 RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() {
@@ -175,9 +174,9 @@ RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() {
 }
 
 void RunConstantSubgraphAnalyser::run(Operation *op) {
-  if (failed(solver.initializeAndRun(op))) {
+  if (failed(solver.initializeAndRun(op)))
    return;
-  }
+
  getConstantSubgraph(solver, op);
 }
 
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 17270d54f..b093ffd2c 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -69,9 +69,9 @@ template <typename T> int64_t getDataSize(T t) {
  unsigned bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes
  ArrayRef<int64_t> shape = t.getShape();
  int64_t size = bitWidth;
-  for (auto s : shape) {
+  for (auto s : shape)
    size *= s;
-  }
+
  return size;
 }
 
@@ -94,13 +94,12 @@ bool singleOperand(Operation *op) {
    Value firstOperand = op->getOperand(0);
    for (int64_t i = 1; i < op->getNumOperands(); ++i) {
      Value operand = op->getOperand(i);
-      if (firstOperand == operand) {
+      if (firstOperand == operand)
        continue;
-      }
+
      auto parentOp = operand.getDefiningOp();
-      if (parentOp && !isa<tensor::EmptyOp>(parentOp)) {
+      if (parentOp && !isa<tensor::EmptyOp>(parentOp))
        return false;
-      }
    }
  }
  return true;
@@ -121,16 +120,14 @@ bool canMoveBefore(Operation *op) {
 
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  for (auto &affineMap : indexingMaps) {
-    if (!affineMap.isIdentity()) {
+    if (!affineMap.isIdentity())
      return false;
-    }
  }
 
  SmallVector<utils::IteratorType> iterTypes = linalgOp.getIteratorTypesArray();
  for (auto &iterType : iterTypes) {
-    if (iterType != utils::IteratorType::parallel) {
+    if (iterType != utils::IteratorType::parallel)
      return false;
-    }
  }
 
  if (op->getNumOperands() > 1) {
@@ -140,9 +137,8 @@ bool canMoveBefore(Operation *op) {
    for (int64_t i = 0; i < numInits; ++i) {
      OpOperand *outOperand = linalgOp.getDpsInitOperand(i);
      auto parentOp = outOperand->get().getDefiningOp();
-      if (!isa<tensor::EmptyOp>(parentOp)) {
+      if (!isa<tensor::EmptyOp>(parentOp))
        return false;
-      }
    }
  }
 
@@ -156,9 +152,8 @@ void postponeBroadcast(Block &block) {
  for (Operation &op : block.getOperations()) {
    if (isa<linalg::BroadcastOp>(&op)) {
      Operation *bcOp = &op;
-      if (isInConstantSubgraph(bcOp)) {
+      if (isInConstantSubgraph(bcOp))
        constBcOps.push_back(bcOp);
-      }
    }
  }
 
@@ -172,9 +167,9 @@ void postponeBroadcast(Block &block) {
    SmallVector<Operation *> prevOps;
    Operation *currOp = bcOp;
    while (true) {
-      if (currOp->getNumOperands() != 1) {
+      if (currOp->getNumOperands() != 1)
        break;
-      }
+
      Value operand = currOp->getOperand(0);
      if (isa<BlockArgument>(operand)) {
        break;
@@ -188,9 +183,9 @@ void postponeBroadcast(Block &block) {
    SmallVector<Operation *> postOps;
    currOp = bcOp;
    while (true) {
-      if (currOp->getNumResults() != 1 || !currOp->hasOneUse()) {
+      if (currOp->getNumResults() != 1 || !currOp->hasOneUse())
        break;
-      }
+
      Value input = currOp->getResult(0);
      currOp = *(input.getUsers().begin());
      Value output = currOp->getResult(0);
@@ -212,9 +207,8 @@ void postponeBroadcast(Block &block) {
      postOps.push_back(currOp);
    }
  }
-  if (postOps.empty()) {
+  if (postOps.empty())
    continue;
-  }
 
  // move bcOp after the last constant op
  SmallVector<Operation *> newPostOps;
@@ -308,17 +302,12 @@ void postponeBroadcast(Block &block) {
      return op == bcOp;
    });
 
-    for (auto it = postOps.rbegin(); it != postOps.rend(); ++it) {
+    for (auto it = postOps.rbegin(); it != postOps.rend(); ++it)
      (*it)->erase();
-    }
  }
 }
 
-static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
-
-// get from dnnl_graph_compiler_context
-// void *allocator(size_t size) { return std::aligned_alloc(64, size); }
-// void deallocator(void *ptr) { std::free(ptr); }
+// TODO: The following manager will be moved to an appropriate place later.
 
 // std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
 //   // simply allocate buffer and return
@@ -331,9 +320,7 @@ size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
 // Manager
 struct ConstGraphTensorCacheManager {
-  // dnnl_graph_compiler_context *ctx;
-
-  uint64_t cachedTensorGlobalId = 0;
+  int64_t cachedTensorGlobalId = 0;
 
  // singleton
  static std::shared_ptr<ConstGraphTensorCacheManager> get() {
@@ -343,14 +330,14 @@ struct ConstGraphTensorCacheManager {
  }
 
  // alloc and set the buf_base_ and offset_ attributes of cache
-  std::vector<uint64_t> alloc(std::vector<size_t> buffersSize) {
+  std::vector<int64_t> alloc(std::vector<size_t> buffersSize) {
    size_t totalSize = 0;
-    for (size_t size : buffersSize) {
+    for (size_t size : buffersSize)
      totalSize += divideAndCeil(size, 64) * 64;
-    }
+
    LLVM_DEBUG(llvm::dbgs() << "Alloc total size: " << totalSize << '\n');
    // auto base = createConstCacheProxy(totalSize);
-    std::vector<uint64_t> globalIds(buffersSize.size());
+    std::vector<int64_t> globalIds(buffersSize.size());
    size_t offset = 0;
    for (size_t i = 0; i < buffersSize.size(); i++) {
      LLVM_DEBUG(llvm::dbgs() << "Alloc offset: " << offset << '\n');
@@ -427,9 +414,9 @@ void getArithConstantOutputs(Block &block, SmallVector<Type> &outputTypes,
    if (isa<arith::ConstantOp>(&op)) {
      Operation *constOp = &op;
      auto constTensor = constOp->getResults().front();
-      if (!isa<TensorType>(constTensor.getType())) {
+      if (!isa<TensorType>(constTensor.getType()))
        continue;
-      }
+
      auto v = dyn_cast<OpResult>(constTensor);
      SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
      std::deque<OpResult> dq;
@@ -465,6 +452,8 @@ void getArithConstantOutputs(Block &block, SmallVector<Type> &outputTypes,
  }
 }
 
+static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
+
 void getInputsAndOutputs(Block &block,
                          std::unordered_set<int> &constArgsIndexes,
                          SmallVector<Type> &inputTypes,
@@ -511,15 +500,15 @@ void getInputsAndOutputs(Block &block,
      }
      continue;
    }
-    if (!v.hasOneUse()) {
+    if (!v.hasOneUse())
      simpleTopo = false;
-    }
+
    // the children ops of v are all constant, we push their results to
    // queue
    for (Operation *child : v.getUsers()) {
-      if (!singleOperand(child) || child->getResults().size() > 1) {
+      if (!singleOperand(child) || child->getResults().size() > 1)
        simpleTopo = false;
-      }
+
      for (OpResult result : child->getResults()) {
        auto r = dyn_cast<OpResult>(result);
        dq.push_back(r);
@@ -596,9 +585,9 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
  }
  auto manager = ConstGraphTensorCacheManager::get();
  SmallVector<int64_t> globalIndexes;
-  for (auto id : manager->alloc(buffersSize)) {
+  for (auto id : manager->alloc(buffersSize))
    globalIndexes.push_back(id);
-  }
+
  globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
  auto moduleOp = dyn_cast<ModuleOp>(topOp);
  addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder,
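`ConstGraphTensorCacheManager::alloc` above reserves one contiguous region and hands out 64-byte-aligned offsets. A standalone sketch of that packing arithmetic, illustrative only:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Round x up to the next multiple of y, expressed as a quotient.
size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

int main() {
  // Hypothetical per-tensor byte sizes for three folded constants.
  std::vector<size_t> buffersSize = {100, 64, 1};
  size_t offset = 0;
  for (size_t size : buffersSize) {
    assert(offset % 64 == 0); // every cached tensor starts 64-byte aligned
    offset += divideAndCeil(size, 64) * 64;
  }
  // 100 -> 128, 64 -> 64, 1 -> 64; total reserved: 256 bytes.
  assert(offset == 256);
}
```

Padding each buffer to a cache-line multiple keeps neighboring cached tensors from sharing a line, at the cost of at most 63 wasted bytes per tensor.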
From b8b0dd288137e78c026c010c7865f745578035e1 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Sat, 14 Sep 2024 15:33:55 +0800
Subject: [PATCH 60/64] Move manager

---
 .../CPURuntime/ConstantCache.h               | 68 ++++++++++++++-----
 lib/gc/ExecutionEngine/Driver/Driver.cpp     | 12 ++--
 lib/gc/Transforms/ConstantTensorFolding.cpp  | 45 ++----------
 .../unittests/ExecutionEngine/JitWrapper.cpp |  5 +-
 4 files changed, 66 insertions(+), 64 deletions(-)

diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
index a3756220d..9e4ea4f6f 100644
--- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
+++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
@@ -138,8 +138,6 @@ struct CachedGraphTensor {
  StridedMemRefType ref;
 };
 
-static std::unordered_map<int64_t, std::shared_ptr<CachedGraphTensor>> cache;
-
 inline std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
  // simply allocate buffer and return
  std::shared_ptr base = std::shared_ptr{
@@ -147,24 +145,62 @@ inline std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
 }
 
-inline std::shared_ptr<CachedGraphTensor> queryCacheTensor(int64_t key) {
-  auto itr = cache.find(key);
-  if (itr != cache.end()) {
-    return itr->second;
-  }
-  return nullptr;
+inline static size_t divideAndCeil(size_t x, size_t y) {
+  return (x + y - 1) / y;
 }
 
-inline bool regCachedTensor(int64_t key,
-                            const std::shared_ptr<ConstCacheProxy> &base,
-                            size_t offset) {
-  if (queryCacheTensor(key)) {
-    return false;
+// Manager
+struct ConstGraphTensorCacheManager {
+  int64_t cachedTensorGlobalId = 0;
+
+  std::unordered_map<int64_t, std::shared_ptr<CachedGraphTensor>> cache;
+
+  // singleton
+  static std::shared_ptr<ConstGraphTensorCacheManager> get() {
+    static std::shared_ptr<ConstGraphTensorCacheManager> c =
+        std::make_shared<ConstGraphTensorCacheManager>();
+    return c;
  }
 
-  cache[key] = std::make_shared<CachedGraphTensor>(base, offset);
-  return true;
-}
+  std::shared_ptr<CachedGraphTensor> queryCacheTensor(int64_t key) {
+    auto itr = cache.find(key);
+    if (itr != cache.end()) {
+      return itr->second;
+    }
+    return nullptr;
+  }
+
+  bool regCachedTensor(int64_t key,
+                       const std::shared_ptr<ConstCacheProxy> &base,
+                       size_t offset) {
+    if (queryCacheTensor(key)) {
+      return false;
+    }
+
+    cache[key] = std::make_shared<CachedGraphTensor>(base, offset);
+    return true;
+  }
+
+  // alloc and set the buf_base_ and offset_ attributes of cache
+  std::vector<int64_t> alloc(std::vector<size_t> buffersSize) {
+    size_t totalSize = 0;
+    for (size_t size : buffersSize) {
+      totalSize += divideAndCeil(size, 64) * 64;
+    }
+    auto base = createConstCacheProxy(totalSize);
+    std::vector<int64_t> globalIds(buffersSize.size());
+    size_t offset = 0;
+    for (size_t i = 0; i < buffersSize.size(); i++) {
+      bool regRes = regCachedTensor(cachedTensorGlobalId, base, offset);
+      assert(regRes && "Register constant tensor failed");
+      globalIds[i] = cachedTensorGlobalId;
+      ++cachedTensorGlobalId;
+      offset += divideAndCeil(buffersSize[i], 64) * 64;
+    }
+    return globalIds;
+  }
+};
+
 } // namespace gc
 } // namespace mlir
diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp
index 8d4a0e5c4..e7d1f1a66 100644
--- a/lib/gc/ExecutionEngine/Driver/Driver.cpp
+++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp
@@ -143,11 +143,13 @@ JitModule::create(Operation *op, const DriverOptions &options) {
  std::vector<std::shared_ptr<CachedGraphTensor>> foldInfo;
  foldInfo.reserve(foldBufferIds.size());
+  auto cacheManager = ConstGraphTensorCacheManager::get();
  for (auto bufId : foldBufferIds) {
-    auto ret = queryCacheTensor(bufId);
+    auto ret = cacheManager->queryCacheTensor(bufId);
    if (!ret) {
      return llvm::make_error<llvm::StringError>(
-          "Failed to query the folded cached tensor",
+          "Failed to query the folded cached tensor of id: " +
+              std::to_string(bufId),
          llvm::inconvertibleErrorCode());
    }
    foldInfo.emplace_back(std::move(ret));
@@ -251,7 +253,8 @@ void JitModule::call(GeneralMemrefPtr *args, int32_t numArgs) {
    llvm::SmallVector<void *> realargs;
    prepareCallArgs(realargs, args, numArgs, numOrigArgs, foldedCache,
                    foldArgs);
-    LLVM_DEBUG(llvm::dbgs() << "foldArgs size: " << foldArgs.size() << '\n');
+    LLVM_DEBUG(llvm::dbgs()
+               << "fold func args size: " << foldArgs.size() << '\n');
    fold(realargs.data());
  }
 
@@ -261,8 +264,7 @@ void JitModule::call(GeneralMemrefPtr *args, int32_t numArgs) {
  prepareCallArgs(realargs, args, numArgs, numOrigArgs, foldedCache,
                  entryArgs);
  LLVM_DEBUG(llvm::dbgs()
-             << "entryArgs size: " << entryArgs.size()
-             << ", Entry real args size: " << realargs.size() << '\n');
+             << "entry func args size: " << realargs.size() << '\n');
  entry(realargs.data());
 }
 
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 9fca113b3..c22de5fc3 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -314,45 +314,6 @@ void postponeBroadcast(Block &block) {
  }
 }
 
-static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
-
-size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
-
-// Manager
-struct ConstGraphTensorCacheManager {
-  int64_t cachedTensorGlobalId = 0;
-
-  // singleton
-  static std::shared_ptr<ConstGraphTensorCacheManager> get() {
-    static std::shared_ptr<ConstGraphTensorCacheManager> c =
-        std::make_shared<ConstGraphTensorCacheManager>();
-    return c;
-  }
-
-  // alloc and set the buf_base_ and offset_ attributes of cache
-  std::vector<int64_t> alloc(std::vector<size_t> buffersSize) {
-    size_t totalSize = 0;
-    for (size_t size : buffersSize) {
-      totalSize += divideAndCeil(size, 64) * 64;
-    }
-    LLVM_DEBUG(llvm::dbgs() << "Alloc total size: " << totalSize << '\n');
-    auto base = createConstCacheProxy(totalSize);
-    std::vector<int64_t> globalIds(buffersSize.size());
-    size_t offset = 0;
-    for (size_t i = 0; i < buffersSize.size(); i++) {
-      LLVM_DEBUG(llvm::dbgs() << "Alloc offset: " << offset << '\n');
-      bool regRes = regCachedTensor(cachedTensorGlobalId, base, offset);
-      assert(regRes && "Register constant tensor failed");
-      globalIds[i] = cachedTensorGlobalId;
-      ++cachedTensorGlobalId;
-      offset += divideAndCeil(buffersSize[i], 64) * 64;
-    }
-    return globalIds;
-  }
-};
-
 static void addGlobalI32(ModuleOp &module, Location loc, OpBuilder &builder,
                          StringRef name, int32_t value) {
  OpBuilder::InsertionGuard insertGuard(builder);
@@ -455,6 +416,8 @@ void getArithConstantOutputs(Block &block, SmallVector<Type> &outputTypes,
  }
 }
 
+static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
+
 void getInputsAndOutputs(Block &block,
                          std::unordered_set<int> &constArgsIndexes,
                          SmallVector<Type> &inputTypes,
@@ -584,9 +547,9 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
               << "Allocate buffer for tensor: " << tensor << "\n");
    buffersSize.push_back(getValueSize(tensor));
  }
-  auto manager = ConstGraphTensorCacheManager::get();
+  auto cacheManager = ConstGraphTensorCacheManager::get();
  SmallVector<int64_t> globalIndexes;
-  for (auto id : manager->alloc(buffersSize)) {
+  for (auto id : cacheManager->alloc(buffersSize)) {
    globalIndexes.push_back(id);
  }
  globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
diff --git a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
index 6d7a489c9..032c9e0d7 100644
--- a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
+++ b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
@@ -106,13 +106,14 @@ TEST(ExecutionEngine, JitWrapperCached) {
  }
  ASSERT_TRUE(jit_success);
 
+  auto cacheManager = gc::ConstGraphTensorCacheManager::get();
  auto ret = std::shared_ptr<float[]>(new float[128]);
  auto proxy = std::make_shared<gc::ConstCacheProxy>(ret, ret.get(),
                                                     128 * sizeof(float), true);
 
  // Cannot register with an already existing key.
-  ASSERT_FALSE(gc::regCachedTensor(0, proxy, 0));
+  ASSERT_FALSE(cacheManager->regCachedTensor(0, proxy, 0));
 
-  proxy = gc::queryCacheTensor(0)->base;
+  proxy = cacheManager->queryCacheTensor(0)->base;
  auto data = (float *)proxy->getBufferUnsafe();
 
  OwningMemRef<float, 1> bufA{

From 6a041dd1e48a1a220b54a798107ccec02727383a Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Sat, 14 Sep 2024 16:02:07 +0800
Subject: [PATCH 61/64] Use atomic

---
 include/gc/ExecutionEngine/CPURuntime/ConstantCache.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
index 9e4ea4f6f..8a7330eaa 100644
--- a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
+++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
@@ -151,7 +151,7 @@ inline static size_t divideAndCeil(size_t x, size_t y) {
 
 // Manager
 struct ConstGraphTensorCacheManager {
-  int64_t cachedTensorGlobalId = 0;
+  std::atomic_int64_t cachedTensorGlobalId = 0;
 
  std::unordered_map<int64_t, std::shared_ptr<CachedGraphTensor>> cache;
 

From c876358488a7d77fa8120a9a13d87ee0bb5d584a Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Sat, 14 Sep 2024 16:31:53 +0800
Subject: [PATCH 62/64] Fix

---
 lib/gc/ExecutionEngine/Driver/Driver.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp
index e7d1f1a66..42fa83b68 100644
--- a/lib/gc/ExecutionEngine/Driver/Driver.cpp
+++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp
@@ -79,7 +79,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
  }
  JitModuleFuncT entry = *expectEntry;
 
-  int32_t numOrigArgs;
+  int32_t numOrigArgs = 0;
  llvm::ArrayRef<int64_t> foldBufferIds;
  JitModuleFuncT fold = nullptr;
  llvm::ArrayRef<int32_t> entryArgs;
@@ -100,7 +100,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
  {
    auto expectBufferIds = engine->lookup("__runtime_fold_buffer_ids");
    if (!expectBufferIds) {
-      expectBufferIds.takeError();
+      llvm_unreachable("Symbol: __runtime_fold_buffer_ids not found");
      break;
    }
    auto raw = reinterpret_cast<int64_t *>(*expectBufferIds);
@@ -112,7 +112,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
  {
    auto expectFold = engine->lookupPacked(defaultFoldName);
    if (!expectFold) {
-      expectFold.takeError();
+      llvm_unreachable("Symbol: runtime_fold not found");
      break;
    }
    fold = *expectFold;
@@ -122,7 +122,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
  {
    auto expectFold = engine->lookup("__fold_args");
    if (!expectFold) {
-      expectFold.takeError();
+      llvm_unreachable("Symbol: __fold_args not found");
      break;
    }
    auto raw = reinterpret_cast<int32_t *>(*expectFold);
@@ -133,7 +133,7 @@ JitModule::create(Operation *op, const DriverOptions &options) {
  {
    auto expect = engine->lookup("__compute_args");
    if (!expect) {
-      expect.takeError();
+      llvm_unreachable("Symbol: __compute_args not found");
      break;
    }
    auto raw = reinterpret_cast<int32_t *>(*expect);
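PATCH 62 replaces bare `takeError()` calls, which silently drop the failure, with `llvm_unreachable` so a missing JIT symbol fails loudly. A minimal standalone illustration of the `llvm::Expected` discipline involved (function and symbol names are hypothetical):

```cpp
#include "llvm/Support/Error.h"

// An Expected must be checked, and its error consumed or acted upon,
// before it is destroyed; otherwise LLVM aborts in assertion builds.
void *lookupOrDie(llvm::Expected<void *> expect, const char *symbol) {
  if (!expect) {
    // Consume the error so the Expected can be destroyed safely...
    llvm::consumeError(expect.takeError());
    // ...then fail loudly instead of continuing with an invalid pointer.
    llvm::report_fatal_error(llvm::Twine("Symbol not found: ") + symbol);
  }
  return *expect;
}
```

Whether to abort (as the patch does for symbols the compiler itself emitted) or to propagate an `llvm::Error` upward is a policy choice; here the symbols are generated alongside the module, so their absence indicates an internal bug rather than user error.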
From 77e0f0258f4ccd0d87b8824132dca5718474e6fc Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Wed, 18 Sep 2024 15:12:18 +0800
Subject: [PATCH 63/64] Merge into one pass

---
 include/gc/Transforms/Passes.h              |  1 -
 include/gc/Transforms/Passes.td             |  8 ---
 lib/gc/Transforms/CMakeLists.txt            |  1 -
 .../Transforms/ConstantSubgraphAnalysis.cpp | 54 -------------------
 lib/gc/Transforms/ConstantTensorFolding.cpp |  5 ++
 lib/gc/Transforms/Pipeline.cpp              |  2 -
 .../test_constant_tensor_folding-0.mlir     |  2 +-
 .../test_constant_tensor_folding-1.mlir     |  2 +-
 8 files changed, 7 insertions(+), 68 deletions(-)
 delete mode 100644 lib/gc/Transforms/ConstantSubgraphAnalysis.cpp

diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h
index a42dba87b..06a3ee83d 100644
--- a/include/gc/Transforms/Passes.h
+++ b/include/gc/Transforms/Passes.h
@@ -124,7 +124,6 @@ void populateGPUPipeline(mlir::OpPassManager &);
 #define GEN_PASS_DECL_CONSTANTTENSORFOLDING
 #include "gc/Transforms/Passes.h.inc"
 
-std::unique_ptr<Pass> createConstantSubgraphAnalysisPass();
 std::unique_ptr<Pass> createConstantTensorFoldingPass();
 
 #define GEN_PASS_REGISTRATION
diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index 034380323..9a968c3bd 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -169,14 +169,6 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
  let dependentDialects = ["scf::SCFDialect"];
 }
 
-def ConstantSubgraphAnalysis : Pass<"constant-subgraph-analysis"> {
-  let summary = "Constant Subgraph Analysis";
-  let description = [{
-    This pass implements a constant subgraph analysis.
-  }];
-  let constructor = "mlir::gc::createConstantSubgraphAnalysisPass()";
-}
-
 def ConstantTensorFolding : Pass<"constant-tensor-folding"> {
  let summary = "Constant Tensor Folding Transform";
  let description = [{
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index 44415fece..08d60e513 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -16,7 +16,6 @@ gc_add_mlir_library(GcPasses
  IterativeTilingAndFusion.cpp
  TilingUsingInterfaceX.cpp
  VerifyTargetDescription.cpp
-  ConstantSubgraphAnalysis.cpp
  ConstantTensorFolding.cpp
  DecomposeAggregatedOps.cpp
  DeepTileContractionOp.cpp
diff --git a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp b/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
deleted file mode 100644
index 511d76f21..000000000
--- a/lib/gc/Transforms/ConstantSubgraphAnalysis.cpp
+++ /dev/null
@@ -1,54 +0,0 @@
-//===-- ConstantSubgraphAnalysis.cpp - Constant Subgraph --------*- C++ -*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This transformation pass performs a constant subgraph analysis
-// in MLIR.
-//
-//===----------------------------------------------------------------------===//
-#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Passes.h"
-
-namespace mlir {
-namespace gc {
-#define GEN_PASS_DEF_CONSTANTSUBGRAPHANALYSIS
-#include "gc/Transforms/Passes.h.inc"
-} // namespace gc
-
-using namespace mlir;
-using namespace mlir::dataflow;
-
-namespace gc {
-
-struct ConstantSubgraphAnalysis
-    : public impl::ConstantSubgraphAnalysisBase<ConstantSubgraphAnalysis> {
-  void runOnOperation() override;
-};
-
-void ConstantSubgraphAnalysis::runOnOperation() {
-  Operation *op = getOperation();
-  auto &func =
-      op->getRegions().front().getBlocks().front().getOperations().front();
-
-  // Hard-code example: set some arguments to be constant.
-  // OpBuilder builder(op->getContext());
-  // func.setAttr("runtime_const_args_index",
-  //              builder.getI32ArrayAttr({1,2,3,4}));
-
-  RunConstantSubgraphAnalyser runAnalyser;
-  (void)runAnalyser.run(&func);
-}
-
-std::unique_ptr<Pass> createConstantSubgraphAnalysisPass() {
-  return std::make_unique<ConstantSubgraphAnalysis>();
-}
-
-} // namespace gc
-} // namespace mlir
\ No newline at end of file
diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 6000ec844..3fa85a496 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -13,6 +13,7 @@
 #include <deque>
 #include <unordered_set>
 
+#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -766,6 +767,10 @@ void ConstantTensorFolding::runOnOperation() {
  MLIRContext *context = topOp->getContext();
  auto &topFunc =
      topOp->getRegions().front().getBlocks().front().getOperations().front();
+
+  dataflow::RunConstantSubgraphAnalyser runAnalyser;
+  (void)runAnalyser.run(&topFunc);
+
  OpBuilder builder(context);
  Region &region = topFunc.getRegions().front();
  Block &block = region.getBlocks().front();
diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 4cd1e9272..40527f644 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -52,8 +52,6 @@ void populateFrontendPasses(mlir::OpPassManager &pm) {
 void populateTensorPasses(mlir::OpPassManager &pm) {
  // todo: padding propagation pass
  // todo: layout propagation pass
-  // todo: tensor constant propagation pass
-  pm.addPass(createConstantSubgraphAnalysisPass());
  pm.addPass(createConstantTensorFoldingPass());
  // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
  pm.addNestedPass<func::FuncOp>(createDeepTileContractionOp());
diff --git a/test/gc/Transforms/test_constant_tensor_folding-0.mlir b/test/gc/Transforms/test_constant_tensor_folding-0.mlir
index eabdacc93..155e0875e 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-0.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-0.mlir
@@ -1,4 +1,4 @@
-// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s
+// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-tensor-folding)" %s | FileCheck %s
 
 // COM: A complete example of compile-time and runtime folding.
diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
index 92231703d..ca70f8d6a 100644
--- a/test/gc/Transforms/test_constant_tensor_folding-1.mlir
+++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir
@@ -1,4 +1,4 @@
-// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-subgraph-analysis,constant-tensor-folding)" %s | FileCheck %s
+// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-tensor-folding)" %s | FileCheck %s
 
 // COM: Test the 'postponeBroadcast' feature of constant tensor folding.

From 2df16c29b8dbadb4b31d812475599d166872b51e Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Wed, 18 Sep 2024 15:46:18 +0800
Subject: [PATCH 64/64] Skip case

---
 lib/gc/Transforms/ConstantTensorFolding.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp
index 6000ec844..3fa85a496 100644
--- a/lib/gc/Transforms/ConstantTensorFolding.cpp
+++ b/lib/gc/Transforms/ConstantTensorFolding.cpp
@@ -431,6 +431,9 @@ void getArithConstantOutputs(Block &block, SmallVector<Type> &outputTypes,
                   [](Operation *child) {
                     return !isInConstantSubgraph(child);
                   })) {
+      if (valuesOnTheWay.size() == 1) {
+        continue;
+      }
      if (std::find(outputValues.begin(), outputValues.end(), v) ==
          outputValues.end()) {
        outputTypes.push_back(v.getType());
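After PATCH 63, `constant-subgraph-analysis` no longer needs to be scheduled before `constant-tensor-folding`; the folding pass runs the analyser internally. A sketch of driving the merged pass from C++, assuming the usual MLIR PassManager setup (the function name here is illustrative):

```cpp
#include "gc/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Run constant tensor folding on a module; the constant-subgraph
// analysis is performed inside the pass itself.
mlir::LogicalResult runFolding(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::gc::createConstantTensorFoldingPass());
  return pm.run(module);
}
```

This matches the updated RUN lines in the tests, where the pipeline shrinks from `builtin.module(constant-subgraph-analysis,constant-tensor-folding)` to `builtin.module(constant-tensor-folding)`.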