From dd69e6143642f810c8b091554616624b9c644641 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Mon, 13 May 2024 11:12:14 +0800 Subject: [PATCH 1/5] tmp remove --- include/gc/Dialect/OnednnGraph/CMakeLists.txt | 3 -- .../Dialect/OnednnGraph/OnednnGraphDialect.h | 16 ---------- .../Dialect/OnednnGraph/OnednnGraphDialect.td | 29 ------------------- .../gc/Dialect/OnednnGraph/OnednnGraphOps.h | 17 ----------- .../gc/Dialect/OnednnGraph/OnednnGraphOps.td | 14 --------- lib/gc/Dialect/OnednnGraph/CMakeLists.txt | 13 --------- .../OnednnGraph/OnednnGraphDialect.cpp | 20 ------------- lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp | 14 --------- 8 files changed, 126 deletions(-) delete mode 100644 include/gc/Dialect/OnednnGraph/CMakeLists.txt delete mode 100644 include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h delete mode 100644 include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td delete mode 100644 include/gc/Dialect/OnednnGraph/OnednnGraphOps.h delete mode 100644 include/gc/Dialect/OnednnGraph/OnednnGraphOps.td delete mode 100644 lib/gc/Dialect/OnednnGraph/CMakeLists.txt delete mode 100644 lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp delete mode 100644 lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp diff --git a/include/gc/Dialect/OnednnGraph/CMakeLists.txt b/include/gc/Dialect/OnednnGraph/CMakeLists.txt deleted file mode 100644 index 7e7c7eb68..000000000 --- a/include/gc/Dialect/OnednnGraph/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_mlir_dialect(OnednnGraphOps onednn_graph) -add_mlir_doc(OnednnGraphOps OnednnGraphOps gc/Dialect/OnednnGraph/ -gen-op-doc) -add_mlir_doc(OnednnGraphDialect OnednnGraphDialect gc/Dialect/OnednnGraph/ -gen-dialect-doc) diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h b/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h deleted file mode 100644 index 7f128cb64..000000000 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h +++ /dev/null @@ -1,16 +0,0 @@ -//===- OnednnGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef GC_DIALECTS_ONEDNNGRAPHDIALECT_H -#define GC_DIALECTS_ONEDNNGRAPHDIALECT_H - -#include "mlir/IR/Dialect.h" - -#include "gc/Dialect/OnednnGraph/OnednnGraphOpsDialect.h.inc" - -#endif // GC_DIALECTS_ONEDNNGRAPHDIALECT_H diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td b/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td deleted file mode 100644 index 16615a4d3..000000000 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td +++ /dev/null @@ -1,29 +0,0 @@ -//===- OnednnGraphDialect.td - OneDNN input dialect --------*- tablegen -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ONEDNNGRAPH_DIALECT -#define ONEDNNGRAPH_DIALECT - -include "mlir/IR/OpBase.td" - -//===----------------------------------------------------------------------===// -// OneDNNGraph dialect definition. -//===----------------------------------------------------------------------===// - -def OnednnGraphDialect : Dialect { - let name = "onednn_graph"; - let summary = "A dialect for oneDNN Graph."; - let description = [{ - This dialect follows oneDNN Graph Specification. - }]; - let cppNamespace = "::mlir::onednn_graph"; - - let useDefaultTypePrinterParser = 1; -} - -#endif // ONEDNNGRAPH_DIALECT diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.h b/include/gc/Dialect/OnednnGraph/OnednnGraphOps.h deleted file mode 100644 index ad86d908c..000000000 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.h +++ /dev/null @@ -1,17 +0,0 @@ -//===- OnednnGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef GC_DIALECTS_ONEDNNGRAPHOPS_H -#define GC_DIALECTS_ONEDNNGRAPHOPS_H - -#include "mlir/IR/OpDefinition.h" - -#define GET_OP_CLASSES -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h.inc" - -#endif // GC_DIALECTS_ONEDNNGRAPHOPS_H diff --git a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.td b/include/gc/Dialect/OnednnGraph/OnednnGraphOps.td deleted file mode 100644 index e460011ee..000000000 --- a/include/gc/Dialect/OnednnGraph/OnednnGraphOps.td +++ /dev/null @@ -1,14 +0,0 @@ -//===- OnednnGraphOps.td - OneDNN input dialect ops --------*- tablegen -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ONEDNNGRAPH_OPS -#define ONEDNNGRAPH_OPS - -include "OnednnGraphDialect.td" - -#endif // ONEDNNGRAPH_OPS \ No newline at end of file diff --git a/lib/gc/Dialect/OnednnGraph/CMakeLists.txt b/lib/gc/Dialect/OnednnGraph/CMakeLists.txt deleted file mode 100644 index 4571697b3..000000000 --- a/lib/gc/Dialect/OnednnGraph/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_mlir_dialect_library(MLIROnednnGraph - OnednnGraphDialect.cpp - OnednnGraphOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/gc/Dialect/OnednnGraph - - DEPENDS - MLIROnednnGraphOpsIncGen - - LINK_LIBS PUBLIC - MLIRIR -) \ No newline at end of file diff --git a/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp b/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp deleted file mode 100644 index 434fa8a57..000000000 --- a/lib/gc/Dialect/OnednnGraph/OnednnGraphDialect.cpp +++ /dev/null @@ -1,20 +0,0 @@ -//===- OnednnGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/Dialect/OnednnGraph/OnednnGraphDialect.h" -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h" - -using namespace mlir; -using namespace mlir::onednn_graph; - -void OnednnGraphDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.cpp.inc" - >(); -} diff --git a/lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp b/lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp deleted file mode 100644 index b5f1dadca..000000000 --- a/lib/gc/Dialect/OnednnGraph/OnednnGraphOps.cpp +++ /dev/null @@ -1,14 +0,0 @@ -//===- OnednnGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h" -#include "gc/Dialect/OnednnGraph/OnednnGraphDialect.h" -#include "mlir/IR/OpImplementation.h" - -#define GET_OP_CLASSES -#include "gc/Dialect/OnednnGraph/OnednnGraphOps.cpp.inc" From cd43162aaa58c096d987853c68db351bc87a4d0f Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Mon, 13 May 2024 11:15:33 +0800 Subject: [PATCH 2/5] rebase --- CMakeLists.txt | 2 +- include/gc/Dialect/CMakeLists.txt | 2 +- include/gc/Dialect/OneDNNGraph/CMakeLists.txt | 3 + .../Dialect/OneDNNGraph/OneDNNGraphDialect.h | 19 ++ .../Dialect/OneDNNGraph/OneDNNGraphDialect.td | 27 +++ .../gc/Dialect/OneDNNGraph/OneDNNGraphOps.h | 23 +++ .../gc/Dialect/OneDNNGraph/OneDNNGraphOps.td | 82 ++++++++ .../gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h | 17 ++ .../Dialect/OneDNNGraph/OneDNNGraphTypes.td | 30 +++ lib/gc/Dialect/CMakeLists.txt | 2 +- lib/gc/Dialect/OneDNNGraph/CMakeLists.txt | 13 ++ .../OneDNNGraph/OneDNNGraphDialect.cpp | 44 +++++ lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp | 184 ++++++++++++++++++ src/gc-opt.cpp | 2 + .../Dialect/OneDNNGraph/onednn-graph-mlp.mlir | 42 ++++ 15 files changed, 489 insertions(+), 3 deletions(-) create mode 100644 include/gc/Dialect/OneDNNGraph/CMakeLists.txt create mode 100644 include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h create mode 100644 include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td create mode 100644 include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h create mode 100644 include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td create mode 100644 include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h create mode 100644 include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td create mode 100644 lib/gc/Dialect/OneDNNGraph/CMakeLists.txt create mode 100644 lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp create mode 100644 lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp create mode 100644 test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir diff --git a/CMakeLists.txt b/CMakeLists.txt index eb12ee346..c1692f486 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,7 +52,7 @@ target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES}) set(GC_LIB_LINKED_LIBS MLIRLinalgx MLIRMicrokernel - MLIROnednnGraph + MLIROneDNNGraph ) target_link_libraries(graph_compiler PRIVATE ${GC_LIB_LINKED_LIBS}) diff --git a/include/gc/Dialect/CMakeLists.txt b/include/gc/Dialect/CMakeLists.txt index ffeda0aa7..bedadc2af 100644 --- a/include/gc/Dialect/CMakeLists.txt +++ b/include/gc/Dialect/CMakeLists.txt @@ -1,3 +1,3 @@ -add_subdirectory(OnednnGraph) +add_subdirectory(OneDNNGraph) add_subdirectory(Microkernel) add_subdirectory(Linalgx) \ No newline at end of file diff --git a/include/gc/Dialect/OneDNNGraph/CMakeLists.txt b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt new file mode 100644 index 000000000..7b446f467 --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_dialect(OneDNNGraphOps onednn_graph) +add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc-dialects/OneDNNGraph/ -gen-op-doc) +add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc-dialects/OneDNNGraph/ -gen-dialect-doc) diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h new file mode 100644 index 000000000..ade444b9b --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h @@ -0,0 +1,19 @@ +//===- OneDNNGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_DIALECTS_ONEDNNGRAPHDIALECT_H +#define GC_DIALECTS_ONEDNNGRAPHDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_OP_CLASSES +#include "gc-dialects/OneDNNGraph/OneDNNGraphOpsDialect.h.inc" + +#endif // GC_DIALECTS_ONEDNNGRAPHDIALECT_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td new file mode 100644 index 000000000..6e1eaceca --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td @@ -0,0 +1,27 @@ +//===- OneDNNGraphDialect.td - OneDNN input dialect --------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_DIALECT +#define ONEDNNGRAPH_DIALECT + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// OneDNNGraph dialect definition. +//===----------------------------------------------------------------------===// + +def OneDNNGraphDialect : Dialect { + let name = "onednn_graph"; + let summary = "A dialect for oneDNN Graph."; + let description = [{ + This dialect follows oneDNN Graph Specification. + }]; + let cppNamespace = "::mlir::onednn_graph"; +} + +#endif // ONEDNNGRAPH_DIALECT diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h new file mode 100644 index 000000000..4e0c34396 --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h @@ -0,0 +1,23 @@ +//===- OneDNNGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_DIALECTS_ONEDNNGRAPHOPS_H +#define GC_DIALECTS_ONEDNNGRAPHOPS_H + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h.inc" + +#endif // GC_DIALECTS_ONEDNNGRAPHOPS_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td new file mode 100644 index 000000000..031a84fbc --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td @@ -0,0 +1,82 @@ +//===- OneDNNGraphOps.td - OneDNN input dialect ops --------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_OPS +#define ONEDNNGRAPH_OPS + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" +include "OneDNNGraphDialect.td" +include "OneDNNGraphTypes.td" + +//===----------------------------------------------------------------------===// +// Base OneDNNGraph operation definition. +//===----------------------------------------------------------------------===// + +class OneDNNGraph_Op traits = []> : + Op; + +class OneDNNGraph_ElemwiseBinaryOp traits = []> : + OneDNNGraph_Op { + let arguments = (ins OneDNNGraph_LogicalTensor:$input_0, + OneDNNGraph_LogicalTensor:$input_1); + let results = (outs OneDNNGraph_LogicalTensor:$result); + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +class OneDNNGraph_ElemwiseUnaryOp traits = []> : + OneDNNGraph_Op { + let arguments = (ins OneDNNGraph_LogicalTensor:$operand); + let results = (outs OneDNNGraph_LogicalTensor:$result); + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// OneDNNGraph op definitions +//===----------------------------------------------------------------------===// + +def OneDNNGraph_MatMulOp : + OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> { + let summary = "Generalized matrix multiplication"; + let description = [{ + `https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html` + }]; + + let arguments = (ins OneDNNGraph_LogicalTensor:$input_a, + OneDNNGraph_LogicalTensor:$input_b, + Optional:$bias, + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b); + let results = (outs OneDNNGraph_LogicalTensor:$result); + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; +} + +def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> { + let summary = "element-wise relu"; + let description = [{ + `https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html` + }]; +} + +def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> { + let summary = "element-wise addition with multi-directional broadcast"; + let description = [{ + `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html` + }]; +} + +#endif // ONEDNNGRAPH_OPS \ No newline at end of file diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h new file mode 100644 index 000000000..de8ae3062 --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h @@ -0,0 +1,17 @@ +//===- OneDNNGraphTypes.h - OneDNN input dialect types ----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_ONEDNNGRAPHTYPES_H +#define ONEDNNGRAPH_ONEDNNGRAPHTYPES_H + +#include "mlir/IR/BuiltinTypes.h" + +#define GET_TYPEDEF_CLASSES +#include "gc-dialects/OneDNNGraph/OneDNNGraphOpsTypes.h.inc" + +#endif // ONEDNNGRAPH_ONEDNNGRAPHTYPES_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td new file mode 100644 index 000000000..e75df7515 --- /dev/null +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td @@ -0,0 +1,30 @@ +//===- OneDNNGraphTypes.h - OneDNN input dialect types -----*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ONEDNNGRAPH_TYPES +#define ONEDNNGRAPH_TYPES + +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/AttrTypeBase.td" +include "OneDNNGraphDialect.td" + +//===----------------------------------------------------------------------===// +// OneDNNGraph type definitions +//===----------------------------------------------------------------------===// + +def OneDNNGraph_DataType : AnyTypeOf<[ + F16, + BF16, + F32, + SI<32>, + SI<8>, + UI<8>]>; + +def OneDNNGraph_LogicalTensor : TensorOf<[OneDNNGraph_DataType]>; + +#endif // ONEDNNGRAPH_TYPES diff --git a/lib/gc/Dialect/CMakeLists.txt b/lib/gc/Dialect/CMakeLists.txt index a880ff2ed..802d78764 100644 --- a/lib/gc/Dialect/CMakeLists.txt +++ b/lib/gc/Dialect/CMakeLists.txt @@ -1,3 +1,3 @@ add_subdirectory(Linalgx) add_subdirectory(Microkernel) -add_subdirectory(OnednnGraph) +add_subdirectory(OneDNNGraph) diff --git a/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt new file mode 100644 index 000000000..88c36c5fc --- /dev/null +++ b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROneDNNGraph + OneDNNGraphDialect.cpp + OneDNNGraphOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/gc-dialects/OneDNNGraph + + DEPENDS + MLIROneDNNGraphOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR +) \ No newline at end of file diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp new file mode 100644 index 000000000..4a81fc647 --- /dev/null +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp @@ -0,0 +1,44 @@ +//===- OneDNNGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h" +#include "gc-dialects/OneDNNGraph/OneDNNGraphTypes.h" + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::onednn_graph; + +#include "gc-dialects/OneDNNGraph/OneDNNGraphOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// OneDNNGraph dialect. +//===----------------------------------------------------------------------===// + +void OneDNNGraphDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.cpp.inc" + >(); +} diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp new file mode 100644 index 000000000..b96d023d2 --- /dev/null +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp @@ -0,0 +1,184 @@ +//===- OneDNNGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h" +#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/Support/Debug.h" + +#define GET_OP_CLASSES +#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.cpp.inc" + +namespace mlir { +namespace onednn_graph { + +// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md +template +static LogicalResult inferBroadcastShape( + ShapeRange operands, SmallVector &outShape, + const std::function &getShapeIdx) { + int64_t outRank = 0; + for (size_t i = 0; i < operands.size(); i++) { + auto shape = getShapeIdx(operands, i); + if (!shape.hasRank()) { + return failure(); + } + outRank = std::max(outRank, shape.getRank()); + } + // Start with all 1 dim + outShape.clear(); + outShape.resize(outRank, 1); + // Scan each shape for match dims + for (size_t i = 0; i < operands.size(); i++) { + auto shape = getShapeIdx(operands, i); + auto diff = outShape.size() - shape.getRank(); + for (int64_t j = 0; j < shape.getRank(); j++) { + auto dim1 = outShape[diff + j]; + auto dim2 = shape.getDimSize(j); + auto resolvedDim = dim1; + + if (dim1 == 1) { + resolvedDim = dim2; + } else if (dim2 == 1) { + resolvedDim = dim1; + } else if (dim1 != dim2) { + return failure(); + } + outShape[diff + j] = resolvedDim; + } + } + return success(); +} + +LogicalResult onednn_graph::AddOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + llvm::SmallVector outShape; + auto resultTy = dyn_cast(operands.front().getType()); + auto getShapeIdx = [](ValueShapeRange operands, size_t i) { + return operands.getShape(i); + }; + auto ret = + inferBroadcastShape(operands, outShape, getShapeIdx); + inferredReturnShapes.push_back( + ShapedTypeComponents(outShape, resultTy.getElementType())); + return ret; +} + +LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + MatMulOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + // get batch dims from shape + auto extractBatch = [](const ShapeAdaptor &lhsShape, + const ShapeAdaptor &rhsShape, int64_t range, + int64_t diff, SmallVector &outDims) { + for (int64_t i = 0; i < range; i++) { + int64_t idx = i - diff; + if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) { + return failure(); + } + outDims.push_back(lhsShape.getDimSize(i)); + } + return success(); + }; + // get row col of 2d matrix according to transpose info + auto getMatRowCol = [](const ShapeAdaptor &shape, bool transpose) { + using pairRowCol = std::pair; + auto rank = shape.getRank(); + assert(rank > 1); + return transpose ? pairRowCol(shape.getDimSize(rank - 1), + shape.getDimSize(rank - 2)) + : pairRowCol(shape.getDimSize(rank - 2), + shape.getDimSize(rank - 1)); + }; + ShapeAdaptor lhsShape(adaptor.getInputA().getType()); + ShapeAdaptor rhsShape(adaptor.getInputB().getType()); + bool transposeA = adaptor.getTransposeA(); + bool transposeB = adaptor.getTransposeB(); + int64_t lRank = lhsShape.getRank(); + int64_t rRank = rhsShape.getRank(); + // + SmallVector outShape; + LogicalResult status = failure(); + if (lRank == 1 && rRank == 1) { + // 1D x 1D + if (lhsShape.getDimSize(0) != rhsShape.getDimSize(0)) { + return failure(); + } + outShape.push_back(1); + } else if (lRank == 1 && rRank > 1) { + // 1D x ND + auto rMatRowCol = getMatRowCol(rhsShape, transposeB); + status = extractBatch(rhsShape, rhsShape, rRank - 2, 0, outShape); + if (lhsShape.getDimSize(0) != rMatRowCol.first) { + return failure(); + } + outShape.push_back(rhsShape.getDimSize(rMatRowCol.second)); + } else if (lRank > 1 && rRank == 1) { + // ND x 1D + auto lMatRowCol = getMatRowCol(lhsShape, transposeA); + status = extractBatch(lhsShape, lhsShape, lRank - 2, 0, outShape); + if (lMatRowCol.second != rhsShape.getDimSize(0)) { + return failure(); + } + outShape.push_back(lhsShape.getDimSize(lMatRowCol.first)); + } else if (lRank > 1 && rRank > 1) { + if (lRank == rRank) { + // ND x ND + auto range = lRank - 2; + status = extractBatch(lhsShape, rhsShape, range, 0, outShape); + } else if (lRank > rRank) { + // MD x ND (M > N) + auto range = lRank - 2; + auto diff = lRank - rRank; + status = extractBatch(lhsShape, rhsShape, range, diff, outShape); + } else { + // ND x MD (M > N) + auto range = rRank - 2; + auto diff = rRank - lRank; + status = extractBatch(rhsShape, lhsShape, range, diff, outShape); + } + // + auto lMatRowCol = getMatRowCol(lhsShape, transposeA); + auto rMatRowCol = getMatRowCol(rhsShape, transposeB); + if (failed(status) || (lMatRowCol.second != rMatRowCol.first)) { + return failure(); + } + outShape.push_back(lMatRowCol.first); + outShape.push_back(rMatRowCol.second); + } else { + // Not supported + return failure(); + } + auto getShapeIdx = [](ArrayRef operands, size_t i) { + return operands[i]; + }; + // final shape + auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType()); + inferredReturnShapes.push_back(retShape); + // check for bias broadcasting + if (adaptor.getBias()) { + ShapeAdaptor biasShape(adaptor.getBias().getType()); + ShapeAdaptor matShape(retShape); + bool biasRankMatch = biasShape.getRank() == 1 || + biasShape.getRank() == (int64_t)outShape.size(); + SmallVector bcastShape; + if (!biasRankMatch || + failed(inferBroadcastShape>( + {matShape, biasShape}, bcastShape, getShapeIdx))) { + return failure(); + } + } + return success(); +} + +} // namespace onednn_graph +} // namespace mlir diff --git a/src/gc-opt.cpp b/src/gc-opt.cpp index 72a25abf5..818f777e7 100644 --- a/src/gc-opt.cpp +++ b/src/gc-opt.cpp @@ -17,6 +17,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -28,6 +29,7 @@ int main(int argc, char *argv[]) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); + registry.insert(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Graph Compiler modular optimizer driver\n", registry)); } diff --git a/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir new file mode 100644 index 000000000..3b9d8e863 --- /dev/null +++ b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir @@ -0,0 +1,42 @@ +// RUN: gc-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @mlp +func.func @mlp(%in: tensor<128x512xbf16>, + %weight0: tensor<512x64xbf16>, %bias0: tensor<64xbf16>, + %weight1: tensor<64x256xbf16>, %bias1: tensor<256xbf16>) -> tensor<128x256xbf16> { + %0 = onednn_graph.matmul %in, %weight0, %bias0 + : (tensor<128x512xbf16>, tensor<512x64xbf16>, tensor<64xbf16>) -> tensor<128x64xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = onednn_graph.matmul %1, %weight1 + : (tensor<128x64xbf16>, tensor<64x256xbf16>) -> tensor<128x256xbf16> + %3 = onednn_graph.add %2, %bias1 : (tensor<128x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %4 = onednn_graph.relu %3 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %4 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a +func.func @mlp_transpose_a(%in: tensor<512x128xbf16>, + %weight0: tensor<512x256xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true} + : (tensor<512x128xbf16>, tensor<512x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a +func.func @mlp_transpose_b(%in: tensor<128x512xbf16>, + %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_b = true} + : (tensor<128x512xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} + +// CHECK-LABEL: @mlp_transpose_a_b +func.func @mlp_transpose_a_b(%in: tensor<512x128xbf16>, + %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true, transpose_b = true} + : (tensor<512x128xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> + %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %1 : tensor<128x256xbf16> +} From 861e546e28a14a0d3faf18e24f5b72970344bcfc Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Mon, 13 May 2024 11:24:55 +0800 Subject: [PATCH 3/5] fix --- include/gc/Dialect/OneDNNGraph/CMakeLists.txt | 4 ++-- include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h | 2 +- include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h | 2 +- include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h | 2 +- lib/gc/Dialect/OneDNNGraph/CMakeLists.txt | 2 +- lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp | 10 +++++----- lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp | 6 +++--- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/include/gc/Dialect/OneDNNGraph/CMakeLists.txt b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt index 7b446f467..63dfde793 100644 --- a/include/gc/Dialect/OneDNNGraph/CMakeLists.txt +++ b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -1,3 +1,3 @@ add_mlir_dialect(OneDNNGraphOps onednn_graph) -add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc-dialects/OneDNNGraph/ -gen-op-doc) -add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc-dialects/OneDNNGraph/ -gen-dialect-doc) +add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc) +add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc) diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h index ade444b9b..bae992507 100644 --- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h @@ -14,6 +14,6 @@ #include "mlir/IR/OpImplementation.h" #define GET_OP_CLASSES -#include "gc-dialects/OneDNNGraph/OneDNNGraphOpsDialect.h.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsDialect.h.inc" #endif // GC_DIALECTS_ONEDNNGRAPHDIALECT_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h index 4e0c34396..d4fe91387 100644 --- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h @@ -18,6 +18,6 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #define GET_OP_CLASSES -#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h.inc" #endif // GC_DIALECTS_ONEDNNGRAPHOPS_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h index de8ae3062..c897fb8be 100644 --- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h @@ -12,6 +12,6 @@ #include "mlir/IR/BuiltinTypes.h" #define GET_TYPEDEF_CLASSES -#include "gc-dialects/OneDNNGraph/OneDNNGraphOpsTypes.h.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsTypes.h.inc" #endif // ONEDNNGRAPH_ONEDNNGRAPHTYPES_H diff --git a/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt index 88c36c5fc..2dc50f1b7 100644 --- a/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt +++ b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIROneDNNGraph OneDNNGraphOps.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/gc-dialects/OneDNNGraph + ${PROJECT_SOURCE_DIR}/include/gc/Dialect/OneDNNGraph DEPENDS MLIROneDNNGraphOpsIncGen diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp index 4a81fc647..1caa71f44 100644 --- a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h" -#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h" -#include "gc-dialects/OneDNNGraph/OneDNNGraphTypes.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -30,7 +30,7 @@ using namespace mlir; using namespace mlir::onednn_graph; -#include "gc-dialects/OneDNNGraph/OneDNNGraphOpsDialect.cpp.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsDialect.cpp.inc" //===----------------------------------------------------------------------===// // OneDNNGraph dialect. @@ -39,6 +39,6 @@ using namespace mlir::onednn_graph; void OneDNNGraphDialect::initialize() { addOperations< #define GET_OP_LIST -#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.cpp.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc" >(); } diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp index b96d023d2..0519ca8f3 100644 --- a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp @@ -6,13 +6,13 @@ // //===----------------------------------------------------------------------===// -#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h" -#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "mlir/IR/OpImplementation.h" #include "llvm/Support/Debug.h" #define GET_OP_CLASSES -#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.cpp.inc" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc" namespace mlir { namespace onednn_graph { From 6a7716706ada835ed6555b249244480494287986 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Mon, 13 May 2024 12:11:16 +0800 Subject: [PATCH 4/5] fix test --- test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir index 3b9d8e863..6d3653a89 100644 --- a/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir +++ b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir @@ -23,7 +23,7 @@ func.func @mlp_transpose_a(%in: tensor<512x128xbf16>, return %1 : tensor<128x256xbf16> } -// CHECK-LABEL: @mlp_transpose_a +// CHECK-LABEL: @mlp_transpose_b func.func @mlp_transpose_b(%in: tensor<128x512xbf16>, %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_b = true} From 08cf9cdc98674b6c27b6261dd4ab47b6d41c9ad8 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 14 May 2024 12:33:19 +0800 Subject: [PATCH 5/5] fix --- .../gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td | 2 +- .../Dialect/OneDNNGraph/OneDNNGraphDialect.cpp | 17 ----------------- .../Dialect/OneDNNGraph/onednn-graph-mlp.mlir | 18 ++++++++++++++++++ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td index e75df7515..3c9f0e41d 100644 --- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td @@ -18,7 +18,7 @@ include "OneDNNGraphDialect.td" //===----------------------------------------------------------------------===// def OneDNNGraph_DataType : AnyTypeOf<[ - F16, + F16, BF16, F32, SI<32>, diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp index 1caa71f44..7529e04fc 100644 --- a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp @@ -10,23 +10,6 @@ #include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" -#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/TypeSwitch.h" - using namespace mlir; using namespace mlir::onednn_graph; diff --git a/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir index 6d3653a89..d10146d28 100644 --- a/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir +++ b/test/gc/Dialect/OneDNNGraph/onednn-graph-mlp.mlir @@ -4,6 +4,12 @@ func.func @mlp(%in: tensor<128x512xbf16>, %weight0: tensor<512x64xbf16>, %bias0: tensor<64xbf16>, %weight1: tensor<64x256xbf16>, %bias1: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK: [[MM2:%.+]] = onednn_graph.matmul + // CHECK: [[AD2:%.+]] = onednn_graph.add [[MM2]] + // CHECK: [[RL2:%.+]] = onednn_graph.relu [[AD2]] + // CHECK: return [[RL2]] %0 = onednn_graph.matmul %in, %weight0, %bias0 : (tensor<128x512xbf16>, tensor<512x64xbf16>, tensor<64xbf16>) -> tensor<128x64xbf16> %1 = onednn_graph.relu %0 : (tensor<128x64xbf16>) -> tensor<128x64xbf16> @@ -17,6 +23,10 @@ func.func @mlp(%in: tensor<128x512xbf16>, // CHECK-LABEL: @mlp_transpose_a func.func @mlp_transpose_a(%in: tensor<512x128xbf16>, %weight0: tensor<512x256xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: {transpose_a = true} + // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK-NEXT: return [[RL1]] %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true} : (tensor<512x128xbf16>, tensor<512x256xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> @@ -26,6 +36,10 @@ func.func @mlp_transpose_a(%in: tensor<512x128xbf16>, // CHECK-LABEL: @mlp_transpose_b func.func @mlp_transpose_b(%in: tensor<128x512xbf16>, %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: {transpose_b = true} + // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK-NEXT: return [[RL1]] %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_b = true} : (tensor<128x512xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16> @@ -35,6 +49,10 @@ func.func @mlp_transpose_b(%in: tensor<128x512xbf16>, // CHECK-LABEL: @mlp_transpose_a_b func.func @mlp_transpose_a_b(%in: tensor<512x128xbf16>, %weight0: tensor<256x512xbf16>, %bias0: tensor<256xbf16>) -> tensor<128x256xbf16> { + // CHECK: [[MM1:%.+]] = onednn_graph.matmul + // CHECK: {transpose_a = true, transpose_b = true} + // CHECK-NEXT: [[RL1:%.+]] = onednn_graph.relu [[MM1]] + // CHECK-NEXT: return [[RL1]] %0 = onednn_graph.matmul %in, %weight0, %bias0 {transpose_a = true, transpose_b = true} : (tensor<512x128xbf16>, tensor<256x512xbf16>, tensor<256xbf16>) -> tensor<128x256xbf16> %1 = onednn_graph.relu %0 : (tensor<128x256xbf16>) -> tensor<128x256xbf16>