diff --git a/CMakeLists.txt b/CMakeLists.txt index 04da65c01..0b6195d9a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,7 +64,7 @@ add_subdirectory(src) set(GC_LIB_LINKED_LIBS MLIRLinalgx MLIRMicrokernel - MLIROnednnGraph + MLIROneDNNGraph ) add_library(graph_compiler SHARED ${GC_LIB_SOURCES}) target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES}) diff --git a/include/gc/Dialect/CMakeLists.txt b/include/gc/Dialect/CMakeLists.txt index ffeda0aa7..bedadc2af 100644 --- a/include/gc/Dialect/CMakeLists.txt +++ b/include/gc/Dialect/CMakeLists.txt @@ -1,3 +1,3 @@ -add_subdirectory(OnednnGraph) +add_subdirectory(OneDNNGraph) add_subdirectory(Microkernel) add_subdirectory(Linalgx) \ No newline at end of file diff --git a/include/gc/Dialect/OneDNNGraph/CMakeLists.txt b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt new file mode 100644 index 000000000..63dfde793 --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_dialect(OneDNNGraphOps onednn_graph) +add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc) +add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc) diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h similarity index 67% rename from include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h rename to include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h index 7f128cb64..bae992507 100644 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h @@ -1,4 +1,4 @@ -//===- OnednnGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// +//===- OneDNNGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -10,7 +10,10 @@ #define GC_DIALECTS_ONEDNNGRAPHDIALECT_H #include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" -#include "gc/Dialect/OnednnGraph/OnednnGraphOpsDialect.h.inc" +#define GET_OP_CLASSES +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsDialect.h.inc" #endif // GC_DIALECTS_ONEDNNGRAPHDIALECT_H diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td similarity index 85% rename from include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td rename to include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td index 16615a4d3..6e1eaceca 100644 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td @@ -1,4 +1,4 @@ -//===- OnednnGraphDialect.td - OneDNN input dialect --------*- tablegen -*-===// +//===- OneDNNGraphDialect.td - OneDNN input dialect --------*- tablegen -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -15,15 +15,13 @@ include "mlir/IR/OpBase.td" // OneDNNGraph dialect definition. //===----------------------------------------------------------------------===// -def OnednnGraphDialect : Dialect { +def OneDNNGraphDialect : Dialect { let name = "onednn_graph"; let summary = "A dialect for oneDNN Graph."; let description = [{ This dialect follows oneDNN Graph Specification. 
}]; let cppNamespace = "::mlir::onednn_graph"; - - let useDefaultTypePrinterParser = 1; } #endif // ONEDNNGRAPH_DIALECT diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h similarity index 54% rename from include/gc/Dialect/OnednnGraph/OnednnGraphOps.h rename to include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h index ad86d908c..d4fe91387 100644 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.h +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h @@ -1,4 +1,4 @@ -//===- OnednnGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// +//===- OneDNNGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,9 +9,15 @@ #ifndef GC_DIALECTS_ONEDNNGRAPHOPS_H #define GC_DIALECTS_ONEDNNGRAPHOPS_H +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #define GET_OP_CLASSES -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h.inc" #endif // GC_DIALECTS_ONEDNNGRAPHOPS_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td new file mode 100644 index 000000000..031a84fbc --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td @@ -0,0 +1,82 @@ +//===- OneDNNGraphOps.td - OneDNN input dialect ops --------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_OPS +#define ONEDNNGRAPH_OPS + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" +include "OneDNNGraphDialect.td" +include "OneDNNGraphTypes.td" + +//===----------------------------------------------------------------------===// +// Base OneDNNGraph operation definition. +//===----------------------------------------------------------------------===// + +class OneDNNGraph_Op traits = []> : + Op; + +class OneDNNGraph_ElemwiseBinaryOp traits = []> : + OneDNNGraph_Op { + let arguments = (ins OneDNNGraph_LogicalTensor:$input_0, + OneDNNGraph_LogicalTensor:$input_1); + let results = (outs OneDNNGraph_LogicalTensor:$result); + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +class OneDNNGraph_ElemwiseUnaryOp traits = []> : + OneDNNGraph_Op { + let arguments = (ins OneDNNGraph_LogicalTensor:$operand); + let results = (outs OneDNNGraph_LogicalTensor:$result); + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// OneDNNGraph op definitions +//===----------------------------------------------------------------------===// + +def OneDNNGraph_MatMulOp : + OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> { + let summary = "Generalized matrix multiplication"; + let description = [{ + `https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html` + }]; + + let arguments = (ins OneDNNGraph_LogicalTensor:$input_a, + OneDNNGraph_LogicalTensor:$input_b, + Optional:$bias, + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b); + 
let results = (outs OneDNNGraph_LogicalTensor:$result); + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> { + let summary = "element-wise relu"; + let description = [{ + `https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html` + }]; +} + +def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> { + let summary = "element-wise addition with multi-directional broadcast"; + let description = [{ + `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html` + }]; +} + +#endif // ONEDNNGRAPH_OPS \ No newline at end of file diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h new file mode 100644 index 000000000..c897fb8be --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h @@ -0,0 +1,17 @@ +//===- OneDNNGraphTypes.h - OneDNN input dialect types ----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_ONEDNNGRAPHTYPES_H +#define ONEDNNGRAPH_ONEDNNGRAPHTYPES_H + +#include "mlir/IR/BuiltinTypes.h" + +#define GET_TYPEDEF_CLASSES +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsTypes.h.inc" + +#endif // ONEDNNGRAPH_ONEDNNGRAPHTYPES_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td new file mode 100644 index 000000000..3c9f0e41d --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td @@ -0,0 +1,30 @@ +//===- OneDNNGraphTypes.h - OneDNN input dialect types -----*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_TYPES +#define ONEDNNGRAPH_TYPES + +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/AttrTypeBase.td" +include "OneDNNGraphDialect.td" + +//===----------------------------------------------------------------------===// +// OneDNNGraph type definitions +//===----------------------------------------------------------------------===// + +def OneDNNGraph_DataType : AnyTypeOf<[ + F16, + BF16, + F32, + SI<32>, + SI<8>, + UI<8>]>; + +def OneDNNGraph_LogicalTensor : TensorOf<[OneDNNGraph_DataType]>; + +#endif // ONEDNNGRAPH_TYPES diff --git a/include/gc/Dialect/OnednnGraph/CMakeLists.txt b/include/gc/Dialect/OnednnGraph/CMakeLists.txt deleted file mode 100644 index 7e7c7eb68..000000000 --- a/include/gc/Dialect/OnednnGraph/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_mlir_dialect(OnednnGraphOps onednn_graph) -add_mlir_doc(OnednnGraphOps OnednnGraphOps gc/Dialect/OnednnGraph/ -gen-op-doc) -add_mlir_doc(OnednnGraphDialect OnednnGraphDialect gc/Dialect/OnednnGraph/ -gen-dialect-doc) diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.td b/include/gc/Dialect/OnednnGraph/OnednnGraphOps.td deleted file mode 100644 index e460011ee..000000000 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.td +++ /dev/null @@ -1,14 +0,0 @@ -//===- OnednnGraphOps.td - OneDNN input dialect ops --------*- tablegen -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ONEDNNGRAPH_OPS -#define ONEDNNGRAPH_OPS - -include "OnednnGraphDialect.td" - -#endif // ONEDNNGRAPH_OPS \ No newline at end of file diff --git a/lib/gc/Dialect/CMakeLists.txt b/lib/gc/Dialect/CMakeLists.txt index a880ff2ed..802d78764 100644 --- a/lib/gc/Dialect/CMakeLists.txt +++ b/lib/gc/Dialect/CMakeLists.txt @@ -1,3 +1,3 @@ add_subdirectory(Linalgx) add_subdirectory(Microkernel) -add_subdirectory(OnednnGraph) +add_subdirectory(OneDNNGraph) diff --git a/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt new file mode 100644 index 000000000..2dc50f1b7 --- /dev/null +++ b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROneDNNGraph + OneDNNGraphDialect.cpp + OneDNNGraphOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/gc/Dialect/OneDNNGraph + + DEPENDS + MLIROneDNNGraphOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR +) \ No newline at end of file diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp new file mode 100644 index 000000000..7529e04fc --- /dev/null +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp @@ -0,0 +1,27 @@ +//===- OneDNNGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h" + +using namespace mlir; +using namespace mlir::onednn_graph; + +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// OneDNNGraph dialect. +//===----------------------------------------------------------------------===// + +void OneDNNGraphDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc" + >(); +} diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp new file mode 100644 index 000000000..0519ca8f3 --- /dev/null +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp @@ -0,0 +1,184 @@ +//===- OneDNNGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/Support/Debug.h" + +#define GET_OP_CLASSES +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc" + +namespace mlir { +namespace onednn_graph { + +// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md +template <typename ShapeRange> +static LogicalResult inferBroadcastShape( + ShapeRange operands, SmallVector<int64_t> &outShape, + const std::function<ShapeAdaptor(ShapeRange, size_t)> &getShapeIdx) { + int64_t outRank = 0; + for (size_t i = 0; i < operands.size(); i++) { + auto shape = getShapeIdx(operands, i); + if (!shape.hasRank()) { + return failure(); + } + outRank = std::max(outRank, shape.getRank()); + } + // Start with all 1 dim + outShape.clear(); + outShape.resize(outRank, 1); + // Scan each shape for match dims + for (size_t i = 0; i < operands.size(); i++) { + auto shape = getShapeIdx(operands, i); + auto diff = outShape.size() - shape.getRank(); + for (int64_t j = 0; j < shape.getRank(); j++) { + auto dim1 = outShape[diff + j]; + auto dim2 = shape.getDimSize(j); + auto resolvedDim = dim1; + + if (dim1 == 1) { + resolvedDim = dim2; + } else if (dim2 == 1) { + resolvedDim = dim1; + } else if (dim1 != dim2) { + return failure(); + } + outShape[diff + j] = resolvedDim; + } + } + return success(); +} + +LogicalResult onednn_graph::AddOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + llvm::SmallVector<int64_t> outShape; + auto resultTy = dyn_cast<ShapedType>(operands.front().getType()); + auto getShapeIdx = [](ValueShapeRange operands, size_t i) { + return operands.getShape(i); + }; + auto ret = + inferBroadcastShape<ValueShapeRange>(operands, outShape, 
getShapeIdx); + inferredReturnShapes.push_back( + ShapedTypeComponents(outShape, resultTy.getElementType())); + return ret; +} + +LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + MatMulOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + // get batch dims from shape + auto extractBatch = [](const ShapeAdaptor &lhsShape, + const ShapeAdaptor &rhsShape, int64_t range, + int64_t diff, SmallVector<int64_t> &outDims) { + for (int64_t i = 0; i < range; i++) { + int64_t idx = i - diff; + if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) { + return failure(); + } + outDims.push_back(lhsShape.getDimSize(i)); + } + return success(); + }; + // get row col of 2d matrix according to transpose info + auto getMatRowCol = [](const ShapeAdaptor &shape, bool transpose) { + using pairRowCol = std::pair<int64_t, int64_t>; + auto rank = shape.getRank(); + assert(rank > 1); + return transpose ? pairRowCol(shape.getDimSize(rank - 1), + shape.getDimSize(rank - 2)) + : pairRowCol(shape.getDimSize(rank - 2), + shape.getDimSize(rank - 1)); + }; + ShapeAdaptor lhsShape(adaptor.getInputA().getType()); + ShapeAdaptor rhsShape(adaptor.getInputB().getType()); + bool transposeA = adaptor.getTransposeA(); + bool transposeB = adaptor.getTransposeB(); + int64_t lRank = lhsShape.getRank(); + int64_t rRank = rhsShape.getRank(); + // + SmallVector<int64_t> outShape; + LogicalResult status = failure(); + if (lRank == 1 && rRank == 1) { + // 1D x 1D + if (lhsShape.getDimSize(0) != rhsShape.getDimSize(0)) { + return failure(); + } + outShape.push_back(1); + } else if (lRank == 1 && rRank > 1) { + // 1D x ND + auto rMatRowCol = getMatRowCol(rhsShape, transposeB); + status = extractBatch(rhsShape, rhsShape, rRank - 2, 0, outShape); + if (lhsShape.getDimSize(0) != rMatRowCol.first) { + return failure(); + } + outShape.push_back(rMatRowCol.second); + } else if (lRank > 1 && rRank == 1) { + // ND x 1D + auto lMatRowCol = 
getMatRowCol(lhsShape, transposeA); + status = extractBatch(lhsShape, lhsShape, lRank - 2, 0, outShape); + if (lMatRowCol.second != rhsShape.getDimSize(0)) { + return failure(); + } + outShape.push_back(lMatRowCol.first); + } else if (lRank > 1 && rRank > 1) { + if (lRank == rRank) { + // ND x ND + auto range = lRank - 2; + status = extractBatch(lhsShape, rhsShape, range, 0, outShape); + } else if (lRank > rRank) { + // MD x ND (M > N) + auto range = lRank - 2; + auto diff = lRank - rRank; + status = extractBatch(lhsShape, rhsShape, range, diff, outShape); + } else { + // ND x MD (M > N) + auto range = rRank - 2; + auto diff = rRank - lRank; + status = extractBatch(rhsShape, lhsShape, range, diff, outShape); + } + // + auto lMatRowCol = getMatRowCol(lhsShape, transposeA); + auto rMatRowCol = getMatRowCol(rhsShape, transposeB); + if (failed(status) || (lMatRowCol.second != rMatRowCol.first)) { + return failure(); + } + outShape.push_back(lMatRowCol.first); + outShape.push_back(rMatRowCol.second); + } else { + // Not supported + return failure(); + } + auto getShapeIdx = [](ArrayRef<ShapeAdaptor> operands, size_t i) { + return operands[i]; + }; + // final shape + auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType()); + inferredReturnShapes.push_back(retShape); + // check for bias broadcasting + if (adaptor.getBias()) { + ShapeAdaptor biasShape(adaptor.getBias().getType()); + ShapeAdaptor matShape(retShape); + bool biasRankMatch = biasShape.getRank() == 1 || + biasShape.getRank() == (int64_t)outShape.size(); + SmallVector<int64_t> bcastShape; + if (!biasRankMatch || + failed(inferBroadcastShape<ArrayRef<ShapeAdaptor>>( + {matShape, biasShape}, bcastShape, getShapeIdx))) { + return failure(); + } + } + return success(); +} + +} // namespace onednn_graph +} // namespace mlir diff --git a/lib/gc/Dialect/OnednnGraph/CMakeLists.txt b/lib/gc/Dialect/OnednnGraph/CMakeLists.txt deleted file mode 100644 index 4571697b3..000000000 --- a/lib/gc/Dialect/OnednnGraph/CMakeLists.txt +++ 
/dev/null @@ -1,13 +0,0 @@ -add_mlir_dialect_library(MLIROnednnGraph - OnednnGraphDialect.cpp - OnednnGraphOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/gc/Dialect/OnednnGraph - - DEPENDS - MLIROnednnGraphOpsIncGen - - LINK_LIBS PUBLIC - MLIRIR -) \ No newline at end of file diff --git a/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp b/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp deleted file mode 100644 index 434fa8a57..000000000 --- a/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp +++ /dev/null @@ -1,20 +0,0 @@ -//===- OnednnGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/Dialect/OnednnGraph/OnednnGraphDialect.h" -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h" - -using namespace mlir; -using namespace mlir::onednn_graph; - -void OnednnGraphDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.cpp.inc" - >(); -} diff --git a/lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp b/lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp deleted file mode 100644 index b5f1dadca..000000000 --- a/lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp +++ /dev/null @@ -1,14 +0,0 @@ -//===- OnednnGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h" -#include "gc/Dialect/OnednnGraph/OnednnGraphDialect.h" -#include "mlir/IR/OpImplementation.h" - -#define GET_OP_CLASSES -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.cpp.inc" diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 72a25abf5..818f777e7 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -17,6 +17,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -28,6 +29,7 @@ int main(int argc, char *argv[]) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); + registry.insert<mlir::onednn_graph::OneDNNGraphDialect>(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Graph Compiler modular optimizer driver\n", registry)); } diff --git a/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir new file mode 100644 index 000000000..d10146d28 --- /dev/null +++ b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir @@ -0,0 +1,60 @@ +// RUN: gc-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @mlp +func.func @mlp(%in: tensor<128x512xbf16>, + %weight0: tensor<512x64xbf16>, %bias0: tensor<64xbf16>, + %weight1: tensor<64x256xbf16>, %bias1: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK: [[MM2:%.+]] = onednn_graph.matmul + // CHECK: [[AD2:%.+]] = onednn_graph.add [[MM2]] + // CHECK: [[RL2:%.+]] = onednn_graph.relu [[AD2]] + // CHECK: return [[RL2]] + %0 = onednn_graph.matmul %in, %weight0, %bias0 + : (tensor<128x512xbf16>, tensor<512x64xbf16>, tensor<64xbf16>) -> tensor<128x64xbf16> + %1 = onednn_graph.relu %0 : 
(tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = onednn_graph.matmul %1, %weight1 + : (tensor<128x64xbf16>, tensor<64x256xbf16>) -> tensor<128x256xbf16> + %3 = onednn_graph.add %2, %bias1 : (tensor<128x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %4 = onednn_graph.relu %3 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %4 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a +func.func @mlp_transpose_a(%in: tensor<512x128xbf16>, + %weight0: tensor<512x256xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: {transpose_a = true} + // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK-NEXT: return [[RL1]] + %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true} + : (tensor<512x128xbf16>, tensor<512x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_b +func.func @mlp_transpose_b(%in: tensor<128x512xbf16>, + %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: {transpose_b = true} + // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK-NEXT: return [[RL1]] + %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_b = true} + : (tensor<128x512xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a_b +func.func @mlp_transpose_a_b(%in: tensor<512x128xbf16>, + %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: {transpose_a = true, transpose_b = true} + // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK-NEXT: return [[RL1]] + %0 = 
onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true, transpose_b = true} + : (tensor<512x128xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +}