Skip to content

Conversation

DavidTruby
Copy link
Member

This patch adds support for BIND(C) derived types as return values
matching the AArch64 Procedure Call Standard for C.

Support for BIND(C) derived types as value parameters will be in a
separate patch.

This patch adds support for BIND(C) derived types as return values
matching the AArch64 Procedure Call Standard for C.

Support for BIND(C) derived types as value parameters will be in a
separate patch.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Oct 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Truby (DavidTruby)

Changes

This patch adds support for BIND(C) derived types as return values
matching the AArch64 Procedure Call Standard for C.

Support for BIND(C) derived types as value parameters will be in a
separate patch.


Full diff: https://github.com/llvm/llvm-project/pull/114051.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+42)
  • (added) flang/test/Fir/struct-return-aarch64.fir (+156)
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 6c148dffb0e55a..15ffdb74ef51d6 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -825,6 +825,48 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     }
     return marshal;
   }
+
+  static bool isHFA(fir::RecordType ty) {
+    auto types = ty.getTypeList();
+    if (types.empty() || types.size() > 4) {
+      return false;
+    }
+
+    if (!isa_real(types.front().second)) {
+      types.front().second.dump();
+      return false;
+    }
+
+    return llvm::all_equal(llvm::make_second_range(types));
+  }
+
+  CodeGenSpecifics::Marshalling
+  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+    CodeGenSpecifics::Marshalling marshal;
+
+    if (isHFA(ty)) {
+      auto newTy = fir::SequenceType::get({ty.getNumFields()}, ty.getType(0));
+      marshal.emplace_back(newTy, AT{});
+      return marshal;
+    }
+
+    auto [size, align] =
+        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+
+    // return in registers if size <= 16 bytes
+    if (size <= 16) {
+      auto dwordSize = (size + 7) / 8;
+      auto newTy = fir::SequenceType::get(
+          dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
+      marshal.emplace_back(newTy, AT{});
+      return marshal;
+    }
+
+    unsigned short stackAlign = std::max<unsigned short>(align, 8u);
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{stackAlign, false, true});
+    return marshal;
+  }
 };
 } // namespace
 
diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir
new file mode 100644
index 00000000000000..96f2f9999b3435
--- /dev/null
+++ b/flang/test/Fir/struct-return-aarch64.fir
@@ -0,0 +1,156 @@
+// Test AArch64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+
+!composite = !fir.type<t1{i:f32,j:i32,k:f32}>
+// CHECK-LABEL: func.func private @test_composite() -> !fir.array<2xi64>
+func.func private @test_composite() -> !composite
+// CHECK-LABEL: func.func @test_call_composite(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>)
+func.func @test_call_composite(%arg0 : !fir.ref<!composite>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_composite() : () -> !fir.array<2xi64>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xi64>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xi64>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xi64>>) -> !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_composite() : () -> !composite
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+  fir.store %out to %arg0 : !fir.ref<!composite>
+  // CHECK: return
+  return
+}
+
+!hfa_f16 = !fir.type<t2{x:f16, y:f16}>
+// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.array<2xf16>
+func.func private @test_hfa_f16() -> !hfa_f16
+// CHECK-LABEL: func.func @test_call_hfa_f16(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t2{x:f16,y:f16}>>) {
+func.func @test_call_hfa_f16(%arg0 : !fir.ref<!hfa_f16>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.array<2xf16>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xf16>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xf16>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xf16>>) -> !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f16() : () -> !hfa_f16
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f16>
+  return
+}
+
+!hfa_f32 = !fir.type<t3{w:f32, x:f32, y:f32, z:f32}>
+// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.array<4xf32>
+func.func private @test_hfa_f32() -> !hfa_f32
+// CHECK-LABEL: func.func @test_call_hfa_f32(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>) {
+func.func @test_call_hfa_f32(%arg0 : !fir.ref<!hfa_f32>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.array<4xf32>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf32>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf32>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f32() : () -> !hfa_f32
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f32>
+  return
+}
+
+!hfa_f64 = !fir.type<t4{x:f64, y:f64, z:f64}>
+// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.array<3xf64>
+func.func private @test_hfa_f64() -> !hfa_f64
+// CHECK-LABEL: func.func @test_call_hfa_f64(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>)
+func.func @test_call_hfa_f64(%arg0 : !fir.ref<!hfa_f64>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.array<3xf64>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<3xf64>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<3xf64>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<3xf64>>) -> !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f64() : () -> !hfa_f64
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f64>
+  return
+}
+
+!hfa_f128 = !fir.type<t5{w:f128, x:f128, y:f128, z:f128}>
+// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.array<4xf128>
+func.func private @test_hfa_f128() -> !hfa_f128
+// CHECK-LABEL: func.func @test_call_hfa_f128(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>) {
+func.func @test_call_hfa_f128(%arg0 : !fir.ref<!hfa_f128>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.array<4xf128>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf128>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf128>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf128>>) -> !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f128() : () -> !hfa_f128
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f128>
+  return
+}
+
+!hfa_bf16 = !fir.type<t6{w:bf16, x:bf16, y:bf16, z:bf16}>
+// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.array<4xbf16>
+func.func private @test_hfa_bf16() -> !hfa_bf16
+// CHECK-LABEL: func.func @test_call_hfa_bf16(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>) {
+func.func @test_call_hfa_bf16(%arg0 : !fir.ref<!hfa_bf16>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.array<4xbf16>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xbf16>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xbf16>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xbf16>>) -> !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_bf16>
+  return
+}
+
+!too_big = !fir.type<t7{x:i64, y:i64, z:i64}>
+// CHECK-LABEL: func.func private @test_too_big(!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+// CHECK-SAME:    {llvm.align = 8 : i32, llvm.sret = !fir.type<t7{x:i64,y:i64,z:i64}>})
+func.func private @test_too_big() -> !too_big
+// CHECK-LABEL: func.func @test_call_too_big(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) {
+func.func @test_call_too_big(%arg0 : !fir.ref<!too_big>) {
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARG:.*]] = fir.alloca !fir.type<t7{x:i64,y:i64,z:i64}>
+  // CHECK: fir.call @test_too_big(%[[ARG]]) : (!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) -> ()
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) -> !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_too_big() : () -> !too_big
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+  fir.store %out to %arg0 : !fir.ref<!too_big>
+  return
+}
+
+
+!too_big_hfa = !fir.type<t8{i:!fir.array<5xf32>}>
+// CHECK-LABEL: func.func private @test_too_big_hfa(!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+// CHECK-SAME:    {llvm.align = 8 : i32, llvm.sret = !fir.type<t8{i:!fir.array<5xf32>}>})
+func.func private @test_too_big_hfa() -> !too_big_hfa
+// CHECK-LABEL: func.func @test_call_too_big_hfa(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) {
+func.func @test_call_too_big_hfa(%arg0 : !fir.ref<!too_big_hfa>) {
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARG:.*]] = fir.alloca !fir.type<t8{i:!fir.array<5xf32>}>
+  // CHECK: fir.call @test_too_big_hfa(%[[ARG]]) : (!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) -> ()
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) -> !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_too_big_hfa() : () -> !too_big_hfa
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+  fir.store %out to %arg0 : !fir.ref<!too_big_hfa>
+  return
+}

