6565# define WEBGPU_WAIT_ANY_TIMEOUT_MS UINT64_MAX
6666#else
6767# define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8
68- # define WEBGPU_WAIT_ANY_TIMEOUT_MS 1
68+ # define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
6969#endif
7070
7171/* Constants */
7272
7373#define WEBGPU_MUL_MAT_WG_SIZE 256
7474#define WEBGPU_NUM_PARAM_BUFS 32
75+ // Maximum number of in-flight submissions per thread, to avoid exhausting the parameter buffer pool. NOTE(review): the macro below expands to an unparenthesized `A / B`; wrap the expansion in parentheses so use sites with surrounding operators parse as intended.
76+ #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
7577#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
7678#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
7779#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
@@ -107,6 +109,11 @@ struct webgpu_pool_bufs {
107109 wgpu::Buffer dev_buf;
108110};
109111
112+ // The futures to wait on for a single queue submission
113+ struct webgpu_submission_futures {
114+ std::vector<wgpu::FutureWaitInfo> futures;
115+ };
116+
110117// Holds a pool of parameter buffers for WebGPU operations
111118struct webgpu_buf_pool {
112119 std::vector<webgpu_pool_bufs> free;
@@ -243,6 +250,7 @@ struct webgpu_context_struct {
243250 uint32_t max_wg_size_x;
244251
245252 std::recursive_mutex mutex;
253+ std::atomic_int inflight_threads = 0 ;
246254
247255 webgpu_buf_pool param_buf_pool;
248256 webgpu_buf_pool set_rows_error_buf_pool;
@@ -365,12 +373,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
365373/* * WebGPU Actions */
366374
367375// Wait for the queue to finish processing all submitted work
368- static void ggml_backend_webgpu_wait (webgpu_context & ctx,
369- std::vector<std::vector<wgpu::FutureWaitInfo>> & futures,
370- uint64_t timeout_ms = UINT64_MAX) {
376+ static void ggml_backend_webgpu_wait (webgpu_context & ctx,
377+ std::vector<webgpu_submission_futures> & futures,
378+ uint64_t timeout_ms = UINT64_MAX) {
379+     // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
380+     // inflight_max may be 0, meaning that we must wait on all futures. NOTE(review): inflight_threads is only incremented in graph_compute; this function is also reached from the memset path, where the counter can still be 0 — the division below would then be UB. Guard the divisor (e.g. std::max(ctx->inflight_threads.load(), 1)) — confirm against callers.
381+ int inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / ctx->inflight_threads ;
382+ while (futures.size () >= inflight_max && futures.size () > 0 ) {
383+ ctx->instance .WaitAny (futures[0 ].futures .size (), futures[0 ].futures .data (), UINT64_MAX);
384+ futures.erase (futures.begin ());
385+ }
371386 size_t i = 0 ;
372387 while (i < futures.size ()) {
373- auto waitStatus = ctx->instance .WaitAny (futures[i].size (), futures[i].data (), timeout_ms);
388+ auto waitStatus = ctx->instance .WaitAny (futures[i].futures . size (), futures[i]. futures .data (), timeout_ms);
374389 switch (waitStatus) {
375390 case wgpu::WaitStatus::Success:
376391 futures.erase (futures.begin () + i);
@@ -424,8 +439,7 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
424439}
425440#endif
426441
427- static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit (webgpu_context ctx,
428- std::vector<webgpu_command> commands) {
442+ static webgpu_submission_futures ggml_backend_webgpu_submit (webgpu_context ctx, std::vector<webgpu_command> commands) {
429443 std::vector<wgpu::CommandBuffer> command_buffers;
430444 std::vector<webgpu_pool_bufs> params_bufs;
431445 std::vector<webgpu_pool_bufs> set_rows_error_bufs;
@@ -484,9 +498,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
484498 if (status != wgpu::MapAsyncStatus::Success) {
485499 GGML_LOG_ERROR (" ggml_webgpu: Failed to map timestamp buffer: %s\n " , std::string (message).c_str ());
486500 } else {
487- const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf .GetConstMappedRange ();
501+ const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf .GetConstMappedRange ();
488502 // WebGPU timestamps are in ns; convert to ms
489- double elapsed_ms = double (ts_data[1 ] - ts_data[0 ]) * 1e-6 ;
503+ double elapsed_ms = double (ts_data[1 ] - ts_data[0 ]) * 1e-6 ;
490504 ctx->shader_gpu_time_ms [label] += elapsed_ms;
491505 // We can't unmap in here due to WebGPU reentrancy limitations.
492506 ctx->timestamp_query_buf_pool .free_bufs ({ ts_bufs });
@@ -495,7 +509,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
495509 futures.push_back ({ f });
496510 }
497511#endif
498- return futures;
512+ return { futures } ;
499513}
500514
501515static webgpu_command ggml_backend_webgpu_build (webgpu_context & ctx,
@@ -588,7 +602,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
588602 uint32_t wg_x = ((size + 3 ) + bytes_per_wg - 1 ) / bytes_per_wg;
589603
590604 webgpu_command command = ggml_backend_webgpu_build (ctx, ctx->memset_pipeline , params, entries, wg_x);
591- std::vector<std::vector<wgpu::FutureWaitInfo> > futures = { ggml_backend_webgpu_submit (ctx, { command }) };
605+ std::vector<webgpu_submission_futures > futures = { ggml_backend_webgpu_submit (ctx, { command }) };
592606 ggml_backend_webgpu_wait (ctx, futures);
593607}
594608
@@ -1255,25 +1269,31 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
12551269
12561270 WEBGPU_CPU_PROFILE_TOTAL_START (graph_compute);
12571271
1258- std::vector<webgpu_command> commands;
1259- std::vector<std::vector<wgpu::FutureWaitInfo>> futures;
1272+ ctx->inflight_threads ++;
1273+
1274+ std::vector<webgpu_command> commands;
1275+ std::vector<webgpu_submission_futures> futures;
12601276 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
12611277 if (auto cmd = ggml_webgpu_encode_node (ctx, cgraph->nodes [i])) {
12621278 commands.push_back (*cmd);
12631279 }
1264- if (commands.size () >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
1265- std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit (ctx, commands);
1266- // check if previous futures have finished
1280+ // compute the batch size based on the number of inflight threads
1281+ int batch_size = std::min (std::max (1 , WEBGPU_NUM_PARAM_BUFS / ctx->inflight_threads ),
1282+ WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
1283+ if (commands.size () >= batch_size) {
1284+ futures.push_back (ggml_backend_webgpu_submit (ctx, commands));
1285+ // Process events and check for completed submissions
1286+ ctx->instance .ProcessEvents ();
12671287 ggml_backend_webgpu_wait (ctx, futures, WEBGPU_WAIT_ANY_TIMEOUT_MS);
1268- futures.push_back ({ new_futures });
12691288 commands.clear ();
12701289 }
12711290 }
12721291 if (!commands.empty ()) {
1273- std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit (ctx, commands);
1274- futures.push_back ({ new_futures } );
1292+ webgpu_submission_futures new_futures = ggml_backend_webgpu_submit (ctx, commands);
1293+ futures.push_back (new_futures);
12751294 }
12761295 ggml_backend_webgpu_wait (ctx, futures);
1296+ ctx->inflight_threads --;
12771297 WEBGPU_CPU_PROFILE_TOTAL_END (graph_compute, ctx);
12781298 return GGML_STATUS_SUCCESS;
12791299}
0 commit comments