Skip to content

Commit c71b42d

Browse files
committed
Add ShapedTypeMatchesElementCountAndTypes
1 parent 1aa54ae commit c71b42d

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -792,10 +792,7 @@ def Vector_FMAOp :
792792

793793
def Vector_ToElementsOp : Vector_Op<"to_elements", [
794794
Pure,
795-
TypesMatchWith<"operand element type matches result types",
796-
"source", "elements", "SmallVector<Type>("
797-
"::llvm::cast<VectorType>($_self).getNumElements(), "
798-
"::llvm::cast<VectorType>($_self).getElementType())">]> {
795+
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
799796
let summary = "operation that decomposes a vector into all its scalar elements";
800797
let description = [{
801798
This operation decomposes all the scalar elements from a vector. The
@@ -843,10 +840,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
843840

844841
def Vector_FromElementsOp : Vector_Op<"from_elements", [
845842
Pure,
846-
TypesMatchWith<"operand types match result element type",
847-
"dest", "elements", "SmallVector<Type>("
848-
"::llvm::cast<VectorType>($_self).getNumElements(), "
849-
"::llvm::cast<VectorType>($_self).getElementType())">]> {
843+
ShapedTypeMatchesElementCountAndTypes<"dest", "elements">]> {
850844
let summary = "operation that defines a vector from scalar elements";
851845
let description = [{
852846
This operation defines a vector from one or multiple scalar elements. The

mlir/include/mlir/IR/OpBase.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,25 @@ class AllShapesMatch<list<string> names> :
556556
class AllTypesMatch<list<string> names> :
557557
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
558558

559+
// A type constraint that verifies that a shaped type matches the size and
560+
// element type of a container with element types. More specifically, it denotes
561+
// shapedArg.getType().getNumElements() == elementsArg.size() &&
562+
// shapedArg.getType().getElementType() == elementsArg[i].getType(), for i in
563+
// [0, elementsArg.size()).
564+
class ShapedTypeMatchesElementCountAndTypes<string shapedArg,
565+
string elementsArg> :
566+
PredOpTrait<"shaped type '" # shapedArg # "' matches '" # elementsArg # "' "
567+
"element count and types",
568+
And<[CPred<ElementCount<shapedArg>.result # " == "
569+
"$" # elementsArg # ".getTypes().size()">,
570+
CPred<"::llvm::all_of($" # elementsArg # ".getTypes(), "
571+
"[&](::mlir::Type t) { return t == "
572+
# ElementType<shapedArg>.result # "; })">]>> {
573+
574+
string shaped = shapedArg;
575+
string elements = elementsArg;
576+
}
577+
559578
// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
560579
// An optional comparator function may be provided that changes the above form
561580
// into: `comparator(transform(lhs.getType()), rhs.getType())`.

mlir/lib/TableGen/Operator.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,37 @@ void Operator::populateTypeInferenceInfo(
468468
continue;
469469
}
470470

471+
// The `ShapedTypeMatchesElementCountAndTypes` trait represents a 1 -> 1
472+
// type inference edge where a shaped type matches element count and types
473+
// of variadic elements.
474+
if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
475+
StringRef shapedArg = def.getValueAsString("shaped");
476+
StringRef elementsArg = def.getValueAsString("elements");
477+
478+
int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg);
479+
int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg);
480+
481+
// Handle result type inference from shaped type to variadic elements.
482+
if (InferredResultType::isResultIndex(elementsIndex) &&
483+
InferredResultType::isArgIndex(shapedIndex)) {
484+
int resultIndex = InferredResultType::unmapResultIndex(elementsIndex);
485+
ResultTypeInference &infer = inference[resultIndex];
486+
if (!infer.inferred) {
487+
infer.sources.emplace_back(
488+
shapedIndex,
489+
"::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
490+
"ShapedType>($_self).getNumElements(), "
491+
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
492+
infer.inferred = true;
493+
}
494+
}
495+
496+
// Type inference in the opposite direction is not possible as the actual
497+
// shaped type can't be inferred from the variadic elements.
498+
499+
continue;
500+
}
501+
471502
if (!def.isSubClassOf("AllTypesMatch"))
472503
continue;
473504

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,11 @@ class OpFormatParser : public FormatParser {
27872787
void handleTypesMatchConstraint(
27882788
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
27892789

2790+
/// Check for inferable type resolution based on
2791+
/// `ShapedTypeMatchesElementCountAndTypes` constraint.
2792+
void handleShapedTypeMatchesElementCountAndTypesConstraint(
2793+
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
2794+
27902795
/// Returns an argument or attribute with the given name that has been seen
27912796
/// within the format.
27922797
ConstArgument findSeenArg(StringRef name);
@@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
28502855
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
28512856
} else if (def.isSubClassOf("TypesMatchWith")) {
28522857
handleTypesMatchConstraint(variableTyResolver, def);
2858+
} else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
2859+
handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
2860+
def);
28532861
} else if (!op.allResultTypesKnown()) {
28542862
// This doesn't check the name directly to handle
28552863
// DeclareOpInterfaceMethods<InferTypeOpInterface>
@@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
32893297
variableTyResolver[rhsName] = {arg, transformer};
32903298
}
32913299

3300+
void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
3301+
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
3302+
StringRef shapedArg = def.getValueAsString("shaped");
3303+
StringRef elementsArg = def.getValueAsString("elements");
3304+
3305+
// Check if the 'shaped' argument is seen, then we can infer the 'elements'
3306+
// types.
3307+
if (ConstArgument arg = findSeenArg(shapedArg)) {
3308+
variableTyResolver[elementsArg] = {
3309+
arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
3310+
"ShapedType>($_self).getNumElements(), "
3311+
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
3312+
}
3313+
3314+
// Type inference in the opposite direction is not possible as the actual
3315+
// shaped type can't be inferred from the variadic elements.
3316+
}
3317+
32923318
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
32933319
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
32943320
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;

0 commit comments

Comments
 (0)