@@ -28,20 +28,58 @@ namespace sycl {
28
28
namespace detail {
29
29
30
30
template <class T > struct LessByHash {
31
- bool operator ()(const T &LHS, const T &RHS) {
31
+ bool operator ()(const T &LHS, const T &RHS) const {
32
32
return getSyclObjImpl (LHS) < getSyclObjImpl (RHS);
33
33
}
34
34
};
35
35
36
+ static bool checkAllDevicesAreInContext (const std::vector<device> &Devices,
37
+ const context &Context) {
38
+ const std::vector<device> &ContextDevices = Context.get_devices ();
39
+ return std::all_of (
40
+ Devices.begin (), Devices.end (), [&ContextDevices](const device &Dev) {
41
+ return ContextDevices.end () !=
42
+ std::find (ContextDevices.begin (), ContextDevices.end (), Dev);
43
+ });
44
+ }
45
+
46
+ static bool checkAllDevicesHaveAspect (const std::vector<device> &Devices,
47
+ aspect Aspect) {
48
+ return std::all_of (Devices.begin (), Devices.end (),
49
+ [&Aspect](const device &Dev) { return Dev.has (Aspect); });
50
+ }
51
+
36
52
// The class is an impl counterpart of the sycl::kernel_bundle.
37
53
// It provides an access and utilities to manage set of sycl::device_images
38
54
// objects.
39
55
class kernel_bundle_impl {
40
56
57
+ void common_ctor_checks (bundle_state State) {
58
+ const bool AllDevicesInTheContext =
59
+ checkAllDevicesAreInContext (MDevices, MContext);
60
+ if (MDevices.empty () || !AllDevicesInTheContext)
61
+ throw sycl::exception (
62
+ make_error_code (errc::invalid),
63
+ " Not all devices are associated with the context or "
64
+ " vector of devices is empty" );
65
+
66
+ if (bundle_state::input == State &&
67
+ !checkAllDevicesHaveAspect (MDevices, aspect::online_compiler))
68
+ throw sycl::exception (make_error_code (errc::invalid),
69
+ " Not all devices have aspect::online_compiler" );
70
+
71
+ if (bundle_state::object == State &&
72
+ !checkAllDevicesHaveAspect (MDevices, aspect::online_linker))
73
+ throw sycl::exception (make_error_code (errc::invalid),
74
+ " Not all devices have aspect::online_linker" );
75
+ }
76
+
41
77
public:
42
78
kernel_bundle_impl (context Ctx, std::vector<device> Devs, bundle_state State)
43
79
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {
44
80
81
+ common_ctor_checks (State);
82
+
45
83
MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
46
84
MContext, MDevices, State);
47
85
}
@@ -54,6 +92,21 @@ class kernel_bundle_impl {
54
92
bundle_state TargetState)
55
93
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)) {
56
94
95
+ const std::vector<device> &InputBundleDevices =
96
+ getSyclObjImpl (InputBundle)->get_devices ();
97
+ const bool AllDevsAssociatedWithInputBundle =
98
+ std::all_of (MDevices.begin (), MDevices.end (),
99
+ [&InputBundleDevices](const device &Dev) {
100
+ return InputBundleDevices.end () !=
101
+ std::find (InputBundleDevices.begin (),
102
+ InputBundleDevices.end (), Dev);
103
+ });
104
+ if (MDevices.empty () || !AllDevsAssociatedWithInputBundle)
105
+ throw sycl::exception (
106
+ make_error_code (errc::invalid),
107
+ " Not all devices are in the set of associated "
108
+ " devices for input bundle or vector of devices is empty" );
109
+
57
110
for (const device_image_plain &DeviceImage : InputBundle) {
58
111
// Skip images which are not compatible with devices provided
59
112
if (std::none_of (
@@ -85,7 +138,39 @@ class kernel_bundle_impl {
85
138
kernel_bundle_impl (
86
139
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
87
140
std::vector<device> Devs, const property_list &PropList)
88
- : MContext(ObjectBundles[0 ].get_context()), MDevices(std::move(Devs)) {
141
+ : MDevices(std::move(Devs)) {
142
+
143
+ if (ObjectBundles.empty ())
144
+ return ;
145
+
146
+ MContext = ObjectBundles[0 ].get_context ();
147
+ for (size_t I = 1 ; I < ObjectBundles.size (); ++I) {
148
+ if (ObjectBundles[I].get_context () != MContext)
149
+ throw sycl::exception (
150
+ make_error_code (errc::invalid),
151
+ " Not all input bundles have the same associated context" );
152
+ }
153
+
154
+ // Check if any of the devices in devs are not in the set of associated
155
+ // devices for any of the bundles in ObjectBundles
156
+ const bool AllDevsAssociatedWithInputBundles = std::all_of (
157
+ MDevices.begin (), MDevices.end (), [&ObjectBundles](const device &Dev) {
158
+ // Number of devices is expected to be small
159
+ return std::all_of (
160
+ ObjectBundles.begin (), ObjectBundles.end (),
161
+ [&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
162
+ const std::vector<device> &BundleDevices =
163
+ getSyclObjImpl (KernelBundle)->get_devices ();
164
+ return BundleDevices.end () != std::find (BundleDevices.begin (),
165
+ BundleDevices.end (),
166
+ Dev);
167
+ });
168
+ });
169
+ if (MDevices.empty () || !AllDevsAssociatedWithInputBundles)
170
+ throw sycl::exception (
171
+ make_error_code (errc::invalid),
172
+ " Not all devices are in the set of associated "
173
+ " devices for input bundles or vector of devices is empty" );
89
174
90
175
// TODO: Unify with c'tor for sycl::comile and sycl::build by calling
91
176
// sycl::join on vector of kernel_bundles
@@ -116,6 +201,10 @@ class kernel_bundle_impl {
116
201
bundle_state State)
117
202
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {
118
203
204
+ // TODO: Add a check that all kernel ids are compatible with at least one
205
+ // device in Devs
206
+ common_ctor_checks (State);
207
+
119
208
MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
120
209
MContext, MDevices, KernelIDs, State);
121
210
}
@@ -124,24 +213,36 @@ class kernel_bundle_impl {
124
213
const DevImgSelectorImpl &Selector, bundle_state State)
125
214
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {
126
215
216
+ common_ctor_checks (State);
217
+
127
218
MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
128
219
MContext, MDevices, Selector, State);
129
220
}
130
221
131
222
// C'tor matches sycl::join API
132
223
kernel_bundle_impl (const std::vector<detail::KernelBundleImplPtr> &Bundles) {
224
+ if (Bundles.empty ())
225
+ return ;
226
+
133
227
MContext = Bundles[0 ]->MContext ;
228
+ MDevices = Bundles[0 ]->MDevices ;
229
+ for (size_t I = 1 ; I < Bundles.size (); ++I) {
230
+ if (Bundles[I]->MContext != MContext)
231
+ throw sycl::exception (
232
+ make_error_code (errc::invalid),
233
+ " Not all input bundles have the same associated context." );
234
+ if (Bundles[I]->MDevices != MDevices)
235
+ throw sycl::exception (
236
+ make_error_code (errc::invalid),
237
+ " Not all input bundles have the same set of associated devices." );
238
+ }
239
+
134
240
for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
135
- MDevices.insert (MDevices.end (), Bundle->MDevices .begin (),
136
- Bundle->MDevices .end ());
241
+
137
242
MDeviceImages.insert (MDeviceImages.end (), Bundle->MDeviceImages .begin (),
138
243
Bundle->MDeviceImages .end ());
139
244
}
140
245
141
- std::sort (MDevices.begin (), MDevices.end (), LessByHash<device>{});
142
- const auto DevIt = std::unique (MDevices.begin (), MDevices.end ());
143
- MDevices.erase (DevIt, MDevices.end ());
144
-
145
246
std::sort (MDeviceImages.begin (), MDeviceImages.end (),
146
247
LessByHash<device_image_plain>{});
147
248
const auto DevImgIt =
@@ -171,14 +272,7 @@ class kernel_bundle_impl {
171
272
}
172
273
std::sort (Result.begin (), Result.end (), LessByNameComp{});
173
274
174
- auto NewIt =
175
- std::unique (Result.begin (), Result.end (),
176
- [](const sycl::kernel_id &LHS, const sycl::kernel_id &RHS) {
177
- return strcmp (LHS.get_name (), RHS.get_name ()) == 0 ;
178
- }
179
-
180
- );
181
-
275
+ auto NewIt = std::unique (Result.begin (), Result.end (), EqualByNameComp{});
182
276
Result.erase (NewIt, Result.end ());
183
277
184
278
return Result;
@@ -192,6 +286,12 @@ class kernel_bundle_impl {
192
286
[&KernelID](const device_image_plain &DeviceImage) {
193
287
return DeviceImage.has_kernel (KernelID);
194
288
});
289
+
290
+ if (MDeviceImages.end () == It)
291
+ throw sycl::exception (make_error_code (errc::invalid),
292
+ " The kernel bundle does not contain the kernel "
293
+ " identified by kernelId." );
294
+
195
295
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
196
296
detail::getSyclObjImpl (*It);
197
297
0 commit comments