Skip to content

Commit 2afdbf3

Browse files
Lancerncjdb
authored andcommitted
[mlir][LLVM] Add support for constant struct with multiple fields (#102752)
Currently `mlir.llvm.constant` of structure types restricts that the structure type effectively represents a complex type -- it must have exactly two fields of the same type and the field type must be either an integer type or a float type. This PR relaxes this restriction and it allows the structure type to have an arbitrary number of fields.
1 parent f194c42 commit 2afdbf3

File tree

6 files changed

+96
-55
lines changed

6 files changed

+96
-55
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

+24-13
Original file line numberDiff line numberDiff line change
@@ -1620,19 +1620,30 @@ def LLVM_ConstantOp
16201620
let description = [{
16211621
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
16221622
all constants must be created as SSA values before being used in other
1623-
operations. `llvm.mlir.constant` creates such values for scalars and
1624-
vectors. It has a mandatory `value` attribute, which may be an integer,
1625-
floating point attribute; dense or sparse attribute containing integers or
1626-
floats. The type of the attribute is one of the corresponding MLIR builtin
1627-
types. It may be omitted for `i64` and `f64` types that are implied.
1628-
1629-
The operation produces a new SSA value of the specified LLVM IR dialect
1630-
type. Certain builtin types such as integer, float and vector types are
1631-
also allowed. The result type _must_ correspond to the attribute type
1632-
converted to LLVM IR. In particular, the number of elements of a container
1633-
type must match the number of elements in the attribute. If the type is or
1634-
contains a scalable vector type, the attribute must be a splat elements
1635-
attribute.
1623+
operations. `llvm.mlir.constant` creates such values for scalars, vectors,
1624+
strings, and structs. It has a mandatory `value` attribute whose type
1625+
depends on the type of the constant value. The type of the constant value
1626+
must correspond to the attribute type converted to LLVM IR type.
1627+
1628+
When creating constant scalars, the `value` attribute must be either an
1629+
integer attribute or a floating point attribute. The type of the attribute
1630+
may be omitted for `i64` and `f64` types that are implied.
1631+
1632+
When creating constant vectors, the `value` attribute must be either an
1633+
array attribute, a dense attribute, or a sparse attribute that contains
1634+
integers or floats. The number of elements in the result vector must match
1635+
the number of elements in the attribute.
1636+
1637+
When creating constant strings, the `value` attribute must be a string
1638+
attribute. The type of the constant must be an LLVM array of `i8`s, and the
1639+
length of the array must match the length of the attribute.
1640+
1641+
When creating constant structs, the `value` attribute must be an array
1642+
attribute that contains integers or floats. The type of the constant must be
1643+
an LLVM struct type. The number of fields in the struct must match the
1644+
number of elements in the attribute, and the type of each LLVM struct field
1645+
must correspond to the type of the corresponding attribute element converted
1646+
to LLVM IR.
16361647

16371648
Examples:
16381649

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

+27-20
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1717
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
1818
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19+
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/BuiltinOps.h"
2122
#include "mlir/IR/BuiltinTypes.h"
@@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() {
27102711
}
27112712
return success();
27122713
}
2713-
if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
2714-
if (structType.getBody().size() != 2 ||
2715-
structType.getBody()[0] != structType.getBody()[1]) {
2716-
return emitError() << "expected struct type with two elements of the "
2717-
"same type, the type of a complex constant";
2714+
if (auto structType = dyn_cast<LLVMStructType>(getType())) {
2715+
auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
2716+
if (!arrayAttr) {
2717+
return emitOpError() << "expected array attribute for a struct constant";
27182718
}
27192719

2720-
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
2721-
if (!arrayAttr || arrayAttr.size() != 2) {
2722-
return emitOpError() << "expected array attribute with two elements, "
2723-
"representing a complex constant";
2720+
ArrayRef<Type> elementTypes = structType.getBody();
2721+
if (arrayAttr.size() != elementTypes.size()) {
2722+
return emitOpError() << "expected array attribute of size "
2723+
<< elementTypes.size();
27242724
}
2725-
auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
2726-
auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
2727-
if (!re || !im || re.getType() != im.getType()) {
2728-
return emitOpError()
2729-
<< "expected array attribute with two elements of the same type";
2725+
for (auto elementTy : elementTypes) {
2726+
if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
2727+
return emitOpError() << "expected struct element types to be floating "
2728+
"point type or integer type";
2729+
}
27302730
}
27312731

2732-
Type elementType = structType.getBody()[0];
2733-
if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
2734-
elementType)) {
2735-
return emitError()
2736-
<< "expected struct element types to be floating point type or "
2737-
"integer type";
2732+
for (size_t i = 0; i < elementTypes.size(); ++i) {
2733+
Attribute element = arrayAttr[i];
2734+
if (!isa<IntegerAttr, FloatAttr>(element)) {
2735+
return emitOpError()
2736+
<< "expected struct element attribute types to be floating "
2737+
"point type or integer type";
2738+
}
2739+
auto elementType = cast<TypedAttr>(element).getType();
2740+
if (elementType != elementTypes[i]) {
2741+
return emitOpError()
2742+
<< "struct element at index " << i << " is of wrong type";
2743+
}
27382744
}
2745+
27392746
return success();
27402747
}
27412748
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

+13-12
Original file line numberDiff line numberDiff line change
@@ -557,20 +557,21 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
557557
return llvm::UndefValue::get(llvmType);
558558
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
559559
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
560-
if (!arrayAttr || arrayAttr.size() != 2) {
561-
emitError(loc, "expected struct type to be a complex number");
560+
if (!arrayAttr) {
561+
emitError(loc, "expected an array attribute for a struct constant");
562562
return nullptr;
563563
}
564-
llvm::Type *elementType = structType->getElementType(0);
565-
llvm::Constant *real =
566-
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
567-
if (!real)
568-
return nullptr;
569-
llvm::Constant *imag =
570-
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
571-
if (!imag)
572-
return nullptr;
573-
return llvm::ConstantStruct::get(structType, {real, imag});
564+
SmallVector<llvm::Constant *> structElements;
565+
structElements.reserve(structType->getNumElements());
566+
for (auto [elemType, elemAttr] :
567+
zip_equal(structType->elements(), arrayAttr)) {
568+
llvm::Constant *element =
569+
getLLVMConstant(elemType, elemAttr, loc, moduleTranslation);
570+
if (!element)
571+
return nullptr;
572+
structElements.push_back(element);
573+
}
574+
return llvm::ConstantStruct::get(structType, structElements);
574575
}
575576
// For integer types, we allow a mismatch in sizes as the index type in
576577
// MLIR might have a different size than the index type in the LLVM module.

mlir/test/Dialect/LLVMIR/invalid.mlir

+5-5
Original file line numberDiff line numberDiff line change
@@ -367,39 +367,39 @@ func.func @constant_wrong_type_string() {
367367
// -----
368368

369369
llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
370-
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
370+
// expected-error @+1 {{expected array attribute of size 2}}
371371
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
372372
llvm.return %0 : !llvm.struct<(f64, f64)>
373373
}
374374

375375
// -----
376376

377377
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
378-
// expected-error @+1 {{expected array attribute with two elements of the same type}}
378+
// expected-error @+1 {{struct element at index 1 is of wrong type}}
379379
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
380380
llvm.return %0 : !llvm.struct<(f64, f64)>
381381
}
382382

383383
// -----
384384

385385
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
386-
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
386+
// expected-error @+1 {{expected array attribute}}
387387
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
388388
llvm.return %0 : !llvm.struct<(f64, f64)>
389389
}
390390

391391
// -----
392392

393393
llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
394-
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
394+
// expected-error @+1 {{expected array attribute of size 1}}
395395
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
396396
llvm.return %0 : !llvm.struct<(f64)>
397397
}
398398

