3535#include " mlir/IR/Operation.h"
3636#include " mlir/IR/Region.h"
3737#include " mlir/IR/TypeRange.h"
38+ #include " mlir/IR/Value.h"
3839#include " mlir/IR/ValueRange.h"
3940#include " mlir/Pass/Pass.h"
4041#include " mlir/Pass/PassManager.h"
@@ -105,15 +106,55 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
105106 }
106107};
107108
109+ static mlir::Type convertTypeForMemory (const mlir::TypeConverter &converter,
110+ mlir::Type type) {
111+ // TODO(cir): Handle other types similarly to clang's codegen
112+ // convertTypeForMemory
113+ if (isa<cir::BoolType>(type)) {
114+ // TODO: Use datalayout to get the size of bool
115+ return mlir::IntegerType::get (type.getContext (), 8 );
116+ }
117+
118+ return converter.convertType (type);
119+ }
120+
121+ static mlir::Value emitFromMemory (mlir::ConversionPatternRewriter &rewriter,
122+ cir::LoadOp op, mlir::Value value) {
123+
124+ // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
125+ if (isa<cir::BoolType>(op.getResult ().getType ())) {
126+ // Create trunc of value from i8 to i1
127+ // TODO: Use datalayout to get the size of bool
128+ assert (value.getType ().isInteger (8 ));
129+ return createIntCast (rewriter, value, rewriter.getI1Type ());
130+ }
131+
132+ return value;
133+ }
134+
135+ static mlir::Value emitToMemory (mlir::ConversionPatternRewriter &rewriter,
136+ cir::StoreOp op, mlir::Value value) {
137+
138+ // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
139+ if (isa<cir::BoolType>(op.getValue ().getType ())) {
140+ // Create zext of value from i1 to i8
141+ // TODO: Use datalayout to get the size of bool
142+ return createIntCast (rewriter, value, rewriter.getI8Type ());
143+ }
144+
145+ return value;
146+ }
147+
108148class CIRAllocaOpLowering : public mlir ::OpConversionPattern<cir::AllocaOp> {
109149public:
110150 using OpConversionPattern<cir::AllocaOp>::OpConversionPattern;
111151
112152 mlir::LogicalResult
113153 matchAndRewrite (cir::AllocaOp op, OpAdaptor adaptor,
114154 mlir::ConversionPatternRewriter &rewriter) const override {
115- auto type = adaptor.getAllocaType ();
116- auto mlirType = getTypeConverter ()->convertType (type);
155+
156+ mlir::Type mlirType =
157+ convertTypeForMemory (*getTypeConverter (), adaptor.getAllocaType ());
117158
118159 // FIXME: Some types can not be converted yet (e.g. struct)
119160 if (!mlirType)
@@ -174,12 +215,20 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
174215 mlir::Value base;
175216 SmallVector<mlir::Value> indices;
176217 SmallVector<mlir::Operation *> eraseList;
218+ mlir::memref::LoadOp newLoad;
177219 if (findBaseAndIndices (adaptor.getAddr (), base, indices, eraseList,
178220 rewriter)) {
179- rewriter.replaceOpWithNewOp <mlir::memref::LoadOp>(op, base, indices);
221+ newLoad =
222+ rewriter.create <mlir::memref::LoadOp>(op.getLoc (), base, indices);
223+ // rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
180224 eraseIfSafe (op.getAddr (), base, eraseList, rewriter);
181225 } else
182- rewriter.replaceOpWithNewOp <mlir::memref::LoadOp>(op, adaptor.getAddr ());
226+ newLoad =
227+ rewriter.create <mlir::memref::LoadOp>(op.getLoc (), adaptor.getAddr ());
228+
229+ // Convert adapted result to its original type if needed.
230+ mlir::Value result = emitFromMemory (rewriter, op, newLoad.getResult ());
231+ rewriter.replaceOp (op, result);
183232 return mlir::LogicalResult::success ();
184233 }
185234};
@@ -194,13 +243,16 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
194243 mlir::Value base;
195244 SmallVector<mlir::Value> indices;
196245 SmallVector<mlir::Operation *> eraseList;
246+
247+ // Convert adapted value to its memory type if needed.
248+ mlir::Value value = emitToMemory (rewriter, op, adaptor.getValue ());
197249 if (findBaseAndIndices (adaptor.getAddr (), base, indices, eraseList,
198250 rewriter)) {
199- rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, adaptor. getValue () ,
200- base, indices);
251+ rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, value, base ,
252+ indices);
201253 eraseIfSafe (op.getAddr (), base, eraseList, rewriter);
202254 } else
203- rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, adaptor. getValue () ,
255+ rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, value ,
204256 adaptor.getAddr ());
205257 return mlir::LogicalResult::success ();
206258 }
@@ -741,29 +793,20 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
741793 mlir::ConversionPatternRewriter &rewriter) const override {
742794 auto type = op.getLhs ().getType ();
743795
744- mlir::Value mlirResult;
745-
746796 if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
747797 auto kind = convertCmpKindToCmpIPredicate (op.getKind (), ty.isSigned ());
748- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
749- op. getLoc () , kind, adaptor.getLhs (), adaptor.getRhs ());
798+ rewriter.replaceOpWithNewOp <mlir::arith::CmpIOp>(
799+ op, kind, adaptor.getLhs (), adaptor.getRhs ());
750800 } else if (auto ty = mlir::dyn_cast<cir::CIRFPTypeInterface>(type)) {
751801 auto kind = convertCmpKindToCmpFPredicate (op.getKind ());
752- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
753- op. getLoc () , kind, adaptor.getLhs (), adaptor.getRhs ());
802+ rewriter.replaceOpWithNewOp <mlir::arith::CmpFOp>(
803+ op, kind, adaptor.getLhs (), adaptor.getRhs ());
754804 } else if (auto ty = mlir::dyn_cast<cir::PointerType>(type)) {
755805 llvm_unreachable (" pointer comparison not supported yet" );
756806 } else {
757807 return op.emitError () << " unsupported type for CmpOp: " << type;
758808 }
759809
760- // MLIR comparison ops return i1, but cir::CmpOp returns the same type as
761- // the LHS value. Since this return value can be used later, we need to
762- // restore the type with the extension below.
763- auto mlirResultTy = getTypeConverter ()->convertType (op.getType ());
764- rewriter.replaceOpWithNewOp <mlir::arith::ExtUIOp>(op, mlirResultTy,
765- mlirResult);
766-
767810 return mlir::LogicalResult::success ();
768811 }
769812};
@@ -823,12 +866,8 @@ struct CIRBrCondOpLowering : public mlir::OpConversionPattern<cir::BrCondOp> {
823866 mlir::LogicalResult
824867 matchAndRewrite (cir::BrCondOp brOp, OpAdaptor adaptor,
825868 mlir::ConversionPatternRewriter &rewriter) const override {
826-
827- auto condition = adaptor.getCond ();
828- auto i1Condition = rewriter.create <mlir::arith::TruncIOp>(
829- brOp.getLoc (), rewriter.getI1Type (), condition);
830869 rewriter.replaceOpWithNewOp <mlir::cf::CondBranchOp>(
831- brOp, i1Condition. getResult (), brOp.getDestTrue (),
870+ brOp, adaptor. getCond (), brOp.getDestTrue (),
832871 adaptor.getDestOperandsTrue (), brOp.getDestFalse (),
833872 adaptor.getDestOperandsFalse ());
834873
@@ -844,16 +883,13 @@ class CIRTernaryOpLowering : public mlir::OpConversionPattern<cir::TernaryOp> {
844883 matchAndRewrite (cir::TernaryOp op, OpAdaptor adaptor,
845884 mlir::ConversionPatternRewriter &rewriter) const override {
846885 rewriter.setInsertionPoint (op);
847- auto condition = adaptor.getCond ();
848- auto i1Condition = rewriter.create <mlir::arith::TruncIOp>(
849- op.getLoc (), rewriter.getI1Type (), condition);
850886 SmallVector<mlir::Type> resultTypes;
851887 if (mlir::failed (getTypeConverter ()->convertTypes (op->getResultTypes (),
852888 resultTypes)))
853889 return mlir::failure ();
854890
855891 auto ifOp = rewriter.create <mlir::scf::IfOp>(op.getLoc (), resultTypes,
856- i1Condition. getResult (), true );
892+ adaptor. getCond (), true );
857893 auto *thenBlock = &ifOp.getThenRegion ().front ();
858894 auto *elseBlock = &ifOp.getElseRegion ().front ();
859895 rewriter.inlineBlockBefore (&op.getTrueRegion ().front (), thenBlock,
@@ -890,11 +926,8 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<cir::IfOp> {
890926 mlir::LogicalResult
891927 matchAndRewrite (cir::IfOp ifop, OpAdaptor adaptor,
892928 mlir::ConversionPatternRewriter &rewriter) const override {
893- auto condition = adaptor.getCondition ();
894- auto i1Condition = rewriter.create <mlir::arith::TruncIOp>(
895- ifop->getLoc (), rewriter.getI1Type (), condition);
896929 auto newIfOp = rewriter.create <mlir::scf::IfOp>(
897- ifop->getLoc (), ifop->getResultTypes (), i1Condition );
930+ ifop->getLoc (), ifop->getResultTypes (), adaptor. getCondition () );
898931 auto *thenBlock = rewriter.createBlock (&newIfOp.getThenRegion ());
899932 rewriter.inlineBlockBefore (&ifop.getThenRegion ().front (), thenBlock,
900933 thenBlock->end ());
@@ -921,7 +954,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
921954 mlir::OpBuilder b (moduleOp.getContext ());
922955
923956 const auto CIRSymType = op.getSymType ();
924- auto convertedType = getTypeConverter ()-> convertType ( CIRSymType);
957+ auto convertedType = convertTypeForMemory (* getTypeConverter (), CIRSymType);
925958 if (!convertedType)
926959 return mlir::failure ();
927960 auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
@@ -1167,19 +1200,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
11671200 return mlir::success ();
11681201 }
11691202 case CIR::float_to_bool: {
1170- auto dstTy = mlir::cast<cir::BoolType>(op.getType ());
1171- auto newDstType = convertTy (dstTy);
11721203 auto kind = mlir::arith::CmpFPredicate::UNE;
11731204
11741205 // Check if float is not equal to zero.
11751206 auto zeroFloat = rewriter.create <mlir::arith::ConstantOp>(
11761207 op.getLoc (), src.getType (), mlir::FloatAttr::get (src.getType (), 0.0 ));
11771208
1178- // Extend comparison result to either bool (C++) or int (C).
1179- mlir::Value cmpResult = rewriter.create <mlir::arith::CmpFOp>(
1180- op.getLoc (), kind, src, zeroFloat);
1181- rewriter.replaceOpWithNewOp <mlir::arith::ExtUIOp>(op, newDstType,
1182- cmpResult);
1209+ rewriter.replaceOpWithNewOp <mlir::arith::CmpFOp>(op, kind, src,
1210+ zeroFloat);
11831211 return mlir::success ();
11841212 }
11851213 case CIR::bool_to_int: {
@@ -1327,7 +1355,7 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13271355static mlir::TypeConverter prepareTypeConverter () {
13281356 mlir::TypeConverter converter;
13291357 converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
1330- auto ty = converter. convertType ( type.getPointee ());
1358+ auto ty = convertTypeForMemory (converter, type.getPointee ());
13311359 // FIXME: The pointee type might not be converted (e.g. struct)
13321360 if (!ty)
13331361 return nullptr ;
@@ -1347,7 +1375,7 @@ static mlir::TypeConverter prepareTypeConverter() {
13471375 mlir::IntegerType::SignednessSemantics::Signless);
13481376 });
13491377 converter.addConversion ([&](cir::BoolType type) -> mlir::Type {
1350- return mlir::IntegerType::get (type.getContext (), 8 );
1378+ return mlir::IntegerType::get (type.getContext (), 1 );
13511379 });
13521380 converter.addConversion ([&](cir::SingleType type) -> mlir::Type {
13531381 return mlir::FloatType::getF32 (type.getContext ());
0 commit comments