@@ -93,6 +93,39 @@ void convertToDenseElementsAttrImpl(
9393 }
9494}
9595
96+ template <typename AttrTy, typename StorageTy>
97+ void convertToDenseElementsAttrImpl (
98+ cir::ConstVectorAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
99+ const llvm::SmallVectorImpl<int64_t > ¤tDims, int64_t dimIndex,
100+ int64_t currentIndex) {
101+ dimIndex++;
102+ std::size_t elementsSizeInCurrentDim = 1 ;
103+ for (std::size_t i = dimIndex; i < currentDims.size (); i++)
104+ elementsSizeInCurrentDim *= currentDims[i];
105+
106+ auto arrayAttr = mlir::cast<mlir::ArrayAttr>(attr.getElts ());
107+ for (auto eltAttr : arrayAttr) {
108+ if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
109+ values[currentIndex++] = valueAttr.getValue ();
110+ continue ;
111+ }
112+
113+ if (auto subArrayAttr = mlir::dyn_cast<cir::ConstArrayAttr>(eltAttr)) {
114+ convertToDenseElementsAttrImpl<AttrTy>(subArrayAttr, values, currentDims,
115+ dimIndex, currentIndex);
116+ currentIndex += elementsSizeInCurrentDim;
117+ continue ;
118+ }
119+
120+ if (mlir::isa<cir::ZeroAttr, cir::UndefAttr>(eltAttr)) {
121+ currentIndex += elementsSizeInCurrentDim;
122+ continue ;
123+ }
124+
125+ llvm_unreachable (" unknown element in ConstArrayAttr" );
126+ }
127+ }
128+
96129template <typename AttrTy, typename StorageTy>
97130mlir::DenseElementsAttr convertToDenseElementsAttr (
98131 cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t > &dims,
@@ -109,6 +142,22 @@ mlir::DenseElementsAttr convertToDenseElementsAttr(
109142 llvm::ArrayRef (values));
110143}
111144
145+ template <typename AttrTy, typename StorageTy>
146+ mlir::DenseElementsAttr convertToDenseElementsAttr (
147+ cir::ConstVectorAttr attr, const llvm::SmallVectorImpl<int64_t > &dims,
148+ mlir::Type elementType, mlir::Type convertedElementType) {
149+ unsigned vector_size = 1 ;
150+ for (auto dim : dims)
151+ vector_size *= dim;
152+ auto values = llvm::SmallVector<StorageTy, 8 >(
153+ vector_size, getZeroInitFromType<StorageTy>(elementType));
154+ convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /* currentDim=*/ 0 ,
155+ /* initialIndex=*/ 0 );
156+ return mlir::DenseElementsAttr::get (
157+ mlir::RankedTensorType::get (dims, convertedElementType),
158+ llvm::ArrayRef (values));
159+ }
160+
112161std::optional<mlir::Attribute>
113162lowerConstArrayAttr (cir::ConstArrayAttr constArr,
114163 const mlir::TypeConverter *converter) {
@@ -141,3 +190,33 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
141190
142191 return std::nullopt ;
143192}
193+
194+ std::optional<mlir::Attribute>
195+ lowerConstVectorAttr (cir::ConstVectorAttr constArr,
196+ const mlir::TypeConverter *converter) {
197+
198+ // Ensure ConstArrayAttr has a type.
199+ auto typedConstArr = mlir::dyn_cast<mlir::TypedAttr>(constArr);
200+ assert (typedConstArr && " cir::ConstArrayAttr is not a mlir::TypedAttr" );
201+
202+ // Ensure ConstArrayAttr type is a ArrayType.
203+ auto cirArrayType = mlir::dyn_cast<cir::VectorType>(typedConstArr.getType ());
204+ assert (cirArrayType && " cir::ConstArrayAttr is not a cir::ArrayType" );
205+
206+ // Is a ConstArrayAttr with an cir::ArrayType: fetch element type.
207+ mlir::Type type = cirArrayType;
208+ auto dims = llvm::SmallVector<int64_t , 2 >{};
209+ while (auto arrayType = mlir::dyn_cast<cir::ArrayType>(type)) {
210+ dims.push_back (arrayType.getSize ());
211+ type = arrayType.getEltType ();
212+ }
213+
214+ if (mlir::isa<cir::IntType>(type))
215+ return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
216+ constArr, dims, type, converter->convertType (type));
217+ if (mlir::isa<cir::CIRFPTypeInterface>(type))
218+ return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
219+ constArr, dims, type, converter->convertType (type));
220+
221+ return std::nullopt ;
222+ }
0 commit comments