diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..e80dbb2afb9ef 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -684,6 +684,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b..97b90333e2b20 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
       - !ScalarExpression
         scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
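The YAML-generated definition removed above is superseded by the hand-written ODS definition in the next file, and existing IR keeps working. A minimal sketch (function name and shapes are assumed for illustration) of the unchanged default form, which omits `indexing_maps` and therefore receives the default maps:

```mlir
func.func @plain_matmul(%a: tensor<4x8xf32>, %b: tensor<8x16xf32>,
                        %c: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // No indexing_maps attribute: the op gets the default projections
  // (d0, d2), (d2, d1), (d0, d1) and round-trips exactly as before.
  %0 = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
                     outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
```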
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 31f2913924726..61d4fc9734c6d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -535,6 +535,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or
+    transpose.
+  }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the
+    explicit attribute 'indexing_maps' as shown below. This is a list
+    attribute, so it must include maps for all arguments if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.matmul indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                    affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+                          inputs, outputs, attributes,
+                          MatmulOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs, "Attribute":$cast,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMap with the typical matmul indexing
+    /// characteristic.
+    SmallVector<AffineMap> getDefaultIndexingMaps();
+
+    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    std::string getLibraryCallName();
+    bool hasDynamicIndexingMaps();
+    /// Checks if the op has broadcast and/or transpose semantics. Returns
+    /// true if the user-defined indexing maps are not equal to the default
+    /// maps.
+    bool hasUserDefinedMaps();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
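For orientation (an illustration, not part of the patch): with the default maps and the region produced by `regionBuilder`, the op is equivalent to the following `linalg.generic`; names and shapes are assumed:

```mlir
#map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
#map_c = affine_map<(d0, d1, d2) -> (d0, d1)>

// Equivalent generic form of a default linalg.matmul on buffers.
func.func @matmul_as_generic(%a: memref<4x8xf32>, %b: memref<8x16xf32>,
                             %c: memref<4x16xf32>) {
  linalg.generic {indexing_maps = [#map_a, #map_b, #map_c],
                  iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b : memref<4x8xf32>, memref<8x16xf32>)
      outs(%c : memref<4x16xf32>) {
  ^bb0(%in_a: f32, %in_b: f32, %out_c: f32):
    // Mirrors MatmulOp::regionBuilder: cast (a no-op for f32), mul, add.
    %0 = arith.mulf %in_a, %in_b : f32
    %1 = arith.addf %out_c, %0 : f32
    linalg.yield %1 : f32
  }
  return
}
```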
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 40795879c3026..3b9194098fa78 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,13 +15,20 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1142,7 +1149,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1162,6 +1168,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Ops with user-defined indexing maps need special care here: the error
+  // conditions below assume the default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1178,13 +1186,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1194,9 +1202,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883e..4f350ea236da8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create the default indexing maps for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the default indexing maps for MatmulOp as an array of
+/// attributes.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexing_maps for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr.
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult
 parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
                        unsigned numRegionArgs,
                        RegionBuilderFn regionBuilder) {
+
+  // Parse an optional, user-supplied list of indexing map attributes.
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexing_maps with the default maps if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
-                                   ValueRange inputs, ValueRange outputs) {
-  p.printOptionalAttrDict(
-      op->getAttrs(),
-      /*elidedAttrs=*/{"operandSegmentSizes",
-                       // See generated code in
-                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                                   ValueRange inputs, ValueRange outputs,
+                                   ArrayRef<StringRef> elidedAttrs = {}) {
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Returns true if the result AffineExprs of \p explicitMap are the same as
+/// those of \p defaultMap.
+static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
+  auto explicitRange = explicitMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if \p explicitMap is broadcast with respect to
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
+  return explicitMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantics specified by the explicit
+/// indexing map of the MatmulOp \p matmulOp for the operand specified by
+/// \p opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+                                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap, 3> defaultIndexingMaps =
+      matmulOp.getDefaultIndexingMaps();
+
+  auto opIndexingMap = opIndexingMaps[opIndex];
+  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+  // Check general validity of indexing map results.
+  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    return matmulOp->emitOpError()
+           << "Unexpected dim expression in map result.";
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
+      return matmulOp->emitOpError()
+             << "Invalid broadcast requested, should be (d2).";
+    }
+    return success();
+  }
+  return success();
+}
+
+namespace mlir {
+namespace linalg {
+
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+                                          utils::IteratorType::parallel,
+                                          utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Checks if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool MatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+/// Returns a list of AffineMap with the typical matmul indexing
+/// characteristic.
+SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
+  MLIRContext *context = this->getContext();
+  return getDefaultIndexingMapsForMatmul(context);
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
+  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
+  AffineExpr exp = bcastMap.getResult(0);
+  // The map is invalid if the common (reduction) dimension of the matmul is
+  // not found.
+  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+                                MatmulOp::getRegionBuilder());
+}
+
+void MatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+
+  // Print the indexing maps only when they differ from the defaults.
+  SmallVector<Attribute, 3> indexingMaps =
+      getDefaultIndexingMapAttr(getContext());
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
+
+/// Verify the user-defined indexing maps.
+LogicalResult MatmulOp::verify() {
+  // Verification of a pure matmul is handled by verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
+      return failure();
+  }
+  return success();
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index aa0052ce47fa7..6b934f7e8157d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -31,6 +31,13 @@ using namespace mlir::linalg;
 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                      linalg::MatmulOp matmulOp,
                                                      bool transposeLHS) {
+  // Bail out on matmul ops with extended (user-defined indexing map)
+  // semantics; this transform only handles the default maps.
+  if (matmulOp.hasUserDefinedMaps()) {
+    return rewriter.notifyMatchFailure(
+        matmulOp, "only matmul ops with non-extended semantics are supported");
+  }
+
   if (!bufferization::hasTensorSemantics(matmulOp))
     return rewriter.notifyMatchFailure(
         matmulOp, "only matmul ops with tensors are supported");
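The guard added above (and the analogous ones in the following hunks) rejects the extended form. A sketch of the kind of op that now bails out of these transforms; function name and shapes are assumed:

```mlir
// hasUserDefinedMaps() returns true here because the maps transpose A,
// so transposeMatmul reports a match failure instead of rewriting.
func.func @extended_matmul(%a: tensor<8x4xf32>, %b: tensor<8x16xf32>,
                           %c: tensor<4x16xf32>) -> tensor<4x16xf32> {
  %0 = linalg.matmul indexing_maps = [
         affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
         affine_map<(d0, d1, d2) -> (d2, d1)>,
         affine_map<(d0, d1, d2) -> (d0, d1)>
       ]
       ins(%a, %b : tensor<8x4xf32>, tensor<8x16xf32>)
       outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
```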
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 09c6b2683b438..e3f010d9cfb20 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2071,6 +2071,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
     return failure();
   }
 
+  // Bail out on matmul ops with extended (user-defined indexing map)
+  // semantics; scalable vectorization does not support them.
+  if (linalgOp.hasUserDefinedMaps())
+    return failure();
+
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 0c2275bbc4b22..3c508ed6e324b 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -821,6 +821,12 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
   bool fail = true;
   // TODO: more robust detection of matmulOp, with transposes etc.
   if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
+    // Bail out on matmul ops with extended (user-defined indexing map)
+    // semantics; this rewrite only handles the default maps.
+    if (linalgOp.hasUserDefinedMaps()) {
+      return emitSilenceableError()
+             << "only matmul ops with non-extended semantics are supported";
+    }
     Location loc = linalgOp.getLoc();
     // TODO: more robust computation of laneId, for now assume a single warp.
     Value laneId = rewriter.create<gpu::ThreadIdOp>(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index e4a6ec7487bb2..d5e79b4d3cb6d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
     O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
 
 
-@linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
- """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) - - @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index 1e8f1435ca0fa..aba26c35931fd 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -29,6 +29,34 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, // ----- +func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-LABEL: func.func @matmul_bcast_a( +// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { +// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) { +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32): +// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32 +// CHECK: linalg.yield %[[VAL_7]] : f32 +// CHECK: } +// CHECK: return +// CHECK: } + +// ----- + func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> @@ -891,3 +919,86 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor, tensor, tensor> } + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func.func @matmul_transpose_a_explicit( +// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { + +// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK: arith.mulf +// CHECK: arith.addf + +func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d0)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) + outs(%arg2: memref<3x7xf32>) + + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-LABEL: func.func @matmul_transpose_b_explicit( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: 
memref<7x5xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { + +// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK: arith.mulf +// CHECK: arith.addf + +func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) + outs(%arg2: memref<3x7xf32>) + + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit( +// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { + +// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK: arith.mulf +// CHECK: arith.addf + +func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d0)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) + outs(%arg2: memref<3x7xf32>) + + return +} + +// ----- + diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index c481a723c5623..b2869893b8042 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -361,6 +361,165 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, // ----- +func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) { + // expected-error @+1 {{expected attribute value}} + linalg.matmul indexing_maps = [ + , + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>) + outs(%arg2 :memref<2x4xf32>) + return +} + +// ----- + +func.func @invalid_matmul_dim_a(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) { + // expected-error @+1 {{Unexpected dim expression in map result}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>) + return +} + +// ----- + +func.func @invalid_matmul_dim_b(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) { + // expected-error @+1 {{Unexpected dim expression in map result}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d0)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>) + return +} + +// ----- + +func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> { + // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}} + %0 = linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> 
(d2, d0)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) + outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32> + return %0: tensor<4x64xf32> +} + +// ----- + +func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> { + // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}} + %0 = linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) + outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32> + return %0: tensor<4x64xf32> +} + +// ----- + +func.func @invalid_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { + // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// ----- + +func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) { + // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// ----- + +func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { + // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #0 (1)}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// ----- + +func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { + // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #1 (1)}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// ----- + +func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) { + // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d0)>, + affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// ----- + +func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) { + // expected-error @+1 {{'linalg.matmul' op Unexpected dim expression in map result.}} + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> 
(d0, d1)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>) + return +} + +// ----- + +func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) { + // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}} + linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>) + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + return +} + +// ----- + func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) { // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}} linalg.conv_2d_nhwc_hwcf diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 02ecbed232c8b..65c18de842477 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1201,6 +1201,249 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a // ----- +// CHECK-LABEL: func @matmul_transpose_a_explicit +// CHECK: linalg.matmul +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) +func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d0)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) + outs(%arg2: memref<3x7xf32>) + + return +} + +// ----- + +func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) + outs(%arg2: memref<3x7xf32>) + + return +} + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func.func @matmul_transpose_b_explicit( +// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { +// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: return +// CHECK: } + +// ----- + +func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d0)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) + outs(%arg2: memref<3x7xf32>) + return +} + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit( +// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, 
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2, d0)>,
+                  affine_map<(d0, d1, d2) -> (d1, d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d2, d1)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d2, d1)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a_dim1
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d0, d2)>,
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_a_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d0, d2)>,
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b_dim1
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d2, d1)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @dynamic_matmul_bcast_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d1, d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_a_transpose_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2, d0)>,
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_b_transpose_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
 // CHECK-LABEL: func @matmul_transpose_b
 // CHECK: linalg.matmul_transpose_b
 // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
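The trailing context above shows the pre-existing `linalg.matmul_transpose_b` named op. For comparison, a sketch (shapes taken from the tests above) of the two ways to express a transposed-B matmul after this patch; both compute the same result:

```mlir
func.func @transpose_b_two_ways(%a: memref<3x5xf32>, %b: memref<7x5xf32>,
                                %c: memref<3x7xf32>) {
  // Dedicated named op.
  linalg.matmul_transpose_b ins(%a, %b : memref<3x5xf32>, memref<7x5xf32>)
                            outs(%c : memref<3x7xf32>)
  // Same semantics via user-defined indexing maps on linalg.matmul.
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%a, %b : memref<3x5xf32>, memref<7x5xf32>)
                outs(%c : memref<3x7xf32>)
  return
}
```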
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 3bfbcf7d7f7c8..72045a07b2da8 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
 
     print(module)
 
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
-@run
-def testNamedStructuredOpGenericForm():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def named_form(lhs, rhs):
-                init_result = tensor.empty([4, 8], f32)
-                # CHECK: "linalg.matmul"(%{{.*}})
-                # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
-                # CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
-                # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-                # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
-                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-                return linalg.matmul(lhs, rhs, outs=[init_result])
-
-            module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
-@run
-def testNamedStructuredAsGenericOp():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def generic_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.generic
-                return linalg.matmul(
-                    lhs, rhs, outs=[init_result.result], emit_generic=True
-                )
-
-            print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
-@run
-def testOpResultFromOtherOp():
-    with Context(), Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def pass_an_op_directly(arg0, arg1):
-                one = arith.ConstantOp(F32Type.get(), 1.0)
-                # CHECK: %[[LHS:.*]] = linalg.fill
-                lhs = linalg.fill(one, outs=[arg0])
-                # CHECK: %[[RHS:.*]] = linalg.fill
-                rhs = linalg.fill(one, outs=[arg1])
-                # CHECK: %[[INIT:.*]] = tensor.empty
-                init = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.matmul
-                # CHECK: ins(%[[LHS]], %[[RHS]]
-                # CHECK: outs(%[[INIT]]
-                return linalg.matmul(lhs, rhs, outs=init)
-
-            print(module)
-
-
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index aa5a52a21f125..f820cb7ee8c3c 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -681,7 +681,11 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
                          {0}::getNumRegionArgs(), {0}::getRegionBuilder());
 }
 void {0}::print(OpAsmPrinter &p) {{
-  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+  SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
+                                            "linalg.memoized_indexing_maps",
+                                            "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
 }
 )FMT";
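One caveat worth illustrating (an assumed example, not from the patch): the parser added in LinalgOps.cpp accepts `indexing_maps` only before `ins`, while `MatmulOp::print` emits the clause after `outs` (as the CHECK lines in the tests show), so non-default maps are written like this in source form:

```mlir
func.func @parse_form(%a: memref<8x4xf32>, %b: memref<8x16xf32>,
                      %c: memref<4x16xf32>) {
  // Maps precede ins/outs when parsed; a default-map matmul parses and
  // prints with no indexing_maps clause at all.
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%a, %b : memref<8x4xf32>, memref<8x16xf32>)
                outs(%c : memref<4x16xf32>)
  return
}
```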