Skip to content

[SYCL] Support image dependencies in kernel bundles #16228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 4, 2024
Merged
1 change: 1 addition & 0 deletions sycl/source/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class RTDeviceBinaryImage {
// Returns a ConstIterator constructed from the stored Begin member, i.e. the
// start of this range.
ConstIterator begin() const { return ConstIterator(Begin); }
// Returns a ConstIterator constructed from the stored End member, i.e. the
// past-the-end position of this range.
ConstIterator end() const { return ConstIterator(End); }
// Number of elements in the range, computed as the iterator distance from
// begin() to end().
size_t size() const {
  const auto Count = std::distance(begin(), end());
  return static_cast<size_t>(Count);
}
// True when begin() and end() compare equal, i.e. the range holds no elements.
bool empty() const { return begin() == end(); }
friend class RTDeviceBinaryImage;
// True when Begin is not equal to nullptr.
// NOTE(review): presumably a null Begin means the range was never
// initialized from a binary image - confirm against the code that sets Begin.
bool isAvailable() const { return !(Begin == nullptr); }

Expand Down
141 changes: 87 additions & 54 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class kernel_bundle_impl {

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, State);
fillUniqueDeviceImages();
}

// Interop constructor used by make_kernel
Expand All @@ -103,7 +104,8 @@ class kernel_bundle_impl {
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
device_image_plain &DevImage)
: kernel_bundle_impl(Ctx, Devs) {
MDeviceImages.push_back(DevImage);
MDeviceImages.emplace_back(DevImage);
MUniqueDeviceImages.emplace_back(DevImage);
}

