diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h
index d4fe91387..ec571e6ca 100644
--- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h
+++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/Traits.h"
 
 #define GET_OP_CLASSES
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h.inc"
diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td
index 031a84fbc..9d0c2fab3 100644
--- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td
+++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td
@@ -25,10 +25,11 @@
 class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> : Op<OneDNNGraphDialect, mnemonic, traits>;
 
 class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    OneDNNGraph_Op<mnemonic, !listconcat(traits, [SameOperandsAndResultElementType, InferTensorTypeAdaptor])> {
-  let arguments = (ins OneDNNGraph_LogicalTensor:$input_0,
-                       OneDNNGraph_LogicalTensor:$input_1);
-  let results = (outs OneDNNGraph_LogicalTensor:$result);
+    OneDNNGraph_Op<mnemonic, !listconcat(traits, [SameOperandsAndResultElementType,
+        InferTensorTypeAdaptor, ResultsBroadcastableShape])> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$input_0,
+                       OneDNNGraph_FloatTensor:$input_1);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
@@ -36,8 +37,8 @@
 class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
     OneDNNGraph_Op<mnemonic, !listconcat(traits, [SameOperandsAndResultType])> {
-  let arguments = (ins OneDNNGraph_LogicalTensor:$operand);
-  let results = (outs OneDNNGraph_LogicalTensor:$result);
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
@@ -51,15 +52,15 @@ def OneDNNGraph_MatMulOp :
     OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
   let summary = "Generalized matrix multiplication";
   let description = [{
-    `https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html`
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html`
   }];
 
-  let arguments = (ins OneDNNGraph_LogicalTensor:$input_a,
-                       OneDNNGraph_LogicalTensor:$input_b,
+  let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
+                       OneDNNGraph_FloatTensor:$input_b,
                        Optional<OneDNNGraph_LogicalTensor>:$bias,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
-  let results = (outs OneDNNGraph_LogicalTensor:$result);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
@@ -68,14 +69,14 @@
 def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
   let summary = "element-wise relu";
   let description = [{
-    `https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html`
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_relu.html`
   }];
 }
 
 def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
   let summary = "element-wise addition with multi-directional broadcast";
   let description = [{
-    `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html`
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_add.html`
   }];
 }
diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td
index 3c9f0e41d..216ba85bc 100644
--- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td
+++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td
@@ -17,14 +17,26 @@ include "OneDNNGraphDialect.td"
 // OneDNNGraph type definitions
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def OneDNNGraph_Float : AnyTypeOf<[F32,
+                                   F16,
+                                   BF16]>;
+
+//===----------------------------------------------------------------------===//
+// Integer types.
+//===----------------------------------------------------------------------===//
+
+def OneDNNGraph_Int : AnyTypeOf<[SI<8>,
+                                 UI<8>]>;
+
 def OneDNNGraph_DataType : AnyTypeOf<[
-    F16,
-    BF16,
-    F32,
-    SI<32>,
-    SI<8>,
-    UI<8>]>;
+    OneDNNGraph_Float,
+    OneDNNGraph_Int
+    ]>;
 
 def OneDNNGraph_LogicalTensor : TensorOf<[OneDNNGraph_DataType]>;
+def OneDNNGraph_FloatTensor : TensorOf<[OneDNNGraph_Float]>;
 
 #endif // ONEDNNGRAPH_TYPES
diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp
index 0519ca8f3..d81061c18 100644
--- a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp
+++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp
@@ -17,59 +17,22 @@
 namespace mlir {
 namespace onednn_graph {
 
-// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
-template <typename ShapeRange>
-static LogicalResult inferBroadcastShape(
-    ShapeRange operands, SmallVector<int64_t> &outShape,
-    const std::function<ShapeAdaptor(ShapeRange, size_t)> &getShapeIdx) {
-  int64_t outRank = 0;
-  for (size_t i = 0; i < operands.size(); i++) {
-    auto shape = getShapeIdx(operands, i);
-    if (!shape.hasRank()) {
-      return failure();
-    }
-    outRank = std::max(outRank, shape.getRank());
-  }
-  // Start with all 1 dim
-  outShape.clear();
-  outShape.resize(outRank, 1);
-  // Scan each shape for match dims
-  for (size_t i = 0; i < operands.size(); i++) {
-    auto shape = getShapeIdx(operands, i);
-    auto diff = outShape.size() - shape.getRank();
-    for (int64_t j = 0; j < shape.getRank(); j++) {
-      auto dim1 = outShape[diff + j];
-      auto dim2 = shape.getDimSize(j);
-      auto resolvedDim = dim1;
-
-      if (dim1 == 1) {
-        resolvedDim = dim2;
-      } else if (dim2 == 1) {
-        resolvedDim = dim1;
-      } else if (dim1 != dim2) {
-        return failure();
-      }
-      outShape[diff + j] = resolvedDim;
-    }
-  }
-  return success();
-}
-
 LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outShape;
-  auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
-  auto getShapeIdx = [](ValueShapeRange operands, size_t i) {
-    return operands.getShape(i);
+  auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
+  auto getShapeIdx = [&operands](size_t i) {
+    return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
   };
-  auto ret =
-      inferBroadcastShape(operands, outShape, getShapeIdx);
+
+  auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
+                                                outShape);
   inferredReturnShapes.push_back(
       ShapedTypeComponents(outShape, resultTy.getElementType()));
-  return ret;
+  return LogicalResult::success(ret);
 }
 
 LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
@@ -158,22 +121,21 @@
     // Not supported
     return failure();
   }
-  auto getShapeIdx = [](ArrayRef<ShapeAdaptor> operands, size_t i) {
-    return operands[i];
-  };
   // final shape
   auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType());
   inferredReturnShapes.push_back(retShape);
   // check for bias broadcasting
   if (adaptor.getBias()) {
-    ShapeAdaptor biasShape(adaptor.getBias().getType());
-    ShapeAdaptor matShape(retShape);
+    auto biasType = adaptor.getBias().getType();
+    ShapeAdaptor biasShape(biasType);
+
     bool biasRankMatch = biasShape.getRank() == 1 ||
                          biasShape.getRank() == (int64_t)outShape.size();
-    SmallVector<int64_t> bcastShape;
+    SmallVector<int64_t> resultShape;
     if (!biasRankMatch ||
-        failed(inferBroadcastShape<ArrayRef<ShapeAdaptor>>(
-            {matShape, biasShape}, bcastShape, getShapeIdx))) {
+        !OpTrait::util::getBroadcastedShape(
+            retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
+            resultShape)) {
       return failure();
     }
   }
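Note for reviewers: the hand-rolled `inferBroadcastShape` deleted above re-implemented the ONNX/NumPy multi-directional broadcasting rules, while `OpTrait::util::getBroadcastedShape` from `mlir/Dialect/Traits.h` applies the same resolution (right-align ranks, stretch 1-dims, reject mismatched non-1 dims) and also understands dynamic dimensions. Below is a minimal standalone sketch of the helper's behavior; the `main`, the concrete shapes, and the asserts are illustrative only and not part of this patch.

```cpp
// Illustrative only: exercises mlir::OpTrait::util::getBroadcastedShape
// with the same kind of shapes AddOp/MatMulOp inference now feeds it.
#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

int main() {
  llvm::SmallVector<int64_t> out;

  // Ranks are right-aligned and 1-dims stretch: 3x1x5 + 4x1 -> 3x4x5.
  bool ok = mlir::OpTrait::util::getBroadcastedShape({3, 1, 5}, {4, 1}, out);
  assert(ok && (out == llvm::SmallVector<int64_t>{3, 4, 5}));

  // Incompatible non-1 dims (3 vs 4) make the helper return false; `out`
  // is not meaningful after a failed call.
  ok = mlir::OpTrait::util::getBroadcastedShape({3, 2}, {4, 2}, out);
  assert(!ok);
  return 0;
}
```

Note the polarity flip at the call sites: the old helper returned a `LogicalResult` checked with `failed(...)`, whereas the upstream helper returns `bool` success, hence the `!` in the MatMul bias check and the `LogicalResult::success(ret)` wrapper in `AddOp::inferReturnTypeComponents`.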