Skip to content

Commit feeacc1

Browse files
jcranmer-intel and romanovvlad
authored and committed
[SYCL] [USM] Implement prefetch.
Signed-off-by: Joshua Cranmer <[email protected]>
1 parent 866d634 commit feeacc1

File tree

10 files changed

+190
-1
lines changed

10 files changed

+190
-1
lines changed

sycl/include/CL/sycl/detail/cg.hpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ class CG {
340340
UPDATE_HOST,
341341
RUN_ON_HOST_INTEL,
342342
COPY_USM,
343-
FILL_USM
343+
FILL_USM,
344+
PREFETCH_USM
344345
};
345346

346347
CG(CGTYPE Type, std::vector<std::vector<char>> ArgsStorage,
@@ -529,6 +530,26 @@ class CGFillUSM : public CG {
529530
int getFill() { return MPattern[0]; }
530531
};
531532

533+
// The class which represents "prefetch" command group for USM pointers.
534+
class CGPrefetchUSM : public CG {
535+
void *MDst;
536+
size_t MLength;
537+
538+
public:
539+
CGPrefetchUSM(void *DstPtr, size_t Length,
540+
std::vector<std::vector<char>> ArgsStorage,
541+
std::vector<detail::AccessorImplPtr> AccStorage,
542+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
543+
std::vector<Requirement *> Requirements,
544+
std::vector<detail::EventImplPtr> Events)
545+
: CG(PREFETCH_USM, std::move(ArgsStorage), std::move(AccStorage),
546+
std::move(SharedPtrStorage), std::move(Requirements),
547+
std::move(Events)),
548+
MDst(DstPtr), MLength(Length) {}
549+
void *getDst() { return MDst; }
550+
size_t getLength() { return MLength; }
551+
};
552+
532553
} // namespace detail
533554
} // namespace sycl
534555
} // namespace cl

sycl/include/CL/sycl/detail/memory_manager.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ class MemoryManager {
129129
int Pattern, std::vector<RT::PiEvent> DepEvents,
130130
RT::PiEvent &OutEvent);
131131

132+
static void prefetch_usm(void *Ptr, QueueImplPtr Queue, size_t Len,
133+
std::vector<RT::PiEvent> DepEvents,
134+
RT::PiEvent &OutEvent);
135+
132136
};
133137
} // namespace detail
134138
} // namespace sycl

sycl/include/CL/sycl/detail/usm_dispatch.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class USMDispatcher {
4949
void *ParamValue, size_t *ParamValueSizeRet);
5050
void memAdvise(pi_queue Queue, const void *Ptr, size_t Length, int Advice,
5151
pi_event *Event);
52+
pi_result enqueuePrefetch(pi_queue Queue, void *Ptr, size_t Size,
53+
pi_uint32 NumEventsInWaitList,
54+
const pi_event *EventWaitList, pi_event *Event);
5255

5356
private:
5457
bool mEmulated = false;

sycl/include/CL/sycl/handler.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,12 @@ class handler {
397397
std::move(MAccStorage), std::move(MSharedPtrStorage),
398398
std::move(MRequirements), std::move(MEvents)));
399399
break;
400+
case detail::CG::PREFETCH_USM:
401+
CommandGroup.reset(new detail::CGPrefetchUSM(
402+
MDstPtr, MLength, std::move(MArgsStorage),
403+
std::move(MAccStorage), std::move(MSharedPtrStorage),
404+
std::move(MRequirements), std::move(MEvents)));
405+
break;
400406
case detail::CG::NONE:
401407
throw runtime_error("Command group submitted without a kernel or a "
402408
"explicit memory operation.");
@@ -1163,6 +1169,13 @@ class handler {
11631169
MLength = Count;
11641170
MCGType = detail::CG::FILL_USM;
11651171
}
1172+
1173+
// Prefetch the memory pointed to by the pointer.
1174+
void prefetch(const void *Ptr, size_t Count) {
1175+
MDstPtr = const_cast<void *>(Ptr);
1176+
MLength = Count;
1177+
MCGType = detail::CG::PREFETCH_USM;
1178+
}
11661179
};
11671180
} // namespace sycl
11681181
} // namespace cl

