Skip to content

Commit c399a1c

Browse files
[mlir][IR] Auto-generate element type verification for VectorType
1 parent 27f3ffa commit c399a1c

File tree

4 files changed

+10
-15
lines changed

4 files changed

+10
-15
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/IR/BuiltinDialect.td"
1919
include "mlir/IR/BuiltinTypeInterfaces.td"
20+
include "mlir/IR/CommonTypeConstraints.td"
2021

2122
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
2223
// This is to differentiate the types here with the ones in OpBase.td. We should
@@ -1146,7 +1147,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11461147
}];
11471148
let parameters = (ins
11481149
ArrayRefParameter<"int64_t">:$shape,
1149-
"Type":$elementType,
1150+
AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
11501151
ArrayRefParameter<"bool">:$scalableDims
11511152
);
11521153
let builders = [
@@ -1173,6 +1174,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11731174
/// type. In particular, vectors can consist of integer, index, or float
11741175
/// primitives.
11751176
static bool isValidElementType(Type t) {
1177+
// TODO: Auto-generate this function from $elementType.
11761178
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
11771179
}
11781180

mlir/lib/AsmParser/TypeParser.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -458,31 +458,24 @@ Type Parser::parseTupleType() {
458458
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
459459
///
460460
VectorType Parser::parseVectorType() {
461+
SMLoc loc = getToken().getLoc();
461462
consumeToken(Token::kw_vector);
462463

463464
if (parseToken(Token::less, "expected '<' in vector type"))
464465
return nullptr;
465466

467+
// Parse the dimensions.
466468
SmallVector<int64_t, 4> dimensions;
467469
SmallVector<bool, 4> scalableDims;
468470
if (parseVectorDimensionList(dimensions, scalableDims))
469471
return nullptr;
470-
if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
471-
return emitError(getToken().getLoc(),
472-
"vector types must have positive constant sizes"),
473-
nullptr;
474472

475473
// Parse the element type.
476-
auto typeLoc = getToken().getLoc();
477474
auto elementType = parseType();
478475
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
479476
return nullptr;
480477

481-
if (!VectorType::isValidElementType(elementType))
482-
return emitError(typeLoc, "vector elements must be int/index/float type"),
483-
nullptr;
484-
485-
return VectorType::get(dimensions, elementType, scalableDims);
478+
return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
486479
}
487480

488481
/// Parse a dimension list in a vector type. This populates the dimension list.

mlir/test/IR/invalid-builtin-types.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,17 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
120120
// -----
121121

122122
// Test no nested vector.
123-
// expected-error@+1 {{vector elements must be int/index/float type}}
123+
// expected-error@+1 {{failed to verify 'elementType': integer or index or floating-point}}
124124
func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
125125

126126
// -----
127127

128-
// expected-error @+1 {{vector types must have positive constant sizes}}
128+
// expected-error @+1 {{vector types must have positive constant sizes but got 0}}
129129
func.func @zero_vector_type() -> vector<0xi32>
130130

131131
// -----
132132

133-
// expected-error @+1 {{vector types must have positive constant sizes}}
133+
// expected-error @+1 {{vector types must have positive constant sizes but got 1, 0}}
134134
func.func @zero_in_vector_type() -> vector<1x0xi32>
135135

136136
// -----

mlir/test/python/ir/builtin_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def testVectorType():
345345
VectorType.get(shape, none)
346346
except MLIRError as e:
347347
# CHECK: Invalid type:
348-
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
348+
# CHECK: error: unknown: failed to verify 'elementType': integer or index or floating-point
349349
print(e)
350350
else:
351351
print("Exception not produced")

0 commit comments

Comments
 (0)