Skip to content

Commit 8c4c3ca

Browse files
committed
[CIR][ThroughMLIR] Lower cir.bool to i1
1 parent a04cf10 commit 8c4c3ca

File tree

11 files changed

+119
-116
lines changed

11 files changed

+119
-116
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,11 +337,8 @@ class CIRConditionOpLowering
337337
auto *parentOp = op->getParentOp();
338338
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
339339
.Case<mlir::scf::WhileOp>([&](auto) {
340-
auto condition = adaptor.getCondition();
341-
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
342-
op.getLoc(), rewriter.getI1Type(), condition);
343340
rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
344-
op, i1Condition, parentOp->getOperands());
341+
op, adaptor.getCondition(), parentOp->getOperands());
345342
return mlir::success();
346343
})
347344
.Default([](auto) { return mlir::failure(); });

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/IR/Operation.h"
3636
#include "mlir/IR/Region.h"
3737
#include "mlir/IR/TypeRange.h"
38+
#include "mlir/IR/Value.h"
3839
#include "mlir/IR/ValueRange.h"
3940
#include "mlir/Pass/Pass.h"
4041
#include "mlir/Pass/PassManager.h"
@@ -105,15 +106,55 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
105106
}
106107
};
107108

109+
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
110+
mlir::Type type) {
111+
// TODO(cir): Handle other types similarly to clang's codegen
112+
// convertTypeForMemory
113+
if (isa<cir::BoolType>(type)) {
114+
// TODO: Use datalayout to get the size of bool
115+
return mlir::IntegerType::get(type.getContext(), 8);
116+
}
117+
118+
return converter.convertType(type);
119+
}
120+
121+
static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter,
122+
cir::LoadOp op, mlir::Value value) {
123+
124+
// TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
125+
if (isa<cir::BoolType>(op.getResult().getType())) {
126+
// Create trunc of value from i8 to i1
127+
// TODO: Use datalayout to get the size of bool
128+
assert(value.getType().isInteger(8));
129+
return createIntCast(rewriter, value, rewriter.getI1Type());
130+
}
131+
132+
return value;
133+
}
134+
135+
static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter,
136+
cir::StoreOp op, mlir::Value value) {
137+
138+
// TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
139+
if (isa<cir::BoolType>(op.getValue().getType())) {
140+
// Create zext of value from i1 to i8
141+
// TODO: Use datalayout to get the size of bool
142+
return createIntCast(rewriter, value, rewriter.getI8Type());
143+
}
144+
145+
return value;
146+
}
147+
108148
class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
109149
public:
110150
using OpConversionPattern<cir::AllocaOp>::OpConversionPattern;
111151

112152
mlir::LogicalResult
113153
matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor,
114154
mlir::ConversionPatternRewriter &rewriter) const override {
115-
auto type = adaptor.getAllocaType();
116-
auto mlirType = getTypeConverter()->convertType(type);
155+
156+
mlir::Type mlirType =
157+
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());
117158

