Skip to content

Commit a9eac21

Browse files
authored
[CIR] Backport global initialization for VectorType (#1592)
In this PR I extended the incubator implementation to support global init for VectorType
1 parent 62a7cbd commit a9eac21

File tree

6 files changed

+135
-3
lines changed

6 files changed

+135
-3
lines changed

clang/include/clang/CIR/LoweringHelpers.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212
#ifndef LLVM_CLANG_CIR_LOWERINGHELPERS_H
1313
#define LLVM_CLANG_CIR_LOWERINGHELPERS_H
14+
1415
#include "mlir/Dialect/Arith/IR/Arith.h"
1516
#include "mlir/IR/BuiltinAttributes.h"
1617
#include "mlir/IR/BuiltinTypes.h"
@@ -36,7 +37,17 @@ convertToDenseElementsAttr(cir::ConstArrayAttr attr,
3637
const llvm::SmallVectorImpl<int64_t> &dims,
3738
mlir::Type type);
3839

40+
template <typename AttrTy, typename StorageTy>
41+
mlir::DenseElementsAttr
42+
convertToDenseElementsAttr(cir::ConstVectorAttr attr,
43+
const llvm::SmallVectorImpl<int64_t> &dims,
44+
mlir::Type type);
45+
3946
std::optional<mlir::Attribute>
4047
lowerConstArrayAttr(cir::ConstArrayAttr constArr,
4148
const mlir::TypeConverter *converter);
49+
50+
std::optional<mlir::Attribute>
51+
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
52+
const mlir::TypeConverter *converter);
4253
#endif

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,8 +1860,10 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
18601860
rewriter.eraseOp(op);
18611861
return mlir::success();
18621862
}
1863-
} else if (const auto recordAttr =
1864-
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
1863+
}
1864+
1865+
else if (const auto recordAttr =
1866+
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
18651867
// TODO(cir): this diverges from traditional lowering. Normally the
18661868
// initializer would be a global constant that is memcopied. Here we just
18671869
// define a local constant with llvm.undef that will be stored into the
@@ -2421,6 +2423,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
24212423
} else if (mlir::isa<cir::ConstArrayAttr>(init)) {
24222424
return lowerInitializerForConstArray(rewriter, op, init,
24232425
useInitializerRegion);
2426+
} else if (mlir::isa<cir::ConstVectorAttr>(init)) {
2427+
return lowerInitializerForConstVector(rewriter, op, init,
2428+
useInitializerRegion);
24242429
} else if (auto dataMemberAttr = mlir::dyn_cast<cir::DataMemberAttr>(init)) {
24252430
assert(lowerMod && "lower module is not available");
24262431
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
@@ -2437,6 +2442,26 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
24372442
}
24382443
llvm_unreachable("unreachable");
24392444
}
2445+
2446+
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstVector(
2447+
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
2448+
mlir::Attribute &init, bool &useInitializerRegion) const {
2449+
auto constVec = mlir::cast<cir::ConstVectorAttr>(init);
2450+
if (const auto attr = mlir::dyn_cast<mlir::ArrayAttr>(constVec.getElts())) {
2451+
if (auto val = lowerConstVectorAttr(constVec, getTypeConverter());
2452+
val.has_value()) {
2453+
init = val.value();
2454+
useInitializerRegion = false;
2455+
} else
2456+
useInitializerRegion = true;
2457+
return mlir::success();
2458+
}
2459+
2460+
op.emitError() << "unsupported lowering for #cir.const_vector with value "
2461+
<< constVec.getElts();
2462+
return mlir::failure();
2463+
}
2464+
24402465
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray(
24412466
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
24422467
mlir::Attribute &init, bool &useInitializerRegion) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,11 @@ class CIRToLLVMGlobalOpLowering
623623
cir::GlobalOp op, mlir::Attribute &init,
624624
bool &useInitializerRegion) const;
625625

626+
mlir::LogicalResult
627+
lowerInitializerForConstVector(mlir::ConversionPatternRewriter &rewriter,
628+
cir::GlobalOp op, mlir::Attribute &init,
629+
bool &useInitializerRegion) const;
630+
626631
mlir::LogicalResult
627632
lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter,
628633
cir::GlobalOp op, mlir::Type llvmType,

clang/lib/CIR/Lowering/LoweringHelpers.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,39 @@ void convertToDenseElementsAttrImpl(
9393
}
9494
}
9595

