Skip to content

[SYCL][EXT] Define a new device selector that filters based on string input #2163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions sycl/include/CL/sycl/device_selector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,26 @@ class __SYCL_EXPORT host_selector : public device_selector {
int operator()(const device &dev) const override;
};

namespace ext {
namespace oneapi {

/// Selects device based on string.
///
/// \sa device
///
/// \ingroup sycl_api_dev_sel
class __SYCL_EXPORT string_selector : public device_selector {
public:
string_selector(std::string filter);
int operator()(const device &dev) const override;

private:
std::vector<std::string> mPlatforms;
std::vector<std::string> mDeviceTypes;
default_selector mRanker;
};
} // namespace oneapi
} // namespace ext

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
141 changes: 141 additions & 0 deletions sycl/source/device_selector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include <CL/sycl/stl.hpp>
#include <detail/device_impl.hpp>
#include <detail/force_device.hpp>

#include <algorithm>
#include <cctype>
// 4.6.1 Device selection class

__SYCL_INLINE_NAMESPACE(cl) {
Expand Down Expand Up @@ -157,5 +160,143 @@ int host_selector::operator()(const device &dev) const {
return Score;
}

namespace detail {
std::string trim_spaces(std::string input) {
size_t LStart = input.find_first_not_of(" ");
std::string LTrimmed =
(LStart == std::string::npos) ? "" : input.substr(LStart);

size_t REnd = LTrimmed.find_last_not_of(" ");
return (REnd == std::string::npos) ? "" : LTrimmed.substr(0, REnd + 1);
}

std::vector<std::string> tokenize(std::string filter, std::string delim) {
std::vector<std::string> Tokens;
size_t Pos = 0;
std::string Input = filter;
std::string Tok;

while ((Pos = Input.find(delim)) != std::string::npos) {
Tok = Input.substr(0, Pos);
Input.erase(0, Pos + delim.length());
// Erase leading and trailing WS
Tok = trim_spaces(Tok);

if (!Tok.empty())
Tokens.push_back(Tok);
}

if (!Input.empty())
Input = trim_spaces(Input);

// Add remainder
if (!Input.empty())
Tokens.push_back(Input);

return Tokens;
}

enum TokenType { TokPlatform, TokDeviceType, TokUnknown };

TokenType parse_kind(std::string token) {
TokenType Result = TokUnknown;

if (token == "platform")
Result = TokPlatform;
if (token == "type")
Result = TokDeviceType;

return Result;
}

std::string strip_kind(std::string token) {
std::string Prefix = "=";
size_t Loc = token.find(Prefix);

if (Loc == std::string::npos)
return token;

// move past the '='
Loc = Loc + 1;

return token.substr(Loc);
}

bool match(std::string input, std::string pattern) {
return (input.find(pattern) != std::string::npos);
}
} // namespace detail

namespace ext {
namespace oneapi {
string_selector::string_selector(std::string filter) {
std::transform(filter.begin(), filter.end(), filter.begin(),
[](unsigned char c) { return std::tolower(c); });
std::vector<std::string> Tokens = detail::tokenize(filter, ";");

for (auto Tok : Tokens) {
if (Tok.find("=") == std::string::npos)
continue;

std::vector<std::string> Req = detail::tokenize(Tok, "=");
if (Req.size() != 2) {
throw runtime_error("Invalid string_selector input! Please specify the "
"desired platform or device type after '='.",
PI_INVALID_VALUE);
}
detail::TokenType TTy = detail::parse_kind(Req[0]);
std::vector<std::string> SubTokens = detail::tokenize(Req[1], ",");

if (TTy == detail::TokPlatform) {
mPlatforms.insert(mPlatforms.end(), SubTokens.begin(), SubTokens.end());
} else if (TTy == detail::TokDeviceType) {
mDeviceTypes.insert(mDeviceTypes.end(), SubTokens.begin(),
SubTokens.end());
} else {
throw runtime_error("Invalid string_selector input! Please specify at "
"least one platform or device type filter.",
PI_INVALID_VALUE);
}
}
}

int string_selector::operator()(const device &dev) const {
int Score = REJECT_DEVICE_SCORE;

std::string CPU = "cpu";
std::string GPU = "gpu";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there plans to support accelerator(s) eventually as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or even FPGA :-), since it is a string and does not need anything predefined in SYCL specification...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this can be extended with more types in the future as necessary.

std::string PlatformName =
dev.get_platform().get_info<info::platform::name>();
std::transform(PlatformName.begin(), PlatformName.end(), PlatformName.begin(),
[](unsigned char c) { return std::tolower(c); });

if (mPlatforms.empty() && mDeviceTypes.empty()) {
Score = mRanker(dev);
} else if (!mPlatforms.empty() && mDeviceTypes.empty()) {
for (const auto &Plat : mPlatforms) {
if (detail::match(PlatformName, Plat))
Score = mRanker(dev);
}
} else if (mPlatforms.empty() && !mDeviceTypes.empty()) {
for (const auto &DT : mDeviceTypes) {
if ((detail::match(DT, CPU) && dev.is_cpu()) ||
(detail::match(DT, GPU) && dev.is_gpu()))
Score = mRanker(dev);
}
} else {
for (const auto &Plat : mPlatforms) {
for (const auto &DT : mDeviceTypes) {
if (detail::match(PlatformName, Plat) &&
((detail::match(DT, CPU) && dev.is_cpu()) ||
(detail::match(DT, GPU) && dev.is_gpu())))
Score = mRanker(dev);
}
}
}

return Score;
}
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
15 changes: 9 additions & 6 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3608,6 +3608,8 @@ _ZN2cl4sycl20aligned_alloc_deviceEmmRKNS0_5queueE
_ZN2cl4sycl20aligned_alloc_deviceEmmRKNS0_6deviceERKNS0_7contextE
_ZN2cl4sycl20aligned_alloc_sharedEmmRKNS0_5queueE
_ZN2cl4sycl20aligned_alloc_sharedEmmRKNS0_6deviceERKNS0_7contextE
_ZN2cl4sycl3ext6oneapi15string_selectorC1ENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
_ZN2cl4sycl3ext6oneapi15string_selectorC2ENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
_ZN2cl4sycl4freeEPvRKNS0_5queueE
_ZN2cl4sycl4freeEPvRKNS0_7contextE
_ZN2cl4sycl5event13get_wait_listEv
Expand All @@ -3621,8 +3623,8 @@ _ZN2cl4sycl5eventC1Ev
_ZN2cl4sycl5eventC2EP9_cl_eventRKNS0_7contextE
_ZN2cl4sycl5eventC2ESt10shared_ptrINS0_6detail10event_implEE
_ZN2cl4sycl5eventC2Ev
_ZN2cl4sycl5intel6detail17reduComputeWGSizeEmmRm
_ZN2cl4sycl5intel6detail16reduGetMaxWGSizeESt10shared_ptrINS0_6detail10queue_implEEm
_ZN2cl4sycl5intel6detail17reduComputeWGSizeEmmRm
_ZN2cl4sycl5queue10mem_adviseEPKvm14_pi_mem_advice
_ZN2cl4sycl5queue10wait_proxyERKNS0_6detail13code_locationE
_ZN2cl4sycl5queue11submit_implESt8functionIFvRNS0_7handlerEEERKNS0_6detail13code_locationE
Expand Down Expand Up @@ -3728,15 +3730,15 @@ _ZN2cl4sycl6detail12sampler_implD2Ev
_ZN2cl4sycl6detail12split_stringERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEc
_ZN2cl4sycl6detail13MemoryManager12prefetch_usmEPvSt10shared_ptrINS1_10queue_implEEmSt6vectorIP9_pi_eventSaIS9_EERS9_
_ZN2cl4sycl6detail13MemoryManager13releaseMemObjESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvS8_
_ZN2cl4sycl6detail13MemoryManager19allocateImageObjectESt10shared_ptrINS1_12context_implEEPvbRK14_pi_image_descRK16_pi_image_formatRKNS0_13property_listE
_ZN2cl4sycl6detail13MemoryManager16allocateMemImageESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvbmRK14_pi_image_descRK16_pi_image_formatRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event
_ZN2cl4sycl6detail13MemoryManager24allocateInteropMemObjectESt10shared_ptrINS1_12context_implEEPvRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event
_ZN2cl4sycl6detail13MemoryManager20allocateBufferObjectESt10shared_ptrINS1_12context_implEEPvbmRKNS0_13property_listE
_ZN2cl4sycl6detail13MemoryManager17allocateMemBufferESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvbmRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event
_ZN2cl4sycl6detail13MemoryManager18allocateHostMemoryEPNS1_11SYCLMemObjIEPvbmRKNS0_13property_listE
_ZN2cl4sycl6detail13MemoryManager19wrapIntoImageBufferESt10shared_ptrINS1_12context_implEEPvPNS1_11SYCLMemObjIE
_ZN2cl4sycl6detail13MemoryManager18releaseImageBufferESt10shared_ptrINS1_12context_implEEPv
_ZN2cl4sycl6detail13MemoryManager17allocateMemBufferESt10shared_ptrINS1_12context_implEEPNS1_11SYCLMemObjIEPvbmRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event
_ZN2cl4sycl6detail13MemoryManager19allocateImageObjectESt10shared_ptrINS1_12context_implEEPvbRK14_pi_image_descRK16_pi_image_formatRKNS0_13property_listE
_ZN2cl4sycl6detail13MemoryManager19wrapIntoImageBufferESt10shared_ptrINS1_12context_implEEPvPNS1_11SYCLMemObjIE
_ZN2cl4sycl6detail13MemoryManager20allocateBufferObjectESt10shared_ptrINS1_12context_implEEPvbmRKNS0_13property_listE
_ZN2cl4sycl6detail13MemoryManager20allocateMemSubBufferESt10shared_ptrINS1_12context_implEEPvmmNS0_5rangeILi3EEESt6vectorIS3_INS1_10event_implEESaISB_EERP9_pi_event
_ZN2cl4sycl6detail13MemoryManager24allocateInteropMemObjectESt10shared_ptrINS1_12context_implEEPvRKS3_INS1_10event_implEERKS5_RKNS0_13property_listERP9_pi_event
_ZN2cl4sycl6detail13MemoryManager3mapEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEENS0_6access4modeEjNS0_5rangeILi3EEESC_NS0_2idILi3EEEjSt6vectorIP9_pi_eventSaISH_EERSH_
_ZN2cl4sycl6detail13MemoryManager4copyEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEEjNS0_5rangeILi3EEESA_NS0_2idILi3EEEjS5_S8_jSA_SA_SC_jSt6vectorIP9_pi_eventSaISF_EERSF_
_ZN2cl4sycl6detail13MemoryManager4fillEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEEmPKcjNS0_5rangeILi3EEESC_NS0_2idILi3EEEjSt6vectorIP9_pi_eventSaISH_EERSH_
Expand Down Expand Up @@ -3863,6 +3865,7 @@ _ZNK2cl4sycl15interop_handler12GetNativeMemEPNS0_6detail16AccessorImplHostE
_ZNK2cl4sycl15interop_handler14GetNativeQueueEv
_ZNK2cl4sycl16default_selectorclERKNS0_6deviceE
_ZNK2cl4sycl20accelerator_selectorclERKNS0_6deviceE
_ZNK2cl4sycl3ext6oneapi15string_selectorclERKNS0_6deviceE
_ZNK2cl4sycl5event18get_profiling_infoILNS0_4info15event_profilingE4737EEENS3_12param_traitsIS4_XT_EE11return_typeEv
_ZNK2cl4sycl5event18get_profiling_infoILNS0_4info15event_profilingE4738EEENS3_12param_traitsIS4_XT_EE11return_typeEv
_ZNK2cl4sycl5event18get_profiling_infoILNS0_4info15event_profilingE4739EEENS3_12param_traitsIS4_XT_EE11return_typeEv
Expand Down
127 changes: 127 additions & 0 deletions sycl/test/string-selector/select.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t1.out
// RUN: %t1.out

// REQUIRES: cpu, gpu, opencl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checks below are written in quite safe manner, i.e. the test should pass with even only one of CPU or GPU is available.
Can this line be: "// REQUIRES: opencl" ? or be simply removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to check that it can do different things in the same program.


//==------------------- select.cpp - string_selector test ------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <CL/sycl.hpp>

using namespace cl::sycl;
using namespace cl::sycl::ext::oneapi;

int main() {
queue q1(string_selector("platform=Intel"));
std::cout << q1.get_device().get_info<info::device::name>() << std::endl;
std::cout << q1.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q1.get_device().is_host()) {
assert(q1.get_device().get_platform().get_info<info::platform::name>().find(
"Intel") != std::string::npos &&
"Intel platform not found!");
}

queue q2(string_selector("type=cpu"));
std::cout << q2.get_device().get_info<info::device::name>() << std::endl;
std::cout << q2.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q2.get_device().is_host()) {
assert(q2.get_device().is_cpu() && "Device is not CPU!");
}

queue q3(string_selector("type=gpu"));
std::cout << q3.get_device().get_info<info::device::name>() << std::endl;
std::cout << q3.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q3.get_device().is_host()) {
assert(q3.get_device().is_gpu() && "Device is not GPU!");
}