// Matches sycl::build and sycl::compile
Expand All @@ -115,10 +117,12 @@ class kernel_bundle_impl {
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
MState(TargetState) {

MSpecConstValues = getSyclObjImpl(InputBundle)->get_spec_const_map_ref();
const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
getSyclObjImpl(InputBundle);
MSpecConstValues = InputBundleImpl->get_spec_const_map_ref();

const std::vector<device> &InputBundleDevices =
getSyclObjImpl(InputBundle)->get_devices();
InputBundleImpl->get_devices();
const bool AllDevsAssociatedWithInputBundle =
std::all_of(MDevices.begin(), MDevices.end(),
[&InputBundleDevices](const device &Dev) {
Expand All @@ -132,24 +136,37 @@ class kernel_bundle_impl {
"Not all devices are in the set of associated "
"devices for input bundle or vector of devices is empty");

for (const device_image_plain &DeviceImage : InputBundle) {
for (const DevImgPlainWithDeps &DevImgWithDeps :
InputBundleImpl->MDeviceImages) {
// Skip images which are not compatible with devices provided
if (std::none_of(
MDevices.begin(), MDevices.end(),
[&DeviceImage](const device &Dev) {
return getSyclObjImpl(DeviceImage)->compatible_with_device(Dev);
}))
if (std::none_of(MDevices.begin(), MDevices.end(),
[&DevImgWithDeps](const device &Dev) {
return getSyclObjImpl(DevImgWithDeps.getMain())
->compatible_with_device(Dev);
}))
continue;

switch (TargetState) {
case bundle_state::object:
MDeviceImages.push_back(detail::ProgramManager::getInstance().compile(
DeviceImage, MDevices, PropList));
case bundle_state::object: {
DevImgPlainWithDeps CompiledImgWithDeps =
detail::ProgramManager::getInstance().compile(DevImgWithDeps,
MDevices, PropList);

MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
CompiledImgWithDeps.begin(),
CompiledImgWithDeps.end());
MDeviceImages.push_back(std::move(CompiledImgWithDeps));
break;
case bundle_state::executable:
MDeviceImages.push_back(detail::ProgramManager::getInstance().build(
DeviceImage, MDevices, PropList));
}

case bundle_state::executable: {
device_image_plain BuiltImg =
detail::ProgramManager::getInstance().build(DevImgWithDeps,
MDevices, PropList);
MDeviceImages.emplace_back(BuiltImg);
MUniqueDeviceImages.push_back(BuiltImg);
break;
}
case bundle_state::input:
case bundle_state::ext_oneapi_source:
throw exception(make_error_code(errc::runtime),
Expand All @@ -158,6 +175,7 @@ class kernel_bundle_impl {
break;
}
}
removeDuplicateImages();
}

// Matches sycl::link
Expand Down Expand Up @@ -201,7 +219,7 @@ class kernel_bundle_impl {
"Not all devices are in the set of associated "
"devices for input bundles");

// TODO: Unify with c'tor for sycl::comile and sycl::build by calling
// TODO: Unify with c'tor for sycl::compile and sycl::build by calling
// sycl::join on vector of kernel_bundles

// The loop below just links each device image separately, not linking any
Expand All @@ -213,23 +231,27 @@ class kernel_bundle_impl {
// undefined symbols, then the logic in this loop will need to be changed.
for (const kernel_bundle<bundle_state::object> &ObjectBundle :
ObjectBundles) {
for (const device_image_plain &DeviceImage : ObjectBundle) {
for (const DevImgPlainWithDeps &DeviceImageWithDeps :
getSyclObjImpl(ObjectBundle)->MDeviceImages) {

// Skip images which are not compatible with devices provided
if (std::none_of(MDevices.begin(), MDevices.end(),
[&DeviceImage](const device &Dev) {
return getSyclObjImpl(DeviceImage)
[&DeviceImageWithDeps](const device &Dev) {
return getSyclObjImpl(DeviceImageWithDeps.getMain())
->compatible_with_device(Dev);
}))
continue;

std::vector<device_image_plain> LinkedResults =
detail::ProgramManager::getInstance().link(DeviceImage, MDevices,
PropList);
detail::ProgramManager::getInstance().link(DeviceImageWithDeps,
MDevices, PropList);
MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(),
LinkedResults.end());
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
LinkedResults.begin(), LinkedResults.end());
}
}
removeDuplicateImages();

for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
const KernelBundleImplPtr BundlePtr = getSyclObjImpl(Bundle);
Expand All @@ -249,6 +271,7 @@ class kernel_bundle_impl {

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, KernelIDs, State);
fillUniqueDeviceImages();
}

kernel_bundle_impl(context Ctx, std::vector<device> Devs,
Expand All @@ -259,6 +282,7 @@ class kernel_bundle_impl {

MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, MDevices, Selector, State);
fillUniqueDeviceImages();
}

// C'tor matches sycl::join API
Expand Down Expand Up @@ -287,11 +311,10 @@ class kernel_bundle_impl {
Bundle->MDeviceImages.end());
}

std::sort(MDeviceImages.begin(), MDeviceImages.end(),
LessByHash<device_image_plain>{});
fillUniqueDeviceImages();

if (get_bundle_state() == bundle_state::input) {
// Copy spec constants values from the device images to be removed.
// Copy spec constants values from the device images.
auto MergeSpecConstants = [this](const device_image_plain &Img) {
const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl(Img);
const std::map<std::string,
Expand All @@ -310,16 +333,9 @@ class kernel_bundle_impl {
SpecConst.second.back().Size);
}
};
std::for_each(MDeviceImages.begin(), MDeviceImages.end(),
MergeSpecConstants);
std::for_each(begin(), end(), MergeSpecConstants);
}

const auto DevImgIt =
std::unique(MDeviceImages.begin(), MDeviceImages.end());

// Remove duplicate device images.
MDeviceImages.erase(DevImgIt, MDeviceImages.end());

for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
for (const std::pair<const std::string, std::vector<unsigned char>>
&SpecConst : Bundle->MSpecConstValues) {
Expand Down Expand Up @@ -605,7 +621,7 @@ class kernel_bundle_impl {

assert(MDeviceImages.size() > 0);
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
detail::getSyclObjImpl(MDeviceImages[0]);
detail::getSyclObjImpl(MDeviceImages[0].getMain());
ur_program_handle_t UrProgram = DeviceImageImpl->get_ur_program_ref();
ContextImplPtr ContextImpl = getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();
Expand Down Expand Up @@ -634,7 +650,7 @@ class kernel_bundle_impl {
// Collect kernel ids from all device images, then remove duplicates

std::vector<kernel_id> Result;
for (const device_image_plain &DeviceImage : MDeviceImages) {
for (const device_image_plain &DeviceImage : MUniqueDeviceImages) {
const std::vector<kernel_id> &KernelIDs =
getSyclObjImpl(DeviceImage)->get_kernel_ids();

Expand Down Expand Up @@ -662,8 +678,9 @@ class kernel_bundle_impl {
// Used to track if any of the candidate images has specialization values
// set.
bool SpecConstsSet = false;
for (auto &DeviceImage : MDeviceImages) {
if (!DeviceImage.has_kernel(KernelID))
for (const DevImgPlainWithDeps &DeviceImageWithDeps : MDeviceImages) {
const device_image_plain &DeviceImage = DeviceImageWithDeps.getMain();
if (!DeviceImageWithDeps.getMain().has_kernel(KernelID))
continue;

const auto DeviceImageImpl = detail::getSyclObjImpl(DeviceImage);
Expand Down Expand Up @@ -718,39 +735,38 @@ class kernel_bundle_impl {
}

bool has_kernel(const kernel_id &KernelID) const noexcept {
return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
return std::any_of(begin(), end(),
[&KernelID](const device_image_plain &DeviceImage) {
return DeviceImage.has_kernel(KernelID);
});
}

bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept {
return std::any_of(
MDeviceImages.begin(), MDeviceImages.end(),
begin(), end(),
[&KernelID, &Dev](const device_image_plain &DeviceImage) {
return DeviceImage.has_kernel(KernelID, Dev);
});
}

bool contains_specialization_constants() const noexcept {
return std::any_of(
MDeviceImages.begin(), MDeviceImages.end(),
[](const device_image_plain &DeviceImage) {
begin(), end(), [](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)->has_specialization_constants();
});
}

bool native_specialization_constant() const noexcept {
return contains_specialization_constants() &&
std::all_of(MDeviceImages.begin(), MDeviceImages.end(),
std::all_of(begin(), end(),
[](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->all_specialization_constant_native();
});
}

bool has_specialization_constant(const char *SpecName) const noexcept {
return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
return std::any_of(begin(), end(),
[SpecName](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->has_specialization_constant(SpecName);
Expand All @@ -761,7 +777,7 @@ class kernel_bundle_impl {
const void *Value,
size_t Size) noexcept {
if (has_specialization_constant(SpecName))
for (const device_image_plain &DeviceImage : MDeviceImages)
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
getSyclObjImpl(DeviceImage)
->set_specialization_constant_raw_value(SpecName, Value);
else {
Expand All @@ -773,7 +789,7 @@ class kernel_bundle_impl {

void get_specialization_constant_raw_value(const char *SpecName,
void *ValueRet) const noexcept {
for (const device_image_plain &DeviceImage : MDeviceImages)
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
if (getSyclObjImpl(DeviceImage)->has_specialization_constant(SpecName)) {
getSyclObjImpl(DeviceImage)
->get_specialization_constant_raw_value(SpecName, ValueRet);
Expand All @@ -796,21 +812,21 @@ class kernel_bundle_impl {

bool is_specialization_constant_set(const char *SpecName) const noexcept {
bool SetInDevImg =
std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
std::any_of(begin(), end(),
[SpecName](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->is_specialization_constant_set(SpecName);
});
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
}

const device_image_plain *begin() const { return MDeviceImages.data(); }
const device_image_plain *begin() const { return MUniqueDeviceImages.data(); }

const device_image_plain *end() const {
return MDeviceImages.data() + MDeviceImages.size();
return MUniqueDeviceImages.data() + MUniqueDeviceImages.size();
}

size_t size() const noexcept { return MDeviceImages.size(); }
size_t size() const noexcept { return MUniqueDeviceImages.size(); }

// Returns the bundle state stored in MState.
bundle_state get_bundle_state() const { return MState; }

Expand All @@ -827,7 +843,7 @@ class kernel_bundle_impl {

// First try and get images in current bundle state
const bundle_state BundleState = get_bundle_state();
std::vector<device_image_plain> NewDevImgs =
std::vector<DevImgPlainWithDeps> NewDevImgs =
detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, {Dev}, {KernelID}, BundleState);

Expand All @@ -836,21 +852,38 @@ class kernel_bundle_impl {
return false;

// Propagate already set specialization constants to the new images
for (device_image_plain &DevImg : NewDevImgs)
for (auto SpecConst : MSpecConstValues)
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
SpecConst.first.c_str(), SpecConst.second.data());
for (DevImgPlainWithDeps &DevImgWithDeps : NewDevImgs)
for (device_image_plain &DevImg : DevImgWithDeps)
for (auto SpecConst : MSpecConstValues)
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
SpecConst.first.c_str(), SpecConst.second.data());

// Add the images to the collection
MDeviceImages.insert(MDeviceImages.end(), NewDevImgs.begin(),
NewDevImgs.end());
removeDuplicateImages();
return true;
}

private:
void fillUniqueDeviceImages() {
assert(MUniqueDeviceImages.empty());
for (const DevImgPlainWithDeps &Imgs : MDeviceImages)
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(), Imgs.begin(),
Imgs.end());
removeDuplicateImages();
}
// Sorts MUniqueDeviceImages with the LessByHash comparator and then erases
// every element that compares equal to the element directly before it, so
// images that end up adjacent after sorting are kept only once.
void removeDuplicateImages() {
  // Neither std::sort nor std::unique reallocates the vector, so both
  // iterators stay valid until the final erase.
  auto First = MUniqueDeviceImages.begin();
  auto Last = MUniqueDeviceImages.end();
  std::sort(First, Last, LessByHash<device_image_plain>{});
  MUniqueDeviceImages.erase(std::unique(First, Last), Last);
}
context MContext;
std::vector<device> MDevices;
std::vector<device_image_plain> MDeviceImages;
std::vector<DevImgPlainWithDeps> MDeviceImages;
std::vector<device_image_plain> MUniqueDeviceImages;
// This map stores values for specialization constants, that are missing
// from any device image.
SpecConstMapT MSpecConstValues;
Expand Down
Loading
Loading