[flang] AArch64 support for BIND(C) derived return types #114051
Conversation
This patch adds support for BIND(C) derived types as return values matching the AArch64 Procedure Call Standard for C. Support for BIND(C) derived types as value parameters will be in a separate patch.
@llvm/pr-subscribers-flang-codegen @llvm/pr-subscribers-flang-fir-hlfir

Author: David Truby (DavidTruby)

Full diff: https://github.com/llvm/llvm-project/pull/114051.diff (2 files affected)
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 6c148dffb0e55a..15ffdb74ef51d6 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -825,6 +825,48 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
}
return marshal;
}
+
+ static bool isHFA(fir::RecordType ty) {
+ auto types = ty.getTypeList();
+ if (types.empty() || types.size() > 4) {
+ return false;
+ }
+
+ if (!isa_real(types.front().second)) {
+ types.front().second.dump();
+ return false;
+ }
+
+ return llvm::all_equal(llvm::make_second_range(types));
+ }
+
+ CodeGenSpecifics::Marshalling
+ structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+ CodeGenSpecifics::Marshalling marshal;
+
+ if (isHFA(ty)) {
+ auto newTy = fir::SequenceType::get({ty.getNumFields()}, ty.getType(0));
+ marshal.emplace_back(newTy, AT{});
+ return marshal;
+ }
+
+ auto [size, align] =
+ fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+
+ // return in registers if size <= 16 bytes
+ if (size <= 16) {
+ auto dwordSize = (size + 7) / 8;
+ auto newTy = fir::SequenceType::get(
+ dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
+ marshal.emplace_back(newTy, AT{});
+ return marshal;
+ }
+
+ unsigned short stackAlign = std::max<unsigned short>(align, 8u);
+ marshal.emplace_back(fir::ReferenceType::get(ty),
+ AT{stackAlign, false, true});
+ return marshal;
+ }
};
} // namespace
diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir
new file mode 100644
index 00000000000000..96f2f9999b3435
--- /dev/null
+++ b/flang/test/Fir/struct-return-aarch64.fir
@@ -0,0 +1,156 @@
+// Test AArch64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+
+!composite = !fir.type<t1{i:f32,j:i32,k:f32}>
+// CHECK-LABEL: func.func private @test_composite() -> !fir.array<2xi64>
+func.func private @test_composite() -> !composite
+// CHECK-LABEL: func.func @test_call_composite(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>)
+func.func @test_call_composite(%arg0 : !fir.ref<!composite>) {
+ // CHECK: %[[OUT:.*]] = fir.call @test_composite() : () -> !fir.array<2xi64>
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xi64>
+ // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xi64>>
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xi64>>) -> !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_composite() : () -> !composite
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+ fir.store %out to %arg0 : !fir.ref<!composite>
+ // CHECK: return
+ return
+}
+
+!hfa_f16 = !fir.type<t2{x:f16, y:f16}>
+// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.array<2xf16>
+func.func private @test_hfa_f16() -> !hfa_f16
+// CHECK-LABEL: func.func @test_call_hfa_f16(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t2{x:f16,y:f16}>>) {
+func.func @test_call_hfa_f16(%arg0 : !fir.ref<!hfa_f16>) {
+ // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.array<2xf16>
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xf16>
+ // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xf16>>
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xf16>>) -> !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_hfa_f16() : () -> !hfa_f16
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+ fir.store %out to %arg0 : !fir.ref<!hfa_f16>
+ return
+}
+
+!hfa_f32 = !fir.type<t3{w:f32, x:f32, y:f32, z:f32}>
+// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.array<4xf32>
+func.func private @test_hfa_f32() -> !hfa_f32
+// CHECK-LABEL: func.func @test_call_hfa_f32(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>) {
+func.func @test_call_hfa_f32(%arg0 : !fir.ref<!hfa_f32>) {
+ // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.array<4xf32>
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf32>
+ // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf32>>
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_hfa_f32() : () -> !hfa_f32
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+ fir.store %out to %arg0 : !fir.ref<!hfa_f32>
+ return
+}
+
+!hfa_f64 = !fir.type<t4{x:f64, y:f64, z:f64}>
+// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.array<3xf64>
+func.func private @test_hfa_f64() -> !hfa_f64
+// CHECK-LABEL: func.func @test_call_hfa_f64(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>)
+func.func @test_call_hfa_f64(%arg0 : !fir.ref<!hfa_f64>) {
+ // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.array<3xf64>
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<3xf64>
+ // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<3xf64>>
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<3xf64>>) -> !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_hfa_f64() : () -> !hfa_f64
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+ fir.store %out to %arg0 : !fir.ref<!hfa_f64>
+ return
+}
+
+!hfa_f128 = !fir.type<t5{w:f128, x:f128, y:f128, z:f128}>
+// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.array<4xf128>
+func.func private @test_hfa_f128() -> !hfa_f128
+// CHECK-LABEL: func.func @test_call_hfa_f128(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>) {
+func.func @test_call_hfa_f128(%arg0 : !fir.ref<!hfa_f128>) {
+ // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.array<4xf128>
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf128>
+ // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf128>>
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf128>>) -> !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_hfa_f128() : () -> !hfa_f128
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+ fir.store %out to %arg0 : !fir.ref<!hfa_f128>
+ return
+}
+
+!hfa_bf16 = !fir.type<t6{w:bf16, x:bf16, y:bf16, z:bf16}>
+// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.array<4xbf16>
+func.func private @test_hfa_bf16() -> !hfa_bf16
+// CHECK-LABEL: func.func @test_call_hfa_bf16(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>) {
+func.func @test_call_hfa_bf16(%arg0 : !fir.ref<!hfa_bf16>) {
+ // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.array<4xbf16>
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xbf16>
+ // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xbf16>>
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xbf16>>) -> !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+ fir.store %out to %arg0 : !fir.ref<!hfa_bf16>
+ return
+}
+
+!too_big = !fir.type<t7{x:i64, y:i64, z:i64}>
+// CHECK-LABEL: func.func private @test_too_big(!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+// CHECK-SAME: {llvm.align = 8 : i32, llvm.sret = !fir.type<t7{x:i64,y:i64,z:i64}>})
+func.func private @test_too_big() -> !too_big
+// CHECK-LABEL: func.func @test_call_too_big(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) {
+func.func @test_call_too_big(%arg0 : !fir.ref<!too_big>) {
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARG:.*]] = fir.alloca !fir.type<t7{x:i64,y:i64,z:i64}>
+ // CHECK: fir.call @test_too_big(%[[ARG]]) : (!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) -> ()
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) -> !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_too_big() : () -> !too_big
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+ fir.store %out to %arg0 : !fir.ref<!too_big>
+ return
+}
+
+
+!too_big_hfa = !fir.type<t8{i:!fir.array<5xf32>}>
+// CHECK-LABEL: func.func private @test_too_big_hfa(!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+// CHECK-SAME: {llvm.align = 8 : i32, llvm.sret = !fir.type<t8{i:!fir.array<5xf32>}>})
+func.func private @test_too_big_hfa() -> !too_big_hfa
+// CHECK-LABEL: func.func @test_call_too_big_hfa(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) {
+func.func @test_call_too_big_hfa(%arg0 : !fir.ref<!too_big_hfa>) {
+ // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+ // CHECK: %[[ARG:.*]] = fir.alloca !fir.type<t8{i:!fir.array<5xf32>}>
+ // CHECK: fir.call @test_too_big_hfa(%[[ARG]]) : (!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) -> ()
+ // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) -> !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+ // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+ // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+ %out = fir.call @test_too_big_hfa() : () -> !too_big_hfa
+ // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+ fir.store %out to %arg0 : !fir.ref<!too_big_hfa>
+ return
+}
The ABI for return values here is also the same on Windows, so I have not separated out Windows from non-Windows as is done on x86.

