Skip to content

Commit b8bf14e

Browse files
[mlir][LLVMIR] Check number of elements in mlir.constant verifier (#102906)
Check that the number of elements in the result type and the attribute of an `llvm.mlir.constant` op matches. Also fix a broken test where that was not the case.
1 parent 85b113c commit b8bf14e

File tree

4 files changed

+94
-15
lines changed

4 files changed

+94
-15
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,10 +1623,15 @@ def LLVM_ConstantOp
16231623
vectors. It has a mandatory `value` attribute, which may be an integer,
16241624
floating point attribute; dense or sparse attribute containing integers or
16251625
floats. The type of the attribute is one of the corresponding MLIR builtin
1626-
types. It may be omitted for `i64` and `f64` types that are implied. The
1627-
operation produces a new SSA value of the specified LLVM IR dialect type.
1628-
The type of that value _must_ correspond to the attribute type converted to
1629-
LLVM IR.
1626+
types. It may be omitted for `i64` and `f64` types that are implied.
1627+
1628+
The operation produces a new SSA value of the specified LLVM IR dialect
1629+
type. Certain builtin types such as integer, float and vector types are
1630+
also allowed. The result type _must_ correspond to the attribute type
1631+
converted to LLVM IR. In particular, the number of elements of a container
1632+
type must match the number of elements in the attribute. If the type is or
1633+
contains a scalable vector type, the attribute must be a splat elements
1634+
attribute.
16301635

16311636
Examples:
16321637

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2666,6 +2666,39 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
26662666
// ConstantOp.
26672667
//===----------------------------------------------------------------------===//
26682668

2669+
/// Compute the total number of elements in the given type, also taking into
2670+
/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
2671+
/// `LLVMFixedVectorType`. Everything else is treated as a scalar.
2672+
static int64_t getNumElements(Type t) {
2673+
if (auto vecType = dyn_cast<VectorType>(t))
2674+
return vecType.getNumElements() * getNumElements(vecType.getElementType());
2675+
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2676+
return arrayType.getNumElements() *
2677+
getNumElements(arrayType.getElementType());
2678+
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2679+
return vecType.getNumElements() * getNumElements(vecType.getElementType());
2680+
assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
2681+
"number of elements of a scalable vector type is unknown");
2682+
return 1;
2683+
}
2684+
2685+
/// Check if the given type is a scalable vector type or a vector/array type
2686+
/// that contains a nested scalable vector type.
2687+
static bool hasScalableVectorType(Type t) {
2688+
if (isa<LLVM::LLVMScalableVectorType>(t))
2689+
return true;
2690+
if (auto vecType = dyn_cast<VectorType>(t)) {
2691+
if (vecType.isScalable())
2692+
return true;
2693+
return hasScalableVectorType(vecType.getElementType());
2694+
}
2695+
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2696+
return hasScalableVectorType(arrayType.getElementType());
2697+
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2698+
return hasScalableVectorType(vecType.getElementType());
2699+
return false;
2700+
}
2701+
26692702
LogicalResult LLVM::ConstantOp::verify() {
26702703
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
26712704
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -2708,14 +2741,12 @@ LogicalResult LLVM::ConstantOp::verify() {
27082741
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
27092742
return emitOpError() << "does not support target extension type.";
27102743
}
2711-
if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
2712-
return emitOpError()
2713-
<< "only supports integer, float, string or elements attributes";
2744+
2745+
// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
27142746
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
27152747
if (!llvm::isa<IntegerType>(getType()))
27162748
return emitOpError() << "expected integer type";
2717-
}
2718-
if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
2749+
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
27192750
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
27202751
unsigned floatWidth = APFloat::getSizeInBits(sem);
27212752
if (auto floatTy = dyn_cast<FloatType>(getType())) {
@@ -2728,13 +2759,34 @@ LogicalResult LLVM::ConstantOp::verify() {
27282759
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
27292760
return emitOpError() << "expected integer type of width " << floatWidth;
27302761
}
2731-
}
2732-
if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
2733-
if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
2734-
!isa<LLVM::LLVMFixedVectorType>(getType()) &&
2735-
!isa<LLVM::LLVMScalableVectorType>(getType()))
2762+
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
2763+
if (hasScalableVectorType(getType())) {
2764+
// The exact number of elements of a scalable vector is unknown, so we
2765+
// allow only splat attributes.
2766+
auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
2767+
if (!splatElementsAttr)
2768+
return emitOpError()
2769+
<< "scalable vector type requires a splat attribute";
2770+
return success();
2771+
}
2772+
if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
2773+
getType()))
27362774
return emitOpError() << "expected vector or array type";
2775+
// The number of elements of the attribute and the type must match.
2776+
int64_t attrNumElements;
2777+
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
2778+
attrNumElements = elementsAttr.getNumElements();
2779+
else
2780+
attrNumElements = cast<ArrayAttr>(getValue()).size();
2781+
if (getNumElements(getType()) != attrNumElements)
2782+
return emitOpError()
2783+
<< "type and attribute have a different number of elements: "
2784+
<< getNumElements(getType()) << " vs. " << attrNumElements;
2785+
} else {
2786+
return emitOpError()
2787+
<< "only supports integer, float, string or elements attributes";
27372788
}
2789+
27382790
return success();
27392791
}
27402792

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,22 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !
414414

415415
// -----
416416

417+
llvm.func @const_wrong_number_of_elements() -> vector<5xf64> {
418+
// expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}}
419+
%0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64>
420+
llvm.return %0 : vector<5xf64>
421+
}
422+
423+
// -----
424+
425+
llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
426+
// expected-error @+1{{scalable vector type requires a splat attribute}}
427+
%0 = llvm.mlir.constant(dense<[1.0, 1.0, 2.0, 2.0]> : tensor<4xf64>) : vector<[4]xf64>
428+
llvm.return %0 : vector<[4]xf64>
429+
}
430+
431+
// -----
432+
417433
func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
418434
// expected-error@+2 {{expected LLVM IR Dialect type}}
419435
llvm.insertvalue %a, %b[0] : tensor<*xi32>

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
12951295
}
12961296

12971297
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
1298-
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
1298+
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
12991299
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
13001300
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
13011301
}
13021302

1303+
llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> {
1304+
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<1 x !llvm.struct<(i32, i32)>>
1305+
// CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }]
1306+
llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>>
1307+
}
1308+
13031309
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
13041310
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
13051311
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]

0 commit comments

Comments
 (0)