96+
template <typename AttrTy, typename StorageTy>
97+
void convertToDenseElementsAttrImpl(
98+
cir::ConstVectorAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
99+
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
100+
int64_t currentIndex) {
101+
dimIndex++;
102+
std::size_t elementsSizeInCurrentDim = 1;
103+
for (std::size_t i = dimIndex; i < currentDims.size(); i++)
104+
elementsSizeInCurrentDim *= currentDims[i];
105+
106+
auto arrayAttr = mlir::cast<mlir::ArrayAttr>(attr.getElts());
107+
for (auto eltAttr : arrayAttr) {
108+
if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
109+
values[currentIndex++] = valueAttr.getValue();
110+
continue;
111+
}
112+
113+
if (auto subArrayAttr = mlir::dyn_cast<cir::ConstArrayAttr>(eltAttr)) {
114+
convertToDenseElementsAttrImpl<AttrTy>(subArrayAttr, values, currentDims,
115+
dimIndex, currentIndex);
116+
currentIndex += elementsSizeInCurrentDim;
117+
continue;
118+
}
119+
120+
if (mlir::isa<cir::ZeroAttr, cir::UndefAttr>(eltAttr)) {
121+
currentIndex += elementsSizeInCurrentDim;
122+
continue;
123+
}
124+
125+
llvm_unreachable("unknown element in ConstArrayAttr");
126+
}
127+
}
128+
96129
template <typename AttrTy, typename StorageTy>
97130
mlir::DenseElementsAttr convertToDenseElementsAttr(
98131
cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
@@ -109,6 +142,22 @@ mlir::DenseElementsAttr convertToDenseElementsAttr(
109142
llvm::ArrayRef(values));
110143
}
111144

145+
template <typename AttrTy, typename StorageTy>
146+
mlir::DenseElementsAttr convertToDenseElementsAttr(
147+
cir::ConstVectorAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
148+
mlir::Type elementType, mlir::Type convertedElementType) {
149+
unsigned vector_size = 1;
150+
for (auto dim : dims)
151+
vector_size *= dim;
152+
auto values = llvm::SmallVector<StorageTy, 8>(
153+
vector_size, getZeroInitFromType<StorageTy>(elementType));
154+
convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0,
155+
/*initialIndex=*/0);
156+
return mlir::DenseElementsAttr::get(
157+
mlir::RankedTensorType::get(dims, convertedElementType),
158+
llvm::ArrayRef(values));
159+
}
160+
112161
std::optional<mlir::Attribute>
113162
lowerConstArrayAttr(cir::ConstArrayAttr constArr,
114163
const mlir::TypeConverter *converter) {
@@ -141,3 +190,33 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
141190

142191
return std::nullopt;
143192
}
193+
194+
std::optional<mlir::Attribute>
195+
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
196+
const mlir::TypeConverter *converter) {
197+
198+
// Ensure ConstArrayAttr has a type.
199+
auto typedConstArr = mlir::dyn_cast<mlir::TypedAttr>(constArr);
200+
assert(typedConstArr && "cir::ConstArrayAttr is not a mlir::TypedAttr");
201+
202+
// Ensure ConstArrayAttr type is a ArrayType.
203+
auto cirArrayType = mlir::dyn_cast<cir::VectorType>(typedConstArr.getType());
204+
assert(cirArrayType && "cir::ConstArrayAttr is not a cir::ArrayType");
205+
206+
// Is a ConstArrayAttr with an cir::ArrayType: fetch element type.
207+
mlir::Type type = cirArrayType;
208+
auto dims = llvm::SmallVector<int64_t, 2>{};
209+
while (auto arrayType = mlir::dyn_cast<cir::ArrayType>(type)) {
210+
dims.push_back(arrayType.getSize());
211+
type = arrayType.getEltType();
212+
}
213+
214+
if (mlir::isa<cir::IntType>(type))
215+
return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
216+
constArr, dims, type, converter->convertType(type));
217+
if (mlir::isa<cir::CIRFPTypeInterface>(type))
218+
return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
219+
constArr, dims, type, converter->convertType(type));
220+
221+
return std::nullopt;
222+
}

clang/test/CIR/CodeGen/vectype-ext.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,19 @@ vi2 vec_c;
2525

2626
// LLVM: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer
2727

28-
vd2 d;
28+
vd2 vec_d;
2929

3030
// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<!cir.double x 2>
3131

3232
// LLVM: @[[VEC_D:.*]] = global <2 x double> zeroinitializer
3333

34+
vi4 vec_e = { 1, 2, 3, 4 };
35+
36+
// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
37+
// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>
38+
39+
// LLVM: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
40+
3441
// CIR: cir.func {{@.*vector_int_test.*}}
3542
// LLVM: define dso_local void {{@.*vector_int_test.*}}
3643
void vector_int_test(int x) {

clang/test/CIR/CodeGen/vectype.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ vd2 b;
1414
vll2 c;
1515
// CHECK: cir.global external @[[VEC_C:.*]] = #cir.zero : !cir.vector<!s64i x 2>
1616

17+
vi4 d = { 1, 2, 3, 4 };
18+
19+
// CHECK: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
20+
// CHECK-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>
21+
1722
void vector_int_test(int x, unsigned short usx) {
1823

1924
// Vector constant.

0 commit comments

Comments
 (0)