|
16 | 16 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
17 | 17 | #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" |
18 | 18 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 19 | +#include "mlir/IR/Attributes.h" |
19 | 20 | #include "mlir/IR/Builders.h" |
20 | 21 | #include "mlir/IR/BuiltinOps.h" |
21 | 22 | #include "mlir/IR/BuiltinTypes.h" |
@@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() { |
2710 | 2711 | } |
2711 | 2712 | return success(); |
2712 | 2713 | } |
2713 | | - if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) { |
2714 | | - if (structType.getBody().size() != 2 || |
2715 | | - structType.getBody()[0] != structType.getBody()[1]) { |
2716 | | - return emitError() << "expected struct type with two elements of the " |
2717 | | - "same type, the type of a complex constant"; |
| 2714 | + if (auto structType = dyn_cast<LLVMStructType>(getType())) { |
| 2715 | + auto arrayAttr = dyn_cast<ArrayAttr>(getValue()); |
| 2716 | + if (!arrayAttr) { |
| 2717 | + return emitOpError() << "expected array attribute for a struct constant"; |
2718 | 2718 | } |
2719 | 2719 |
|
2720 | | - auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue()); |
2721 | | - if (!arrayAttr || arrayAttr.size() != 2) { |
2722 | | - return emitOpError() << "expected array attribute with two elements, " |
2723 | | - "representing a complex constant"; |
| 2720 | + ArrayRef<Type> elementTypes = structType.getBody(); |
| 2721 | + if (arrayAttr.size() != elementTypes.size()) { |
| 2722 | + return emitOpError() << "expected array attribute of size " |
| 2723 | + << elementTypes.size(); |
2724 | 2724 | } |
2725 | | - auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]); |
2726 | | - auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]); |
2727 | | - if (!re || !im || re.getType() != im.getType()) { |
2728 | | - return emitOpError() |
2729 | | - << "expected array attribute with two elements of the same type"; |
| 2725 | + for (auto elementTy : elementTypes) { |
| 2726 | + if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) { |
| 2727 | + return emitOpError() << "expected struct element types to be floating " |
| 2728 | + "point type or integer type"; |
| 2729 | + } |
2730 | 2730 | } |
2731 | 2731 |
|
2732 | | - Type elementType = structType.getBody()[0]; |
2733 | | - if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>( |
2734 | | - elementType)) { |
2735 | | - return emitError() |
2736 | | - << "expected struct element types to be floating point type or " |
2737 | | - "integer type"; |
| 2732 | + for (size_t i = 0; i < elementTypes.size(); ++i) { |
| 2733 | + Attribute element = arrayAttr[i]; |
| 2734 | + if (!isa<IntegerAttr, FloatAttr>(element)) { |
| 2735 | + return emitOpError() |
| 2736 | + << "expected struct element attribute types to be floating " |
| 2737 | + "point type or integer type"; |
| 2738 | + } |
| 2739 | + auto elementType = cast<TypedAttr>(element).getType(); |
| 2740 | + if (elementType != elementTypes[i]) { |
| 2741 | + return emitOpError() |
| 2742 | + << "struct element at index " << i << " is of wrong type"; |
| 2743 | + } |
2738 | 2744 | } |
| 2745 | + |
2739 | 2746 | return success(); |
2740 | 2747 | } |
2741 | 2748 | if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) { |
|
0 commit comments