Skip to content

Commit 2b737ca

Browse files
authored
rpc : resource management rework (#7562)
* rpc : resource management rework * address review comments
1 parent ee3dff6 commit 2b737ca

File tree

1 file changed

+75
-58
lines changed

1 file changed

+75
-58
lines changed

ggml-rpc.cpp

Lines changed: 75 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <string>
77
#include <vector>
88
#include <memory>
9+
#include <mutex>
910
#include <unordered_map>
1011
#include <unordered_set>
1112
#ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
4748
sockfd_t fd;
4849
socket_t(sockfd_t fd) : fd(fd) {}
4950
~socket_t() {
51+
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
5052
#ifdef _WIN32
5153
closesocket(this->fd);
5254
#else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
9799
}
98100

99101
struct ggml_backend_rpc_buffer_type_context {
100-
std::shared_ptr<socket_t> sock;
102+
std::string endpoint;
101103
std::string name;
102104
size_t alignment;
103105
size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
106108
// Per-backend state of an RPC backend instance.
// The backend does not hold a socket of its own; it only remembers which
// endpoint it talks to, and the connection is looked up by endpoint on use.
struct ggml_backend_rpc_context {
    // endpoint string this backend was created for
    std::string endpoint;
    // backend name
    std::string name;
};
112112

113113
struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231231
return true;
232232
}
233233

234-
// Splits an endpoint of the form "host:port" at the first ':'.
//
//   endpoint - string to parse
//   host     - receives everything before the first ':' (may be empty)
//   port     - receives the decimal port number, 0..65535
//
// Returns true on success. Returns false, leaving host and port untouched,
// when there is no ':', when the port part is empty, when it contains any
// character other than a decimal digit, or when it is larger than 65535.
//
// The port is parsed by hand instead of with std::stoi: std::stoi throws
// std::invalid_argument / std::out_of_range on malformed input, which would
// escape as an exception from a function that reports failure via its
// return value, and it silently accepts signs and trailing garbage.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    const size_t colon = endpoint.find(':');
    if (colon == std::string::npos) {
        return false;
    }
    const size_t first_digit = colon + 1;
    if (first_digit == endpoint.size()) {
        // "host:" - nothing after the separator
        return false;
    }
    const int max_port = 65535;
    int value = 0;
    for (size_t i = first_digit; i < endpoint.size(); ++i) {
        const char c = endpoint[i];
        if (c < '0' || c > '9') {
            return false;
        }
        value = value * 10 + (c - '0');
        // checked on every step so that value can never overflow an int
        if (value > max_port) {
            return false;
        }
    }
    host = endpoint.substr(0, colon);
    port = value;
    return true;
}
244243

@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273272

274273
// RPC client-side implementation
275274

275+
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
276+
static std::mutex mutex;
277+
std::lock_guard<std::mutex> lock(mutex);
278+
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
279+
static bool initialized = false;
280+
281+
auto it = sockets.find(endpoint);
282+
if (it != sockets.end()) {
283+
if (auto sock = it->second.lock()) {
284+
return sock;
285+
}
286+
}
287+
std::string host;
288+
int port;
289+
if (!parse_endpoint(endpoint, host, port)) {
290+
return nullptr;
291+
}
292+
#ifdef _WIN32
293+
if (!initialized) {
294+
WSADATA wsaData;
295+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
296+
if (res != 0) {
297+
return nullptr;
298+
}
299+
initialized = true;
300+
}
301+
#else
302+
UNUSED(initialized);
303+
#endif
304+
auto sock = socket_connect(host.c_str(), port);
305+
if (sock == nullptr) {
306+
return nullptr;
307+
}
308+
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
309+
sockets[endpoint] = sock;
310+
return sock;
311+
}
312+
276313
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
277314
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
278315
return ctx->name.c_str();
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442479
std::vector<uint8_t> input(input_size, 0);
443480
memcpy(input.data(), &size, sizeof(size));
444481
std::vector<uint8_t> output;
445-
bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
482+
auto sock = get_socket(buft_ctx->endpoint);
483+
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
446484
GGML_ASSERT(status);
447485
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
448486
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453491
if (remote_ptr != 0) {
454492
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
455493
ggml_backend_rpc_buffer_interface,
456-
new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
494+
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
457495
remote_size);
458496
return buffer;
459497
} else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508546
}
509547
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
510548
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
511-
return buft_ctx->sock == rpc_ctx->sock;
549+
return buft_ctx->endpoint == rpc_ctx->endpoint;
512550
}
513551

