Skip to content

Commit f5ff4a9

Browse files
committed
[mlir][LLVM] Add support for constant struct with multiple fields
Currently `mlir.llvm.constant` of structure types restricts that the structure type effectively represents a complex type -- it must have exactly two fields of the same type and the field type must be either an integer type or a float type. This patch relaxes this restriction and it allows the structure type to have an arbitrary number of fields.
1 parent ac47edd commit f5ff4a9

File tree

5 files changed

+45
-40
lines changed

5 files changed

+45
-40
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2678,31 +2678,30 @@ LogicalResult LLVM::ConstantOp::verify() {
26782678
return success();
26792679
}
26802680
if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
2681-
if (structType.getBody().size() != 2 ||
2682-
structType.getBody()[0] != structType.getBody()[1]) {
2683-
return emitError() << "expected struct type with two elements of the "
2684-
"same type, the type of a complex constant";
2685-
}
2686-
26872681
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
2688-
if (!arrayAttr || arrayAttr.size() != 2) {
2689-
return emitOpError() << "expected array attribute with two elements, "
2690-
"representing a complex constant";
2682+
if (!arrayAttr) {
2683+
return emitOpError() << "expected array attribute for a struct constant";
26912684
}
2692-
auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
2693-
auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
2694-
if (!re || !im || re.getType() != im.getType()) {
2695-
return emitOpError()
2696-
<< "expected array attribute with two elements of the same type";
2685+
2686+
ArrayRef<Type> elementTypes = structType.getBody();
2687+
if (arrayAttr.size() != elementTypes.size()) {
2688+
return emitOpError() << "expected array attribute of size "
2689+
<< elementTypes.size();
26972690
}
26982691

2699-
Type elementType = structType.getBody()[0];
2700-
if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
2701-
elementType)) {
2702-
return emitError()
2703-
<< "expected struct element types to be floating point type or "
2704-
"integer type";
2692+
for (size_t i = 0; i < elementTypes.size(); ++i) {
2693+
auto element = arrayAttr[i];
2694+
if (!mlir::isa<IntegerAttr, FloatAttr>(element)) {
2695+
return emitOpError() << "expected struct element types to be floating "
2696+
"point type or integer type";
2697+
}
2698+
auto elementType = mlir::cast<TypedAttr>(element).getType();
2699+
if (elementType != elementTypes[i]) {
2700+
return emitOpError()
2701+
<< "struct element at index " << i << " is of wrong type";
2702+
}
27052703
}
2704+
27062705
return success();
27072706
}
27082707
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -557,20 +557,20 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
557557
return llvm::UndefValue::get(llvmType);
558558
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
559559
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
560-
if (!arrayAttr || arrayAttr.size() != 2) {
561-
emitError(loc, "expected struct type to be a complex number");
560+
if (!arrayAttr) {
561+
emitError(loc, "expected an array attribute for a struct constant");
562562
return nullptr;
563563
}
564-
llvm::Type *elementType = structType->getElementType(0);
565-
llvm::Constant *real =
566-
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
567-
if (!real)
568-
return nullptr;
569-
llvm::Constant *imag =
570-
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
571-
if (!imag)
572-
return nullptr;
573-
return llvm::ConstantStruct::get(structType, {real, imag});
564+
llvm::SmallVector<llvm::Constant *, 8> structElements;
565+
structElements.reserve(structType->getNumElements());
566+
for (size_t i = 0; i < arrayAttr.size(); ++i) {
567+
llvm::Constant *element = getLLVMConstant(
568+
structType->getElementType(i), arrayAttr[i], loc, moduleTranslation);
569+
if (!element)
570+
return nullptr;
571+
structElements.push_back(element);
572+
}
573+
return llvm::ConstantStruct::get(structType, structElements);
574574
}
575575
// For integer types, we allow a mismatch in sizes as the index type in
576576
// MLIR might have a different size than the index type in the LLVM module.

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,39 +367,39 @@ func.func @constant_wrong_type_string() {
367367
// -----
368368

369369
llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
370-
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
370+
// expected-error @+1 {{expected array attribute of size 2}}
371371
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
372372
llvm.return %0 : !llvm.struct<(f64, f64)>
373373
}
374374

375375
// -----
376376

377377
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
378-
// expected-error @+1 {{expected array attribute with two elements of the same type}}
378+
// expected-error @+1 {{struct element at index 1 is of wrong type}}
379379
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
380380
llvm.return %0 : !llvm.struct<(f64, f64)>
381381
}
382382

383383
// -----
384384

385385
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
386-
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
386+
// expected-error @+1 {{expected array attribute}}
387387
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
388388
llvm.return %0 : !llvm.struct<(f64, f64)>
389389
}
390390

391391
// -----
392392

393393
llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
394-
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
394+
// expected-error @+1 {{expected array attribute of size 1}}
395395
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
396396
llvm.return %0 : !llvm.struct<(f64)>
397397
}
398398

399399
// -----
400400

401401
llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
402-
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
402+
// expected-error @+1 {{struct element at index 1 is of wrong type}}
403403
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
404404
llvm.return %0 : !llvm.struct<(f64, f32)>
405405
}

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,23 @@ llvm.func @vector_with_non_vector_type() -> f32 {
1616
// -----
1717

1818
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
19-
// expected-error @below{{expected struct type to be a complex number}}
19+
// expected-error @below{{expected an array attribute for a struct constant}}
2020
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
2121
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
2222
}
2323

2424
// -----
2525

2626
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
27-
// expected-error @below{{expected struct type to be a complex number}}
27+
// expected-error @below{{expected an array attribute for a struct constant}}
2828
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
2929
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
3030
}
3131

3232
// -----
3333

3434
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
35-
// expected-error @below{{FloatAttr does not match expected type of the constant}}
35+
// expected-error @below{{struct element at index 0 is of wrong type}}
3636
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
3737
llvm.return %0 : !llvm.struct<(f64, f64)>
3838
}

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
13061306
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
13071307
}
13081308

1309+
llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
1310+
%1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
1311+
// CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
1312+
llvm.return %1 : !llvm.struct<(i32, f32)>
1313+
}
1314+
13091315
// CHECK-LABEL: @indexconstantsplat
13101316
llvm.func @indexconstantsplat() -> vector<3xi32> {
13111317
%1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>

0 commit comments

Comments
 (0)