|
16 | 16 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
|
17 | 17 | #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
|
18 | 18 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
| 19 | +#include "mlir/IR/Attributes.h" |
19 | 20 | #include "mlir/IR/Builders.h"
|
20 | 21 | #include "mlir/IR/BuiltinOps.h"
|
21 | 22 | #include "mlir/IR/BuiltinTypes.h"
|
@@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() {
|
2710 | 2711 | }
|
2711 | 2712 | return success();
|
2712 | 2713 | }
|
2713 |
| - if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) { |
2714 |
| - if (structType.getBody().size() != 2 || |
2715 |
| - structType.getBody()[0] != structType.getBody()[1]) { |
2716 |
| - return emitError() << "expected struct type with two elements of the " |
2717 |
| - "same type, the type of a complex constant"; |
| 2714 | + if (auto structType = dyn_cast<LLVMStructType>(getType())) { |
| 2715 | + auto arrayAttr = dyn_cast<ArrayAttr>(getValue()); |
| 2716 | + if (!arrayAttr) { |
| 2717 | + return emitOpError() << "expected array attribute for a struct constant"; |
2718 | 2718 | }
|
2719 | 2719 |
|
2720 |
| - auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue()); |
2721 |
| - if (!arrayAttr || arrayAttr.size() != 2) { |
2722 |
| - return emitOpError() << "expected array attribute with two elements, " |
2723 |
| - "representing a complex constant"; |
| 2720 | + ArrayRef<Type> elementTypes = structType.getBody(); |
| 2721 | + if (arrayAttr.size() != elementTypes.size()) { |
| 2722 | + return emitOpError() << "expected array attribute of size " |
| 2723 | + << elementTypes.size(); |
2724 | 2724 | }
|
2725 |
| - auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]); |
2726 |
| - auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]); |
2727 |
| - if (!re || !im || re.getType() != im.getType()) { |
2728 |
| - return emitOpError() |
2729 |
| - << "expected array attribute with two elements of the same type"; |
| 2725 | + for (auto elementTy : elementTypes) { |
| 2726 | + if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) { |
| 2727 | + return emitOpError() << "expected struct element types to be floating " |
| 2728 | + "point type or integer type"; |
| 2729 | + } |
2730 | 2730 | }
|
2731 | 2731 |
|
2732 |
| - Type elementType = structType.getBody()[0]; |
2733 |
| - if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>( |
2734 |
| - elementType)) { |
2735 |
| - return emitError() |
2736 |
| - << "expected struct element types to be floating point type or " |
2737 |
| - "integer type"; |
| 2732 | + for (size_t i = 0; i < elementTypes.size(); ++i) { |
| 2733 | + Attribute element = arrayAttr[i]; |
| 2734 | + if (!isa<IntegerAttr, FloatAttr>(element)) { |
| 2735 | + return emitOpError() |
| 2736 | + << "expected struct element attribute types to be floating " |
| 2737 | + "point type or integer type"; |
| 2738 | + } |
| 2739 | + auto elementType = cast<TypedAttr>(element).getType(); |
| 2740 | + if (elementType != elementTypes[i]) { |
| 2741 | + return emitOpError() |
| 2742 | + << "struct element at index " << i << " is of wrong type"; |
| 2743 | + } |
2738 | 2744 | }
|
| 2745 | + |
2739 | 2746 | return success();
|
2740 | 2747 | }
|
2741 | 2748 | if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
|
|
0 commit comments