From ddafed4824c4cde006d6c9e8afd8eb9fee7abe57 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Wed, 15 May 2024 04:20:50 +0800 Subject: [PATCH] [Transform] Add basic onednn_graph dialect lowering (#61) --- include/gc/Transforms/Passes.td | 14 + lib/gc/Transforms/CMakeLists.txt | 2 + lib/gc/Transforms/OneDNNGraphToLinalg.cpp | 298 ++++++++++++++++++ .../OneDNNGraph/onednn-graph-to-linalg.mlir | 113 +++++++ 4 files changed, 427 insertions(+) create mode 100644 lib/gc/Transforms/OneDNNGraphToLinalg.cpp create mode 100644 test/gc/Dialect/OneDNNGraph/onednn-graph-to-linalg.mlir diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 7274534b7..d31baa5a7 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -17,4 +17,18 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> { ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"]; } +def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { + let summary = "Lower the operations from the oneDNN Graph dialect into Linalg"; + let description = [{ + Lowers the `onednn_graph` ops to `linalg` ops. + }]; + let dependentDialects = [ + "func::FuncDialect", + "math::MathDialect", + "arith::ArithDialect", + "tensor::TensorDialect", + "linalg::LinalgDialect" + ]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index df8a14d01..1bcdf115c 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(GCPasses + OneDNNGraphToLinalg.cpp TileNamed.cpp ADDITIONAL_HEADER_DIRS @@ -9,6 +10,7 @@ add_mlir_library(GCPasses LINK_LIBS PUBLIC ${mlir_dialect_libs} + MLIROneDNNGraph MLIRIR MLIRSupport MLIRBufferizationToMemRef diff --git a/lib/gc/Transforms/OneDNNGraphToLinalg.cpp b/lib/gc/Transforms/OneDNNGraphToLinalg.cpp new file mode 100644 index 000000000..c115b42c2 --- /dev/null +++ b/lib/gc/Transforms/OneDNNGraphToLinalg.cpp @@ -0,0 +1,298 @@ +//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering --*- C++ -*-=// +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" +#include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir::onednn_graph; + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG +#include "gc/Transforms/Passes.h.inc" + +namespace { +//===----------------------------------------------------------------------===// +// Util funcs +//===----------------------------------------------------------------------===// + +Value createBroadcastOperand(Location loc, PatternRewriter &rewriter, + TensorType ty, Value op) { + auto opTy = dyn_cast(op.getType()); + llvm::ArrayRef bcastShape = ty.getShape(); + llvm::ArrayRef opShape = opTy.getShape(); + int64_t diff = bcastShape.size() - opShape.size(); + + if (bcastShape.equals(opShape)) { + return op; + } else { + // get broadcast dimensions + llvm::SmallVector bcastDims; + for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) { + int64_t idxOp = i - diff; + if (idxOp < 0) { + bcastDims.push_back(i); + } else if (bcastShape[i] != opShape[idxOp]) { + bcastDims.push_back(i); + } + } + // create a new output tensor + Value initTensor = + rewriter.create(loc, bcastShape, ty.getElementType()); + return rewriter + .create( + /*location=*/loc, + /*inputs=*/op, + /*inits=*/initTensor, + /*dimensions=*/bcastDims) + .getResults() + .front(); + } +} + +// Typedef for function to get operands for transformed op +typedef mlir::Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType); + +// Functions to get operands for from original op +struct OriginalOperand { + template + static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) { + if (I >= op->getNumOperands()) { + op->emitError("Index exceeds operand num.\n"); + return nullptr; + } + return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I)); + } +}; + +// Functions to get constant operands +struct ConstantOperand { + template + static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) { + const auto loc = op->getLoc(); + const auto elemTy = ty.getElementType(); + if (llvm::isa(elemTy)) { + return b.create( + loc, DenseElementsAttr::get(ty, b.getIntegerAttr(elemTy, I))); + } else if (llvm::isa(elemTy)) { + return b.create( + loc, DenseElementsAttr::get(ty, b.getFloatAttr(elemTy, I))); + } else { + op->emitError("Not a supported element type for constant.\n"); + return nullptr; + } + } +}; + +//===----------------------------------------------------------------------===// +// Elemwise lowering +//===----------------------------------------------------------------------===// + +// Generate elementwise op using linalg named ops +template +Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty, + llvm::ArrayRef inputs) { + // create a new output tensor + Value outTensor = + rewriter.create(loc, ty.getShape(), ty.getElementType()); + + auto elemwiseOp = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/outTensor.getType(), + /*inputs=*/inputs, + /*outputs=*/outTensor); + + return elemwiseOp.getResult(0); +} + +template +struct UnaryElemwiseLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(UnaryOp op, + PatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto resultTy = dyn_cast(op->getResultTypes().front()); + auto inOp = GetOperand(op, rewriter, resultTy); + if (!inOp) { + return rewriter.notifyMatchFailure(op, "Fail to get operand."); + } + auto unaryOp = createElemwiseOp(loc, rewriter, resultTy, {inOp}); + rewriter.replaceOp(op, unaryOp); + return success(); + } +}; + +template +struct BinaryElemwiseLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(BinaryOp op, + PatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto resultTy = dyn_cast(op->getResultTypes().front()); + auto lhsOp = GetOperandLHS(op, rewriter, resultTy); + auto rhsOp = GetOperandRHS(op, rewriter, resultTy); + if (!lhsOp || !rhsOp) { + return rewriter.notifyMatchFailure(op, "Fail to get operand."); + } + auto binaryOp = createElemwiseOp(loc, rewriter, resultTy, // + {lhsOp, rhsOp}); + rewriter.replaceOp(op, binaryOp); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Op lowering +//===----------------------------------------------------------------------===// + +using ReLUOpLowering = + BinaryElemwiseLowering, + ConstantOperand::getConst<0>>; + +using AddOpLowering = + BinaryElemwiseLowering, + OriginalOperand::getIdx<1>>; + +//===----------------------------------------------------------------------===// +// MatMulOp lowering +//===----------------------------------------------------------------------===// + +struct MatMulOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(MatMulOp op, + PatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto resultTy = dyn_cast(op->getResultTypes().front()); + auto typeA = dyn_cast(op.getInputA().getType()); + auto typeB = dyn_cast(op.getInputB().getType()); + // + auto getEmptyTensor = [&](TensorType tensorTy) -> Value { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(tensorTy.getElementType())); + Value newTensor = rewriter.create( + loc, tensorTy.getShape(), tensorTy.getElementType()); + return rewriter.create(loc, zero, newTensor).getResult(0); + }; + + if (typeA.getRank() != 2 || typeB.getRank() != 2) { + return rewriter.notifyMatchFailure( + op, "Currently not support multi batch matmul."); + } + bool transposeA = op.getTransposeA(); + bool transposeB = op.getTransposeB(); + Operation *newOp = nullptr; + if (!transposeA && !transposeB) { + // (A * B) + auto outTensor = getEmptyTensor(resultTy); + newOp = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/resultTy, + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, + /*outputs=*/outTensor); + } else if (transposeA && !transposeB) { + // T(A) * B + auto outTensor = getEmptyTensor(resultTy); + newOp = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/resultTy, + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, + /*outputs=*/outTensor); + } else if (!transposeA && transposeB) { + // A * T(B) + auto outTensor = getEmptyTensor(resultTy); + newOp = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/resultTy, + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, + /*outputs=*/outTensor); + } else { + // T(B * A) + const auto &resultShape = resultTy.getShape(); + SmallVector transShape{resultShape[1], resultShape[0]}; + SmallVector permutation{1, 0}; + auto transTy = resultTy.clone(transShape); + auto transTensor = getEmptyTensor(transTy); + auto matmulOp = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/transTy, + /*inputs=*/ValueRange{op.getInputB(), op.getInputA()}, + /*outputs=*/transTensor); + auto outTensor = getEmptyTensor(resultTy); + newOp = rewriter.create( + /*location=*/loc, + /*inputs=*/matmulOp.getResult(0), + /*outputs=*/outTensor, + /*permutation=*/permutation); + } + + if (op.getBias()) { + Value bias = + createBroadcastOperand(loc, rewriter, resultTy, op.getBias()); + Value outBias = rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType()); + newOp = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/outBias.getType(), + /*inputs=*/ValueRange{newOp->getResult(0), bias}, + /*outputs=*/outBias); + } + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass define +//===----------------------------------------------------------------------===// + +struct ConvertOneDNNGraphToLinalg + : public impl::ConvertOneDNNGraphToLinalgBase { + + void runOnOperation() final { + auto *ctx = &getContext(); + // add lowering target + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + // set pattern + RewritePatternSet patterns(ctx); + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + // perform conversion + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace +} // namespace gc +} // namespace mlir diff --git a/test/gc/Dialect/OneDNNGraph/onednn-graph-to-linalg.mlir b/test/gc/Dialect/OneDNNGraph/onednn-graph-to-linalg.mlir new file mode 100644 index 000000000..7ea195b34 --- /dev/null +++ b/test/gc/Dialect/OneDNNGraph/onednn-graph-to-linalg.mlir @@ -0,0 +1,113 @@ +// RUN: gc-opt --split-input-file --convert-onednn-graph-to-linalg %s -verify-diagnostics -o -| FileCheck %s + +// CHECK-LABEL: @matmul +func.func @matmul(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[C0:%.+]] = arith.constant 0 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16> + return %0 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @matmul_bias +func.func @matmul_bias(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>, %arg3: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[C0:%.+]] = arith.constant 0 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: tensor.empty() + // CHECK: linalg.broadcast + // CHECK: tensor.empty() + // CHECK: linalg.add + %0 = onednn_graph.matmul %arg0, %arg1, %arg3 : (tensor<128x512xbf16>, tensor<512x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + return %0 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @add +func.func @add(%arg0: tensor<128x256xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { + // CHECK: tensor.empty() + // CHECK: linalg.add + %0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<128x256xf32>) -> tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} + +// CHECK-LABEL: @add_bcast +func.func @add_bcast(%arg0: tensor<128x256xf32>, %arg1: tensor<256xf32>) -> tensor<128x256xf32> { + // CHECK: tensor.empty() + // CHECK: linalg.broadcast + // CHECK: tensor.empty() + // CHECK: linalg.add + %0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<256xf32>) -> tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} + +// CHECK-LABEL: @relu +func.func @relu(%arg0: tensor<128x256xbf16>) -> tensor<128x256xbf16> { + // CHECK: arith.constant dense<0.0{{.*}}> + // CHECK: tensor.empty() + // CHECK: linalg.max + %0 = onednn_graph.relu %arg0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %0 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a +func.func @mlp_transpose_a(%arg0: tensor<512x128xbf16>, %arg1: tensor<512x256xbf16>, %arg3: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[C0:%.+]] = arith.constant 0 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<512x128xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: tensor.empty() + // CHECK: linalg.broadcast + // CHECK: tensor.empty() + // CHECK: linalg.add + // CHECK: arith.constant dense<0.0{{.*}}> + // CHECK: tensor.empty() + // CHECK: linalg.max + %0 = onednn_graph.matmul %arg0, %arg1, %arg3 {transpose_a = true} + : (tensor<512x128xbf16>, tensor<512x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_b +func.func @mlp_transpose_b(%arg0: tensor<128x512xbf16>, %arg1: tensor<256x512xbf16>, %arg3: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[C0:%.+]] = arith.constant 0 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<256x512xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: tensor.empty() + // CHECK: linalg.broadcast + // CHECK: tensor.empty() + // CHECK: linalg.add + // CHECK: arith.constant dense<0.0{{.*}}> + // CHECK: tensor.empty() + // CHECK: linalg.max + %0 = onednn_graph.matmul %arg0, %arg1, %arg3 {transpose_b = true} + : (tensor<128x512xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a_b +func.func @mlp_transpose_a_b(%arg0: tensor<512x128xbf16>, %arg1: tensor<256x512xbf16>, %arg3: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[C0:%.+]] = arith.constant 0 + // CHECK: [[INIT0:%.+]] = tensor.empty() + // CHECK: [[FILLED0:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT0]] : tensor<256x128xbf16>) -> tensor<256x128xbf16> + // CHECK: [[MMT:%.+]] = linalg.matmul ins(%arg1, %arg0 : tensor<256x512xbf16>, tensor<512x128xbf16>) outs([[FILLED0]] : tensor<256x128xbf16>) -> tensor<256x128xbf16> + // CHECK: [[C1:%.+]] = arith.constant 0 + // CHECK: [[INIT1:%.+]] = tensor.empty() + // CHECK: [[FILLED1:%.+]] = linalg.fill ins([[C1]] : bf16) outs([[INIT1]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> + // CHECK: linalg.transpose ins([[MMT]] : tensor<256x128xbf16>) outs([[FILLED1]] : tensor<128x256xbf16>) + // CHECK: tensor.empty() + // CHECK: linalg.broadcast + // CHECK: tensor.empty() + // CHECK: linalg.add + // CHECK: arith.constant dense<0.0{{.*}}> + // CHECK: tensor.empty() + // CHECK: linalg.max + %0 = onednn_graph.matmul %arg0, %arg1, %arg3 {transpose_a = true, transpose_b = true} + : (tensor<512x128xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +}