diff --git a/sycl/include/CL/sycl/device_selector.hpp b/sycl/include/CL/sycl/device_selector.hpp index 76d0e9489ce84..53782b37d6976 100644 --- a/sycl/include/CL/sycl/device_selector.hpp +++ b/sycl/include/CL/sycl/device_selector.hpp @@ -87,5 +87,26 @@ class __SYCL_EXPORT host_selector : public device_selector { int operator()(const device &dev) const override; }; +namespace ext { +namespace oneapi { + +/// Selects device based on string. +/// +/// \sa device +/// +/// \ingroup sycl_api_dev_sel +class __SYCL_EXPORT string_selector : public device_selector { +public: + string_selector(std::string filter); + int operator()(const device &dev) const override; + +private: + std::vector mPlatforms; + std::vector mDeviceTypes; + default_selector mRanker; +}; +} // namespace oneapi +} // namespace ext + } // namespace sycl } // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/source/device_selector.cpp b/sycl/source/device_selector.cpp index 22b3a613467ec..6e1b292f0cfe5 100644 --- a/sycl/source/device_selector.cpp +++ b/sycl/source/device_selector.cpp @@ -13,6 +13,9 @@ #include #include #include + +#include +#include // 4.6.1 Device selection class __SYCL_INLINE_NAMESPACE(cl) { @@ -157,5 +160,143 @@ int host_selector::operator()(const device &dev) const { return Score; } +namespace detail { +std::string trim_spaces(std::string input) { + size_t LStart = input.find_first_not_of(" "); + std::string LTrimmed = + (LStart == std::string::npos) ? "" : input.substr(LStart); + + size_t REnd = LTrimmed.find_last_not_of(" "); + return (REnd == std::string::npos) ? "" : LTrimmed.substr(0, REnd + 1); +} + +std::vector tokenize(std::string filter, std::string delim) { + std::vector Tokens; + size_t Pos = 0; + std::string Input = filter; + std::string Tok; + + while ((Pos = Input.find(delim)) != std::string::npos) { + Tok = Input.substr(0, Pos); + Input.erase(0, Pos + delim.length()); + // Erase leading and trailing WS + Tok = trim_spaces(Tok); + + if (!Tok.empty()) + Tokens.push_back(Tok); + } + + if (!Input.empty()) + Input = trim_spaces(Input); + + // Add remainder + if (!Input.empty()) + Tokens.push_back(Input); + + return Tokens; +} + +enum TokenType { TokPlatform, TokDeviceType, TokUnknown }; + +TokenType parse_kind(std::string token) { + TokenType Result = TokUnknown; + + if (token == "platform") + Result = TokPlatform; + if (token == "type") + Result = TokDeviceType; + + return Result; +} + +std::string strip_kind(std::string token) { + std::string Prefix = "="; + size_t Loc = token.find(Prefix); + + if (Loc == std::string::npos) + return token; + + // move past the '=' + Loc = Loc + 1; + + return token.substr(Loc); +} + +bool match(std::string input, std::string pattern) { + return (input.find(pattern) != std::string::npos); +} +} // namespace detail + +namespace ext { +namespace oneapi { +string_selector::string_selector(std::string filter) { + std::transform(filter.begin(), filter.end(), filter.begin(), + [](unsigned char c) { return std::tolower(c); }); + std::vector Tokens = detail::tokenize(filter, ";"); + + for (auto Tok : Tokens) { + if (Tok.find("=") == std::string::npos) + continue; + + std::vector Req = detail::tokenize(Tok, "="); + if (Req.size() != 2) { + throw runtime_error("Invalid string_selector input! Please specify the " + "desired platform or device type after '='.", + PI_INVALID_VALUE); + } + detail::TokenType TTy = detail::parse_kind(Req[0]); + std::vector SubTokens = detail::tokenize(Req[1], ","); + + if (TTy == detail::TokPlatform) { + mPlatforms.insert(mPlatforms.end(), SubTokens.begin(), SubTokens.end()); + } else if (TTy == detail::TokDeviceType) { + mDeviceTypes.insert(mDeviceTypes.end(), SubTokens.begin(), + SubTokens.end()); + } else { + throw runtime_error("Invalid string_selector input! Please specify at " + "least one platform or device type filter.", + PI_INVALID_VALUE); + } + } +} + +int string_selector::operator()(const device &dev) const { + int Score = REJECT_DEVICE_SCORE; + + std::string CPU = "cpu"; + std::string GPU = "gpu"; + std::string PlatformName = + dev.get_platform().get_info(); + std::transform(PlatformName.begin(), PlatformName.end(), PlatformName.begin(), + [](unsigned char c) { return std::tolower(c); }); + + if (mPlatforms.empty() && mDeviceTypes.empty()) { + Score = mRanker(dev); + } else if (!mPlatforms.empty() && mDeviceTypes.empty()) { + for (const auto &Plat : mPlatforms) { + if (detail::match(PlatformName, Plat)) + Score = mRanker(dev); + } + } else if (mPlatforms.empty() && !mDeviceTypes.empty()) { + for (const auto &DT : mDeviceTypes) { + if ((detail::match(DT, CPU) && dev.is_cpu()) || + (detail::match(DT, GPU) && dev.is_gpu())) + Score = mRanker(dev); + } + } else { + for (const auto &Plat : mPlatforms) { + for (const auto &DT : mDeviceTypes) { + if (detail::match(PlatformName, Plat) && + ((detail::match(DT, CPU) && dev.is_cpu()) || + (detail::match(DT, GPU) && dev.is_gpu()))) + Score = mRanker(dev); + } + } + } + + return Score; +} +} // namespace oneapi +} // namespace ext } // namespace sycl } // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index 9ff2e1195cac9..5b0d0e20c6611 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3608,6 +3608,8 @@ _ZN2cl4sycl20aligned_alloc_deviceEmmRKNS0_5queueE _ZN2cl4sycl20aligned_alloc_deviceEmmRKNS0_6deviceERKNS0_7contextE _ZN2cl4sycl20aligned_alloc_sharedEmmRKNS0_5queueE _ZN2cl4sycl20aligned_alloc_sharedEmmRKNS0_6deviceERKNS0_7contextE +_ZN2cl4sycl3ext6oneapi15string_selectorC1ENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE +_ZN2cl4sycl3ext6oneapi15string_selectorC2ENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE _ZN2cl4sycl4freeEPvRKNS0_5queueE _ZN2cl4sycl4freeEPvRKNS0_7contextE _ZN2cl4sycl5event13get_wait_listEv @@ -3621,8 +3623,8 @@ _ZN2cl4sycl5eventC1Ev _ZN2cl4sycl5eventC2EP9_cl_eventRKNS0_7contextE _ZN2cl4sycl5eventC2ESt10shared_ptrINS0_6detail10event_implEE _ZN2cl4sycl5eventC2Ev -_ZN2cl4sycl5intel6detail17reduComputeWGSizeEmmRm _ZN2cl4sycl5intel6detail16reduGetMaxWGSizeESt10shared_ptrINS0_6detail10queue_implEEm +_ZN2cl4sycl5intel6detail17reduComputeWGSizeEmmRm _ZN2cl4sycl5queue10mem_adviseEPKvm14_pi_mem_advice _ZN2cl4sycl5queue10wait_proxyERKNS0_6detail13code_locationE _ZN2cl4sycl5queue11submit_implESt8functionIFvRNS0_7handlerEEERKNS0_6detail13code_locationE @@ -3728,15 +3730,15 @@ _ZN2cl4sycl6detail12sampler_implD2Ev _ZN2cl4sycl6detail12split_stringERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEc _ZN2cl4sycl6detail13MemoryManager12prefetch_usmEPvSt10shared_ptrINS1_10queue_implEEmSt6vectorIP9_pi_eventSaIS9_EERS9_ _ZN2cl4sycl6detail13MemoryManager13releaseMemObjESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvS8_ -_ZN2cl4sycl6detail13MemoryManager19allocateImageObjectESt10shared_ptrINS1_12context_implEEPvbRK14_pi_image_descRK16_pi_image_formatRKNS0_13property_listE _ZN2cl4sycl6detail13MemoryManager16allocateMemImageESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvbmRK14_pi_image_descRK16_pi_image_formatRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event -_ZN2cl4sycl6detail13MemoryManager24allocateInteropMemObjectESt10shared_ptrINS1_12context_implEEPvRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event -_ZN2cl4sycl6detail13MemoryManager20allocateBufferObjectESt10shared_ptrINS1_12context_implEEPvbmRKNS0_13property_listE +_ZN2cl4sycl6detail13MemoryManager17allocateMemBufferESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvbmRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event _ZN2cl4sycl6detail13MemoryManager18allocateHostMemoryEPNS1_11SYCLMemObjIEPvbmRKNS0_13property_listE -_ZN2cl4sycl6detail13MemoryManager19wrapIntoImageBufferESt10shared_ptrINS1_12context_implEEPvPNS1_11SYCLMemObjIE _ZN2cl4sycl6detail13MemoryManager18releaseImageBufferESt10shared_ptrINS1_12context_implEEPv -_ZN2cl4sycl6detail13MemoryManager17allocateMemBufferESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvbmRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event +_ZN2cl4sycl6detail13MemoryManager19allocateImageObjectESt10shared_ptrINS1_12context_implEEPvbRK14_pi_image_descRK16_pi_image_formatRKNS0_13property_listE +_ZN2cl4sycl6detail13MemoryManager19wrapIntoImageBufferESt10shared_ptrINS1_12context_implEEPvPNS1_11SYCLMemObjIE +_ZN2cl4sycl6detail13MemoryManager20allocateBufferObjectESt10shared_ptrINS1_12context_implEEPvbmRKNS0_13property_listE _ZN2cl4sycl6detail13MemoryManager20allocateMemSubBufferESt10shared_ptrINS1_12context_implEEPvmmNS0_5rangeILi3EEESt6vectorIS3_INS1_10event_implEESaISB_EERP9_pi_event +_ZN2cl4sycl6detail13MemoryManager24allocateInteropMemObjectESt10shared_ptrINS1_12context_implEEPvRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event _ZN2cl4sycl6detail13MemoryManager3mapEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEENS0_6access4modeEjNS0_5rangeILi3EEESC_NS0_2idILi3EEEjSt6vectorIP9_pi_eventSaISH_EERSH_ _ZN2cl4sycl6detail13MemoryManager4copyEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEEjNS0_5rangeILi3EEESA_NS0_2idILi3EEEjS5_S8_jSA_SA_SC_jSt6vectorIP9_pi_eventSaISF_EERSF_ _ZN2cl4sycl6detail13MemoryManager4fillEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEEmPKcjNS0_5rangeILi3EEESC_NS0_2idILi3EEEjSt6vectorIP9_pi_eventSaISH_EERSH_ @@ -3863,6 +3865,7 @@ _ZNK2cl4sycl15interop_handler12GetNativeMemEPNS0_6detail16AccessorImplHostE _ZNK2cl4sycl15interop_handler14GetNativeQueueEv _ZNK2cl4sycl16default_selectorclERKNS0_6deviceE _ZNK2cl4sycl20accelerator_selectorclERKNS0_6deviceE +_ZNK2cl4sycl3ext6oneapi15string_selectorclERKNS0_6deviceE _ZNK2cl4sycl5event18get_profiling_infoILNS0_4info15event_profilingE4737EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl5event18get_profiling_infoILNS0_4info15event_profilingE4738EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl5event18get_profiling_infoILNS0_4info15event_profilingE4739EEENS3_12param_traitsIS4_XT_EE11return_typeEv diff --git a/sycl/test/string-selector/select.cpp b/sycl/test/string-selector/select.cpp new file mode 100644 index 0000000000000..f8372a7708b38 --- /dev/null +++ b/sycl/test/string-selector/select.cpp @@ -0,0 +1,127 @@ +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t1.out +// RUN: %t1.out + +// REQUIRES: cpu, gpu, opencl + +//==------------------- select.cpp - string_selector test ------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +using namespace cl::sycl; +using namespace cl::sycl::ext::oneapi; + +int main() { + queue q1(string_selector("platform=Intel")); + std::cout << q1.get_device().get_info() << std::endl; + std::cout << q1.get_device().get_platform().get_info() + << std::endl; + if (!q1.get_device().is_host()) { + assert(q1.get_device().get_platform().get_info().find( + "Intel") != std::string::npos && + "Intel platform not found!"); + } + + queue q2(string_selector("type=cpu")); + std::cout << q2.get_device().get_info() << std::endl; + std::cout << q2.get_device().get_platform().get_info() + << std::endl; + if (!q2.get_device().is_host()) { + assert(q2.get_device().is_cpu() && "Device is not CPU!"); + } + + queue q3(string_selector("type=gpu")); + std::cout << q3.get_device().get_info() << std::endl; + std::cout << q3.get_device().get_platform().get_info() + << std::endl; + if (!q3.get_device().is_host()) { + assert(q3.get_device().is_gpu() && "Device is not GPU!"); + } + + queue q4(string_selector("type=cpu,gpu")); + std::cout << q4.get_device().get_info() << std::endl; + std::cout << q4.get_device().get_platform().get_info() + << std::endl; + if (!q4.get_device().is_host()) { + assert((q4.get_device().is_gpu() || q4.get_device().is_cpu()) && + "Device is not GPU or CPU!"); + } + + queue q5(string_selector("platform=OpenCL;type=cpu")); + std::cout << q5.get_device().get_info() << std::endl; + std::cout << q5.get_device().get_platform().get_info() + << std::endl; + if (!q5.get_device().is_host()) { + assert(q5.get_device().get_platform().get_info().find( + "OpenCL") != std::string::npos && + "OpenCL platform not found!"); + assert(q5.get_device().is_cpu() && "Device is not CPU!"); + } + + queue q6(string_selector("platform=OpenCL;type=cpu;")); + std::cout << q6.get_device().get_info() << std::endl; + std::cout << q6.get_device().get_platform().get_info() + << std::endl; + if (!q6.get_device().is_host()) { + assert(q6.get_device().get_platform().get_info().find( + "OpenCL") != std::string::npos && + "OpenCL platform not found!"); + assert(q6.get_device().is_cpu() && "Device is not CPU!"); + } + + queue q7(string_selector(";platform=OpenCL;type=cpu;")); + std::cout << q7.get_device().get_info() << std::endl; + std::cout << q7.get_device().get_platform().get_info() + << std::endl; + if (!q7.get_device().is_host()) { + assert(q7.get_device().get_platform().get_info().find( + "OpenCL") != std::string::npos && + "OpenCL platform not found!"); + assert(q7.get_device().is_cpu() && "Device is not CPU!"); + } + + queue q8(string_selector("; platform = OpenCL ; type= cpu,;")); + std::cout << q8.get_device().get_info() << std::endl; + std::cout << q8.get_device().get_platform().get_info() + << std::endl; + if (!q8.get_device().is_host()) { + assert(q8.get_device().get_platform().get_info().find( + "OpenCL") != std::string::npos && + "OpenCL platform not found!"); + assert(q8.get_device().is_cpu() && "Device is not CPU!"); + } + + queue q9(string_selector("; ,,,, ; ")); + std::cout << q9.get_device().get_info() << std::endl; + std::cout << q9.get_device().get_platform().get_info() + << std::endl; + + queue q10(string_selector(" ")); + std::cout << q10.get_device().get_info() << std::endl; + std::cout << q10.get_device().get_platform().get_info() + << std::endl; + + queue q11(string_selector("")); + std::cout << q11.get_device().get_info() << std::endl; + std::cout << q11.get_device().get_platform().get_info() + << std::endl; + + try { + queue q12(string_selector("plat=Intel")); + } catch (runtime_error e) { + std::cout << "TEST PASS: " << e.what() << std::endl; + } + + try { + queue q13(string_selector("plat_type=Foo")); + } catch (runtime_error e) { + std::cout << "TEST PASS: " << e.what() << std::endl; + } + + return 0; +}