Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/windows_x64_release_build_x64_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,34 @@ jobs:
ALLOW_RELEASED_ONNX_OPSET_ONLY: '0'
DocUpdateNeeded: 'false'

- name: Run onnxruntime_provider_test with example_plugin_ep
shell: pwsh
run: |
# Note on skipped tests:
# The skipped tests are either:
# - relying on CPU EP fallback for BFloat16 which is not supported
# - testing the LayerNormalization contrib op with mixed input/output types (only supported by a few EPs)
# Some other hardcoded EP types are skipped in these tests. For a plugin EP, we skip these tests by
# specifying them in the dynamic plugin EP configuration.

$env:ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON = @"
{
"ep_library_registration_name": "example_ep",
"ep_library_path": "./example_plugin_ep.dll",
"selected_ep_name": "example_ep",
"tests_to_skip": [
"LayerNormTest.LayerNorm_BFloat16Input",
"LayerNormTest.LayerNorm_Scale_Float16Input",
"LayerNormTest.LayerNorm_Scale_Float16ScaleOutput",
"LayerNormTest.LayerNorm_Scale_Bias_Float16Input",
"LayerNormTest.LayerNorm_Scale_Bias_Float16ScaleBiasOutput"
]
}
"@

.\onnxruntime_provider_test.exe
working-directory: ${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo

- name: Validate C# native delegates
shell: cmd
run: python tools\ValidateNativeDelegateAttributes.py
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/optimizer/transformer_memcpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,11 @@ static const IExecutionProvider* FindProviderByType(ProviderTypeToProviderMap pr

bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const {
const auto& node_provider_type = node.GetExecutionProviderType();
ORT_ENFORCE(!node_provider_type.empty(),
"Provider type for ", node.OpType(), " node with name '", node.Name(), "' is not set.");
const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type);
ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type);
ORT_ENFORCE(node_provider != nullptr,
"Unable to get provider associated with provider type '", node_provider_type, "'.");

// Same provider?
if (node_provider->Type() == provider_.Type()) {
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/test/unittest_main/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

#include <algorithm>
#include <cstdlib>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#ifdef _WIN32
#include <iostream>
#include <locale>
Expand Down Expand Up @@ -35,6 +37,7 @@

#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE)
#include "test/unittest_util/test_dynamic_plugin_ep.h"
#include "test/util/include/skipping_test_listener.h"
#endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE)

std::unique_ptr<Ort::Env> ort_env;
Expand Down Expand Up @@ -107,6 +110,19 @@ extern "C" void ortenv_teardown() {
ort_env.reset();
}

static std::vector<std::unique_ptr<::testing::TestEventListener>> MakeTestEventListeners() {
std::vector<std::unique_ptr<::testing::TestEventListener>> result{};
#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE)
{
namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra;
const auto tests_to_skip = dynamic_plugin_ep_infra::GetTestsToSkip();
auto skipping_test_listener = std::make_unique<onnxruntime::test::SkippingTestListener>(tests_to_skip);
result.emplace_back(std::move(skipping_test_listener));
}
#endif
return result;
}

#ifdef USE_TENSORRT

#if defined(_MSC_VER)
Expand Down Expand Up @@ -152,6 +168,14 @@ int TEST_MAIN(int argc, char** argv) {
ortenv_setup();
::testing::InitGoogleTest(&argc, argv);

{
auto& test_listeners = ::testing::UnitTest::GetInstance()->listeners();
auto test_listeners_to_add = MakeTestEventListeners();
for (auto& test_listener_to_add : test_listeners_to_add) {
test_listeners.Append(test_listener_to_add.release());
}
}

status = RUN_ALL_TESTS();
}
ORT_CATCH(const std::exception& ex) {
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Status ParseInitializationConfig(std::string_view json_str, InitializationConfig
config.selected_ep_name = parsed_json.value<decltype(config.selected_ep_name)>("selected_ep_name", {});
config.selected_ep_device_indices =
parsed_json.value<decltype(config.selected_ep_device_indices)>("selected_ep_device_indices", {});
config.tests_to_skip = parsed_json.value<decltype(config.tests_to_skip)>("tests_to_skip", {});

config_out = std::move(config);
return Status::OK();
Expand Down Expand Up @@ -198,4 +199,12 @@ std::optional<std::string> GetEpName() {
return g_plugin_ep_infrastructure_state->ep_name;
}

std::vector<std::string> GetTestsToSkip() {
if (!IsInitialized()) {
return {};
}

return g_plugin_ep_infrastructure_state->config.tests_to_skip;
}

} // namespace onnxruntime::test::dynamic_plugin_ep_infra
7 changes: 7 additions & 0 deletions onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ struct InitializationConfig {
std::vector<size_t> selected_ep_device_indices{};

std::map<std::string, std::string> default_ep_options{};

// Specifies any tests to skip.
// Tests should be specified by full name, i.e., "<test suite name>.<test name>".
std::vector<std::string> tests_to_skip{};
};

// Parses `InitializationConfig` from JSON.
Expand Down Expand Up @@ -75,6 +79,9 @@ std::unique_ptr<IExecutionProvider> MakeEp(const logging::Logger* logger = nullp
// Gets the dynamic plugin EP name, or `std::nullopt` if uninitialized.
std::optional<std::string> GetEpName();

// Gets the list of tests to skip, or an empty list if uninitialized.
std::vector<std::string> GetTestsToSkip();

} // namespace dynamic_plugin_ep_infra
} // namespace test
} // namespace onnxruntime
33 changes: 33 additions & 0 deletions onnxruntime/test/util/include/skipping_test_listener.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <unordered_set>

#include "gsl/gsl"

#include "gtest/gtest.h"

namespace onnxruntime::test {

// A test event listener that skips the specified tests.
class SkippingTestListener : public ::testing::EmptyTestEventListener {
public:
explicit SkippingTestListener(gsl::span<const std::string> tests_to_skip)
: tests_to_skip_(tests_to_skip.begin(), tests_to_skip.end()) {
}

private:
void OnTestStart(const ::testing::TestInfo& test_info) override {
const auto full_test_name = std::string(test_info.test_suite_name()) + "." + test_info.name();
if (tests_to_skip_.find(full_test_name) != tests_to_skip_.end()) {
GTEST_SKIP();
}
}

std::unordered_set<std::string> tests_to_skip_;
};

} // namespace onnxruntime::test
Loading