Skip to content

Commit 26ca8ef

Browse files
authored
[libc] GPU RPC interface: add return value to rpc_host_call (#111288)
1 parent 56757e5 commit 26ca8ef

File tree

6 files changed

+32
-12
lines changed

6 files changed

+32
-12
lines changed

libc/newhdrgen/yaml/gpu/rpc.yaml

+1-1
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ functions:
1616
- name: rpc_host_call
1717
standards:
1818
- GPUExtensions
19-
return_type: void
19+
return_type: unsigned long long
2020
arguments:
2121
- type: void *
2222
- type: void *

libc/spec/gpu_ext.td

+1-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ def GPUExtensions : StandardSpec<"GPUExtensions"> {
77
[
88
FunctionSpec<
99
"rpc_host_call",
10-
RetValSpec<VoidType>,
10+
RetValSpec<UnsignedLongLongType>,
1111
[ArgSpec<VoidPtr>, ArgSpec<VoidPtr>, ArgSpec<SizeTType>]
1212
>,
1313
]

libc/src/gpu/rpc_host_call.cpp

+7-2
Original file line number | Diff line number | Diff line change
@@ -17,14 +17,19 @@ namespace LIBC_NAMESPACE_DECL {
1717

1818
// This calls the associated function pointer on the RPC server with the given
1919
// arguments. We expect that the pointer here is a valid pointer on the server.
20-
LLVM_LIBC_FUNCTION(void, rpc_host_call, (void *fn, void *data, size_t size)) {
20+
LLVM_LIBC_FUNCTION(unsigned long long, rpc_host_call,
21+
(void *fn, void *data, size_t size)) {
2122
rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
2223
port.send_n(data, size);
2324
port.send([=](rpc::Buffer *buffer) {
2425
buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
2526
});
26-
port.recv([](rpc::Buffer *) {});
27+
unsigned long long ret;
28+
port.recv([&](rpc::Buffer *buffer) {
29+
ret = static_cast<unsigned long long>(buffer->data[0]);
30+
});
2731
port.close();
32+
return ret;
2833
}
2934

3035
} // namespace LIBC_NAMESPACE_DECL

libc/src/gpu/rpc_host_call.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
namespace LIBC_NAMESPACE_DECL {
1616

17-
void rpc_host_call(void *fn, void *buffer, size_t size);
17+
unsigned long long rpc_host_call(void *fn, void *buffer, size_t size);
1818

1919
} // namespace LIBC_NAMESPACE_DECL
2020

libc/utils/gpu/server/rpc_server.cpp

+7-2
Original file line number | Diff line number | Diff line change
@@ -319,13 +319,18 @@ rpc_status_t handle_server_impl(
319319
}
320320
case RPC_HOST_CALL: {
321321
uint64_t sizes[lane_size] = {0};
322+
unsigned long long results[lane_size] = {0};
322323
void *args[lane_size] = {nullptr};
323324
port->recv_n(args, sizes,
324325
[&](uint64_t size) { return temp_storage.alloc(size); });
325326
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
326-
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
327+
using func_ptr_t = unsigned long long (*)(void *);
328+
auto func = reinterpret_cast<func_ptr_t>(buffer->data[0]);
329+
results[id] = func(args[id]);
330+
});
331+
port->send([&](rpc::Buffer *buffer, uint32_t id) {
332+
buffer->data[0] = static_cast<uint64_t>(results[id]);
327333
});
328-
port->send([&](rpc::Buffer *, uint32_t id) {});
329334
break;
330335
}
331336
case RPC_FEOF: {

offload/test/libc/host_call.c

+15-5
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,14 @@
88

99
#pragma omp begin declare variant match(device = {kind(gpu)})
1010
// Extension provided by the 'libc' project.
11-
void rpc_host_call(void *fn, void *args, size_t size);
11+
unsigned long long rpc_host_call(void *fn, void *args, size_t size);
1212
#pragma omp declare target to(rpc_host_call) device_type(nohost)
1313
#pragma omp end declare variant
1414

1515
#pragma omp begin declare variant match(device = {kind(cpu)})
1616
// Dummy host implementation to make this work for all targets.
17-
void rpc_host_call(void *fn, void *args, size_t size) {
18-
((void (*)(void *))fn)(args);
17+
unsigned long long rpc_host_call(void *fn, void *args, size_t size) {
18+
return ((unsigned long long (*)(void *))fn)(args);
1919
}
2020
#pragma omp end declare variant
2121

@@ -25,17 +25,26 @@ typedef struct args_s {
2525
} args_t;
2626

2727
// CHECK-DAG: Thread: 0, Block: 0
28+
// CHECK-DAG: Result: 42
2829
// CHECK-DAG: Thread: 1, Block: 0
30+
// CHECK-DAG: Result: 42
2931
// CHECK-DAG: Thread: 0, Block: 1
32+
// CHECK-DAG: Result: 42
3033
// CHECK-DAG: Thread: 1, Block: 1
34+
// CHECK-DAG: Result: 42
3135
// CHECK-DAG: Thread: 0, Block: 2
36+
// CHECK-DAG: Result: 42
3237
// CHECK-DAG: Thread: 1, Block: 2
38+
// CHECK-DAG: Result: 42
3339
// CHECK-DAG: Thread: 0, Block: 3
40+
// CHECK-DAG: Result: 42
3441
// CHECK-DAG: Thread: 1, Block: 3
35-
void foo(void *data) {
42+
// CHECK-DAG: Result: 42
43+
long long foo(void *data) {
3644
assert(omp_is_initial_device() && "Not executing on host?");
3745
args_t *args = (args_t *)data;
3846
printf("Thread: %d, Block: %d\n", args->thread_id, args->block_id);
47+
return 42;
3948
}
4049

4150
void *fn_ptr = NULL;
@@ -49,6 +58,7 @@ int main() {
4958
#pragma omp parallel num_threads(2)
5059
{
5160
args_t args = {omp_get_thread_num(), omp_get_team_num()};
52-
rpc_host_call(fn_ptr, &args, sizeof(args_t));
61+
unsigned long long res = rpc_host_call(fn_ptr, &args, sizeof(args_t));
62+
printf("Result: %d\n", (int)res);
5363
}
5464
}

0 commit comments

Comments
 (0)