Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
283bd5b
First draft (untested)
adrianlizarraga Sep 16, 2025
bc4df99
Properly discard invalid hw devices provided by the EP; Add hw device…
adrianlizarraga Sep 16, 2025
fb7441c
Add example code to example EP factory
adrianlizarraga Sep 16, 2025
5dafb50
Merge branch 'main' into adrianl/plugin-ep-specify-ort-hw-device
adrianlizarraga Oct 7, 2025
d23bd3a
Stub out default implementation of new factory function for internal eps
adrianlizarraga Oct 8, 2025
f86a4e9
Check for vendor too
adrianlizarraga Oct 9, 2025
f711df7
Formalize hardware metadata keys
adrianlizarraga Oct 9, 2025
02ced24
Remove discrete key
adrianlizarraga Oct 9, 2025
7d30fa7
Make global keys const
adrianlizarraga Oct 9, 2025
f9514a3
Add a new EP for testing
adrianlizarraga Oct 9, 2025
b553858
Rename class
adrianlizarraga Oct 9, 2025
24316d9
Add test for new test EP
adrianlizarraga Oct 10, 2025
4cae365
Merge branch 'main' into adrianl/plugin-ep-specify-ort-hw-device
adrianlizarraga Oct 22, 2025
a2cd0e8
Use ep registration name suffix to trigger virtual device creation. T…
adrianlizarraga Oct 23, 2025
66aee09
Add OrtEpFactory function to pass options from the ORT env
adrianlizarraga Oct 24, 2025
73f5a2b
Fix test
adrianlizarraga Oct 25, 2025
d87f8b8
Remove unnecessary code from test virtual EP. Add test for creating a…
adrianlizarraga Oct 27, 2025
d520595
Merge branch 'main' into adrianl/plugin-ep-specify-ort-hw-device
adrianlizarraga Oct 28, 2025
9314e03
Check OrtEpFactory.ort_version_supported before trying to call new fu…
adrianlizarraga Oct 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1982,9 +1982,13 @@ endif()
if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)

#
# example_plugin_ep
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
#
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.cc"
"${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)
Expand All @@ -1994,12 +1998,12 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG "-Xlinker -dead_strip")
elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX")
string(CONCAT ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.lds "
"-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep/example_plugin_ep_library.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
else()
set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG
"-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.def")
"-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep/example_plugin_ep_library.def")
endif()

set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS
Expand All @@ -2008,7 +2012,42 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest")
source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src})

#
# example_plugin_ep_virt_gpu
#
set(onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src
"${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep.cc")
onnxruntime_add_shared_library_module(example_plugin_ep_virt_gpu ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src})
target_include_directories(example_plugin_ep_virt_gpu PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep_virt_gpu PRIVATE onnxruntime)

if(UNIX)
if (APPLE)
set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG "-Xlinker -dead_strip")
elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX")
string(CONCAT ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
else()
set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG
"-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def")
endif()

set_property(TARGET example_plugin_ep_virt_gpu APPEND_STRING PROPERTY LINK_FLAGS
${ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG})

set_target_properties(example_plugin_ep_virt_gpu PROPERTIES FOLDER "ONNXRuntimeTest")
source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src})

#
# test library
#
file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
"${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")

Expand Down Expand Up @@ -2041,7 +2080,7 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
TARGET onnxruntime_autoep_test
SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src}
LIBS ${onnxruntime_autoep_test_LIBS}
DEPENDS ${all_dependencies} example_plugin_ep
DEPENDS ${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu
)
endif()

Expand Down
56 changes: 56 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,33 @@ struct OrtEpApi {
*/
ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream,
_In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream);

