diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 460736ff8506e..57ad1e597a205 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1982,9 +1982,13 @@ endif() if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) + + # # example_plugin_ep - file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h" - "${TEST_SRC_DIR}/autoep/library/*.cc") + # + file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.cc" + "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h") onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep PRIVATE onnxruntime) @@ -1994,12 +1998,12 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG "-Xlinker -dead_strip") elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") string(CONCAT ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG - "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.lds " + "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep/example_plugin_ep_library.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() else() set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG - "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.def") + "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep/example_plugin_ep_library.def") endif() set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS @@ -2008,7 +2012,42 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src}) + # + # example_plugin_ep_virt_gpu + # + set(onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src + "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep.cc") + onnxruntime_add_shared_library_module(example_plugin_ep_virt_gpu ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src}) + target_include_directories(example_plugin_ep_virt_gpu PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) + target_link_libraries(example_plugin_ep_virt_gpu PRIVATE onnxruntime) + + if(UNIX) + if (APPLE) + set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG "-Xlinker -dead_strip") + elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + string(CONCAT ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG + "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds " + "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") + endif() + else() + set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG + "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def") + endif() + + set_property(TARGET example_plugin_ep_virt_gpu APPEND_STRING PROPERTY LINK_FLAGS + ${ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG}) + + set_target_properties(example_plugin_ep_virt_gpu PROPERTIES FOLDER "ONNXRuntimeTest") + source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src}) + + # # test library + # file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h" "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc") @@ -2041,7 +2080,7 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND TARGET onnxruntime_autoep_test SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src} LIBS ${onnxruntime_autoep_test_LIBS} - DEPENDS ${all_dependencies} example_plugin_ep + DEPENDS ${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu ) endif() diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 975f6b453a88d..a8e94d690177f 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -465,6 +465,33 @@ struct OrtEpApi { */ ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); + + /** \brief Create an OrtHardwareDevice. + * + * \note Called within OrtEpFactory::GetSupportedDevices to create a new hardware device (e.g., virtual). + * + * \param[in] type The hardware device type. + * \param[in] vendor_id The hardware device's vendor identifier. + * \param[in] device_id The hardware device's identifier. + * \param[in] vendor_name The hardware device's vendor name as a null-terminated string. Copied by ORT. + * \param[in] metadata Optional OrtKeyValuePairs instance for hardware device metadata that may be queried by + * applications via OrtApi::GetEpDevices(). + * Refer to onnxruntime_ep_device_ep_metadata_keys.h for common OrtHardwareDevice metadata keys. + * \param[out] hardware_device Output parameter set to the new OrtHardwareDevice instance that is created. + * Must be release with ReleaseHardwareDevice(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, + _In_ uint32_t vendor_id, + _In_ uint32_t device_id, + _In_ const char* vendor_name, + _In_opt_ const OrtKeyValuePairs* metadata, + _Out_ OrtHardwareDevice** hardware_device); + + ORT_CLASS_RELEASE(HardwareDevice); }; /** @@ -981,6 +1008,35 @@ struct OrtEpFactory { _In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStreamImpl** stream); + + /** \brief Set environment options on this EP factory. + * + * Environment options can be set by ORT after calling the library's 'CreateEpFactories' function to + * create EP factories. + * + * Supported options: + * "allow_virtual_devices": Allows EP factory to specify OrtEpDevice instances that use custom + * virtual OrtHardwareDevices, which can be created via OrtEpApi::CreateHardwareDevice(). + * + * A virtual OrtHardwareDevice does not represent actual hardware on the device, and is identified + * via the metadata entry "is_virtual" with a value of "1". + * Refer to onnxruntime_ep_device_ep_metadata_keys.h for well-known OrtHardwareDevice metadata keys. + * + * Allowed values: + * -# "0": Default. Creation of virtual devices is not allowed. + * -# "1": Creation of virtual devices is allowed. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] options The configuration options. + * + * \note Implementation of this function is optional. + * An EP factory should only implement this if it needs to handle any environment options. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); }; #ifdef __cplusplus diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index bbd6a43bb7a41..f6afd84dabc5e 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -3,7 +3,7 @@ #pragma once -// This file contains well-known keys for OrtEpDevice EP metadata entries. +// This file contains well-known keys for OrtEpDevice and OrtHardwareDevice metadata entries. // It does NOT specify all available metadata keys. // Key for the execution provider version string. This should be available for all plugin EPs. @@ -16,3 +16,10 @@ static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compa // Key for the execution provider library path (for dynamically loaded EPs) static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; + +// Optional metadata key to determine if a OrtHardwareDevice represents a virtual (non-hardware) device. +// Possible values: +// - "0": OrtHardwareDevice is not virtual (i.e., actual hardware device). This is the assumed default value +// if this metadata key is not present. +// - "1": OrtHardwareDevice is virtual. +static const char* const kOrtHardwareDevice_MetadataKey_IsVirtual = "is_virtual"; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 9c40eb75780ee..cde77eeed8aa5 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -746,6 +746,29 @@ std::vector SortDevicesByType() { return sorted_devices; } + +bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { + constexpr std::string_view suffix{".virtual"}; + + return lib_registration_name.size() >= suffix.size() && + lib_registration_name.compare(lib_registration_name.size() - suffix.size(), + suffix.size(), suffix) == 0; +} + +Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name) { + // OrtEpFactory::SetEnvironmentOptions was added in ORT 1.24 + if (factory.ort_version_supported < 24 || factory.SetEnvironmentOptions == nullptr) { + return Status::OK(); + } + + // We only set one option now but this can be generalized if necessary. + OrtKeyValuePairs options; + options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? "1" : "0"); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &options))); + + return Status::OK(); +} } // namespace Status Environment::EpInfo::Create(std::unique_ptr library_in, std::unique_ptr& out, @@ -772,6 +795,8 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u auto& factory = *factory_ptr; + ORT_RETURN_IF_ERROR(SetEpFactoryEnvironmentOptions(factory, instance.library->RegistrationName())); + std::array ep_devices{nullptr}; size_t num_ep_devices = 0; ORT_RETURN_IF_ERROR(ToStatusAndRelease( diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 366f96e585918..7efb0a68c735d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -4,6 +4,8 @@ #include "core/session/plugin_ep/ep_api.h" #include +#include +#include #include #include "core/common/semver.h" @@ -205,6 +207,32 @@ ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* pr return id; } +ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, + _In_ uint32_t vendor_id, + _In_ uint32_t device_id, + _In_ const char* vendor_name, + _In_opt_ const OrtKeyValuePairs* metadata, + _Out_ OrtHardwareDevice** hardware_device) { + API_IMPL_BEGIN + auto device = std::make_unique(); + device->type = type; + device->vendor_id = vendor_id; + device->device_id = device_id; + device->vendor = std::string(vendor_name); + + if (metadata) { + device->metadata = *metadata; + } + + *hardware_device = device.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device) { + delete device; +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -231,6 +259,9 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::SyncStream_GetSyncId, &OrtExecutionProviderApi::GetSyncIdForLastWaitOnSyncStream, // End of Version 23 - DO NOT MODIFY ABOVE + + &OrtExecutionProviderApi::CreateHardwareDevice, + &OrtExecutionProviderApi::ReleaseHardwareDevice, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index c0dc79f3fb333..129230be4f618 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -40,4 +40,11 @@ ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); +ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, + _In_ uint32_t vendor_id, + _In_ uint32_t device_id, + _In_ const char* vendor_name, + _In_opt_ const OrtKeyValuePairs* metadata, + _Out_ OrtHardwareDevice** hardware_device); +ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index f3e30caf07e81..364bab471ddbe 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -3,6 +3,8 @@ #include "core/session/plugin_ep/ep_factory_internal.h" +#include + #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" @@ -13,7 +15,8 @@ namespace onnxruntime { using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) - : impl_{std::move(impl)} { + : OrtEpFactory{}, // Ensure optional functions are default initialized to nullptr + impl_{std::move(impl)} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; @@ -29,6 +32,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; OrtEpFactory::IsStreamAware = Forward::IsStreamAware; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; + OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 093bfce462d32..6eb83a117fb63 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -87,6 +87,10 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); } + OrtStatus* SetEnvironmentOptions(_In_ const OrtKeyValuePairs* options) noexcept { + return impl_->SetEnvironmentOptions(options); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index f29154d19c53c..de9e2d44431bf 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -83,6 +83,11 @@ class EpFactoryInternalImpl { "CreateSyncStreamForDevice is not implemented for this EP factory."); } + virtual OrtStatus* SetEnvironmentOptions(const OrtKeyValuePairs* /*options*/) noexcept { + // Default implementation does not handle any options. + return nullptr; + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 2cceb1d08d536..65c396181f0a7 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -82,6 +82,11 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->CreateSyncStreamForDevice(memory_device, stream_options, stream); } + static OrtStatus* ORT_API_CALL SetEnvironmentOptions(_In_ OrtEpFactory* this_ptr, + _In_ const OrtKeyValuePairs* options) noexcept { + return static_cast(this_ptr)->SetEnvironmentOptions(options); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep.cc diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h similarity index 98% rename from onnxruntime/test/autoep/library/ep.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 279925a7ec3e1..7e96a523cf285 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -5,7 +5,7 @@ #include -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" class ExampleEpFactory; struct MulKernel; diff --git a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_allocator.h similarity index 99% rename from onnxruntime/test/autoep/library/ep_allocator.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_allocator.h index e46c03dfc8f14..febf8c7dbd8c1 100644 --- a/onnxruntime/test/autoep/library/ep_allocator.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_allocator.h @@ -3,7 +3,7 @@ #pragma once -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" #include diff --git a/onnxruntime/test/autoep/library/ep_arena.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_arena.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.cc diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.h similarity index 99% rename from onnxruntime/test/autoep/library/ep_arena.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.h index caa2c61db835f..c8fd1db5dc007 100644 --- a/onnxruntime/test/autoep/library/ep_arena.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.h @@ -26,7 +26,7 @@ limitations under the License. #undef ORT_API_MANUAL_INIT #include "ep_allocator.h" -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" #if defined(PLATFORM_WINDOWS) #include diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_data_transfer.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.cc diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.h similarity index 97% rename from onnxruntime/test/autoep/library/ep_data_transfer.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.h index da74d42b4affe..f1dad784ff84b 100644 --- a/onnxruntime/test/autoep/library/ep_data_transfer.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.h @@ -3,7 +3,7 @@ #pragma once -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" struct ExampleDataTransfer : OrtDataTransferImpl, ApiPtrs { ExampleDataTransfer(ApiPtrs api_ptrs, diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc similarity index 98% rename from onnxruntime/test/autoep/library/ep_factory.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 4da7d722a5e0b..a2a4848ed060f 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -12,7 +12,7 @@ #include "ep_stream_support.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger) - : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { + : OrtEpFactory{}, ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; @@ -190,8 +190,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, // Create EP configuration from session options, if needed. // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. std::string ep_context_enable; - RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(factory->ort_api, *session_options, - "ep.context_enable", "0", ep_context_enable)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, "ep.context_enable", "0", ep_context_enable)); ExampleEp::Config config = {}; config.enable_ep_context = ep_context_enable == "1"; diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h similarity index 99% rename from onnxruntime/test/autoep/library/ep_factory.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 088deda1fe9d2..6f6aaf2aaa9fc 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -7,7 +7,7 @@ #include "ep_arena.h" #include "ep_data_transfer.h" -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_stream_support.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.cc diff --git a/onnxruntime/test/autoep/library/ep_stream_support.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.h similarity index 98% rename from onnxruntime/test/autoep/library/ep_stream_support.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.h index a825e5afd2250..847ed708c5ca7 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.h @@ -5,7 +5,7 @@ #include "onnxruntime_c_api.h" #include "ep_factory.h" -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" class ExampleEpFactory; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep.cc similarity index 100% rename from onnxruntime/test/autoep/library/example_plugin_ep.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep.cc diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_library.def b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.def similarity index 100% rename from onnxruntime/test/autoep/library/example_plugin_ep_library.def rename to onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.def diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_library.lds b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.lds similarity index 100% rename from onnxruntime/test/autoep/library/example_plugin_ep_library.lds rename to onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.lds diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc deleted file mode 100644 index 8b36f5f4e9a13..0000000000000 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "example_plugin_ep_utils.h" - -#include - -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& /* ort_api */, const OrtSessionOptions& session_options, - const char* config_key, const std::string& default_val, - /*out*/ std::string& config_val) { - try { - Ort::ConstSessionOptions sess_opt{&session_options}; - config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } - - return nullptr; -} - -void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { - result = false; - - auto type_info = value_info.TypeInfo(); - ONNXType onnx_type = type_info.GetONNXType(); - if (onnx_type != ONNX_TYPE_TENSOR) { - return; - } - - auto type_shape = type_info.GetTensorTypeAndShapeInfo(); - ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return; - } - result = true; -} - -std::optional> GetTensorShape(Ort::ConstValueInfo value_info) { - const auto type_info = value_info.TypeInfo(); - const auto onnx_type = type_info.GetONNXType(); - if (onnx_type != ONNX_TYPE_TENSOR) { - return std::nullopt; - } - - const auto type_shape = type_info.GetTensorTypeAndShapeInfo(); - return type_shape.GetShape(); -} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc new file mode 100644 index 0000000000000..99a580f5577f7 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc @@ -0,0 +1,321 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ep_factory.h" +#include "../plugin_ep_utils.h" + +/// +/// Empty (not implemented) computation functor for a compiled Add. +/// This EP only supports a virtual GPU device that cannot run inference, but can create a compiled model. +/// +struct VirtualCompiledAdd { + VirtualCompiledAdd(const OrtApi& ort_api, const OrtLogger& logger) : ort_api(ort_api), logger(logger) {} + + OrtStatus* Compute(OrtKernelContext* /*kernel_ctx*/) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "VirtualCompiledAdd::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + return ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "EP only supports a virtual GPU that cannot run ops."); + } + + const OrtApi& ort_api; + const OrtLogger& logger; +}; + +/// +/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. +/// +struct AddNodeComputeInfo : OrtNodeComputeInfo { + explicit AddNodeComputeInfo(EpVirtualGpu& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + EpVirtualGpu& ep; +}; + +EpVirtualGpu::EpVirtualGpu(EpFactoryVirtualGpu& factory, const EpVirtualGpu::Config& config, const OrtLogger& logger) + : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized + factory_{factory}, + config_{config}, + ort_api_{factory.GetOrtApi()}, + ep_api_{factory.GetEpApi()}, + model_editor_api_{factory.GetModelEditorApi()}, + name_{factory.GetEpName()}, + logger_{logger} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + Compile = CompileImpl; + ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + + auto status = ort_api_.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("EpVirtualGpu has been created with name " + name_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__); + // ignore status for now + (void)status; +} + +EpVirtualGpu::~EpVirtualGpu() = default; + +/*static*/ +const char* ORT_API_CALL EpVirtualGpu ::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->name_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL EpVirtualGpu::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + try { + EpVirtualGpu* ep = static_cast(this_ptr); + + Ort::ConstGraph graph{ort_graph}; + std::vector nodes = graph.GetNodes(); + + if (nodes.empty()) { + return nullptr; // No nodes to process + } + + std::vector supported_nodes; + + for (const auto& node : nodes) { + auto op_type = node.GetOperatorType(); + + if (op_type == "Add") { + supported_nodes.push_back(node); // Only support a single Add for now. + break; + } + } + + if (supported_nodes.empty()) { + return nullptr; + } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + node_fusion_options.drop_constant_initializers = false; + + RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpVirtualGpu::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** ort_graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { + try { + if (count != 1) { + Ort::Status status("Expected to compile a single graph", ORT_EP_FAIL); + return status.release(); + } + + EpVirtualGpu* ep = static_cast(this_ptr); + + Ort::ConstGraph graph{ort_graphs[0]}; + + std::vector nodes = graph.GetNodes(); + if (nodes.size() != 1) { + Ort::Status status("Expected to compile a single Add node", ORT_EP_FAIL); + return status.release(); + } + + auto node_op_type = nodes[0].GetOperatorType(); + if (node_op_type != "Add") { + Ort::Status status("Expected to compile a single Add node", ORT_EP_FAIL); + return status.release(); + } + + // Now we know we're compiling a single Add node. Create a computation kernel. + Ort::ConstNode fused_node{fused_nodes[0]}; + auto ep_name = fused_node.GetEpName(); + if (ep_name != ep->name_) { + Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); + return status.release(); + } + + // Associate the name of the fused node with our VirtualCompiledAdd. + auto fused_node_name = fused_node.GetName(); + ep->compiled_subgraphs_.emplace(std::move(fused_node_name), + std::make_unique(ep->ort_api_, ep->logger_)); + + // Update the OrtNodeComputeInfo associated with the graph. + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); + + // Create EpContext nodes for the fused nodes we compiled (if enabled by user via session options). + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +// Creates EPContext nodes from the given fused nodes. +// This is an example implementation that can be used to generate an EPContext model. However, this example EP +// cannot currently run the EPContext model. +OrtStatus* EpVirtualGpu::CreateEpContextNodes(gsl::span fused_nodes, + /*out*/ gsl::span ep_context_nodes) { + try { + assert(fused_nodes.size() == ep_context_nodes.size()); + + // Helper to collect input or output names from an array of OrtValueInfo instances. + auto collect_input_output_names = [&](gsl::span value_infos, + std::vector& result) { + std::vector value_names; + value_names.reserve(value_infos.size()); + + for (const auto vi : value_infos) { + value_names.push_back(vi.GetName()); + } + + result = std::move(value_names); + }; + + // Create an "EPContext" node for every fused node. + for (size_t i = 0; i < fused_nodes.size(); ++i) { + Ort::ConstNode fused_node{fused_nodes[i]}; + auto fused_node_name = fused_node.GetName(); + + std::vector fused_node_inputs = fused_node.GetInputs(); + std::vector fused_node_outputs = fused_node.GetOutputs(); + + std::vector input_names; + std::vector output_names; + + collect_input_output_names(fused_node_inputs, /*out*/ input_names); + collect_input_output_names(fused_node_outputs, /*out*/ output_names); + + int64_t is_main_context = (i == 0); + int64_t embed_mode = 1; + + // Create node attributes. The CreateNode() function copies the attributes. + std::array attributes = {}; + std::string ep_ctx = "binary_data"; + attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), + ORT_OP_ATTR_STRING); + + attributes[1] = Ort::OpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT); + attributes[2] = Ort::OpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT); + attributes[3] = Ort::OpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING); + attributes[4] = Ort::OpAttr("partition_name", fused_node_name.data(), static_cast(fused_node_name.size()), + ORT_OP_ATTR_STRING); + + attributes[5] = Ort::OpAttr("source", this->name_.data(), static_cast(this->name_.size()), + ORT_OP_ATTR_STRING); + + std::vector c_input_names; + std::transform(input_names.begin(), input_names.end(), std::back_inserter(c_input_names), + [](const std::string& s) { return s.c_str(); }); + std::vector c_output_names; + std::transform(output_names.begin(), output_names.end(), std::back_inserter(c_output_names), + [](const std::string& s) { return s.c_str(); }); + + OrtOpAttr** op_attrs = reinterpret_cast(attributes.data()); + RETURN_IF_ERROR(model_editor_api_.CreateNode("EPContext", "com.microsoft", fused_node_name.c_str(), + c_input_names.data(), c_input_names.size(), + c_output_names.data(), c_output_names.size(), + op_attrs, attributes.size(), + &ep_context_nodes[i])); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL EpVirtualGpu::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) noexcept { + (void)this_ptr; + for (size_t i = 0; i < num_node_compute_infos; i++) { + delete node_compute_infos[i]; + } +} + +// +// Implementation of AddNodeComputeInfo +// +AddNodeComputeInfo::AddNodeComputeInfo(EpVirtualGpu& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* AddNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + EpVirtualGpu& ep = node_compute_info->ep; + + std::string fused_node_name = ep.GetEpApi().NodeComputeContext_NodeName(compute_context); + auto subgraph_it = ep.GetCompiledSubgraphs().find(fused_node_name); + if (subgraph_it == ep.GetCompiledSubgraphs().end()) { + std::string message = "Unable to get compiled subgraph for fused node with name " + fused_node_name; + return ep.GetOrtApi().CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + VirtualCompiledAdd& add_impl = *subgraph_it->second; + *compute_state = &add_impl; + return nullptr; +} + +OrtStatus* AddNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + (void)this_ptr; + VirtualCompiledAdd& add_impl = *reinterpret_cast(compute_state); + return add_impl.Compute(kernel_context); +} + +void AddNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + VirtualCompiledAdd& add_impl = *reinterpret_cast(compute_state); + (void)add_impl; + // Do nothing for this example. +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.h new file mode 100644 index 0000000000000..da7bb05d79e62 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +class EpFactoryVirtualGpu; +struct VirtualCompiledAdd; + +/// +/// Example EP for a virtual GPU OrtHardwareDevice that was created by the EP factory itself (not ORT). +/// Can only compile/execute a single Add node. Only used to test that an EP can provide additional hardware devices. +/// +class EpVirtualGpu : public OrtEp { + public: + struct Config { + bool enable_ep_context = false; + // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) + }; + + EpVirtualGpu(EpFactoryVirtualGpu& factory, const Config& config, const OrtLogger& logger); + ~EpVirtualGpu(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + + std::unordered_map>& GetCompiledSubgraphs() { + return compiled_subgraphs_; + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) noexcept; + + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + /*out*/ gsl::span ep_context_nodes); + + EpFactoryVirtualGpu& factory_; + Config config_{}; + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtModelEditorApi& model_editor_api_; + std::string name_; + const OrtLogger& logger_; + std::unordered_map> compiled_subgraphs_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc new file mode 100644 index 0000000000000..da8c2b3ed7c6b --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_factory.h" + +#include + +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + +#include "ep.h" +#include "../plugin_ep_utils.h" + +EpFactoryVirtualGpu::EpFactoryVirtualGpu(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtModelEditorApi& model_editor_api, const OrtLogger& default_logger) + : OrtEpFactory{}, + ort_api_(ort_api), + ep_api_(ep_api), + model_editor_api_(model_editor_api), + allow_virtual_devices_{false}, + default_logger_{default_logger} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + + SetEnvironmentOptions = SetEnvironmentOptionsImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; +} + +EpFactoryVirtualGpu::~EpFactoryVirtualGpu() { + if (virtual_hw_device_ != nullptr) { + ep_api_.ReleaseHardwareDevice(virtual_hw_device_); + } +} + +/*static*/ +const char* ORT_API_CALL EpFactoryVirtualGpu::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_name_.c_str(); +} + +/*static*/ +const char* ORT_API_CALL EpFactoryVirtualGpu::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_.c_str(); +} + +/*static*/ +uint32_t ORT_API_CALL EpFactoryVirtualGpu::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id_; +} + +/*static*/ +const char* ORT_API_CALL EpFactoryVirtualGpu::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::SetEnvironmentOptionsImpl(OrtEpFactory* this_ptr, + const OrtKeyValuePairs* options) noexcept { + auto* factory = static_cast(this_ptr); + const char* value = factory->ort_api_.GetKeyValue(options, "allow_virtual_devices"); + + if (value != nullptr) { + factory->allow_virtual_devices_ = strcmp(value, "1") == 0; + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + size_t /*num_devices*/, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + num_ep_devices = 0; + + // Create a virtual OrtHardwareDevice if application indicated it is allowed (e.g., for cross-compiling). + // This example EP creates a virtual GPU OrtHardwareDevice and adds a new OrtEpDevice that uses the virtual GPU. + if (factory->allow_virtual_devices_ && num_ep_devices < max_ep_devices) { + OrtKeyValuePairs* hw_metadata = nullptr; + factory->ort_api_.CreateKeyValuePairs(&hw_metadata); + factory->ort_api_.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_IsVirtual, "1"); + + auto* status = factory->ep_api_.CreateHardwareDevice(OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, + factory->vendor_id_, + /*device_id*/ 0, + factory->vendor_.c_str(), + hw_metadata, + &factory->virtual_hw_device_); + factory->ort_api_.ReleaseKeyValuePairs(hw_metadata); // Release since ORT makes a copy. + + if (status != nullptr) { + return status; + } + + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api_.CreateKeyValuePairs(&ep_metadata); + factory->ort_api_.CreateKeyValuePairs(&ep_options); + + // made up example metadata values. + factory->ort_api_.AddKeyValuePair(ep_metadata, "some_metadata", "1"); + factory->ort_api_.AddKeyValuePair(ep_options, "compile_optimization", "O3"); + + OrtEpDevice* virtual_ep_device = nullptr; + status = factory->ort_api_.GetEpApi()->CreateEpDevice(factory, factory->virtual_hw_device_, ep_metadata, + ep_options, &virtual_ep_device); + + factory->ort_api_.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api_.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + + ep_devices[num_ep_devices++] = virtual_ep_device; + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + // we only registered for GPU and only expected to be selected for one GPU + return factory->ort_api_.CreateStatus(ORT_INVALID_ARGUMENT, + "EpFactoryVirtualGpu only supports selection for one device."); + } + + std::string ep_context_enable; + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, "ep.context_enable", "0", ep_context_enable)); + + EpVirtualGpu::Config config = {}; + config.enable_ep_context = ep_context_enable == "1"; + + auto actual_ep = std::make_unique(*factory, config, *logger); + + *ep = actual_ep.release(); + return nullptr; +} + +/*static*/ +void ORT_API_CALL EpFactoryVirtualGpu::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + EpVirtualGpu* dummy_ep = static_cast(ep); + delete dummy_ep; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateAllocatorImpl(OrtEpFactory* /*this_ptr*/, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + // Don't support custom allocators in this example for simplicity. A GPU EP would normally support allocators. + *allocator = nullptr; + return nullptr; +} + +/*static*/ +void ORT_API_CALL EpFactoryVirtualGpu::ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, + OrtAllocator* /*allocator*/) noexcept { + // Do nothing. +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + // Don't support data transfer in this example for simplicity. A GPU EP would normally support it. + *data_transfer = nullptr; + return nullptr; +} + +/*static*/ +bool ORT_API_CALL EpFactoryVirtualGpu::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateSyncStreamForDeviceImpl(OrtEpFactory* /*this_ptr*/, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + // Don't support sync streams in this example. A GPU EP would normally support it. + *stream = nullptr; + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h new file mode 100644 index 0000000000000..2357b3676aa79 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +/// +/// EP factory that creates an OrtEp instance that supports a virtual GPU OrtHardwareDevice +/// created by the factory itself (not ORT). +/// +class EpFactoryVirtualGpu : public OrtEpFactory { + public: + EpFactoryVirtualGpu(const OrtApi& ort_api, const OrtEpApi& ep_api, const OrtModelEditorApi& model_editor_api, + const OrtLogger& default_logger); + ~EpFactoryVirtualGpu(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + const OrtModelEditorApi& GetModelEditorApi() const { return model_editor_api_; } + const std::string& GetEpName() const { return ep_name_; } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + static OrtStatus* ORT_API_CALL SetEnvironmentOptionsImpl(OrtEpFactory* this_ptr, + const OrtKeyValuePairs* options) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtModelEditorApi& model_editor_api_; + bool allow_virtual_devices_{false}; + const OrtLogger& default_logger_; + OrtHardwareDevice* virtual_hw_device_{}; + const std::string ep_name_{"EpVirtualGpu"}; + const std::string vendor_{"Contoso2"}; // EP vendor name + const uint32_t vendor_id_{0xB358}; // EP vendor ID + const std::string ep_version_{"0.1.0"}; // EP version +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def new file mode 100644 index 0000000000000..e9481d0d60b28 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def @@ -0,0 +1,5 @@ +LIBRARY "example_plugin_ep_virt_gpu.dll" +EXPORTS + CreateEpFactories @1 + ReleaseEpFactory @2 + diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds new file mode 100644 index 0000000000000..a6d2ef09a7b16 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds @@ -0,0 +1,7 @@ +VERS_1.0.0 { + global: + CreateEpFactories; + ReleaseEpFactory; + local: + *; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc new file mode 100644 index 0000000000000..1e438e156828d --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "ep_factory.h" + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + + // Manual init for the C++ API + Ort::InitApi(ort_api); + + std::unique_ptr factory = std::make_unique(*ort_api, *ep_api, *model_editor_api, + *default_logger); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} + +} // extern "C" diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/plugin_ep_utils.h similarity index 75% rename from onnxruntime/test/autoep/library/example_plugin_ep_utils.h rename to onnxruntime/test/autoep/library/plugin_ep_utils.h index decc89251dc7b..2024c5185b0d6 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/plugin_ep_utils.h @@ -104,12 +104,46 @@ struct FloatInitializer { }; // Returns an entry in the session option configurations, or a default value if not present. -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, - const char* config_key, const std::string& default_val, - /*out*/ std::string& config_val); +inline OrtStatus* GetSessionConfigEntryOrDefault(const OrtSessionOptions& session_options, + const char* config_key, const std::string& default_val, + /*out*/ std::string& config_val) { + try { + Ort::ConstSessionOptions sess_opt{&session_options}; + config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } + + return nullptr; +} // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. -void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); +inline void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { + result = false; + + auto type_info = value_info.TypeInfo(); + ONNXType onnx_type = type_info.GetONNXType(); + if (onnx_type != ONNX_TYPE_TENSOR) { + return; + } + + auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + return; + } + result = true; +} // Gets the tensor shape from `value_info`. Returns std::nullopt if `value_info` is not a tensor. -std::optional> GetTensorShape(Ort::ConstValueInfo value_info); +inline std::optional> GetTensorShape(Ort::ConstValueInfo value_info) { + const auto type_info = value_info.TypeInfo(); + const auto onnx_type = type_info.GetONNXType(); + if (onnx_type != ONNX_TYPE_TENSOR) { + return std::nullopt; + } + + const auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + return type_shape.GetShape(); +} diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 88b522eb10dca..3c73237708828 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -61,7 +61,7 @@ struct DummyAllocator : OrtAllocator { // validate CreateSharedAllocator allows adding an arena to the shared allocator TEST(SharedAllocators, AddArenaToSharedAllocator) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); Ort::ConstEpDevice example_ep_device{example_ep.get()}; diff --git a/onnxruntime/test/autoep/test_autoep_utils.cc b/onnxruntime/test/autoep/test_autoep_utils.cc index 7045ccca2f576..d8404e5161e0a 100644 --- a/onnxruntime/test/autoep/test_autoep_utils.cc +++ b/onnxruntime/test/autoep/test_autoep_utils.cc @@ -15,7 +15,28 @@ namespace onnxruntime { namespace test { -Utils::ExamplePluginInfo Utils::example_ep_info; +Utils::ExamplePluginInfo::ExamplePluginInfo(const ORTCHAR_T* lib_path, const char* reg_name, const char* ep_name) + : library_path(lib_path), registration_name(reg_name), ep_name(ep_name) {} + +const Utils::ExamplePluginInfo Utils::example_ep_info( +#if _WIN32 + ORT_TSTR("example_plugin_ep.dll"), +#else + ORT_TSTR("libexample_plugin_ep.so"), +#endif + // The example_plugin_ep always uses the registration name as the EP name. + "example_ep", + "example_ep"); + +const Utils::ExamplePluginInfo Utils::example_ep_virt_gpu_info( +#if _WIN32 + ORT_TSTR("example_plugin_ep_virt_gpu.dll"), +#else + "libexample_plugin_ep_virt_gpu.so", +#endif + "example_plugin_ep_virt_gpu.virtual", // Ends in ".virtual" to allow creation of virtual devices. + // This EP's name is hardcoded to the following + "EpVirtualGpu"); void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { const OrtApi& c_api = Ort::GetApi(); @@ -36,18 +57,19 @@ void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& } } -void Utils::RegisterAndGetExampleEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& registered_ep) { +void Utils::RegisterAndGetExampleEp(Ort::Env& env, const ExamplePluginInfo& ep_info, + RegisteredEpDeviceUniquePtr& registered_ep) { const OrtApi& c_api = Ort::GetApi(); // this should load the library and create OrtEpDevice ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(env, - example_ep_info.registration_name.c_str(), - example_ep_info.library_path.c_str())); + ep_info.registration_name.c_str(), + ep_info.library_path.c_str())); const OrtEpDevice* example_ep = nullptr; - GetEp(env, example_ep_info.registration_name, example_ep); + GetEp(env, ep_info.ep_name, example_ep); ASSERT_NE(example_ep, nullptr); - registered_ep = RegisteredEpDeviceUniquePtr(example_ep, [&env, c_api](const OrtEpDevice* /*ep*/) { - c_api.UnregisterExecutionProviderLibrary(env, example_ep_info.registration_name.c_str()); + registered_ep = RegisteredEpDeviceUniquePtr(example_ep, [&env, &ep_info, c_api](const OrtEpDevice* /*ep*/) { + c_api.UnregisterExecutionProviderLibrary(env, ep_info.registration_name.c_str()); }); } diff --git a/onnxruntime/test/autoep/test_autoep_utils.h b/onnxruntime/test/autoep/test_autoep_utils.h index 2dd7b5f0428e2..f6b5e3623505f 100644 --- a/onnxruntime/test/autoep/test_autoep_utils.h +++ b/onnxruntime/test/autoep/test_autoep_utils.h @@ -15,23 +15,23 @@ using RegisteredEpDeviceUniquePtr = std::unique_ptr ep_options; + + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + // Compile the model. + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + // Make sure the compiled model was generated. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc index 88c2e320990e1..9a0219b57092c 100644 --- a/onnxruntime/test/autoep/test_registration.cc +++ b/onnxruntime/test/autoep/test_registration.cc @@ -23,6 +23,7 @@ namespace test { TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { const std::filesystem::path& library_path = Utils::example_ep_info.library_path; const std::string& registration_name = Utils::example_ep_info.registration_name; + const std::string& ep_name = Utils::example_ep_info.ep_name; const OrtApi* c_api = &Ort::GetApi(); // this should load the library and create OrtEpDevice @@ -35,10 +36,8 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); // should be one device for the example EP auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, - [®istration_name, &c_api](const OrtEpDevice* device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. - return c_api->EpDevice_EpName(device) == registration_name; + [&ep_name, &c_api](const OrtEpDevice* device) { + return c_api->EpDevice_EpName(device) == ep_name; }); ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; @@ -50,6 +49,7 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { const std::filesystem::path& library_path = Utils::example_ep_info.library_path; const std::string& registration_name = Utils::example_ep_info.registration_name; + const std::string& ep_name = Utils::example_ep_info.ep_name; // this should load the library and create OrtEpDevice ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); @@ -58,14 +58,13 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { // should be one device for the example EP auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), - [®istration_name](Ort::ConstEpDevice& device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. - return device.EpName() == registration_name; + [&ep_name](Ort::ConstEpDevice& device) { + return device.EpName() == ep_name; }); ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; - // test all the C++ getters. expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + // test all the C++ getters. + // expected values are from \onnxruntime\test\autoep\library\example_plugin_ep\*.cc ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); auto metadata = test_ep_device->EpMetadata(); @@ -89,6 +88,86 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); } +// Test loading example_plugin_ep_virt_gpu and its associated OrtEpDevice/OrtHardwareDevice. +// This EP creates a new OrtHardwareDevice instance that represents a virtual GPU and gives to ORT. +TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { + const std::filesystem::path& library_path = Utils::example_ep_virt_gpu_info.library_path; + const std::string& registration_name = "example_plugin_ep_virt_gpu"; + const std::string& ep_name = Utils::example_ep_virt_gpu_info.ep_name; + + auto get_plugin_ep_devices = [&ep_name]() -> std::vector { + std::vector all_ep_devices = ort_env->GetEpDevices(); + std::vector ep_devices; + + std::copy_if(all_ep_devices.begin(), all_ep_devices.end(), std::back_inserter(ep_devices), + [&ep_name](Ort::ConstEpDevice& device) { + return device.EpName() == ep_name; + }); + + return ep_devices; + }; + + auto is_hw_device_virtual = [](Ort::ConstHardwareDevice hw_device) -> bool { + std::unordered_map metadata_entries = hw_device.Metadata().GetKeyValuePairs(); + auto iter = metadata_entries.find(kOrtHardwareDevice_MetadataKey_IsVirtual); + + if (iter == metadata_entries.end()) { + return false; + } + + return iter->second == "1"; + }; + + // Test getting EP's supported OrtEpDevices. Do not allow virtual devices. + // The EP should not return any OrtEpDevice instances. + { + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + // Find ep devices for this EP. Should not get any. + std::vector ep_devices = get_plugin_ep_devices(); + ASSERT_EQ(ep_devices.size(), 0); + + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + + // Test getting EP's supported OrtEpDevices, but ALLOW virtual devices. + // The EP should return a OrtEpDevice for a virtual GPU. + { + // Use a registration name ending with ".virtual" to indicate to the EP library (factory) that creating virtual + // devices is allowed. + std::string registration_name_for_virtual_devices = registration_name + ".virtual"; + ort_env->RegisterExecutionProviderLibrary(registration_name_for_virtual_devices.c_str(), library_path.c_str()); + + // Find ep devices for this EP. Should get a virtual gpu. + std::vector ep_devices = get_plugin_ep_devices(); + ASSERT_EQ(ep_devices.size(), 1); + + auto virt_gpu_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [](Ort::ConstEpDevice& ep_device) { + return ep_device.Device().Type() == OrtHardwareDeviceType_GPU; + }); + + ASSERT_TRUE(is_hw_device_virtual(virt_gpu_ep_device->Device())); + + // test metadata and provider options attached to the virtual OrtEpDevice. + // expected values are from \onnxruntime\test\autoep\library\example_plugin_ep_virt_gpu\*.cc + ASSERT_STREQ(virt_gpu_ep_device->EpVendor(), "Contoso2"); + + auto metadata = virt_gpu_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); + ASSERT_STREQ(metadata.GetValue("some_metadata"), "1"); + + auto options = virt_gpu_ep_device->EpOptions(); + ASSERT_STREQ(options.GetValue("compile_optimization"), "O3"); + + // Check the virtual GPU hw device info. + ASSERT_EQ(virt_gpu_ep_device->Device().VendorId(), 0xB358); + ASSERT_EQ(virt_gpu_ep_device->Device().DeviceId(), 0); + ASSERT_STREQ(virt_gpu_ep_device->Device().Vendor(), virt_gpu_ep_device->EpVendor()); + + ort_env->UnregisterExecutionProviderLibrary(registration_name_for_virtual_devices.c_str()); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/autoep/test_selection.cc b/onnxruntime/test/autoep/test_selection.cc index 72f39be917f90..106a9d474c2b2 100644 --- a/onnxruntime/test/autoep/test_selection.cc +++ b/onnxruntime/test/autoep/test_selection.cc @@ -13,6 +13,7 @@ #include "core/session/abi_key_value_pairs.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "test_allocator.h" #include "test/autoep/test_autoep_utils.h" @@ -35,6 +36,12 @@ void DefaultDeviceSelection(const std::string& ep_name, std::vectorEpDevice_EpName(device) == ep_name) { + const auto* hw_device = c_api->EpDevice_Device(device); + const OrtKeyValuePairs* hw_kvps = c_api->HardwareDevice_Metadata(hw_device); + + const char* is_virtual = c_api->GetKeyValue(hw_kvps, kOrtHardwareDevice_MetadataKey_IsVirtual); + ASSERT_TRUE(is_virtual == nullptr || strcmp(is_virtual, "0") == 0); + devices.push_back(device); break; } @@ -193,6 +200,9 @@ TEST(AutoEpSelection, DmlEP) { const auto* device = c_api->EpDevice_Device(ep_device); const OrtKeyValuePairs* kvps = c_api->HardwareDevice_Metadata(device); + const char* is_virtual = c_api->GetKeyValue(kvps, kOrtHardwareDevice_MetadataKey_IsVirtual); + ASSERT_TRUE(is_virtual == nullptr || strcmp(is_virtual, "0") == 0); + if (devices.empty()) { // add the first device devices.push_back(ep_device);