@@ -11,6 +11,7 @@

#include <webgpu/webgpu_cpp.h>

+ #include <atomic>
#include <condition_variable>
#include <cstring>
#include <iostream>
@@ -65,13 +66,15 @@
#    define WEBGPU_WAIT_ANY_TIMEOUT_MS UINT64_MAX
#else
#    define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8
- #    define WEBGPU_WAIT_ANY_TIMEOUT_MS 1
+ #    define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
#endif

/* Constants */

#define WEBGPU_MUL_MAT_WG_SIZE 256
#define WEBGPU_NUM_PARAM_BUFS 32
+ // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
+ #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
@@ -107,6 +110,11 @@ struct webgpu_pool_bufs {
    wgpu::Buffer dev_buf;
};

+ // The futures to wait on for a single queue submission
+ struct webgpu_submission_futures {
+     std::vector<wgpu::FutureWaitInfo> futures;
+ };
+
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
    std::vector<webgpu_pool_bufs> free;
@@ -243,6 +251,7 @@ struct webgpu_context_struct {
    uint32_t max_wg_size_x;

    std::recursive_mutex mutex;
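+     // Number of threads with submissions currently in flight; used to scale per-thread batch sizes and in-flight limits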
+     std::atomic_int inflight_threads = 0;

    webgpu_buf_pool param_buf_pool;
    webgpu_buf_pool set_rows_error_buf_pool;
@@ -365,12 +374,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
/** WebGPU Actions */

// Wait for the queue to finish processing all submitted work
- static void ggml_backend_webgpu_wait(webgpu_context & ctx,
-                                      std::vector<std::vector<wgpu::FutureWaitInfo>> & futures,
-                                      uint64_t timeout_ms = UINT64_MAX) {
+ static void ggml_backend_webgpu_wait(webgpu_context & ctx,
+                                      std::vector<webgpu_submission_futures> & futures,
+                                      uint64_t timeout_ms = UINT64_MAX) {
+     // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
+     // inflight_max may be 0, meaning that we must wait on all futures.
+     int inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / ctx->inflight_threads;
+     while (futures.size() >= inflight_max && futures.size() > 0) {
+         ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
+         futures.erase(futures.begin());
+     }
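+     // Poll the remaining futures with the given timeout, dropping any submissions that have completed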
    size_t i = 0;
    while (i < futures.size()) {
-         auto waitStatus = ctx->instance.WaitAny(futures[i].size(), futures[i].data(), timeout_ms);
+         auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
        switch (waitStatus) {
            case wgpu::WaitStatus::Success:
                futures.erase(futures.begin() + i);
@@ -424,8 +440,7 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
}
#endif

- static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_context ctx,
-                                                                     std::vector<webgpu_command> commands) {
+ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
    std::vector<wgpu::CommandBuffer> command_buffers;
    std::vector<webgpu_pool_bufs> params_bufs;
    std::vector<webgpu_pool_bufs> set_rows_error_bufs;
@@ -484,9 +499,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
            if (status != wgpu::MapAsyncStatus::Success) {
                GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
            } else {
-                 const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
+                 const uint64_t * ts_data    = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
                // WebGPU timestamps are in ns; convert to ms
-                 double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
+                 double           elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
                ctx->shader_gpu_time_ms[label] += elapsed_ms;
                // We can't unmap in here due to WebGPU reentrancy limitations.
                ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
@@ -495,7 +510,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
        futures.push_back({ f });
    }
#endif
-     return futures;
+     return { futures };
}

static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
@@ -588,7 +603,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
    uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;

    webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x);
-     std::vector<std::vector<wgpu::FutureWaitInfo>> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
+     std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
    ggml_backend_webgpu_wait(ctx, futures);
}

@@ -1255,25 +1270,31 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str

    WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);

-     std::vector<webgpu_command> commands;
-     std::vector<std::vector<wgpu::FutureWaitInfo>> futures;
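+     // Count this thread as having submissions in flight so batching and waits can adapt to concurrency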
+     ctx->inflight_threads++;
+
+     std::vector<webgpu_command> commands;
+     std::vector<webgpu_submission_futures> futures;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
            commands.push_back(*cmd);
        }
-         if (commands.size() >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
-             std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit(ctx, commands);
-             // check if previous futures have finished
+         // compute the batch size based on the number of inflight threads
+         int batch_size = std::min(std::max(1, WEBGPU_NUM_PARAM_BUFS / ctx->inflight_threads),
+                                   WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
+         if (commands.size() >= batch_size) {
+             futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
+             // Process events and check for completed submissions
+             ctx->instance.ProcessEvents();
            ggml_backend_webgpu_wait(ctx, futures, WEBGPU_WAIT_ANY_TIMEOUT_MS);
-             futures.push_back({ new_futures });
            commands.clear();
        }
    }
    if (!commands.empty()) {
-         std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit(ctx, commands);
-         futures.push_back({ new_futures });
+         webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
+         futures.push_back(new_futures);
    }
    ggml_backend_webgpu_wait(ctx, futures);
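+     // All of this thread's submissions have completed; stop counting it toward in-flight throttling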
+     ctx->inflight_threads--;
    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
    return GGML_STATUS_SUCCESS;
}