diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 71f249fa538ca..46bf1c9640c17 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1620,19 +1620,30 @@ def LLVM_ConstantOp let description = [{ Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all constants must be created as SSA values before being used in other - operations. `llvm.mlir.constant` creates such values for scalars and - vectors. It has a mandatory `value` attribute, which may be an integer, - floating point attribute; dense or sparse attribute containing integers or - floats. The type of the attribute is one of the corresponding MLIR builtin - types. It may be omitted for `i64` and `f64` types that are implied. - - The operation produces a new SSA value of the specified LLVM IR dialect - type. Certain builtin types such as integer, float and vector types are - also allowed. The result type _must_ correspond to the attribute type - converted to LLVM IR. In particular, the number of elements of a container - type must match the number of elements in the attribute. If the type is or - contains a scalable vector type, the attribute must be a splat elements - attribute. + operations. `llvm.mlir.constant` creates such values for scalars, vectors, + strings, and structs. It has a mandatory `value` attribute whose type + depends on the type of the constant value. The type of the constant value + must correspond to the attribute type converted to LLVM IR type. + + When creating constant scalars, the `value` attribute must be either an + integer attribute or a floating point attribute. The type of the attribute + may be omitted for `i64` and `f64` types that are implied. + + When creating constant vectors, the `value` attribute must be either an + array attribute, a dense attribute, or a sparse attribute that contains + integers or floats. The number of elements in the result vector must match + the number of elements in the attribute. + + When creating constant strings, the `value` attribute must be a string + attribute. The type of the constant must be an LLVM array of `i8`s, and the + length of the array must match the length of the attribute. + + When creating constant structs, the `value` attribute must be an array + attribute that contains integers or floats. The type of the constant must be + an LLVM struct type. The number of fields in the struct must match the + number of elements in the attribute, and the type of each LLVM struct field + must correspond to the type of the corresponding attribute element converted + to LLVM IR. Examples: diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 92f3984e5e6db..3870aab52f199 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() { } return success(); } - if (auto structType = llvm::dyn_cast(getType())) { - if (structType.getBody().size() != 2 || - structType.getBody()[0] != structType.getBody()[1]) { - return emitError() << "expected struct type with two elements of the " - "same type, the type of a complex constant"; + if (auto structType = dyn_cast(getType())) { + auto arrayAttr = dyn_cast(getValue()); + if (!arrayAttr) { + return emitOpError() << "expected array attribute for a struct constant"; } - auto arrayAttr = llvm::dyn_cast(getValue()); - if (!arrayAttr || arrayAttr.size() != 2) { - return emitOpError() << "expected array attribute with two elements, " - "representing a complex constant"; + ArrayRef elementTypes = structType.getBody(); + if (arrayAttr.size() != elementTypes.size()) { + return emitOpError() << "expected array attribute of size " + << elementTypes.size(); } - auto re = llvm::dyn_cast(arrayAttr[0]); - auto im = llvm::dyn_cast(arrayAttr[1]); - if (!re || !im || re.getType() != im.getType()) { - return emitOpError() - << "expected array attribute with two elements of the same type"; + for (auto elementTy : elementTypes) { + if (!isa(elementTy)) { + return emitOpError() << "expected struct element types to be floating " + "point type or integer type"; + } } - Type elementType = structType.getBody()[0]; - if (!llvm::isa( - elementType)) { - return emitError() - << "expected struct element types to be floating point type or " - "integer type"; + for (size_t i = 0; i < elementTypes.size(); ++i) { + Attribute element = arrayAttr[i]; + if (!isa(element)) { + return emitOpError() + << "expected struct element attribute types to be floating " + "point type or integer type"; + } + auto elementType = cast(element).getType(); + if (elementType != elementTypes[i]) { + return emitOpError() + << "struct element at index " << i << " is of wrong type"; + } } + return success(); } if (auto targetExtType = dyn_cast(getType())) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 930300d26c447..adf70e6aab5d1 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -557,20 +557,21 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( return llvm::UndefValue::get(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { auto arrayAttr = dyn_cast(attr); - if (!arrayAttr || arrayAttr.size() != 2) { - emitError(loc, "expected struct type to be a complex number"); + if (!arrayAttr) { + emitError(loc, "expected an array attribute for a struct constant"); return nullptr; } - llvm::Type *elementType = structType->getElementType(0); - llvm::Constant *real = - getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); - if (!real) - return nullptr; - llvm::Constant *imag = - getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); - if (!imag) - return nullptr; - return llvm::ConstantStruct::get(structType, {real, imag}); + SmallVector structElements; + structElements.reserve(structType->getNumElements()); + for (auto [elemType, elemAttr] : + zip_equal(structType->elements(), arrayAttr)) { + llvm::Constant *element = + getLLVMConstant(elemType, elemAttr, loc, moduleTranslation); + if (!element) + return nullptr; + structElements.push_back(element); + } + return llvm::ConstantStruct::get(structType, structElements); } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 62346ce0d2c4b..6670e4b186c39 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -367,7 +367,7 @@ func.func @constant_wrong_type_string() { // ----- llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> { - // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}} + // expected-error @+1 {{expected array attribute of size 2}} %0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } @@ -375,7 +375,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> { // ----- llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> { - // expected-error @+1 {{expected array attribute with two elements of the same type}} + // expected-error @+1 {{struct element at index 1 is of wrong type}} %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } @@ -383,7 +383,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> { // ----- llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> { - // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}} + // expected-error @+1 {{expected array attribute}} %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } @@ -391,7 +391,7 @@ llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> { // ----- llvm.func @struct_one_element() -> !llvm.struct<(f64)> { - // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}} + // expected-error @+1 {{expected array attribute of size 1}} %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)> llvm.return %0 : !llvm.struct<(f64)> } @@ -399,7 +399,7 @@ llvm.func @struct_one_element() -> !llvm.struct<(f64)> { // ----- llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> { - // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}} + // expected-error @+1 {{struct element at index 1 is of wrong type}} %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)> llvm.return %0 : !llvm.struct<(f64, f32)> } diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index 9cf922ad490a9..0e2afe6fb004d 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -15,24 +15,40 @@ llvm.func @vector_with_non_vector_type() -> f32 { // ----- -llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { - // expected-error @below{{expected struct type to be a complex number}} +llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { + // expected-error @below{{expected an array attribute for a struct constant}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> } // ----- -llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { - // expected-error @below{{expected struct type to be a complex number}} +llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { + // expected-error @below{{expected an array attribute for a struct constant}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> } // ----- +llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> { + // expected-error @below{{expected struct element types to be floating point type or integer type}} + %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)> + llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)> +} + +// ----- + +llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> { + // expected-error @below{{expected struct element attribute types to be floating point type or integer type}} + %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)> + llvm.return %0 : !llvm.struct<(f64, f64)> +} + +// ----- + llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { - // expected-error @below{{FloatAttr does not match expected type of the constant}} + // expected-error @below{{struct element at index 0 is of wrong type}} %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 8453983aa07c3..df61fef605fde 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> } +llvm.func @structconstant() -> !llvm.struct<(i32, f32)> { + %1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)> + // CHECK: ret { i32, float } { i32 1, float 2.000000e+00 } + llvm.return %1 : !llvm.struct<(i32, f32)> +} + // CHECK-LABEL: @indexconstantsplat llvm.func @indexconstantsplat() -> vector<3xi32> { %1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>