/** \brief Create an OrtHardwareDevice.
*
* \note Called within OrtEpFactory::GetSupportedDevices to create a new hardware device (e.g., virtual).
*
* \param[in] type The hardware device type.
* \param[in] vendor_id The hardware device's vendor identifier.
* \param[in] device_id The hardware device's identifier.
* \param[in] vendor_name The hardware device's vendor name as a null-terminated string. Copied by ORT.
* \param[in] metadata Optional OrtKeyValuePairs instance for hardware device metadata that may be queried by
* applications via OrtApi::GetEpDevices().
* Refer to onnxruntime_ep_device_ep_metadata_keys.h for common OrtHardwareDevice metadata keys.
* \param[out] hardware_device Output parameter set to the new OrtHardwareDevice instance that is created.
* Must be release with ReleaseHardwareDevice().
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(CreateHardwareDevice, _In_ OrtHardwareDeviceType type,
_In_ uint32_t vendor_id,
_In_ uint32_t device_id,
_In_ const char* vendor_name,
_In_opt_ const OrtKeyValuePairs* metadata,
_Out_ OrtHardwareDevice** hardware_device);

ORT_CLASS_RELEASE(HardwareDevice);
};

/**
Expand Down Expand Up @@ -981,6 +1008,35 @@ struct OrtEpFactory {
_In_ const OrtMemoryDevice* memory_device,
_In_opt_ const OrtKeyValuePairs* stream_options,
_Outptr_ OrtSyncStreamImpl** stream);

/** \brief Set environment options on this EP factory.
*
* Environment options can be set by ORT after calling the library's 'CreateEpFactories' function to
* create EP factories.
*
* Supported options:
* "allow_virtual_devices": Allows EP factory to specify OrtEpDevice instances that use custom
* virtual OrtHardwareDevices, which can be created via OrtEpApi::CreateHardwareDevice().
*
* A virtual OrtHardwareDevice does not represent actual hardware on the device, and is identified
* via the metadata entry "is_virtual" with a value of "1".
* Refer to onnxruntime_ep_device_ep_metadata_keys.h for well-known OrtHardwareDevice metadata keys.
*
* Allowed values:
* -# "0": Default. Creation of virtual devices is not allowed.
* -# "1": Creation of virtual devices is allowed.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] options The configuration options.
*
* \note Implementation of this function is optional.
* An EP factory should only implement this if it needs to handle any environment options.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);
};

#ifdef __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

// This file contains well-known keys for OrtEpDevice EP metadata entries.
// This file contains well-known keys for OrtEpDevice and OrtHardwareDevice metadata entries.
// It does NOT specify all available metadata keys.

// Key for the execution provider version string. This should be available for all plugin EPs.
Expand All @@ -16,3 +16,10 @@ static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compa

// Key for the execution provider library path (for dynamically loaded EPs)
static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path";

// Optional metadata key to determine if a OrtHardwareDevice represents a virtual (non-hardware) device.
// Possible values:
// - "0": OrtHardwareDevice is not virtual (i.e., actual hardware device). This is the assumed default value
// if this metadata key is not present.
// - "1": OrtHardwareDevice is virtual.
static const char* const kOrtHardwareDevice_MetadataKey_IsVirtual = "is_virtual";
25 changes: 25 additions & 0 deletions onnxruntime/core/session/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,29 @@ std::vector<const OrtHardwareDevice*> SortDevicesByType() {

return sorted_devices;
}

bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) {
constexpr std::string_view suffix{".virtual"};

return lib_registration_name.size() >= suffix.size() &&
lib_registration_name.compare(lib_registration_name.size() - suffix.size(),
suffix.size(), suffix) == 0;
}

Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name) {
// OrtEpFactory::SetEnvironmentOptions was added in ORT 1.24
if (factory.ort_version_supported < 24 || factory.SetEnvironmentOptions == nullptr) {
return Status::OK();
}

// We only set one option now but this can be generalized if necessary.
OrtKeyValuePairs options;
options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? "1" : "0");

ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &options)));

return Status::OK();
}
} // namespace

Status Environment::EpInfo::Create(std::unique_ptr<EpLibrary> library_in, std::unique_ptr<EpInfo>& out,
Expand All @@ -772,6 +795,8 @@ Status Environment::EpInfo::Create(std::unique_ptr<EpLibrary> library_in, std::u

auto& factory = *factory_ptr;

ORT_RETURN_IF_ERROR(SetEpFactoryEnvironmentOptions(factory, instance.library->RegistrationName()));

std::array<OrtEpDevice*, 8> ep_devices{nullptr};
size_t num_ep_devices = 0;
ORT_RETURN_IF_ERROR(ToStatusAndRelease(
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "core/session/plugin_ep/ep_api.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "core/common/semver.h"
Expand Down Expand Up @@ -205,6 +207,32 @@ ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* pr
return id;
}

ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type,
_In_ uint32_t vendor_id,
_In_ uint32_t device_id,
_In_ const char* vendor_name,
_In_opt_ const OrtKeyValuePairs* metadata,
_Out_ OrtHardwareDevice** hardware_device) {
API_IMPL_BEGIN
auto device = std::make_unique<OrtHardwareDevice>();
device->type = type;
device->vendor_id = vendor_id;
device->device_id = device_id;
device->vendor = std::string(vendor_name);

if (metadata) {
device->metadata = *metadata;
}

*hardware_device = device.release();
return nullptr;
API_IMPL_END
}

ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device) {
delete device;
}

static constexpr OrtEpApi ort_ep_api = {
// NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
// and no functions can be removed (the implementation needs to change to return an error).
Expand All @@ -231,6 +259,9 @@ static constexpr OrtEpApi ort_ep_api = {
&OrtExecutionProviderApi::SyncStream_GetSyncId,
&OrtExecutionProviderApi::GetSyncIdForLastWaitOnSyncStream,
// End of Version 23 - DO NOT MODIFY ABOVE

&OrtExecutionProviderApi::CreateHardwareDevice,
&OrtExecutionProviderApi::ReleaseHardwareDevice,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,11 @@ ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream*
ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream);
ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream,
_In_ const OrtSyncStream* consumer_stream);
ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type,
_In_ uint32_t vendor_id,
_In_ uint32_t device_id,
_In_ const char* vendor_name,
_In_opt_ const OrtKeyValuePairs* metadata,
_Out_ OrtHardwareDevice** hardware_device);
ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device);
} // namespace OrtExecutionProviderApi
6 changes: 5 additions & 1 deletion onnxruntime/core/session/plugin_ep/ep_factory_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "core/session/plugin_ep/ep_factory_internal.h"

#include <utility>

#include "core/framework/error_code_helper.h"
#include "core/session/abi_devices.h"
#include "core/session/abi_session_options_impl.h"
Expand All @@ -13,7 +15,8 @@ namespace onnxruntime {
using Forward = ForwardToFactoryImpl<EpFactoryInternal>;

EpFactoryInternal::EpFactoryInternal(std::unique_ptr<EpFactoryInternalImpl> impl)
: impl_{std::move(impl)} {
: OrtEpFactory{}, // Ensure optional functions are default initialized to nullptr
impl_{std::move(impl)} {
ort_version_supported = ORT_API_VERSION;

OrtEpFactory::GetName = Forward::GetFactoryName;
Expand All @@ -29,6 +32,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr<EpFactoryInternalImpl> impl
OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer;
OrtEpFactory::IsStreamAware = Forward::IsStreamAware;
OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice;
OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions;
}

InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory,
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class EpFactoryInternal : public OrtEpFactory {
return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility);
}

OrtStatus* SetEnvironmentOptions(_In_ const OrtKeyValuePairs* options) noexcept {
return impl_->SetEnvironmentOptions(options);
}

// Function ORT calls to release an EP instance.
void ReleaseEp(OrtEp* /*ep*/) noexcept {
// we never create an OrtEp so we should never be trying to release one
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class EpFactoryInternalImpl {
"CreateSyncStreamForDevice is not implemented for this EP factory.");
}

