Skip to content

ggml: Add basic SET_ROWS support in WebGPU #15137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ jobs:
- name: Test
id: cmake_test
run: |
export LLAMA_SET_ROWS=0
cd build
ctest -L main --verbose --timeout 900

Expand Down Expand Up @@ -438,7 +437,6 @@ jobs:
- name: Test
id: cmake_test
run: |
export LLAMA_SET_ROWS=0
cd build
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 3600
Expand Down
237 changes: 204 additions & 33 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@
#include <vector>

#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_DEBUG_BUF_ELEMS 32
#else
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG

/* Constants */

#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4

/* End Constants */

Expand All @@ -54,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::BufferUsage usage,
const char * label);

struct webgpu_param_bufs {
struct webgpu_pool_bufs {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
};

// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_param_buf_pool {
std::vector<webgpu_param_bufs> free;
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;

std::mutex mutex;

std::condition_variable cv;

void init(wgpu::Device device) {
for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
void init(wgpu::Device device,
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
wgpu::BufferUsage host_buf_usage) {
for (int i = 0; i < num_bufs; i++) {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device,
host_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
"ggml_webgpu_host_params_buf");
ggml_webgpu_create_buffer(device,
dev_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
"ggml_webgpu_dev_params_buf");
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
free.push_back({ host_buf, dev_buf });
}
}

webgpu_param_bufs alloc_bufs() {
webgpu_pool_bufs alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this] { return !free.empty(); });
webgpu_param_bufs bufs = free.back();
webgpu_pool_bufs bufs = free.back();
free.pop_back();
return bufs;
}

void free_bufs(std::vector<webgpu_param_bufs> bufs) {
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
std::lock_guard<std::mutex> lock(mutex);
free.insert(free.end(), bufs.begin(), bufs.end());
cv.notify_all();
Expand Down Expand Up @@ -121,10 +120,12 @@ struct webgpu_context_struct {

bool device_init = false;

webgpu_param_buf_pool param_buf_pool;
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;

wgpu::ComputePipeline memset_pipeline;
wgpu::ComputePipeline mul_mat_pipeline;
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline cpy_pipeline;

size_t memset_bytes_per_thread;
Expand All @@ -136,9 +137,16 @@ struct webgpu_context_struct {
std::vector<wgpu::CommandBuffer> staged_command_bufs;

// Parameter buffers associated with the staged command buffers
std::vector<webgpu_param_bufs> staged_param_bufs;
std::vector<webgpu_pool_bufs> staged_param_bufs;
// Buffers associated with set_rows operations, used to store potential errors
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;

std::vector<wgpu::FutureWaitInfo> callback_futures;

#ifdef GGML_WEBGPU_DEBUG
wgpu::Buffer debug_host_buf;
wgpu::Buffer debug_dev_buf;
#endif
};

typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
Expand Down Expand Up @@ -249,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
return;
}
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());

// If there are SET_ROWS operations in this submission, copy their error buffers to the host.
if (ctx->staged_set_row_error_bufs.size() > 0) {
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
// Copy the error buffer to the host buffer
encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
}
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
}

ctx->staged_command_bufs.clear();
std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);

// Free the staged parameter buffers once the submission completes
wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
}
// Free the staged parameter buffers
// Free the staged buffers
ctx->param_buf_pool.free_bufs(staged_param_bufs);
});
ctx->callback_futures.push_back({ f });
ctx->callback_futures.push_back({ p_f });

// Check for errrors in SET_ROWS operations
for (auto & error_bufs : staged_set_row_error_bufs) {
wgpu::Future f = error_bufs.host_buf.MapAsync(
wgpu::MapMode::Read,
0,
error_bufs.host_buf.GetSize(),
wgpu::CallbackMode::AllowSpontaneous,
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
} else {
const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
}
});
ctx->callback_futures.push_back({ f });
}
}

static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
Expand All @@ -283,13 +326,34 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
UINT64_MAX);
}

#ifdef GGML_WEBGPU_DEBUG
// This function adds debugging information to shaders, as WebGPU does not support printing directly.
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);

ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug data:";
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
std::cout << " " << i << ": " << debug_data[i];
}
std::cout << "\n";
ctx->debug_host_buf.Unmap();
}
#endif

static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
wgpu::ComputePipeline & pipeline,
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_and_wait = false) {
webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
Expand Down Expand Up @@ -429,6 +493,76 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
}

static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
// For set rows specifically, we need to check if src and idx are empty tensors.
if (ggml_is_empty(src) || ggml_is_empty(idx)) {
return;
}

webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
error_bufs.host_buf.Unmap();
}

size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
// assumes power of 2 offset alignment
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);

std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
(uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2]) };

std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_backend_webgpu_tensor_buf(src),
.offset = ggml_backend_webgpu_tensor_offset(src),
.size = ggml_nbytes(src) },
{ .binding = 1,
.buffer = ggml_backend_webgpu_tensor_buf(idx),
.offset = ggml_backend_webgpu_tensor_offset(idx),
.size = ggml_nbytes(idx) },
{ .binding = 2,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = ggml_backend_webgpu_tensor_offset(dst),
.size = ggml_nbytes(dst) },
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
};

size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;

std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
ctx->staged_set_row_error_bufs.push_back(error_bufs);

ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
}

static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
std::vector<uint32_t> params = {
(uint32_t) dst->ne[1], // number of rows in result (M)
Expand Down Expand Up @@ -487,6 +621,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
ggml_webgpu_cpy(ctx, src0, node);
break;
}
case GGML_OP_SET_ROWS:
{
ggml_webgpu_set_rows(ctx, src0, src1, node);
break;
}
case GGML_OP_MUL_MAT:
{
ggml_webgpu_mul_mat(ctx, src0, src1, node);
Expand Down Expand Up @@ -771,6 +910,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
}

static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
}

static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
Expand Down Expand Up @@ -827,11 +974,35 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();

// Create buffer pool for shader parameters
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device);
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
WEBGPU_NUM_PARAM_BUFS,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);

ggml_webgpu_init_memset_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);

#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(webgpu_ctx->device,
webgpu_ctx->debug_host_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"debug_host_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device,
webgpu_ctx->debug_dev_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
"debug_dev_buf");
#endif
webgpu_ctx->device_init = true;
}

Expand Down Expand Up @@ -882,7 +1053,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return true;
case GGML_OP_CPY:
case GGML_OP_CPY | GGML_OP_SET_ROWS:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_MUL_MAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
Expand Down
Loading
Loading