118159
// FIXME: Some types can not be converted yet (e.g. struct)
119160
if (!mlirType)
@@ -174,12 +215,20 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
174215
mlir::Value base;
175216
SmallVector<mlir::Value> indices;
176217
SmallVector<mlir::Operation *> eraseList;
218+
mlir::memref::LoadOp newLoad;
177219
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
178220
rewriter)) {
179-
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
221+
newLoad =
222+
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), base, indices);
223+
// rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
180224
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
181225
} else
182-
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getAddr());
226+
newLoad =
227+
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), adaptor.getAddr());
228+
229+
// Convert adapted result to its original type if needed.
230+
mlir::Value result = emitFromMemory(rewriter, op, newLoad.getResult());
231+
rewriter.replaceOp(op, result);
183232
return mlir::LogicalResult::success();
184233
}
185234
};
@@ -194,13 +243,16 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
194243
mlir::Value base;
195244
SmallVector<mlir::Value> indices;
196245
SmallVector<mlir::Operation *> eraseList;
246+
247+
// Convert adapted value to its memory type if needed.
248+
mlir::Value value = emitToMemory(rewriter, op, adaptor.getValue());
197249
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
198250
rewriter)) {
199-
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
200-
base, indices);
251+
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value, base,
252+
indices);
201253
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
202254
} else
203-
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
255+
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value,
204256
adaptor.getAddr());
205257
return mlir::LogicalResult::success();
206258
}
@@ -741,29 +793,20 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
741793
mlir::ConversionPatternRewriter &rewriter) const override {
742794
auto type = op.getLhs().getType();
743795

744-
mlir::Value mlirResult;
745-
746796
if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
747797
auto kind = convertCmpKindToCmpIPredicate(op.getKind(), ty.isSigned());
748-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
749-
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
798+
rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(
799+
op, kind, adaptor.getLhs(), adaptor.getRhs());
750800
} else if (auto ty = mlir::dyn_cast<cir::CIRFPTypeInterface>(type)) {
751801
auto kind = convertCmpKindToCmpFPredicate(op.getKind());
752-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
753-
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
802+
rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(
803+
op, kind, adaptor.getLhs(), adaptor.getRhs());
754804
} else if (auto ty = mlir::dyn_cast<cir::PointerType>(type)) {
755805
llvm_unreachable("pointer comparison not supported yet");
756806
} else {
757807
return op.emitError() << "unsupported type for CmpOp: " << type;
758808
}
759809

760-
// MLIR comparison ops return i1, but cir::CmpOp returns the same type as
761-
// the LHS value. Since this return value can be used later, we need to
762-
// restore the type with the extension below.
763-
auto mlirResultTy = getTypeConverter()->convertType(op.getType());
764-
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, mlirResultTy,
765-
mlirResult);
766-
767810
return mlir::LogicalResult::success();
768811
}
769812
};
@@ -823,12 +866,8 @@ struct CIRBrCondOpLowering : public mlir::OpConversionPattern<cir::BrCondOp> {
823866
mlir::LogicalResult
824867
matchAndRewrite(cir::BrCondOp brOp, OpAdaptor adaptor,
825868
mlir::ConversionPatternRewriter &rewriter) const override {
826-
827-
auto condition = adaptor.getCond();
828-
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
829-
brOp.getLoc(), rewriter.getI1Type(), condition);
830869
rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
831-
brOp, i1Condition.getResult(), brOp.getDestTrue(),
870+
brOp, adaptor.getCond(), brOp.getDestTrue(),
832871
adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
833872
adaptor.getDestOperandsFalse());
834873

@@ -844,16 +883,13 @@ class CIRTernaryOpLowering : public mlir::OpConversionPattern<cir::TernaryOp> {
844883
matchAndRewrite(cir::TernaryOp op, OpAdaptor adaptor,
845884
mlir::ConversionPatternRewriter &rewriter) const override {
846885
rewriter.setInsertionPoint(op);
847-
auto condition = adaptor.getCond();
848-
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
849-
op.getLoc(), rewriter.getI1Type(), condition);
850886
SmallVector<mlir::Type> resultTypes;
851887
if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(),
852888
resultTypes)))
853889
return mlir::failure();
854890

855891
auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), resultTypes,
856-
i1Condition.getResult(), true);
892+
adaptor.getCond(), true);
857893
auto *thenBlock = &ifOp.getThenRegion().front();
858894
auto *elseBlock = &ifOp.getElseRegion().front();
859895
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), thenBlock,
@@ -890,11 +926,8 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<cir::IfOp> {
890926
mlir::LogicalResult
891927
matchAndRewrite(cir::IfOp ifop, OpAdaptor adaptor,
892928
mlir::ConversionPatternRewriter &rewriter) const override {
893-
auto condition = adaptor.getCondition();
894-
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
895-
ifop->getLoc(), rewriter.getI1Type(), condition);
896929
auto newIfOp = rewriter.create<mlir::scf::IfOp>(
897-
ifop->getLoc(), ifop->getResultTypes(), i1Condition);
930+
ifop->getLoc(), ifop->getResultTypes(), adaptor.getCondition());
898931
auto *thenBlock = rewriter.createBlock(&newIfOp.getThenRegion());
899932
rewriter.inlineBlockBefore(&ifop.getThenRegion().front(), thenBlock,
900933
thenBlock->end());
@@ -921,7 +954,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
921954
mlir::OpBuilder b(moduleOp.getContext());
922955