@DavidTruby
Copy link
Member Author

The ABI for return values here is also the same on Windows so I have not separated out Windows from non-Windows as is done on x86.

@tblah
Copy link
Contributor

tblah commented Oct 29, 2024

Please could you link to some public documentation for the ABI?

@@ -825,6 +825,48 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
}
return marshal;
}

static bool isHFA(fir::RecordType ty) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is HFA? Please document.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DavidTruby I think you missed this when responding to feedback

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I commented on the wrong one here. I meant the "expand auto" nit below.

@@ -825,6 +825,48 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
}
return marshal;
}

static bool isHFA(fir::RecordType ty) {
auto types = ty.getTypeList();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:expand auto.

}

if (!isa_real(types.front().second)) {
types.front().second.dump();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this leftover debug?


// return in registers if size <= 16 bytes
if (size <= 16) {
auto dwordSize = (size + 7) / 8;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:Spell the type.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One nitpick otherwise great stuff


unsigned short stackAlign = std::max<unsigned short>(align, 8u);
marshal.emplace_back(fir::ReferenceType::get(ty),
AT{stackAlign, false, true});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
AT{stackAlign, false, true});
AT{/*alignment=*/stackAlign, /*byval=*/false, /*sret=*/true});

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, not familiar with the ABI, just a small question. Code looks good otherwise.


if (!isa_real(types.front().second)) {
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Derived types containing small fp arrays will be rejected here, is the following C struct an HFA? :

typedef struct {
    float x[2];
    float y;
} S;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. It is, yes. As is:

typedef struct {
  float x;
  float y;
} S1;

typedef struct {
  S1 s;
  float z;
} S2;

Good spot.

@DavidTruby
Copy link
Member Author

I'm not all that satisfied with how I did this, I feel like I have over complicated it, although it does work. If anyone has any suggestions I'd happily take them!

@DavidTruby
Copy link
Member Author

@jeanPerier does this look ok to you now?

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that for addressing the comment.

Building the flatten type list just for the HFA check is indeed a bit heavy, if this was expected to be exercised for types with many components, I would advise some kind of type visitor without copy and with early exit, but given this is not common and you only call it for types with less than 4 components, it is just fine as it is.

@DavidTruby DavidTruby merged commit 2d62daa into llvm:main Nov 25, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants