Skip to content

[SYCL] Add implementation of kernel_bundle. Part 4 #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions sycl/include/CL/sycl/exception.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ class __SYCL_EXPORT exception : public std::exception {
public:
exception() = default;

exception(std::error_code, const char *Msg)
: exception(Msg, PI_INVALID_VALUE) {}

exception(std::error_code, const std::string &Msg)
: exception(Msg, PI_INVALID_VALUE) {}

const char *what() const noexcept final;

bool has_context() const;
Expand Down Expand Up @@ -111,5 +117,39 @@ class feature_not_supported : public device_error {
using device_error::device_error;
};

enum class errc : unsigned int {
runtime = 0,
kernel = 1,
accessor = 2,
nd_range = 3,
event = 4,
kernel_argument = 5,
build = 6,
invalid = 7,
memory_allocation = 8,
platform = 9,
profiling = 10,
feature_not_supported = 11,
kernel_not_supported = 12,
backend_mismatch = 13,
};

/// Constructs an error code using e and sycl_category()
__SYCL_EXPORT std::error_code make_error_code(sycl::errc E) noexcept;

__SYCL_EXPORT const std::error_category &sycl_category() noexcept;

namespace detail {
class __SYCL_EXPORT SYCLCategory : public std::error_category {
public:
const char *name() const noexcept override { return "SYCL"; }
std::string message(int) const override { return "SYCL Error"; }
};
} // namespace detail

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)