sycl/include/CL/sycl/ordered_queue.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ class ordered_queue {
112112
return impl->memcpy(impl, dest, src, count);
113113
}
114114

115+
/// Submits a command group whose only action is handler::prefetch(Ptr, Count)
/// and returns the event produced by submit().
event prefetch(const void *Ptr, size_t Count) {
  return submit([Ptr, Count](handler &CGH) { CGH.prefetch(Ptr, Count); });
}
120+
115121
private:
116122
std::shared_ptr<detail::queue_impl> impl;
117123
template <class Obj>

sycl/include/CL/sycl/queue.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@ class queue {
116116
return impl->mem_advise(Ptr, Length, Advice);
117117
}
118118

119+
/// Submits a command group whose only action is handler::prefetch(Ptr, Count)
/// and returns the event produced by submit().
event prefetch(const void *Ptr, size_t Count) {
  return submit([Ptr, Count](handler &CGH) { CGH.prefetch(Ptr, Count); });
}
124+
119125
private:
120126
std::shared_ptr<detail::queue_impl> impl;
121127
template <class Obj>

sycl/source/detail/memory_manager.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,21 @@ void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
500500
}
501501
}
502502

503+
void MemoryManager::prefetch_usm(void *Mem, QueueImplPtr Queue, size_t Length,
504+
std::vector<RT::PiEvent> DepEvents,
505+
RT::PiEvent &OutEvent) {
506+
sycl::context Context = Queue->get_context();
507+
508+
if (Context.is_host()) {
509+
// TODO: Potentially implement prefetch on the host.
510+
} else {
511+
std::shared_ptr<usm::USMDispatcher> USMDispatch =
512+
getSyclObjImpl(Context)->getUSMDispatch();
513+
PI_CHECK(USMDispatch->enqueuePrefetch(Queue->getHandleRef(),
514+
Mem, Length, DepEvents.size(), &DepEvents[0], &OutEvent));
515+
}
516+
}
517+
503518
} // namespace detail
504519
} // namespace sycl
505520
} // namespace cl

sycl/source/detail/scheduler/commands.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,9 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
486486
case detail::CG::FILL_USM:
487487
Stream << "CG type: fill usm\\n";
488488
break;
489+
case detail::CG::PREFETCH_USM:
490+
Stream << "CG type: prefetch usm\\n";
491+
break;
489492
default:
490493
Stream << "CG type: unknown\\n";
491494
break;
@@ -785,6 +788,12 @@ cl_int ExecCGCommand::enqueueImp() {
785788
Fill->getFill(), std::move(RawEvents), Event);
786789
return CL_SUCCESS;
787790
}
791+
case CG::CGTYPE::PREFETCH_USM: {
792+
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
793+
MemoryManager::prefetch_usm(Prefetch->getDst(), MQueue,
794+
Prefetch->getLength(), std::move(RawEvents), Event);
795+
return CL_SUCCESS;
796+
}
788797
case CG::CGTYPE::NONE:
789798
default:
790799
throw runtime_error("CG type not implemented.");