Please could you link to some public documentation for the ABI?
  static bool isHFA(fir::RecordType ty) {
What is HFA? Please document.
@DavidTruby I think you missed this when responding to feedback
Sorry I commented on the wrong one here. I meant the "expand auto" nit below.
  static bool isHFA(fir::RecordType ty) {
    auto types = ty.getTypeList();
Nit: expand auto.
    if (!isa_real(types.front().second)) {
      types.front().second.dump();
Is this leftover debug?
    // return in registers if size <= 16 bytes
    if (size <= 16) {
      auto dwordSize = (size + 7) / 8;
Nit: spell the type.
One nitpick, otherwise great stuff.
    unsigned short stackAlign = std::max<unsigned short>(align, 8u);
    marshal.emplace_back(fir::ReferenceType::get(ty),
                         AT{stackAlign, false, true});
Suggested change:
-                         AT{stackAlign, false, true});
+                         AT{/*alignment=*/stackAlign, /*byval=*/false, /*sret=*/true});
Thanks, not familiar with the ABI, just a small question. Code looks good otherwise.
    if (!isa_real(types.front().second)) {
      return false;
    }
Derived types containing small fp arrays will be rejected here. Is the following C struct an HFA?
typedef struct {
float x[2];
float y;
} S;
Ah. It is, yes. As is:
typedef struct {
float x;
float y;
} S1;
typedef struct {
S1 s;
float z;
} S2;
Good spot.
I'm not all that satisfied with how I did this; I feel like I have over-complicated it, although it does work. If anyone has any suggestions I'd happily take them!
@jeanPerier does this look ok to you now?
Yes, thanks for addressing the comment.
Building the flattened type list just for the HFA check is indeed a bit heavy. If this were expected to be exercised for types with many components, I would advise some kind of type visitor without copying and with early exit, but given that this is not common and you only call it for types with fewer than 4 components, it is fine as it is.