diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index c38a2584c8eec..643522d5903fd 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1623,10 +1623,15 @@ def LLVM_ConstantOp vectors. It has a mandatory `value` attribute, which may be an integer, floating point attribute; dense or sparse attribute containing integers or floats. The type of the attribute is one of the corresponding MLIR builtin - types. It may be omitted for `i64` and `f64` types that are implied. The - operation produces a new SSA value of the specified LLVM IR dialect type. - The type of that value _must_ correspond to the attribute type converted to - LLVM IR. + types. It may be omitted for `i64` and `f64` types that are implied. + + The operation produces a new SSA value of the specified LLVM IR dialect + type. Certain builtin types such as integer, float and vector types are + also allowed. The result type _must_ correspond to the attribute type + converted to LLVM IR. In particular, the number of elements of a container + type must match the number of elements in the attribute. If the type is or + contains a scalable vector type, the attribute must be a splat elements + attribute. Examples: diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 90610118a45cd..07262bb8e1bac 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2666,6 +2666,39 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) { // ConstantOp. //===----------------------------------------------------------------------===// +/// Compute the total number of elements in the given type, also taking into +/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and +/// `LLVMFixedVectorType`. Everything else is treated as a scalar. +static int64_t getNumElements(Type t) { + if (auto vecType = dyn_cast(t)) + return vecType.getNumElements() * getNumElements(vecType.getElementType()); + if (auto arrayType = dyn_cast(t)) + return arrayType.getNumElements() * + getNumElements(arrayType.getElementType()); + if (auto vecType = dyn_cast(t)) + return vecType.getNumElements() * getNumElements(vecType.getElementType()); + assert(!isa(t) && + "number of elements of a scalable vector type is unknown"); + return 1; +} + +/// Check if the given type is a scalable vector type or a vector/array type +/// that contains a nested scalable vector type. +static bool hasScalableVectorType(Type t) { + if (isa(t)) + return true; + if (auto vecType = dyn_cast(t)) { + if (vecType.isScalable()) + return true; + return hasScalableVectorType(vecType.getElementType()); + } + if (auto arrayType = dyn_cast(t)) + return hasScalableVectorType(arrayType.getElementType()); + if (auto vecType = dyn_cast(t)) + return hasScalableVectorType(vecType.getElementType()); + return false; +} + LogicalResult LLVM::ConstantOp::verify() { if (StringAttr sAttr = llvm::dyn_cast(getValue())) { auto arrayType = llvm::dyn_cast(getType()); @@ -2708,14 +2741,12 @@ LogicalResult LLVM::ConstantOp::verify() { if (auto targetExtType = dyn_cast(getType())) { return emitOpError() << "does not support target extension type."; } - if (!llvm::isa(getValue())) - return emitOpError() - << "only supports integer, float, string or elements attributes"; + + // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr. if (auto intAttr = dyn_cast(getValue())) { if (!llvm::isa(getType())) return emitOpError() << "expected integer type"; - } - if (auto floatAttr = dyn_cast(getValue())) { + } else if (auto floatAttr = dyn_cast(getValue())) { const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); unsigned floatWidth = APFloat::getSizeInBits(sem); if (auto floatTy = dyn_cast(getType())) { @@ -2728,13 +2759,34 @@ LogicalResult LLVM::ConstantOp::verify() { if (isa(getType()) && !getType().isInteger(floatWidth)) { return emitOpError() << "expected integer type of width " << floatWidth; } - } - if (auto splatAttr = dyn_cast(getValue())) { - if (!isa(getType()) && !isa(getType()) && - !isa(getType()) && - !isa(getType())) + } else if (isa(getValue())) { + if (hasScalableVectorType(getType())) { + // The exact number of elements of a scalable vector is unknown, so we + // allow only splat attributes. + auto splatElementsAttr = dyn_cast(getValue()); + if (!splatElementsAttr) + return emitOpError() + << "scalable vector type requires a splat attribute"; + return success(); + } + if (!isa( + getType())) return emitOpError() << "expected vector or array type"; + // The number of elements of the attribute and the type must match. + int64_t attrNumElements; + if (auto elementsAttr = dyn_cast(getValue())) + attrNumElements = elementsAttr.getNumElements(); + else + attrNumElements = cast(getValue()).size(); + if (getNumElements(getType()) != attrNumElements) + return emitOpError() + << "type and attribute have a different number of elements: " + << getNumElements(getType()) << " vs. " << attrNumElements; + } else { + return emitOpError() + << "only supports integer, float, string or elements attributes"; } + return success(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index fe288dab973f5..62346ce0d2c4b 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -414,6 +414,22 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, ! // ----- +llvm.func @const_wrong_number_of_elements() -> vector<5xf64> { + // expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}} + %0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64> + llvm.return %0 : vector<5xf64> +} + +// ----- + +llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> { + // expected-error @+1{{scalable vector type requires a splat attribute}} + %0 = llvm.mlir.constant(dense<[1.0, 1.0, 2.0, 2.0]> : tensor<4xf64>) : vector<[4]xf64> + llvm.return %0 : vector<[4]xf64> +} + +// ----- + func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) { // expected-error@+2 {{expected LLVM IR Dialect type}} llvm.insertvalue %a, %b[0] : tensor<*xi32> diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index fbdf725f3ec17..8453983aa07c3 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> { } llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> { - %1 = llvm.mlir.constant(dense<(0, 1)> : tensor>) : !llvm.array<2 x !llvm.struct<(i32, i32)>> + %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex>) : !llvm.array<2 x !llvm.struct<(i32, i32)>> // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }] llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>> } +llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> { + %1 = llvm.mlir.constant(dense<(0, 1)> : tensor>) : !llvm.array<1 x !llvm.struct<(i32, i32)>> + // CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }] + llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>> +} + llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> { %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>> // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]