queue q4(string_selector("type=cpu,gpu"));
std::cout << q4.get_device().get_info<info::device::name>() << std::endl;
std::cout << q4.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q4.get_device().is_host()) {
assert((q4.get_device().is_gpu() || q4.get_device().is_cpu()) &&
"Device is not GPU or CPU!");
}

queue q5(string_selector("platform=OpenCL;type=cpu"));
std::cout << q5.get_device().get_info<info::device::name>() << std::endl;
std::cout << q5.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q5.get_device().is_host()) {
assert(q5.get_device().get_platform().get_info<info::platform::name>().find(
"OpenCL") != std::string::npos &&
"OpenCL platform not found!");
assert(q5.get_device().is_cpu() && "Device is not CPU!");
}

queue q6(string_selector("platform=OpenCL;type=cpu;"));
std::cout << q6.get_device().get_info<info::device::name>() << std::endl;
std::cout << q6.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q6.get_device().is_host()) {
assert(q6.get_device().get_platform().get_info<info::platform::name>().find(
"OpenCL") != std::string::npos &&
"OpenCL platform not found!");
assert(q6.get_device().is_cpu() && "Device is not CPU!");
}

queue q7(string_selector(";platform=OpenCL;type=cpu;"));
std::cout << q7.get_device().get_info<info::device::name>() << std::endl;
std::cout << q7.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q7.get_device().is_host()) {
assert(q7.get_device().get_platform().get_info<info::platform::name>().find(
"OpenCL") != std::string::npos &&
"OpenCL platform not found!");
assert(q7.get_device().is_cpu() && "Device is not CPU!");
}

queue q8(string_selector("; platform = OpenCL ; type= cpu,;"));
std::cout << q8.get_device().get_info<info::device::name>() << std::endl;
std::cout << q8.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;
if (!q8.get_device().is_host()) {
assert(q8.get_device().get_platform().get_info<info::platform::name>().find(
"OpenCL") != std::string::npos &&
"OpenCL platform not found!");
assert(q8.get_device().is_cpu() && "Device is not CPU!");
}

queue q9(string_selector("; ,,,, ; "));
std::cout << q9.get_device().get_info<info::device::name>() << std::endl;
std::cout << q9.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;

queue q10(string_selector(" "));
std::cout << q10.get_device().get_info<info::device::name>() << std::endl;
std::cout << q10.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;

queue q11(string_selector(""));
std::cout << q11.get_device().get_info<info::device::name>() << std::endl;
std::cout << q11.get_device().get_platform().get_info<info::platform::name>()
<< std::endl;

try {
queue q12(string_selector("plat=Intel"));
} catch (runtime_error e) {
std::cout << "TEST PASS: " << e.what() << std::endl;
}

try {
queue q13(string_selector("plat_type=Foo"));
} catch (runtime_error e) {
std::cout << "TEST PASS: " << e.what() << std::endl;
}

return 0;
}