sycl/source/detail/usm/usm_dispatch.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,29 @@ void USMDispatcher::memAdvise(pi_queue Queue, const void *Ptr, size_t Length,
364364
}
365365
}
366366
}
367+
368+
/// Enqueues a prefetch hint on \p Queue.
///
/// At present \p Ptr and \p Size are not consulted: on the OpenCL backend the
/// call is lowered to piEnqueueEventsWait over the supplied wait list, so the
/// returned \p Event only orders against those dependencies.
///
/// \return the result of piEnqueueEventsWait on the OpenCL backend, or
///         PI_INVALID_OPERATION when that backend is not in use.
pi_result USMDispatcher::enqueuePrefetch(pi_queue Queue, void *Ptr, size_t Size,
                                         pi_uint32 NumEventsInWaitList,
                                         const pi_event *EventWaitList,
                                         pi_event *Event) {
  if (!pi::useBackend(pi::Backend::SYCL_BE_PI_OPENCL))
    return PI_INVALID_OPERATION;

  pi_result Result = PI_INVALID_OPERATION;
  if (!mEmulated) {
    // TODO: Replace this with real prefetch support when the driver enables
    // it.
    Result = PI_CALL_RESULT(RT::piEnqueueEventsWait(
        Queue, NumEventsInWaitList, EventWaitList, Event));
  } else {
    // Prefetch is a hint, so ignoring it is always safe.
    Result = PI_CALL_RESULT(RT::piEnqueueEventsWait(
        Queue, NumEventsInWaitList, EventWaitList, Event));
  }
  return Result;
}
389+
367390
} // namespace usm
368391
} // namespace detail
369392
} // namespace sycl

sycl/test/usm/prefetch.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//==---- prefetch.cpp - USM prefetch test ----------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// RUN: %clangxx -fsycl %s -o %t1.out -lOpenCL
9+
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
10+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
11+
12+
#include <CL/sycl.hpp>
13+
14+
using namespace cl::sycl;
15+
16+
static constexpr int count = 100;
17+
18+
int main() {
19+
queue q([](exception_list el) {
20+
for (auto &e : el)
21+
throw e;
22+
});
23+
float *src = (float*)malloc_shared(sizeof(float) * count, q.get_device(),
24+
q.get_context());
25+
float *dest = (float*)malloc_shared(sizeof(float) * count, q.get_device(),
26+
q.get_context());
27+
for (int i = 0; i < count; i++)
28+
src[i] = i;
29+
30+
// Test handler::prefetch
31+
{
32+
event init_prefetch = q.submit([&](handler &cgh) {
33+
cgh.prefetch(src, sizeof(float) * count);
34+
});
35+
36+
q.submit([&](handler &cgh) {
37+
cgh.depends_on(init_prefetch);
38+
cgh.single_task<class double_dest>([=]() {
39+
for (int i = 0; i < count; i++)
40+
dest[i] = 2 * src[i];
41+
});
42+
});
43+
q.wait_and_throw();
44+
45+
for (int i = 0; i < count; i++) {
46+
assert(dest[i] == i * 2);
47+
}
48+
}
49+
50+
// Test queue::prefetch
51+
{
52+
event init_prefetch = q.prefetch(src, sizeof(float) * count);
53+
54+
q.submit([&](handler &cgh) {
55+
cgh.depends_on(init_prefetch);
56+
cgh.single_task<class double_dest3>([=]() {
57+
for (int i = 0; i < count; i++)
58+
dest[i] = 3 * src[i];
59+
});
60+
});
61+
q.wait_and_throw();
62+
63+
for (int i = 0; i < count; i++) {
64+
assert(dest[i] == i * 3);
65+
}
66+
}
67+
68+
// Test ordered_queue::prefetch
69+
{
70+
ordered_queue oq([](exception_list el) {
71+
for (auto &e : el)
72+
throw e;
73+
});
74+
event init_prefetch = oq.prefetch(src, sizeof(float) * count);
75+
76+
oq.submit([&](handler &cgh) {
77+
cgh.depends_on(init_prefetch);
78+
cgh.single_task<class double_dest4>([=]() {
79+
for (int i = 0; i < count; i++)
80+
dest[i] = 4 * src[i];
81+
});
82+
});
83+
oq.wait_and_throw();
84+
85+
for (int i = 0; i < count; i++) {
86+
assert(dest[i] == i * 4);
87+
}
88+
}
89+
}

0 commit comments

Comments
 (0)