399399
// -----
400400

401401
llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
402-
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
402+
// expected-error @+1 {{struct element at index 1 is of wrong type}}
403403
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
404404
llvm.return %0 : !llvm.struct<(f64, f32)>
405405
}

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

+21-5
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,40 @@ llvm.func @vector_with_non_vector_type() -> f32 {
1515

1616
// -----
1717

18-
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
19-
// expected-error @below{{expected struct type to be a complex number}}
18+
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
19+
// expected-error @below{{expected an array attribute for a struct constant}}
2020
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
2121
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
2222
}
2323

2424
// -----
2525

26-
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
27-
// expected-error @below{{expected struct type to be a complex number}}
26+
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
27+
// expected-error @below{{expected an array attribute for a struct constant}}
2828
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
2929
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
3030
}
3131

3232
// -----
3333

34+
llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
35+
// expected-error @below{{expected struct element types to be floating point type or integer type}}
36+
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
37+
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
38+
}
39+
40+
// -----
41+
42+
llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
43+
// expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
44+
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
45+
llvm.return %0 : !llvm.struct<(f64, f64)>
46+
}
47+
48+
// -----
49+
3450
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
35-
// expected-error @below{{FloatAttr does not match expected type of the constant}}
51+
// expected-error @below{{struct element at index 0 is of wrong type}}
3652
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
3753
llvm.return %0 : !llvm.struct<(f64, f64)>
3854
}

mlir/test/Target/LLVMIR/llvmir.mlir

+6
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
13121312
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
13131313
}
13141314

1315+
llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
1316+
%1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
1317+
// CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
1318+
llvm.return %1 : !llvm.struct<(i32, f32)>
1319+
}
1320+
13151321
// CHECK-LABEL: @indexconstantsplat
13161322
llvm.func @indexconstantsplat() -> vector<3xi32> {
13171323
%1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>

0 commit comments

Comments
 (0)