Skip to content

Commit 092cee5

Browse files
clementvaljeanPerierschweitzpgi
committed
[fir] Add fir.convert op conversion from FIR to LLVM IR
Add conversion pattern for the `fir.convert` operation. This patch is part of the upstreaming effort from fir-dev branch. This patch was previously landed with a truncated version that was failing the windows buildbot. Reviewed By: rovka, awarzynski Differential Revision: https://reviews.llvm.org/D113469 Co-authored-by: Jean Perier <[email protected]> Co-authored-by: Eric Schweitz <[email protected]>
1 parent a40929d commit 092cee5

File tree

2 files changed

+241
-13
lines changed

2 files changed

+241
-13
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 123 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,121 @@ struct CallOpConversion : public FIROpConversion<fir::CallOp> {
9696
}
9797
};
9898

99+
static mlir::Type getComplexEleTy(mlir::Type complex) {
100+
if (auto cc = complex.dyn_cast<mlir::ComplexType>())
101+
return cc.getElementType();
102+
return complex.cast<fir::ComplexType>().getElementType();
103+
}
104+
105+
/// convert value of from-type to value of to-type
106+
struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
107+
using FIROpConversion::FIROpConversion;
108+
109+
static bool isFloatingPointTy(mlir::Type ty) {
110+
return ty.isa<mlir::FloatType>();
111+
}
112+
113+
mlir::LogicalResult
114+
matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
115+
mlir::ConversionPatternRewriter &rewriter) const override {
116+
auto fromTy = convertType(convert.value().getType());
117+
auto toTy = convertType(convert.res().getType());
118+
mlir::Value op0 = adaptor.getOperands()[0];
119+
if (fromTy == toTy) {
120+
rewriter.replaceOp(convert, op0);
121+
return success();
122+
}
123+
auto loc = convert.getLoc();
124+
auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
125+
unsigned toBits, mlir::Type toTy) -> mlir::Value {
126+
if (fromBits == toBits) {
127+
// TODO: Converting between two floating-point representations with the
128+
// same bitwidth is not allowed for now.
129+
mlir::emitError(loc,
130+
"cannot implicitly convert between two floating-point "
131+
"representations of the same bitwidth");
132+
return {};
133+
}
134+
if (fromBits > toBits)
135+
return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val);
136+
return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val);
137+
};
138+
// Complex to complex conversion.
139+
if (fir::isa_complex(convert.value().getType()) &&
140+
fir::isa_complex(convert.res().getType())) {
141+
// Special case: handle the conversion of a complex such that both the
142+
// real and imaginary parts are converted together.
143+
auto zero = mlir::ArrayAttr::get(convert.getContext(),
144+
rewriter.getI32IntegerAttr(0));
145+
auto one = mlir::ArrayAttr::get(convert.getContext(),
146+
rewriter.getI32IntegerAttr(1));
147+
auto ty = convertType(getComplexEleTy(convert.value().getType()));
148+
auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, zero);
149+
auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, one);
150+
auto nt = convertType(getComplexEleTy(convert.res().getType()));
151+
auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
152+
auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt);
153+
auto rc = convertFpToFp(rp, fromBits, toBits, nt);
154+
auto ic = convertFpToFp(ip, fromBits, toBits, nt);
155+
auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy);
156+
auto i1 =
157+
rewriter.create<mlir::LLVM::InsertValueOp>(loc, toTy, un, rc, zero);
158+
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, toTy, i1,
159+
ic, one);
160+
return mlir::success();
161+
}
162+
// Floating point to floating point conversion.
163+
if (isFloatingPointTy(fromTy)) {
164+
if (isFloatingPointTy(toTy)) {
165+
auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
166+
auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
167+
auto v = convertFpToFp(op0, fromBits, toBits, toTy);
168+
rewriter.replaceOp(convert, v);
169+
return mlir::success();
170+
}
171+
if (toTy.isa<mlir::IntegerType>()) {
172+
rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0);
173+
return mlir::success();
174+
}
175+
} else if (fromTy.isa<mlir::IntegerType>()) {
176+
// Integer to integer conversion.
177+
if (toTy.isa<mlir::IntegerType>()) {
178+
auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
179+
auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
180+
assert(fromBits != toBits);
181+
if (fromBits > toBits) {
182+
rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0);
183+
return mlir::success();
184+
}
185+
rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0);
186+
return mlir::success();
187+
}
188+
// Integer to floating point conversion.
189+
if (isFloatingPointTy(toTy)) {
190+
rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0);
191+
return mlir::success();
192+
}
193+
// Integer to pointer conversion.
194+
if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
195+
rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0);
196+
return mlir::success();
197+
}
198+
} else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) {
199+
// Pointer to integer conversion.
200+
if (toTy.isa<mlir::IntegerType>()) {
201+
rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0);
202+
return mlir::success();
203+
}
204+
// Pointer to pointer conversion.
205+
if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
206+
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0);
207+
return mlir::success();
208+
}
209+
}
210+
return emitError(loc) << "cannot convert " << fromTy << " to " << toTy;
211+
}
212+
};
213+
99214
/// Lower `fir.has_value` operation to `llvm.return` operation.
100215
struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
101216
using FIROpConversion::FIROpConversion;
@@ -489,12 +604,6 @@ struct InsertOnRangeOpConversion
489604
}
490605
};
491606