virtual OrtStatus* SetEnvironmentOptions(const OrtKeyValuePairs* /*options*/) noexcept {
// Default implementation does not handle any options.
return nullptr;
}

// Function ORT calls to release an EP instance.
void ReleaseEp(OrtEp* ep);

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ struct ForwardToFactoryImpl {
return static_cast<TFactory*>(this_ptr)->CreateSyncStreamForDevice(memory_device, stream_options, stream);
}

static OrtStatus* ORT_API_CALL SetEnvironmentOptions(_In_ OrtEpFactory* this_ptr,
_In_ const OrtKeyValuePairs* options) noexcept {
return static_cast<TFactory*>(this_ptr)->SetEnvironmentOptions(options);
}

static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept {
static_cast<TFactory*>(this_ptr)->ReleaseEp(ep);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include <gsl/span>

#include "example_plugin_ep_utils.h"
#include "../plugin_ep_utils.h"

class ExampleEpFactory;
struct MulKernel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#include "example_plugin_ep_utils.h"
#include "../plugin_ep_utils.h"

#include <sstream>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ limitations under the License.
#undef ORT_API_MANUAL_INIT

#include "ep_allocator.h"
#include "example_plugin_ep_utils.h"
#include "../plugin_ep_utils.h"

#if defined(PLATFORM_WINDOWS)
#include <intrin.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#include "example_plugin_ep_utils.h"
#include "../plugin_ep_utils.h"

struct ExampleDataTransfer : OrtDataTransferImpl, ApiPtrs {
ExampleDataTransfer(ApiPtrs api_ptrs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "ep_stream_support.h"

ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger)
: ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} {
: OrtEpFactory{}, ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} {
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with.
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
Expand Down Expand Up @@ -190,8 +190,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr,
// Create EP configuration from session options, if needed.
// Note: should not store a direct reference to the session options object as its lifespan is not guaranteed.
std::string ep_context_enable;
RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(factory->ort_api, *session_options,
"ep.context_enable", "0", ep_context_enable));
RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, "ep.context_enable", "0", ep_context_enable));

ExampleEp::Config config = {};
config.enable_ep_context = ep_context_enable == "1";
Expand Down
Loading
Loading