diff --git a/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h b/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h index 0b45b847d651f..b54c8e17f46a6 100644 --- a/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h +++ b/include/onnxruntime/core/providers/webgpu/webgpu_provider_factory.h @@ -12,3 +12,53 @@ // The WebGPU EP can be enabled via the generic SessionOptionsAppendExecutionProvider method, so no direct usage of // the provider factory is required. + +#pragma once + +#include +#include + +// Define export macros without including onnxruntime_c_api.h to avoid conflicts +#if defined(_WIN32) +#ifdef ORT_DLL_EXPORT +#define ORT_WEBGPU_EXPORT __declspec(dllexport) +#else +#define ORT_WEBGPU_EXPORT __declspec(dllimport) +#endif +#define ORT_WEBGPU_API_CALL __stdcall +#elif defined(__APPLE__) || defined(__linux__) +#define ORT_WEBGPU_EXPORT __attribute__((visibility("default"))) +#define ORT_WEBGPU_API_CALL +#else +#define ORT_WEBGPU_EXPORT +#define ORT_WEBGPU_API_CALL +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \brief Get the Dawn proc table from WebGPU EP context + * \param context_id The WebGPU context ID (0 for default context); currently ignored by the implementation + * \return Pointer to the Dawn proc table, or nullptr if not available (always nullptr when built with __wasm__, BUILD_DAWN_SHARED_LIBRARY or USE_EXTERNAL_DAWN defined, and nullptr if retrieval throws) + */ +ORT_WEBGPU_EXPORT const void* ORT_WEBGPU_API_CALL OrtWebGpuGetDawnProcTable(int context_id); + +/** + * \brief Get the WebGPU instance from WebGPU EP context + * \param context_id The WebGPU context ID (0 for default context) + * \return Pointer to the WebGPU instance (WGPUInstance), or nullptr if not available (always nullptr when built with __wasm__ defined, and nullptr if retrieval throws) + */ +ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetInstance(int context_id); + +/** + * \brief Get the WebGPU device from WebGPU EP context + * \param context_id The WebGPU context ID (0 for default context) + * \return Pointer to the WebGPU device (WGPUDevice), or nullptr if not available (i.e. if retrieval throws) + */ +ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetDevice(int 
context_id); + +#ifdef __cplusplus +} +#endif diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 0e4004db35b10..de2084b74b95c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -28,6 +28,13 @@ #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/buffer_manager.h" + +// Define ORT_DLL_EXPORT before including webgpu_provider_factory.h to ensure +// ORT_WEBGPU_EXPORT becomes __declspec(dllexport) instead of __declspec(dllimport) +#ifndef ORT_DLL_EXPORT +#define ORT_DLL_EXPORT +#endif +#include "core/providers/webgpu/webgpu_provider_factory.h" // For ORT_WEBGPU_EXPORT macros #include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/program_cache_key.h" @@ -982,6 +989,9 @@ void WebGpuContextFactory::ReleaseContext(int context_id) { ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); if (--it->second.ref_count == 0 && !it->second.context->preserve_device_) { + // TODO: Investigate why a memory leak is triggered if we don't explicitly destroy the device. + // It seems that memory leak detection is triggered before the device is destroyed. + it->second.context->Device().Destroy(); contexts_.erase(it); } } @@ -1002,3 +1012,42 @@ WGPUDevice GetDevice(int context_id) { } // namespace webgpu } // namespace onnxruntime + +// C API functions for external access +extern "C" { + +ORT_WEBGPU_EXPORT const void* ORT_WEBGPU_API_CALL OrtWebGpuGetDawnProcTable(int /* context_id */) { +#if !defined(__wasm__) && !defined(BUILD_DAWN_SHARED_LIBRARY) && !defined(USE_EXTERNAL_DAWN) + try { + return &dawn::native::GetProcs(); + } catch (...) 
{ + return nullptr; + } +#else + return nullptr; +#endif +} + +ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetInstance(int context_id) { +#if !defined(__wasm__) + try { + auto& context = onnxruntime::webgpu::WebGpuContextFactory::GetContext(context_id); + return context.Instance().Get(); + } catch (...) { + return nullptr; + } +#else + return nullptr; +#endif +} + +ORT_WEBGPU_EXPORT void* ORT_WEBGPU_API_CALL OrtWebGpuGetDevice(int context_id) { + try { + auto& context = onnxruntime::webgpu::WebGpuContextFactory::GetContext(context_id); + return context.Device().Get(); + } catch (...) { + return nullptr; + } +} + +} // extern "C" diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index f1bebc0d52738..067f6b8cffd9b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -85,6 +85,7 @@ class WebGpuContext final { Status Wait(wgpu::Future f); + const wgpu::Instance& Instance() const { return instance_; } const wgpu::Device& Device() const { return device_; } const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e3291cdce62c5..8c74f5a6db531 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -145,7 +145,7 @@ static bool HasControlflowNodes(const Graph& graph) { static bool HasMemcpyNodes(const Graph& graph) { for (const auto& node : graph.Nodes()) { - if (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost") { + if (node.OpType() == "MemcpyFromHost") { /* NOTE(review): this narrows the check - the replaced line also matched "MemcpyToHost"; confirm that dropping it is intended for a function named HasMemcpyNodes */ return true; } }