namespace std {
template <> struct is_error_condition_enum<cl::sycl::errc> : true_type {};
} // namespace std
29 changes: 28 additions & 1 deletion sycl/include/CL/sycl/kernel_bundle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ compile(const kernel_bundle<bundle_state::input> &InputBundle,
/////////////////////////

namespace detail {
std::vector<sycl::device> find_device_intersection(
__SYCL_EXPORT std::vector<sycl::device> find_device_intersection(
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles);

__SYCL_EXPORT std::shared_ptr<detail::kernel_bundle_impl>
Expand Down Expand Up @@ -628,3 +628,30 @@ build(const kernel_bundle<bundle_state::input> &InputBundle,

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)

namespace std {
template <> struct hash<cl::sycl::kernel_id> {
size_t operator()(const cl::sycl::kernel_id &KernelID) const {
return hash<cl::sycl::shared_ptr_class<cl::sycl::detail::kernel_id_impl>>()(
cl::sycl::detail::getSyclObjImpl(KernelID));
}
};

template <cl::sycl::bundle_state State>
struct hash<cl::sycl::device_image<State>> {
size_t operator()(const cl::sycl::device_image<State> &DeviceImage) const {
return hash<
cl::sycl::shared_ptr_class<cl::sycl::detail::device_image_impl>>()(
cl::sycl::detail::getSyclObjImpl(DeviceImage));
}
};

template <cl::sycl::bundle_state State>
struct hash<cl::sycl::kernel_bundle<State>> {
size_t operator()(const cl::sycl::kernel_bundle<State> &KernelBundle) const {
return hash<
cl::sycl::shared_ptr_class<cl::sycl::detail::kernel_bundle_impl>>()(
cl::sycl::detail::getSyclObjImpl(KernelBundle));
}
};
} // namespace std
132 changes: 116 additions & 16 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,58 @@ namespace sycl {
namespace detail {

template <class T> struct LessByHash {
bool operator()(const T &LHS, const T &RHS) {
bool operator()(const T &LHS, const T &RHS) const {
return getSyclObjImpl(LHS) < getSyclObjImpl(RHS);
}
};

static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
const context &Context) {
const std::vector<device> &ContextDevices = Context.get_devices();
return std::all_of(
Devices.begin(), Devices.end(), [&ContextDevices](const device &Dev) {
return ContextDevices.end() !=
std::find(ContextDevices.begin(), ContextDevices.end(), Dev);
});
}

static bool checkAllDevicesHaveAspect(const std::vector<device> &Devices,
aspect Aspect) {
return std::all_of(Devices.begin(), Devices.end(),
[&Aspect](const device &Dev) { return Dev.has(Aspect); });
}

// The class is an impl counterpart of the sycl::kernel_bundle.
// It provides an access and utilities to manage set of sycl::device_images
// objects.
class kernel_bundle_impl {

void common_ctor_checks(bundle_state State) {
const bool AllDevicesInTheContext =
checkAllDevicesAreInContext(MDevices, MContext);
if (MDevices.empty() || !AllDevicesInTheContext)
throw sycl::exception(
make_error_code(errc::invalid),
"Not all devices are associated with the context or "
"vector of devices is empty");

if (bundle_state::input == State &&
!checkAllDevicesHaveAspect(MDevices, aspect::online_compiler))
throw sycl::exception(make_error_code(errc::invalid),
"Not all devices have aspect::online_compiler");

if (bundle_state::object == State &&
!checkAllDevicesHaveAspect(MDevices, aspect::online_linker))
throw sycl::exception(make_error_code(errc::invalid),
"Not all devices have aspect::online_linker");
}

public:
kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {

common_ctor_checks(State);

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, State);
}
Expand All @@ -54,6 +92,21 @@ class kernel_bundle_impl {
bundle_state TargetState)
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)) {

const std::vector<device> &InputBundleDevices =
getSyclObjImpl(InputBundle)->get_devices();
const bool AllDevsAssociatedWithInputBundle =
std::all_of(MDevices.begin(), MDevices.end(),
[&InputBundleDevices](const device &Dev) {
return InputBundleDevices.end() !=
std::find(InputBundleDevices.begin(),
InputBundleDevices.end(), Dev);
});
if (MDevices.empty() || !AllDevsAssociatedWithInputBundle)
throw sycl::exception(
make_error_code(errc::invalid),
"Not all devices are in the set of associated "
"devices for input bundle or vector of devices is empty");

for (const device_image_plain &DeviceImage : InputBundle) {
// Skip images which are not compatible with devices provided
if (std::none_of(
Expand Down Expand Up @@ -85,7 +138,39 @@ class kernel_bundle_impl {
kernel_bundle_impl(
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
std::vector<device> Devs, const property_list &PropList)
: MContext(ObjectBundles[0].get_context()), MDevices(std::move(Devs)) {
: MDevices(std::move(Devs)) {

if (ObjectBundles.empty())
return;

MContext = ObjectBundles[0].get_context();
for (size_t I = 1; I < ObjectBundles.size(); ++I) {
if (ObjectBundles[I].get_context() != MContext)
throw sycl::exception(
make_error_code(errc::invalid),
"Not all input bundles have the same associated context");
}

// Check if any of the devices in devs are not in the set of associated
// devices for any of the bundles in ObjectBundles
const bool AllDevsAssociatedWithInputBundles = std::all_of(
MDevices.begin(), MDevices.end(), [&ObjectBundles](const device &Dev) {
// Number of devices is expected to be small
return std::all_of(
ObjectBundles.begin(), ObjectBundles.end(),
[&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
const std::vector<device> &BundleDevices =
getSyclObjImpl(KernelBundle)->get_devices();
return BundleDevices.end() != std::find(BundleDevices.begin(),
BundleDevices.end(),
Dev);
});
});
if (MDevices.empty() || !AllDevsAssociatedWithInputBundles)
throw sycl::exception(
make_error_code(errc::invalid),
"Not all devices are in the set of associated "
"devices for input bundles or vector of devices is empty");

// TODO: Unify with c'tor for sycl::comile and sycl::build by calling
// sycl::join on vector of kernel_bundles
Expand Down Expand Up @@ -116,6 +201,10 @@ class kernel_bundle_impl {
bundle_state State)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {

// TODO: Add a check that all kernel ids are compatible with at least one
// device in Devs
common_ctor_checks(State);

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, KernelIDs, State);
}
Expand All @@ -124,24 +213,36 @@ class kernel_bundle_impl {
const DevImgSelectorImpl &Selector, bundle_state State)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {

common_ctor_checks(State);

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, Selector, State);
}

// C'tor matches sycl::join API
kernel_bundle_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles) {
if (Bundles.empty())
return;

MContext = Bundles[0]->MContext;
MDevices = Bundles[0]->MDevices;
for (size_t I = 1; I < Bundles.size(); ++I) {
if (Bundles[I]->MContext != MContext)
throw sycl::exception(
make_error_code(errc::invalid),
"Not all input bundles have the same associated context.");
if (Bundles[I]->MDevices != MDevices)
throw sycl::exception(
make_error_code(errc::invalid),
"Not all input bundles have the same set of associated devices.");
}

for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
MDevices.insert(MDevices.end(), Bundle->MDevices.begin(),
Bundle->MDevices.end());

MDeviceImages.insert(MDeviceImages.end(), Bundle->MDeviceImages.begin(),
Bundle->MDeviceImages.end());
}

std::sort(MDevices.begin(), MDevices.end(), LessByHash<device>{});
const auto DevIt = std::unique(MDevices.begin(), MDevices.end());
MDevices.erase(DevIt, MDevices.end());

std::sort(MDeviceImages.begin(), MDeviceImages.end(),
LessByHash<device_image_plain>{});
const auto DevImgIt =
Expand Down Expand Up @@ -171,14 +272,7 @@ class kernel_bundle_impl {
}
std::sort(Result.begin(), Result.end(), LessByNameComp{});

auto NewIt =
std::unique(Result.begin(), Result.end(),
[](const sycl::kernel_id &LHS, const sycl::kernel_id &RHS) {
return strcmp(LHS.get_name(), RHS.get_name()) == 0;
}

);

auto NewIt = std::unique(Result.begin(), Result.end(), EqualByNameComp{});
Result.erase(NewIt, Result.end());

return Result;
Expand All @@ -192,6 +286,12 @@ class kernel_bundle_impl {
[&KernelID](const device_image_plain &DeviceImage) {
return DeviceImage.has_kernel(KernelID);
});

if (MDeviceImages.end() == It)
throw sycl::exception(make_error_code(errc::invalid),
"The kernel bundle does not contain the kernel "
"identified by kernelId.");

const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
detail::getSyclObjImpl(*It);

Expand Down
10 changes: 9 additions & 1 deletion sycl/source/detail/kernel_id_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@ namespace detail {

// Used for sorting vector of kernel_id's
struct LessByNameComp {
bool operator()(const sycl::kernel_id &LHS, const sycl::kernel_id &RHS) {
bool operator()(const sycl::kernel_id &LHS,
const sycl::kernel_id &RHS) const {
return std::strcmp(LHS.get_name(), RHS.get_name()) < 0;
}
};

struct EqualByNameComp {
bool operator()(const sycl::kernel_id &LHS,
const sycl::kernel_id &RHS) const {
return strcmp(LHS.get_name(), RHS.get_name()) == 0;
}
};

// The class is impl counterpart for sycl::kernel_id which represent a kernel
// identificator
class kernel_id_impl {
Expand Down
22 changes: 16 additions & 6 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,13 +1394,18 @@ ProgramManager::compile(const device_image_plain &DeviceImage,
// TODO: Set spec constatns here.

// TODO: Handle zero sized Device list.
Plugin.call<PiApiKind::piProgramCompile>(
RT::PiResult Error = Plugin.call_nocheck<PiApiKind::piProgramCompile>(
ObjectImpl->get_program_ref(), /*num devices=*/Devs.size(),
PIDevices.data(),
/*options=*/nullptr,
/*num_input_headers=*/0, /*input_headers=*/nullptr,
/*header_include_names=*/nullptr,
/*pfn_notify=*/nullptr, /*user_data*/ nullptr);
if (Error != PI_SUCCESS)
throw sycl::exception(
make_error_code(errc::build),
getProgramBuildLog(ObjectImpl->get_program_ref(),
getSyclObjImpl(ObjectImpl->get_context())));

return createSyclObjFromImpl<device_image_plain>(ObjectImpl);
}
Expand All @@ -1422,19 +1427,23 @@ ProgramManager::link(const std::vector<device_image_plain> &DeviceImages,
PIDevices.push_back(getSyclObjImpl(Dev)->getHandleRef());

const context &Context = getSyclObjImpl(DeviceImages[0])->get_context();
const ContextImplPtr ContextImpl = getSyclObjImpl(Context);

const detail::plugin &Plugin = getSyclObjImpl(Context)->getPlugin();
const detail::plugin &Plugin = ContextImpl->getPlugin();

RT::PiProgram LinkedProg = nullptr;
RT::PiResult Error = Plugin.call_nocheck<PiApiKind::piProgramLink>(
getSyclObjImpl(Context)->getHandleRef(), PIDevices.size(),
PIDevices.data(),
ContextImpl->getHandleRef(), PIDevices.size(), PIDevices.data(),
/*options=*/nullptr, PIPrograms.size(), PIPrograms.data(),
/*pfn_notify=*/nullptr,
/*user_data=*/nullptr, &LinkedProg);

(void)Error;
// TODO: Add error handling
if (Error != PI_SUCCESS) {
const string_class ErrorMsg =
LinkedProg ? getProgramBuildLog(LinkedProg, ContextImpl)
: "Online link operation failed";
throw sycl::exception(make_error_code(errc::build), ErrorMsg);
}

std::vector<kernel_id> KernelIDs;
for (const device_image_plain &DeviceImage : DeviceImages) {
Expand Down Expand Up @@ -1582,6 +1591,7 @@ device_image_plain ProgramManager::build(const device_image_plain &DeviceImage,
SerializedObj SpecConsts = InputImpl->get_spec_const_blob_ref();

const RT::PiDevice PiDevice = getRawSyclObjImpl(Devs[0])->getHandleRef();
// TODO: Throw SYCL2020 style exception
auto BuildResult = getOrBuild<PiProgramT, compile_program_error>(
Cache,
std::make_pair(std::make_pair(std::move(SpecConsts), (size_t)ImgPtr),
Expand Down
9 changes: 9 additions & 0 deletions sycl/source/exception.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,14 @@ context exception::get_context() const {

cl_int exception::get_cl_code() const { return MCLErr; }

const std::error_category &sycl_category() noexcept {
static const detail::SYCLCategory SYCLCategoryObj;
return SYCLCategoryObj;
}

std::error_code make_error_code(sycl::errc Err) noexcept {
return {static_cast<int>(Err), sycl_category()};
}

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
Loading