Skip to content

[mlir][LLVM] Add support for constant struct with multiple fields #102752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1620,19 +1620,30 @@ def LLVM_ConstantOp
let description = [{
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
all constants must be created as SSA values before being used in other
operations. `llvm.mlir.constant` creates such values for scalars and
vectors. It has a mandatory `value` attribute, which may be an integer,
floating point attribute; dense or sparse attribute containing integers or
floats. The type of the attribute is one of the corresponding MLIR builtin
types. It may be omitted for `i64` and `f64` types that are implied.

The operation produces a new SSA value of the specified LLVM IR dialect
type. Certain builtin types such as integer, float and vector types are
also allowed. The result type _must_ correspond to the attribute type
converted to LLVM IR. In particular, the number of elements of a container
type must match the number of elements in the attribute. If the type is or
contains a scalable vector type, the attribute must be a splat elements
attribute.
operations. `llvm.mlir.constant` creates such values for scalars, vectors,
strings, and structs. It has a mandatory `value` attribute whose type
depends on the type of the constant value. The type of the constant value
must correspond to the attribute type converted to LLVM IR type.

When creating constant scalars, the `value` attribute must be either an
integer attribute or a floating point attribute. The type of the attribute
may be omitted for `i64` and `f64` types that are implied.

When creating constant vectors, the `value` attribute must be either an
array attribute, a dense attribute, or a sparse attribute that contains
integers or floats. The number of elements in the result vector must match
the number of elements in the attribute.

When creating constant strings, the `value` attribute must be a string
attribute. The type of the constant must be an LLVM array of `i8`s, and the
length of the array must match the length of the attribute.

When creating constant structs, the `value` attribute must be an array
attribute that contains integers or floats. The type of the constant must be
an LLVM struct type. The number of fields in the struct must match the
number of elements in the attribute, and the type of each LLVM struct field
must correspond to the type of the corresponding attribute element converted
to LLVM IR.

Examples:

Expand Down
47 changes: 27 additions & 20 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() {
}
return success();
}
if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
if (structType.getBody().size() != 2 ||
structType.getBody()[0] != structType.getBody()[1]) {
return emitError() << "expected struct type with two elements of the "
"same type, the type of a complex constant";
if (auto structType = dyn_cast<LLVMStructType>(getType())) {
auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
if (!arrayAttr) {
return emitOpError() << "expected array attribute for a struct constant";
}

auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
if (!arrayAttr || arrayAttr.size() != 2) {
return emitOpError() << "expected array attribute with two elements, "
"representing a complex constant";
ArrayRef<Type> elementTypes = structType.getBody();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this check now be recursive, given that one can define nested struct constants?

if (arrayAttr.size() != elementTypes.size()) {
return emitOpError() << "expected array attribute of size "
<< elementTypes.size();
}
auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (!re || !im || re.getType() != im.getType()) {
return emitOpError()
<< "expected array attribute with two elements of the same type";
for (auto elementTy : elementTypes) {
if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
return emitOpError() << "expected struct element types to be floating "
"point type or integer type";
}
}

Type elementType = structType.getBody()[0];
if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
elementType)) {
return emitError()
<< "expected struct element types to be floating point type or "
"integer type";
for (size_t i = 0; i < elementTypes.size(); ++i) {
Attribute element = arrayAttr[i];
if (!isa<IntegerAttr, FloatAttr>(element)) {
return emitOpError()
<< "expected struct element attribute types to be floating "
"point type or integer type";
}
auto elementType = cast<TypedAttr>(element).getType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto elementType = cast<TypedAttr>(element).getType();
auto elementType = dyn_cast<TypedAttr>(element).getType();

Didn't see this before. I think we should dyn_cast here (as it was before). In case the attribute does not implement the TypedAttr interface we want to return an error and not assert.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this cast fallible? We have already checked that the attribute is either an IntegerAttr or a FloatAttr a few lines above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True you are right!

Then this PR is ready to go. Do you have commit rights?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have commit rights?

No.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I will land it then.

if (elementType != elementTypes[i]) {
return emitOpError()
<< "struct element at index " << i << " is of wrong type";
}
}

return success();
}
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
Expand Down
25 changes: 13 additions & 12 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,20 +557,21 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
if (!arrayAttr || arrayAttr.size() != 2) {
emitError(loc, "expected struct type to be a complex number");
if (!arrayAttr) {
emitError(loc, "expected an array attribute for a struct constant");
return nullptr;
}
llvm::Type *elementType = structType->getElementType(0);
llvm::Constant *real =
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
if (!real)
return nullptr;
llvm::Constant *imag =
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
if (!imag)
return nullptr;
return llvm::ConstantStruct::get(structType, {real, imag});
SmallVector<llvm::Constant *> structElements;
structElements.reserve(structType->getNumElements());
for (auto [elemType, elemAttr] :
zip_equal(structType->elements(), arrayAttr)) {
llvm::Constant *element =
getLLVMConstant(elemType, elemAttr, loc, moduleTranslation);
if (!element)
return nullptr;
structElements.push_back(element);
}
return llvm::ConstantStruct::get(structType, structElements);
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -367,39 +367,39 @@ func.func @constant_wrong_type_string() {
// -----

llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
// expected-error @+1 {{expected array attribute of size 2}}
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements of the same type}}
// expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
// expected-error @+1 {{expected array attribute}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
// expected-error @+1 {{expected array attribute of size 1}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
llvm.return %0 : !llvm.struct<(f64)>
}

// -----

llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
// expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
llvm.return %0 : !llvm.struct<(f64, f32)>
}
Expand Down
26 changes: 21 additions & 5 deletions mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,40 @@ llvm.func @vector_with_non_vector_type() -> f32 {

// -----

llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @below{{expected struct type to be a complex number}}
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}

// -----

llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @below{{expected struct type to be a complex number}}
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}

// -----

llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
// expected-error @below{{expected struct element types to be floating point type or integer type}}
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
}

// -----

llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{FloatAttr does not match expected type of the constant}}
// expected-error @below{{struct element at index 0 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
}

llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
%1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
// CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
llvm.return %1 : !llvm.struct<(i32, f32)>
}

// CHECK-LABEL: @indexconstantsplat
llvm.func @indexconstantsplat() -> vector<3xi32> {
%1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>
Expand Down
Loading