Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,53 @@

// The WebGPU EP can be enabled via the generic SessionOptionsAppendExecutionProvider method, so no direct usage of
// the provider factory is required.

#pragma once

#include <stddef.h>
#include <stdint.h>

// Define export macros without including onnxruntime_c_api.h to avoid conflicts.
//   ORT_WEBGPU_EXPORT   - symbol visibility / import-export decoration for the functions declared below.
//   ORT_WEBGPU_API_CALL - calling convention for the functions declared below.
#if defined(_WIN32)
// Windows: a translation unit that defines ORT_DLL_EXPORT before including this header
// gets dllexport; every other includer gets dllimport.
#ifdef ORT_DLL_EXPORT
#define ORT_WEBGPU_EXPORT __declspec(dllexport)
#else
#define ORT_WEBGPU_EXPORT __declspec(dllimport)
#endif
#define ORT_WEBGPU_API_CALL __stdcall
#elif defined(__APPLE__) || defined(__linux__)
// Apple / Linux: mark the symbols with default visibility; no explicit calling convention.
#define ORT_WEBGPU_EXPORT __attribute__((visibility("default")))
#define ORT_WEBGPU_API_CALL
#else
// Any other platform: both macros expand to nothing.
#define ORT_WEBGPU_EXPORT
#define ORT_WEBGPU_API_CALL
#endif

#ifdef __cplusplus
extern "C" {
#endif

/**
 * \brief Get the Dawn proc table used by the WebGPU EP
 *
 * The result is returned as an untyped pointer so that this header does not depend on Dawn headers.
 * NOTE(review): callers presumably cast it to `const DawnProcTable*` - confirm against the Dawn headers in use.
 *
 * \param context_id The WebGPU context ID (0 for default context). The current implementation does not
 *                   use this value.
 * \return Pointer to the Dawn proc table, or NULL (nullptr in C++) if not available. The current
 *         implementation returns NULL when built for WebAssembly, with BUILD_DAWN_SHARED_LIBRARY, or with
 *         USE_EXTERNAL_DAWN.
 */
ORT_WEBGPU_EXPORT const void* ORT_WEBGPU_API_CALL OrtWebGpuGetDawnProcTable(int context_id);

/**
 * \brief Get the WebGPU instance from WebGPU EP context
 *
 * NOTE(review): the returned handle looks like a borrowed reference owned by the EP context (no extra
 * reference appears to be taken) - confirm the lifetime rules before storing it.
 *
 * \param context_id The WebGPU context ID (0 for default context)
 * \return Pointer to the WebGPU instance (WGPUInstance), or NULL (nullptr in C++) if not available.
 *         The current implementation returns NULL when built for WebAssembly, or when looking up the
 *         context for `context_id` throws.
 */
ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetInstance(int context_id);

/**
 * \brief Get the WebGPU device from WebGPU EP context
 *
 * NOTE(review): the returned handle looks like a borrowed reference owned by the EP context (no extra
 * reference appears to be taken) - confirm the lifetime rules before storing it.
 *
 * \param context_id The WebGPU context ID (0 for default context)
 * \return Pointer to the WebGPU device (WGPUDevice), or NULL (nullptr in C++) if not available.
 *         The current implementation returns NULL when looking up the context for `context_id` throws.
 */
ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetDevice(int context_id);

#ifdef __cplusplus
}
#endif
49 changes: 49 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/webgpu_context.h"
#include "core/providers/webgpu/buffer_manager.h"

// Define ORT_DLL_EXPORT before including webgpu_provider_factory.h to ensure
// ORT_WEBGPU_EXPORT becomes __declspec(dllexport) instead of __declspec(dllimport)
#ifndef ORT_DLL_EXPORT
#define ORT_DLL_EXPORT
#endif
#include "core/providers/webgpu/webgpu_provider_factory.h" // For ORT_WEBGPU_EXPORT macros
#include "core/providers/webgpu/webgpu_execution_provider.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/program_cache_key.h"
Expand Down Expand Up @@ -982,6 +989,9 @@
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");

if (--it->second.ref_count == 0 && !it->second.context->preserve_device_) {
// TODO: Investigate why memory leak is triggered if we don't explicitly destroy the device.

Check warning on line 992 in onnxruntime/core/providers/webgpu/webgpu_context.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webgpu/webgpu_context.cc:992: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// It seems that memory leak detection is triggered before the device is destroyed.
it->second.context->Device().Destroy();
contexts_.erase(it);
}
}
Expand All @@ -1002,3 +1012,42 @@

} // namespace webgpu
} // namespace onnxruntime

// C API functions for external access
extern "C" {

/**
 * Returns the address of Dawn's native proc table as an untyped pointer.
 *
 * The context ID parameter is accepted for interface symmetry with the other accessors but is not used.
 *
 * \return Address of the table returned by dawn::native::GetProcs(), or nullptr when built for
 *         WebAssembly, with BUILD_DAWN_SHARED_LIBRARY, or with USE_EXTERNAL_DAWN.
 */
ORT_WEBGPU_EXPORT const void* ORT_WEBGPU_API_CALL OrtWebGpuGetDawnProcTable(int /* context_id */) {
#if !defined(__wasm__) && !defined(BUILD_DAWN_SHARED_LIBRARY) && !defined(USE_EXTERNAL_DAWN)
  // dawn::native::GetProcs() only hands back a reference to Dawn's proc table, so there is nothing to
  // catch here. The previous try/catch wrapper failed the webgpu_minimal_build_edge CI job, which
  // reported a warning treated as an error on the `try` line.
  return &dawn::native::GetProcs();
#else
  return nullptr;
#endif
}

/**
 * Looks up the WebGPU EP context registered under `context_id` and returns its WGPUInstance handle
 * as an untyped pointer.
 *
 * \return The instance handle, or nullptr when built for WebAssembly or when the lookup throws.
 */
ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetInstance(int context_id) {
#if defined(__wasm__)
  return nullptr;
#else
  void* instance_handle = nullptr;
  try {
    instance_handle = onnxruntime::webgpu::WebGpuContextFactory::GetContext(context_id).Instance().Get();
  } catch (...) {
    // Any failure while resolving the context is reported to the C caller as "not available".
    instance_handle = nullptr;
  }
  return instance_handle;
#endif
}

/**
 * Looks up the WebGPU EP context registered under `context_id` and returns its WGPUDevice handle
 * as an untyped pointer.
 *
 * \return The device handle, or nullptr when the lookup throws.
 */
ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetDevice(int context_id) {
  void* device_handle = nullptr;
  try {
    device_handle = onnxruntime::webgpu::WebGpuContextFactory::GetContext(context_id).Device().Get();
  } catch (...) {
    // Any failure while resolving the context is reported to the C caller as "not available".
    device_handle = nullptr;
  }
  return device_handle;
}

} // extern "C"
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class WebGpuContext final {

Status Wait(wgpu::Future f);

// Read-only access to the wgpu::Instance stored in this context (instance_).
const wgpu::Instance& Instance() const { return instance_; }
// Read-only access to the wgpu::Device stored in this context (device_).
const wgpu::Device& Device() const { return device_; }

// Read-only access to the wgpu::AdapterInfo stored in this context (adapter_info_).
const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ static bool HasControlflowNodes(const Graph& graph) {

static bool HasMemcpyNodes(const Graph& graph) {
for (const auto& node : graph.Nodes()) {
if (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost") {
if (node.OpType() == "MemcpyFromHost") {
return true;
}
}
Expand Down
Loading