diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8353314ed958b..125cd4645ccc2 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -790,40 +790,89 @@ def Vector_FMAOp : }]; } +def Vector_ToElementsOp : Vector_Op<"to_elements", [ + Pure, + ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> { + let summary = "operation that decomposes a vector into all its scalar elements"; + let description = [{ + This operation decomposes all the scalar elements from a vector. The + decomposed scalar elements are returned in row-major order. The number of + scalar results must match the number of elements in the input vector type. + All the result elements have the same result type, which must match the + element type of the input vector. Scalable vectors are not supported. + + Examples: + + ```mlir + // Decompose a 0-D vector. + %0 = vector.to_elements %v0 : vector + // %0 = %v0[0] + + // Decompose a 1-D vector. + %0:2 = vector.to_elements %v1 : vector<2xf32> + // %0#0 = %v1[0] + // %0#1 = %v1[1] + + // Decompose a 2-D. + %0:6 = vector.to_elements %v2 : vector<2x3xf32> + // %0#0 = %v2[0, 0] + // %0#1 = %v2[0, 1] + // %0#2 = %v2[0, 2] + // %0#3 = %v2[1, 0] + // %0#4 = %v2[1, 1] + // %0#5 = %v2[1, 2] + + // Decompose a 3-D vector. + %0:6 = vector.to_elements %v3 : vector<3x1x2xf32> + // %0#0 = %v3[0, 0, 0] + // %0#1 = %v3[0, 0, 1] + // %0#2 = %v3[1, 0, 0] + // %0#3 = %v3[1, 0, 1] + // %0#4 = %v3[2, 0, 0] + // %0#5 = %v3[2, 0, 1] + ``` + }]; + + let arguments = (ins AnyVectorOfAnyRank:$source); + let results = (outs Variadic:$elements); + let assemblyFormat = "$source attr-dict `:` type($source)"; +} + def Vector_FromElementsOp : Vector_Op<"from_elements", [ Pure, - TypesMatchWith<"operand types match result element type", - "result", "elements", "SmallVector(" - "::llvm::cast($_self).getNumElements(), " - "::llvm::cast($_self).getElementType())">]> { + ShapedTypeMatchesElementCountAndTypes<"dest", "elements">]> { let summary = "operation that defines a vector from scalar elements"; let description = [{ This operation defines a vector from one or multiple scalar elements. The - number of elements must match the number of elements in the result type. - All elements must have the same type, which must match the element type of - the result vector type. - - `elements` are a flattened version of the result vector in row-major order. + scalar elements are arranged in row-major within the vector. The number of + elements must match the number of elements in the result type. All elements + must have the same type, which must match the element type of the result + vector type. Scalable vectors are not supported. - Example: + Examples: ```mlir - // %f1 + // Define a 0-D vector. %0 = vector.from_elements %f1 : vector - // [%f1, %f2] + // [%f1] + + // Define a 1-D vector. %1 = vector.from_elements %f1, %f2 : vector<2xf32> - // [[%f1, %f2, %f3], [%f4, %f5, %f6]] + // [%f1, %f2] + + // Define a 2-D vector. %2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32> - // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]] + // [[%f1, %f2, %f3], [%f4, %f5, %f6]] + + // Define a 3-D vector. %3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32> + // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]] ``` - - Note, scalable vectors are not supported. }]; let arguments = (ins Variadic:$elements); - let results = (outs AnyFixedVectorOfAnyRank:$result); - let assemblyFormat = "$elements attr-dict `:` type($result)"; + let results = (outs AnyFixedVectorOfAnyRank:$dest); + let assemblyFormat = "$elements attr-dict `:` type($dest)"; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 51b60972203e7..b3fabe409806f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -556,6 +556,25 @@ class AllShapesMatch names> : class AllTypesMatch names> : AllMatchSameOperatorTrait; +// A type constraint that verifies that a shaped type matches the size and +// element type of a container with element types. More specifically, it denotes +// shapedArg.getType().getNumElements() == elementsArg.size() && +// shapedArg.getType().getElementType() == elementsArg[i].getType(), for i in +// [0, elementsArg.size()). +class ShapedTypeMatchesElementCountAndTypes : + PredOpTrait<"shaped type '" # shapedArg # "' matches '" # elementsArg # "' " + "element count and types", + And<[CPred.result # " == " + "$" # elementsArg # ".getTypes().size()">, + CPred<"::llvm::all_of($" # elementsArg # ".getTypes(), " + "[&](::mlir::Type t) { return t == " + # ElementType.result # "; })">]>> { + + string shaped = shapedArg; + string elements = elementsArg; +} + // A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`. // An optional comparator function may be provided that changes the above form // into: `comparator(transform(lhs.getType()), rhs.getType())`. diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 2544f0a1b91b6..07520a2f94d77 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -468,6 +468,37 @@ void Operator::populateTypeInferenceInfo( continue; } + // The `ShapedTypeMatchesElementCountAndTypes` trait represents a 1 -> 1 + // type inference edge where a shaped type matches element count and types + // of variadic elements. + if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) { + StringRef shapedArg = def.getValueAsString("shaped"); + StringRef elementsArg = def.getValueAsString("elements"); + + int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg); + int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg); + + // Handle result type inference from shaped type to variadic elements. + if (InferredResultType::isResultIndex(elementsIndex) && + InferredResultType::isArgIndex(shapedIndex)) { + int resultIndex = InferredResultType::unmapResultIndex(elementsIndex); + ResultTypeInference &infer = inference[resultIndex]; + if (!infer.inferred) { + infer.sources.emplace_back( + shapedIndex, + "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::" + "ShapedType>($_self).getNumElements(), " + "::llvm::cast<::mlir::ShapedType>($_self).getElementType())"); + infer.inferred = true; + } + } + + // Type inference in the opposite direction is not possible as the actual + // shaped type can't be inferred from the variadic elements. + + continue; + } + if (!def.isSubClassOf("AllTypesMatch")) continue; diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 04810ed52584f..ec7cee7b2c641 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) { // ----- -func.func @invalid_from_elements(%a: f32) { +func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) { + // expected-error @+1 {{operation defines 2 results but was provided 4 to bind}} + %0:4 = vector.to_elements %a : vector<1x1x2xf32> + return +} + +// ----- + +func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 { + // expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}} + // expected-note @+1 {{prior use here}} + %0:2 = vector.to_elements %a : vector<2xf32> + return %0#0 : i32 +} + +// ----- + +func.func @from_elements_wrong_num_operands(%a: f32) { // expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}} vector.from_elements %a : vector<2xf32> return @@ -1905,16 +1922,15 @@ func.func @invalid_from_elements(%a: f32) { // ----- // expected-note @+1 {{prior use here}} -func.func @invalid_from_elements(%a: f32, %b: i32) { +func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) { // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}} vector.from_elements %a, %b : vector<2xf32> return } - // ----- func.func @invalid_from_elements_scalable(%a: f32, %b: i32) { - // expected-error @+1 {{'result' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}} + // expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}} vector.from_elements %a, %b : vector<[2]xf32> return } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index f3220aed4360c..c59f7bd001905 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1175,6 +1175,24 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4 return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32> } +// CHECK-LABEL: func @to_elements( +// CHECK-SAME: %[[A_VEC:.*]]: vector, %[[B_VEC:.*]]: vector<1xf32>, +// CHECK-SAME: %[[C_VEC:.*]]: vector<1x2xf32>, %[[D_VEC:.*]]: vector<2x2xf32>) +func.func @to_elements(%a_vec : vector, %b_vec : vector<1xf32>, + %c_vec : vector<1x2xf32>, %d_vec : vector<2x2xf32>) + -> (f32, f32, f32, f32, f32, f32, f32, f32) { + // CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector + %0 = vector.to_elements %a_vec : vector + // CHECK: %[[B_ELEMS:.*]] = vector.to_elements %[[B_VEC]] : vector<1xf32> + %1 = vector.to_elements %b_vec : vector<1xf32> + // CHECK: %[[C_ELEMS:.*]]:2 = vector.to_elements %[[C_VEC]] : vector<1x2xf32> + %2:2 = vector.to_elements %c_vec : vector<1x2xf32> + // CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32> + %3:4 = vector.to_elements %d_vec : vector<2x2xf32> + // CHECK: return %[[A_ELEMS]], %[[B_ELEMS]], %[[C_ELEMS]]#0, %[[C_ELEMS]]#1, %[[D_ELEMS]]#0, %[[D_ELEMS]]#1, %[[D_ELEMS]]#2, %[[D_ELEMS]]#3 + return %0, %1, %2#0, %2#1, %3#0, %3#1, %3#2, %3#3: f32, f32, f32, f32, f32, f32, f32, f32 +} + // CHECK-LABEL: func @from_elements( // CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) func.func @from_elements(%a: f32, %b: f32) -> (vector, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) { diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 0a9d14d6603a8..ef3a18ba7df22 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2787,6 +2787,11 @@ class OpFormatParser : public FormatParser { void handleTypesMatchConstraint( StringMap &variableTyResolver, const Record &def); + /// Check for inferable type resolution based on + /// `ShapedTypeMatchesElementCountAndTypes` constraint. + void handleShapedTypeMatchesElementCountAndTypesConstraint( + StringMap &variableTyResolver, const Record &def); + /// Returns an argument or attribute with the given name that has been seen /// within the format. ConstArgument findSeenArg(StringRef name); @@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc, handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { handleTypesMatchConstraint(variableTyResolver, def); + } else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) { + handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver, + def); } else if (!op.allResultTypesKnown()) { // This doesn't check the name directly to handle // DeclareOpInterfaceMethods @@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint( variableTyResolver[rhsName] = {arg, transformer}; } +void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint( + StringMap &variableTyResolver, const Record &def) { + StringRef shapedArg = def.getValueAsString("shaped"); + StringRef elementsArg = def.getValueAsString("elements"); + + // Check if the 'shaped' argument is seen, then we can infer the 'elements' + // types. + if (ConstArgument arg = findSeenArg(shapedArg)) { + variableTyResolver[elementsArg] = { + arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::" + "ShapedType>($_self).getNumElements(), " + "::llvm::cast<::mlir::ShapedType>($_self).getElementType())"}; + } + + // Type inference in the opposite direction is not possible as the actual + // shaped type can't be inferred from the variadic elements. +} + ConstArgument OpFormatParser::findSeenArg(StringRef name) { if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;