923956
const auto CIRSymType = op.getSymType();
924-
auto convertedType = getTypeConverter()->convertType(CIRSymType);
957+
auto convertedType = convertTypeForMemory(*getTypeConverter(), CIRSymType);
925958
if (!convertedType)
926959
return mlir::failure();
927960
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
@@ -1167,19 +1200,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
11671200
return mlir::success();
11681201
}
11691202
case CIR::float_to_bool: {
1170-
auto dstTy = mlir::cast<cir::BoolType>(op.getType());
1171-
auto newDstType = convertTy(dstTy);
11721203
auto kind = mlir::arith::CmpFPredicate::UNE;
11731204

11741205
// Check if float is not equal to zero.
11751206
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
11761207
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));
11771208

1178-
// Extend comparison result to either bool (C++) or int (C).
1179-
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
1180-
op.getLoc(), kind, src, zeroFloat);
1181-
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
1182-
cmpResult);
1209+
rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(op, kind, src,
1210+
zeroFloat);
11831211
return mlir::success();
11841212
}
11851213
case CIR::bool_to_int: {
@@ -1327,7 +1355,7 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13271355
static mlir::TypeConverter prepareTypeConverter() {
13281356
mlir::TypeConverter converter;
13291357
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
1330-
auto ty = converter.convertType(type.getPointee());
1358+
auto ty = convertTypeForMemory(converter, type.getPointee());
13311359
// FIXME: The pointee type might not be converted (e.g. struct)
13321360
if (!ty)
13331361
return nullptr;
@@ -1347,7 +1375,7 @@ static mlir::TypeConverter prepareTypeConverter() {
13471375
mlir::IntegerType::SignednessSemantics::Signless);
13481376
});
13491377
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
1350-
return mlir::IntegerType::get(type.getContext(), 8);
1378+
return mlir::IntegerType::get(type.getContext(), 1);
13511379
});
13521380
converter.addConversion([&](cir::SingleType type) -> mlir::Type {
13531381
return mlir::FloatType::getF32(type.getContext());

clang/test/CIR/CodeGen/globals.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ void use_global() {
2020
int li = a;
2121
}
2222

23+
bool bool_global() {
24+
return e;
25+
}
26+
2327
void use_global_string() {
2428
unsigned char c = s2[0];
2529
}

clang/test/CIR/Lowering/ThroughMLIR/bool.cir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ module {
1414

1515
// MLIR: func @foo() {
1616
// MLIR: [[Value:%[a-z0-9]+]] = memref.alloca() {alignment = 1 : i64} : memref<i8>
17-
// MLIR: = arith.constant 1 : i8
18-
// MLIR: memref.store {{.*}}, [[Value]][] : memref<i8>
17+
// MLIR: %[[CONST:.*]] = arith.constant true
18+
// MLIR: %[[BOOL_TO_MEM:.*]] = arith.extui %[[CONST]] : i1 to i8
19+
// MLIR-NEXT: memref.store %[[BOOL_TO_MEM]], [[Value]][] : memref<i8>
1920
// return
2021

2122
// LLVM: = alloca i8, i64

clang/test/CIR/Lowering/ThroughMLIR/branch.cir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i {
1313
}
1414

1515
// MLIR: module {
16-
// MLIR-NEXT: func.func @foo(%arg0: i8) -> i32
17-
// MLIR-NEXT: %0 = arith.trunci %arg0 : i8 to i1
18-
// MLIR-NEXT: cf.cond_br %0, ^bb1, ^bb2
16+
// MLIR-NEXT: func.func @foo(%arg0: i1) -> i32
17+
// MLIR-NEXT: cf.cond_br %arg0, ^bb1, ^bb2
1918
// MLIR-NEXT: ^bb1: // pred: ^bb0
2019
// MLIR-NEXT: %c1_i32 = arith.constant 1 : i32
2120
// MLIR-NEXT: return %c1_i32 : i32
@@ -25,13 +24,12 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i {
2524
// MLIR-NEXT: }
2625
// MLIR-NEXT: }
2726

28-
// LLVM: define i32 @foo(i8 %0)
29-
// LLVM-NEXT: %2 = trunc i8 %0 to i1
30-
// LLVM-NEXT: br i1 %2, label %3, label %4
27+
// LLVM: define i32 @foo(i1 %0)
28+
// LLVM-NEXT: br i1 %0, label %[[TRUE:.*]], label %[[FALSE:.*]]
3129
// LLVM-EMPTY:
32-
// LLVM-NEXT: 3: ; preds = %1
30+
// LLVM-NEXT: [[TRUE]]:
3331
// LLVM-NEXT: ret i32 1
3432
// LLVM-EMPTY:
35-
// LLVM-NEXT: 4: ; preds = %1
33+
// LLVM-NEXT: [[FALSE]]:
3634
// LLVM-NEXT: ret i32 0
3735
// LLVM-NEXT: }

clang/test/CIR/Lowering/ThroughMLIR/cast.cir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
!u16i = !cir.int<u, 16>
88
!u8i = !cir.int<u, 8>
99
module {
10-
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i8
11-
// LLVM-LABEL: define i8 @cast_int_to_bool(i32 %0)
10+
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i1
11+
// LLVM-LABEL: define i1 @cast_int_to_bool(i32 %0)
1212
cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool {
1313
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32
1414
// MLIR-NEXT: arith.cmpi ne, %arg0, %[[ZERO]]
@@ -71,8 +71,8 @@ module {
7171
%1 = cir.cast(floating, %f : !cir.float), !cir.double
7272
cir.return %1 : !cir.double
7373
}
74-
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i8
75-
// LLVM-LABEL: define i8 @cast_float_to_bool(float %0)
74+
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i1
75+
// LLVM-LABEL: define i1 @cast_float_to_bool(float %0)
7676
cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool {
7777
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
7878
// MLIR-NEXT: arith.cmpf une, %arg0, %[[ZERO]] : f32
@@ -81,29 +81,29 @@ module {
8181
%1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool
8282
cir.return %1 : !cir.bool
8383
}
84-
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i8) -> i8
85-
// LLVM-LABEL: define i8 @cast_bool_to_int8(i8 %0)
84+
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i1) -> i8
85+
// LLVM-LABEL: define i8 @cast_bool_to_int8(i1 %0)
8686
cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i {
87-
// MLIR-NEXT: arith.bitcast %arg0 : i8 to i8
88-
// LLVM-NEXT: ret i8 %0
87+
// MLIR-NEXT: arith.extui %arg0 : i1 to i8
88+
// LLVM-NEXT: zext i1 %0 to i8
8989

9090
%1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i
9191
cir.return %1 : !u8i
9292
}
93-
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i8) -> i32
94-
// LLVM-LABEL: define i32 @cast_bool_to_int(i8 %0)
93+
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i1) -> i32
94+
// LLVM-LABEL: define i32 @cast_bool_to_int(i1 %0)
9595
cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i {
96-
// MLIR-NEXT: arith.extui %arg0 : i8 to i32
97-
// LLVM-NEXT: zext i8 %0 to i32
96+
// MLIR-NEXT: arith.extui %arg0 : i1 to i32
97+
// LLVM-NEXT: zext i1 %0 to i32
9898

9999
%1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i
100100
cir.return %1 : !u32i
101101
}
102-
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i8) -> f32
103-
// LLVM-LABEL: define float @cast_bool_to_float(i8 %0)
102+
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i1) -> f32
103+
// LLVM-LABEL: define float @cast_bool_to_float(i1 %0)
104104
cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float {
105-
// MLIR-NEXT: arith.uitofp %arg0 : i8 to f32
106-
// LLVM-NEXT: uitofp i8 %0 to float
105+
// MLIR-NEXT: arith.uitofp %arg0 : i1 to f32
106+
// LLVM-NEXT: uitofp i1 %0 to float
107107

108108
%1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float
109109
cir.return %1 : !cir.float

0 commit comments

Comments
 (0)