Skip to content

Commit ee57a68

Browse files
authored
[libc] Make a dedicated thread for the RPC server (#111210)
Summary: Make a separate thread to run the server when we launch. This is required by CUDA, which you can force with `export CUDA_LAUNCH_BLOCKING=1`. I figured I might as well be consistent and do it for the AMD implementation as well even though I believe it's not necessary.
1 parent 1062007 commit ee57a68

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
#include "hsa/hsa_ext_amd.h"
2929
#endif
3030

31+
#include <atomic>
3132
#include <cstdio>
3233
#include <cstdlib>
3334
#include <cstring>
35+
#include <thread>
3436
#include <tuple>
3537
#include <utility>
3638

@@ -289,18 +291,26 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
289291
__atomic_store_n((uint32_t *)&packet->header, header_word, __ATOMIC_RELEASE);
290292
hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);
291293

294+
std::atomic<bool> finished = false;
295+
std::thread server(
296+
[](std::atomic<bool> *finished, rpc_device_t device) {
297+
while (!*finished) {
298+
if (rpc_status_t err = rpc_handle_server(device))
299+
handle_error(err);
300+
}
301+
},
302+
&finished, device);
303+
292304
// Wait until the kernel has completed execution on the device. Periodically
293305
// check the RPC client for work to be performed on the server.
294-
while (hsa_signal_wait_scacquire(
295-
packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
296-
/*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
297-
if (rpc_status_t err = rpc_handle_server(device))
298-
handle_error(err);
306+
while (hsa_signal_wait_scacquire(packet->completion_signal,
307+
HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
308+
HSA_WAIT_STATE_BLOCKED) != 0)
309+
;
299310

300-
// Handle the server one more time in case the kernel exited with a pending
301-
// send still in flight.
302-
if (rpc_status_t err = rpc_handle_server(device))
303-
handle_error(err);
311+
finished = true;
312+
if (server.joinable())
313+
server.join();
304314

305315
// Destroy the resources acquired to launch the kernel and return.
306316
if (hsa_status_t err = hsa_amd_memory_pool_free(args))

libc/utils/gpu/loader/nvptx/nvptx-loader.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
#include "llvm/Object/ELF.h"
2121
#include "llvm/Object/ELFObjectFile.h"
2222

23+
#include <atomic>
2324
#include <cstddef>
2425
#include <cstdio>
2526
#include <cstdlib>
2627
#include <cstring>
28+
#include <thread>
2729
#include <vector>
2830

2931
using namespace llvm;
@@ -224,24 +226,30 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
224226
if (print_resource_usage)
225227
print_kernel_resources(binary, kernel_name);
226228

229+
std::atomic<bool> finished = false;
230+
std::thread server(
231+
[](std::atomic<bool> *finished, rpc_device_t device) {
232+
while (!*finished) {
233+
if (rpc_status_t err = rpc_handle_server(device))
234+
handle_error(err);
235+
}
236+
},
237+
&finished, rpc_device);
238+
227239
// Call the kernel with the given arguments.
228240
if (CUresult err = cuLaunchKernel(
229241
function, params.num_blocks_x, params.num_blocks_y,
230242
params.num_blocks_z, params.num_threads_x, params.num_threads_y,
231243
params.num_threads_z, 0, stream, nullptr, args_config))
232244
handle_error(err);
233245

234-
// Wait until the kernel has completed execution on the device. Periodically
235-
// check the RPC client for work to be performed on the server.
236-
while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
237-
if (rpc_status_t err = rpc_handle_server(rpc_device))
238-
handle_error(err);
239-
240-
// Handle the server one more time in case the kernel exited with a pending
241-
// send still in flight.
242-
if (rpc_status_t err = rpc_handle_server(rpc_device))
246+
if (CUresult err = cuStreamSynchronize(stream))
243247
handle_error(err);
244248

249+
finished = true;
250+
if (server.joinable())
251+
server.join();
252+
245253
return CUDA_SUCCESS;
246254
}
247255

0 commit comments

Comments
 (0)