6
6
#include < string>
7
7
#include < vector>
8
8
#include < memory>
9
+ #include < mutex>
9
10
#include < unordered_map>
10
11
#include < unordered_set>
11
12
#ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
47
48
sockfd_t fd;
48
49
socket_t (sockfd_t fd) : fd(fd) {}
49
50
~socket_t () {
51
+ GGML_PRINT_DEBUG (" [%s] closing socket %d\n " , __func__, this ->fd );
50
52
#ifdef _WIN32
51
53
closesocket (this ->fd );
52
54
#else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
97
99
}
98
100
99
101
struct ggml_backend_rpc_buffer_type_context {
100
- std::shared_ptr< socket_t > sock ;
102
+ std::string endpoint ;
101
103
std::string name;
102
104
size_t alignment;
103
105
size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
106
108
struct ggml_backend_rpc_context {
107
109
std::string endpoint;
108
110
std::string name;
109
- std::shared_ptr<socket_t > sock;
110
- ggml_backend_buffer_type_t buft;
111
111
};
112
112
113
113
struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231
231
return true ;
232
232
}
233
233
234
- static bool parse_endpoint (const char * endpoint, std::string & host, int & port) {
235
- std::string str (endpoint);
236
- size_t pos = str.find (' :' );
234
+ static bool parse_endpoint (const std::string & endpoint, std::string & host, int & port) {
235
+ size_t pos = endpoint.find (' :' );
237
236
if (pos == std::string::npos) {
238
237
return false ;
239
238
}
240
- host = str .substr (0 , pos);
241
- port = std::stoi (str .substr (pos + 1 ));
239
+ host = endpoint .substr (0 , pos);
240
+ port = std::stoi (endpoint .substr (pos + 1 ));
242
241
return true ;
243
242
}
244
243
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273
272
274
273
// RPC client-side implementation
275
274
275
+ static std::shared_ptr<socket_t > get_socket (const std::string & endpoint) {
276
+ static std::mutex mutex;
277
+ std::lock_guard<std::mutex> lock (mutex);
278
+ static std::unordered_map<std::string, std::weak_ptr<socket_t >> sockets;
279
+ static bool initialized = false ;
280
+
281
+ auto it = sockets.find (endpoint);
282
+ if (it != sockets.end ()) {
283
+ if (auto sock = it->second .lock ()) {
284
+ return sock;
285
+ }
286
+ }
287
+ std::string host;
288
+ int port;
289
+ if (!parse_endpoint (endpoint, host, port)) {
290
+ return nullptr ;
291
+ }
292
+ #ifdef _WIN32
293
+ if (!initialized) {
294
+ WSADATA wsaData;
295
+ int res = WSAStartup (MAKEWORD (2 , 2 ), &wsaData);
296
+ if (res != 0 ) {
297
+ return nullptr ;
298
+ }
299
+ initialized = true ;
300
+ }
301
+ #else
302
+ UNUSED (initialized);
303
+ #endif
304
+ auto sock = socket_connect (host.c_str (), port);
305
+ if (sock == nullptr ) {
306
+ return nullptr ;
307
+ }
308
+ GGML_PRINT_DEBUG (" [%s] connected to %s, sockfd=%d\n " , __func__, endpoint.c_str (), sock->fd );
309
+ sockets[endpoint] = sock;
310
+ return sock;
311
+ }
312
+
276
313
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name (ggml_backend_buffer_t buffer) {
277
314
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
278
315
return ctx->name .c_str ();
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442
479
std::vector<uint8_t > input (input_size, 0 );
443
480
memcpy (input.data (), &size, sizeof (size));
444
481
std::vector<uint8_t > output;
445
- bool status = send_rpc_cmd (buft_ctx->sock , ALLOC_BUFFER, input, output);
482
+ auto sock = get_socket (buft_ctx->endpoint );
483
+ bool status = send_rpc_cmd (sock, ALLOC_BUFFER, input, output);
446
484
GGML_ASSERT (status);
447
485
GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
448
486
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453
491
if (remote_ptr != 0 ) {
454
492
ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
455
493
ggml_backend_rpc_buffer_interface,
456
- new ggml_backend_rpc_buffer_context{buft_ctx-> sock , {}, remote_ptr, " RPC" },
494
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC" },
457
495
remote_size);
458
496
return buffer;
459
497
} else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508
546
}
509
547
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
510
548
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
511
- return buft_ctx->sock == rpc_ctx->sock ;
549
+ return buft_ctx->endpoint == rpc_ctx->endpoint ;
512
550
}
513
551
514
552
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521
559
/* .is_host = */ NULL ,
522
560
};
523
561
524
-
525
562
GGML_CALL static const char * ggml_backend_rpc_name (ggml_backend_t backend) {
526
563
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
527
564
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530
567
531
568
GGML_CALL static void ggml_backend_rpc_free (ggml_backend_t backend) {
532
569
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
533
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft ->context ;
534
- delete buft_ctx;
535
- delete rpc_ctx->buft ;
536
570
delete rpc_ctx;
537
571
delete backend;
538
572
}
539
573
540
574
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type (ggml_backend_t backend) {
541
575
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context ;
542
- return ctx->buft ;
576
+ return ggml_backend_rpc_buffer_type ( ctx->endpoint . c_str ()) ;
543
577
}
544
578
545
579
GGML_CALL static void ggml_backend_rpc_synchronize (ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590
624
std::vector<uint8_t > input;
591
625
serialize_graph (cgraph, input);
592
626
std::vector<uint8_t > output;
593
- bool status = send_rpc_cmd (rpc_ctx->sock , GRAPH_COMPUTE, input, output);
627
+ auto sock = get_socket (rpc_ctx->endpoint );
628
+ bool status = send_rpc_cmd (sock, GRAPH_COMPUTE, input, output);
594
629
GGML_ASSERT (status);
595
630
GGML_ASSERT (output.size () == 1 );
596
631
return (enum ggml_status)output[0 ];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624
659
/* .event_synchronize = */ NULL ,
625
660
};
626
661
627
- static std::unordered_map<std::string, ggml_backend_t > instances;
628
-
629
662
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type (const char * endpoint) {
630
- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
631
- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type (backend) : nullptr ;
632
- }
633
-
634
- GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
635
- std::string endpoint_str (endpoint);
636
- if (instances.find (endpoint_str) != instances.end ()) {
637
- return instances[endpoint_str];
638
- }
639
- #ifdef _WIN32
640
- {
641
- WSADATA wsaData;
642
- int res = WSAStartup (MAKEWORD (2 , 2 ), &wsaData);
643
- if (res != 0 ) {
644
- return nullptr ;
645
- }
646
- }
647
- #endif
648
- fprintf (stderr, " Connecting to %s\n " , endpoint);
649
- std::string host;
650
- int port;
651
- if (!parse_endpoint (endpoint, host, port)) {
652
- return nullptr ;
653
- }
654
- auto sock = socket_connect (host.c_str (), port);
663
+ static std::mutex mutex;
664
+ std::lock_guard<std::mutex> lock (mutex);
665
+ // NOTE: buffer types are allocated and never freed; this is by design
666
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t > buft_map;
667
+ auto it = buft_map.find (endpoint);
668
+ if (it != buft_map.end ()) {
669
+ return it->second ;
670
+ }
671
+ auto sock = get_socket (endpoint);
655
672
if (sock == nullptr ) {
656
673
return nullptr ;
657
674
}
658
675
size_t alignment = get_alignment (sock);
659
676
size_t max_size = get_max_size (sock);
660
677
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
661
- /* .sock = */ sock ,
662
- /* .name = */ " RPC" + std::to_string (sock-> fd ) ,
678
+ /* .endpoint = */ endpoint ,
679
+ /* .name = */ " RPC[ " + std::string (endpoint) + " ] " ,
663
680
/* .alignment = */ alignment,
664
- /* .max_size = */ max_size
681
+ /* .max_size = */ max_size
665
682
};
666
683
667
684
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
668
685
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
669
686
/* .context = */ buft_ctx
670
687
};
688
+ buft_map[endpoint] = buft;
689
+ return buft;
690
+ }
671
691
692
+ GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
672
693
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
673
- /* .endpoint = */ endpoint,
674
- /* .name = */ " RPC" + std::to_string (sock->fd ),
675
- /* .sock = */ sock,
676
- /* .buft = */ buft
694
+ /* .endpoint = */ endpoint,
695
+ /* .name = */ " RPC" ,
677
696
};
678
697
679
- instances[endpoint] = new ggml_backend {
698
+ ggml_backend_t backend = new ggml_backend {
680
699
/* .guid = */ ggml_backend_rpc_guid (),
681
700
/* .interface = */ ggml_backend_rpc_interface,
682
701
/* .context = */ ctx
683
702
};
684
-
685
- return instances[endpoint];
703
+ return backend;
686
704
}
687
705
688
706
GGML_API GGML_CALL bool ggml_backend_is_rpc (ggml_backend_t backend) {
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706
724
}
707
725
708
726
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory (const char * endpoint, size_t * free, size_t * total) {
709
- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
710
- if (backend == nullptr ) {
727
+ auto sock = get_socket (endpoint);
728
+ if (sock == nullptr ) {
711
729
*free = 0 ;
712
730
*total = 0 ;
713
731
return ;
714
732
}
715
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context ;
716
- get_device_memory (ctx->sock , free, total);
733
+ get_device_memory (sock, free, total);
717
734
}
718
735
719
736
// RPC server-side implementation
0 commit comments