514552
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521559
/* .is_host = */ NULL,
522560
};
523561

524-
525562
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
526563
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
527564

@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530567

531568
// Destroys an RPC backend: deletes its ggml_backend_rpc_context and then the
// backend object itself. Buffer types are not touched here.
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
    delete static_cast<ggml_backend_rpc_context *>(backend->context);
    delete backend;
}
539573

540574
// Returns the buffer type for the endpoint this backend was created with,
// by delegating to ggml_backend_rpc_buffer_type().
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
    const ggml_backend_rpc_context * rpc_ctx = static_cast<ggml_backend_rpc_context *>(backend->context);
    const char * endpoint = rpc_ctx->endpoint.c_str();
    return ggml_backend_rpc_buffer_type(endpoint);
}
544578

545579
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590624
std::vector<uint8_t> input;
591625
serialize_graph(cgraph, input);
592626
std::vector<uint8_t> output;
593-
bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
627+
auto sock = get_socket(rpc_ctx->endpoint);
628+
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
594629
GGML_ASSERT(status);
595630
GGML_ASSERT(output.size() == 1);
596631
return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624659
/* .event_synchronize = */ NULL,
625660
};
626661

627-
static std::unordered_map<std::string, ggml_backend_t> instances;
628-
629662
// Returns the buffer type associated with an RPC endpoint.
//
// The first call for an endpoint obtains a socket, queries the alignment and
// max size over it, and creates the buffer type; later calls for the same
// endpoint string return the same object. Calls are serialized by a
// function-local mutex.
//
// Returns nullptr if no socket could be obtained for the endpoint.
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
    static std::mutex mutex;
    // NOTE: buffer types are allocated and never freed; this is by design
    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;

    std::lock_guard<std::mutex> lock(mutex);

    const std::string key(endpoint);

    auto known = buft_map.find(key);
    if (known != buft_map.end()) {
        return known->second;
    }

    auto sock = get_socket(key);
    if (sock == nullptr) {
        return nullptr;
    }

    size_t alignment = get_alignment(sock);
    size_t max_size  = get_max_size(sock);

    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
        /* .endpoint  = */ key,
        /* .name      = */ "RPC[" + key + "]",
        /* .alignment = */ alignment,
        /* .max_size  = */ max_size
    };

    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
        /* .context = */ buft_ctx
    };

    // the key is known to be absent here, so this always inserts
    buft_map.emplace(key, buft);
    return buft;
}
671691

692+
// Creates a new RPC backend object for the given endpoint.
//
// Every call allocates a fresh backend together with its context; the
// endpoint string is stored in the context and the backend is named "RPC".
// No connection is opened here.
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
    ggml_backend_rpc_context * rpc_ctx = new ggml_backend_rpc_context {
        /* .endpoint = */ endpoint,
        /* .name     = */ "RPC",
    };

    return new ggml_backend {
        /* .guid      = */ ggml_backend_rpc_guid(),
        /* .interface = */ ggml_backend_rpc_interface,
        /* .context   = */ rpc_ctx
    };
}
687705

688706
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706724
}
707725

708726
// Queries free and total device memory from the server at the endpoint.
//
//   endpoint - endpoint string used to obtain the socket
//   free     - receives the free memory, or 0 if no socket could be obtained
//   total    - receives the total memory, or 0 if no socket could be obtained
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
    auto sock = get_socket(endpoint);
    if (sock != nullptr) {
        get_device_memory(sock, free, total);
        return;
    }
    // no connection: report zero for both values
    *free  = 0;
    *total = 0;
}
718735

719736
// RPC server-side implementation

0 commit comments

Comments
 (0)