@@ -2666,6 +2666,39 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
2666
2666
// ConstantOp.
2667
2667
// ===----------------------------------------------------------------------===//
2668
2668
2669
+ // / Compute the total number of elements in the given type, also taking into
2670
+ // / account nested types. Supported types are `VectorType`, `LLVMArrayType` and
2671
+ // / `LLVMFixedVectorType`. Everything else is treated as a scalar.
2672
+ static int64_t getNumElements (Type t) {
2673
+ if (auto vecType = dyn_cast<VectorType>(t))
2674
+ return vecType.getNumElements () * getNumElements (vecType.getElementType ());
2675
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2676
+ return arrayType.getNumElements () *
2677
+ getNumElements (arrayType.getElementType ());
2678
+ if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2679
+ return vecType.getNumElements () * getNumElements (vecType.getElementType ());
2680
+ assert (!isa<LLVM::LLVMScalableVectorType>(t) &&
2681
+ " number of elements of a scalable vector type is unknown" );
2682
+ return 1 ;
2683
+ }
2684
+
2685
+ // / Check if the given type is a scalable vector type or a vector/array type
2686
+ // / that contains a nested scalable vector type.
2687
+ static bool hasScalableVectorType (Type t) {
2688
+ if (isa<LLVM::LLVMScalableVectorType>(t))
2689
+ return true ;
2690
+ if (auto vecType = dyn_cast<VectorType>(t)) {
2691
+ if (vecType.isScalable ())
2692
+ return true ;
2693
+ return hasScalableVectorType (vecType.getElementType ());
2694
+ }
2695
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2696
+ return hasScalableVectorType (arrayType.getElementType ());
2697
+ if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2698
+ return hasScalableVectorType (vecType.getElementType ());
2699
+ return false ;
2700
+ }
2701
+
2669
2702
LogicalResult LLVM::ConstantOp::verify () {
2670
2703
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue ())) {
2671
2704
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType ());
@@ -2708,14 +2741,12 @@ LogicalResult LLVM::ConstantOp::verify() {
2708
2741
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType ())) {
2709
2742
return emitOpError () << " does not support target extension type." ;
2710
2743
}
2711
- if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue ()))
2712
- return emitOpError ()
2713
- << " only supports integer, float, string or elements attributes" ;
2744
+
2745
+ // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
2714
2746
if (auto intAttr = dyn_cast<IntegerAttr>(getValue ())) {
2715
2747
if (!llvm::isa<IntegerType>(getType ()))
2716
2748
return emitOpError () << " expected integer type" ;
2717
- }
2718
- if (auto floatAttr = dyn_cast<FloatAttr>(getValue ())) {
2749
+ } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue ())) {
2719
2750
const llvm::fltSemantics &sem = floatAttr.getValue ().getSemantics ();
2720
2751
unsigned floatWidth = APFloat::getSizeInBits (sem);
2721
2752
if (auto floatTy = dyn_cast<FloatType>(getType ())) {
@@ -2728,13 +2759,34 @@ LogicalResult LLVM::ConstantOp::verify() {
2728
2759
if (isa<IntegerType>(getType ()) && !getType ().isInteger (floatWidth)) {
2729
2760
return emitOpError () << " expected integer type of width " << floatWidth;
2730
2761
}
2731
- }
2732
- if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue ())) {
2733
- if (!isa<VectorType>(getType ()) && !isa<LLVM::LLVMArrayType>(getType ()) &&
2734
- !isa<LLVM::LLVMFixedVectorType>(getType ()) &&
2735
- !isa<LLVM::LLVMScalableVectorType>(getType ()))
2762
+ } else if (isa<ElementsAttr, ArrayAttr>(getValue ())) {
2763
+ if (hasScalableVectorType (getType ())) {
2764
+ // The exact number of elements of a scalable vector is unknown, so we
2765
+ // allow only splat attributes.
2766
+ auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue ());
2767
+ if (!splatElementsAttr)
2768
+ return emitOpError ()
2769
+ << " scalable vector type requires a splat attribute" ;
2770
+ return success ();
2771
+ }
2772
+ if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
2773
+ getType ()))
2736
2774
return emitOpError () << " expected vector or array type" ;
2775
+ // The number of elements of the attribute and the type must match.
2776
+ int64_t attrNumElements;
2777
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ()))
2778
+ attrNumElements = elementsAttr.getNumElements ();
2779
+ else
2780
+ attrNumElements = cast<ArrayAttr>(getValue ()).size ();
2781
+ if (getNumElements (getType ()) != attrNumElements)
2782
+ return emitOpError ()
2783
+ << " type and attribute have a different number of elements: "
2784
+ << getNumElements (getType ()) << " vs. " << attrNumElements;
2785
+ } else {
2786
+ return emitOpError ()
2787
+ << " only supports integer, float, string or elements attributes" ;
2737
2788
}
2789
+
2738
2790
return success ();
2739
2791
}
2740
2792
0 commit comments