492-
static mlir::Type getComplexEleTy(mlir::Type complex) {
493-
if (auto cc = complex.dyn_cast<mlir::ComplexType>())
494-
return cc.getElementType();
495-
return complex.cast<fir::ComplexType>().getElementType();
496-
}
497-
498607
//
499608
// Primitive operations on Complex types
500609
//
@@ -679,13 +788,14 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
679788
auto *context = getModule().getContext();
680789
fir::LLVMTypeConverter typeConverter{getModule()};
681790
mlir::OwningRewritePatternList pattern(context);
682-
pattern.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
683-
DivcOpConversion, ExtractValueOpConversion,
684-
HasValueOpConversion, GlobalOpConversion,
685-
InsertOnRangeOpConversion, InsertValueOpConversion,
686-
NegcOpConversion, MulcOpConversion, SelectOpConversion,
687-
SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
688-
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
791+
pattern
792+
.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
793+
ConvertOpConversion, DivcOpConversion, ExtractValueOpConversion,
794+
HasValueOpConversion, GlobalOpConversion,
795+
InsertOnRangeOpConversion, InsertValueOpConversion,
796+
NegcOpConversion, MulcOpConversion, SelectOpConversion,
797+
SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
798+
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
689799
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
690800
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
691801
pattern);

flang/test/Fir/convert-to-llvm.fir

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,121 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
514514
// CHECK: %{{.*}} = llvm.insertvalue %[[NEGX]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)>
515515
// CHECK: %{{.*}} = llvm.insertvalue %[[NEGY]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)>
516516
// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
517+
518+
// -----
519+
520+
// Test `fir.convert` operation conversion from Float type.
521+
522+
func @convert_from_float(%arg0 : f32) {
523+
%0 = fir.convert %arg0 : (f32) -> f16
524+
%1 = fir.convert %arg0 : (f32) -> f32
525+
%2 = fir.convert %arg0 : (f32) -> f64
526+
%3 = fir.convert %arg0 : (f32) -> f80
527+
%4 = fir.convert %arg0 : (f32) -> f128
528+
%5 = fir.convert %arg0 : (f32) -> i1
529+
%6 = fir.convert %arg0 : (f32) -> i8
530+
%7 = fir.convert %arg0 : (f32) -> i16
531+
%8 = fir.convert %arg0 : (f32) -> i32
532+
%9 = fir.convert %arg0 : (f32) -> i64
533+
return
534+
}
535+
536+
// CHECK-LABEL: convert_from_float(
537+
// CHECK-SAME: %[[ARG0:.*]]: f32
538+
// CHECK: %{{.*}} = llvm.fptrunc %[[ARG0]] : f32 to f16
539+
// CHECK-NOT: f32 to f32
540+
// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f64
541+
// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f80
542+
// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f128
543+
// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i1
544+
// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i8
545+
// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i16
546+
// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i32
547+
// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i64
548+
549+
// -----
550+
551+
// Test `fir.convert` operation conversion from Integer type.
552+
553+
func @convert_from_int(%arg0 : i32) {
554+
%0 = fir.convert %arg0 : (i32) -> f16
555+
%1 = fir.convert %arg0 : (i32) -> f32
556+
%2 = fir.convert %arg0 : (i32) -> f64
557+
%3 = fir.convert %arg0 : (i32) -> f80
558+
%4 = fir.convert %arg0 : (i32) -> f128
559+
%5 = fir.convert %arg0 : (i32) -> i1
560+
%6 = fir.convert %arg0 : (i32) -> i8
561+
%7 = fir.convert %arg0 : (i32) -> i16
562+
%8 = fir.convert %arg0 : (i32) -> i32
563+
%9 = fir.convert %arg0 : (i32) -> i64
564+
%10 = fir.convert %arg0 : (i32) -> i64
565+
%ptr = fir.convert %10 : (i64) -> !fir.ref<i64>
566+
return
567+
}
568+
569+
// CHECK-LABEL: convert_from_int(
570+
// CHECK-SAME: %[[ARG0:.*]]: i32
571+
// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f16
572+
// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f32
573+
// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f64
574+
// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f80
575+
// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f128
576+
// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i1
577+
// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i8
578+
// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i16
579+
// CHECK-NOT: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i32
580+
// CHECK: %{{.*}} = llvm.sext %[[ARG0]] : i32 to i64
581+
// CHECK: %{{.*}} = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<i64>
582+
583+
// -----
584+
585+
// Test `fir.convert` operation conversion from !fir.ref<> type.
586+
587+
func @convert_from_ref(%arg0 : !fir.ref<i32>) {
588+
%0 = fir.convert %arg0 : (!fir.ref<i32>) -> !fir.ref<i8>
589+
%1 = fir.convert %arg0 : (!fir.ref<i32>) -> i32
590+
return
591+
}
592+
593+
// CHECK-LABEL: convert_from_ref(
594+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>
595+
// CHECK: %{{.*}} = llvm.bitcast %[[ARG0]] : !llvm.ptr<i32> to !llvm.ptr<i8>
596+
// CHECK: %{{.*}} = llvm.ptrtoint %[[ARG0]] : !llvm.ptr<i32> to i32
597+
598+
// -----
599+
600+
// Test `fir.convert` operation conversion between fir.complex types.
601+
602+
func @convert_complex4(%arg0 : !fir.complex<4>) -> !fir.complex<8> {
603+
%0 = fir.convert %arg0 : (!fir.complex<4>) -> !fir.complex<8>
604+
return %0 : !fir.complex<8>
605+
}
606+
607+
// CHECK-LABEL: func @convert_complex4(
608+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f32, f32)>) -> !llvm.struct<(f64, f64)>
609+
// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f32, f32)>
610+
// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f32, f32)>
611+
// CHECK: %[[CONVERTX:.*]] = llvm.fpext %[[X]] : f32 to f64
612+
// CHECK: %[[CONVERTY:.*]] = llvm.fpext %[[Y]] : f32 to f64
613+
// CHECK: %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
614+
// CHECK: %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f64, f64)>
615+
// CHECK: %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f64, f64)>
616+
// CHECK: llvm.return %[[STRUCT2]] : !llvm.struct<(f64, f64)>
617+
618+
// Test `fir.convert` operation conversion between fir.complex types.
619+
620+
func @convert_complex16(%arg0 : !fir.complex<16>) -> !fir.complex<2> {
621+
%0 = fir.convert %arg0 : (!fir.complex<16>) -> !fir.complex<2>
622+
return %0 : !fir.complex<2>
623+
}
624+
625+
// CHECK-LABEL: func @convert_complex16(
626+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f16, f16)>
627+
// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)>
628+
// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)>
629+
// CHECK: %[[CONVERTX:.*]] = llvm.fptrunc %[[X]] : f128 to f16
630+
// CHECK: %[[CONVERTY:.*]] = llvm.fptrunc %[[Y]] : f128 to f16
631+
// CHECK: %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f16, f16)>
632+
// CHECK: %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f16, f16)>
633+
// CHECK: %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f16, f16)>
634+
// CHECK: llvm.return %[[STRUCT2]] : !llvm.struct<(f16, f16)>

0 commit comments

Comments
 (0)