Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {

class handler;
class queue;
namespace ext {
namespace oneapi {
namespace experimental {
Expand Down Expand Up @@ -70,6 +71,40 @@ class __SYCL_EXPORT command_graph {
finalize(const sycl::context &syclContext,
const property_list &propList = {}) const;

/// Change the state of a queue to be recording and associate this graph with
/// it.
/// @param recordingQueue The queue to change state on and associate this
/// graph with.
/// @return True if the queue had its state changed from executing to
/// recording, false if it was already recording to this graph.
/// @throws sycl::exception with errc::invalid if the queue is already
/// recording to a different graph.
bool begin_recording(queue recordingQueue);

/// Change the state of multiple queues to be recording and associate this
/// graph with each of them. The queues are processed in order, one at a time.
/// @param recordingQueues The queues to change state on and associate this
/// graph with.
/// @return True if any queue had its state changed from executing to
/// recording.
/// @throws sycl::exception with errc::invalid if one of the queues is already
/// recording to a different graph; queues earlier in the list have already
/// been switched to recording at that point.
bool begin_recording(const std::vector<queue> &recordingQueues);

/// Set all queues currently recording to this graph to the executing state.
/// @return True if any queue had its state changed from recording to
/// executing, false if no queue was recording to this graph.
bool end_recording();

/// Set a queue currently recording to this graph to the executing state.
/// @param recordingQueue The queue to change state on.
/// @return True if the queue had its state changed from recording to
/// executing, false if the queue was not recording to any graph.
/// @throws sycl::exception with errc::invalid if the queue is recording to a
/// different graph.
bool end_recording(queue recordingQueue);

/// Set multiple queues currently recording to this graph to the executing
/// state. The queues are processed in order, one at a time.
/// @param recordingQueues The queues to change state on.
/// @return True if any queue had its state changed from recording to
/// executing.
/// @throws sycl::exception with errc::invalid if one of the queues is
/// recording to a different graph; queues earlier in the list have already
/// been switched back to executing at that point.
bool end_recording(const std::vector<queue> &recordingQueues);

private:
/// Construct a command_graph that wraps an existing implementation object;
/// Impl is stored in the impl member.
command_graph(detail::graph_ptr Impl) : impl(Impl) {}

Expand Down
10 changes: 10 additions & 0 deletions sycl/include/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ static event submitAssertCapture(queue &, event &, queue *,
#endif
} // namespace detail

namespace ext {
namespace oneapi {
namespace experimental {
// State of a queue with regards to graph recording,
// returned by info::queue::state.
//   executing - the queue is not recording to a command graph.
//   recording - the queue has been associated with a command graph through
//               command_graph::begin_recording().
enum class queue_state { executing, recording };
} // namespace experimental
} // namespace oneapi
} // namespace ext

/// Encapsulates a single SYCL queue which schedules kernels on a SYCL device.
///
/// A SYCL queue can be used to submit command groups to be executed by the SYCL
Expand Down
123 changes: 119 additions & 4 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <detail/graph_impl.hpp>
#include <detail/queue_impl.hpp>
#include <sycl/queue.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -61,12 +62,56 @@ void graph_impl::remove_root(node_ptr n) {
MSchedule.clear();
}

// Depth-first search for graph nodes which use a given kernel argument.
//
// All successors of a node are visited before the node itself is examined. A
// node is recorded in deps only when it is not already present there, it
// holds the argument, and none of its successors recorded a dependency
// during this call.
//
// @param[in] arg The kernel argument to look for.
// @param[in] currentNode The graph node at which this step of the search
// starts.
// @param[in,out] deps The unique set of dependencies identified so far for
// this arg.
// @param[in] dereferencePtr If true, arg comes directly from the handler and
// has to be dereferenced before its value is compared.
//
// @returns True if this node or any of its successors was added to deps.
bool check_for_arg(const sycl::detail::ArgDesc &arg, node_ptr currentNode,
                   std::set<node_ptr> &deps, bool dereferencePtr = false) {
  bool addedBelow = false;
  for (auto &next : currentNode->MSuccessors) {
    // Every successor is visited; there is no early exit.
    if (check_for_arg(arg, next, deps, dereferencePtr)) {
      addedBelow = true;
    }
  }

  const bool alreadyListed = deps.find(currentNode) != deps.end();
  if (alreadyListed || !currentNode->has_arg(arg, dereferencePtr) ||
      addedBelow) {
    return addedBelow;
  }

  deps.insert(currentNode);
  return true;
}

template <typename T>
node_ptr graph_impl::add(graph_ptr impl, T cgf,
const std::vector<sycl::detail::ArgDesc> &args,
const std::vector<node_ptr> &dep) {
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf);
if (!dep.empty()) {
for (auto n : dep) {
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf, args);
// Copy deps so we can modify them
auto deps = dep;
// A unique set of dependencies obtained by checking kernel arguments
std::set<node_ptr> uniqueDeps;
for (auto &arg : args) {
if (arg.MType != sycl::detail::kernel_param_kind_t::kind_pointer) {
continue;
}
// Look through the graph for nodes which share this argument
for (auto nodePtr : MRoots) {
check_for_arg(arg, nodePtr, uniqueDeps, true);
}
}

// Add any deps determined from arguments into the dependency list
deps.insert(deps.end(), uniqueDeps.begin(), uniqueDeps.end());
if (!deps.empty()) {
for (auto n : deps) {
n->register_successor(nodeImpl); // register successor
this->remove_root(nodeImpl); // remove receiver from root node
// list
Expand All @@ -77,6 +122,17 @@ node_ptr graph_impl::add(graph_ptr impl, T cgf,
return nodeImpl;
}

bool graph_impl::clear_queues() {
bool anyQueuesCleared = false;
for (auto &q : MRecordingQueues) {
q->setCommandGraph(nullptr);
anyQueuesCleared = true;
}
MRecordingQueues.clear();

return anyQueuesCleared;
}

void node_impl::exec(sycl::detail::queue_ptr q) {
std::vector<sycl::event> deps;
for (auto i : MPredecessors)
Expand All @@ -100,7 +156,7 @@ node command_graph<graph_state::modifiable>::add_impl(
depImpls.push_back(sycl::detail::getSyclObjImpl(d));
}

auto nodeImpl = impl->add(impl, cgf, depImpls);
auto nodeImpl = impl->add(impl, cgf, {}, depImpls);
return sycl::detail::createSyclObjFromImpl<node>(nodeImpl);
}

Expand All @@ -121,6 +177,65 @@ command_graph<graph_state::modifiable>::finalize(
return command_graph<graph_state::executable>{this->impl, ctx};
}

// Puts recordingQueue into the recording state for this graph.
//
// @returns True if the queue had no command graph and is now recording to
// this one, false if it was already recording to this graph.
// Throws sycl::exception (errc::invalid) if the queue is recording to a
// different graph.
template <>
bool command_graph<graph_state::modifiable>::begin_recording(
    queue recordingQueue) {
  auto targetQueue = sycl::detail::getSyclObjImpl(recordingQueue);

  if (targetQueue->getCommandGraph() != nullptr) {
    // Already recording: only acceptable when it is to this very graph.
    if (targetQueue->getCommandGraph() == impl) {
      return false;
    }
    throw sycl::exception(make_error_code(errc::invalid),
                          "begin_recording called for a queue which is already "
                          "recording to a different graph.");
  }

  // The queue has no graph yet: attach it and track it on the graph side.
  targetQueue->setCommandGraph(impl);
  impl->add_queue(targetQueue);
  return true;
}

// Starts recording on each queue in recordingQueues, in order, by delegating
// to the single-queue overload.
//
// @returns True if at least one queue changed state.
template <>
bool command_graph<graph_state::modifiable>::begin_recording(
    const std::vector<queue> &recordingQueues) {
  bool anyStarted = false;
  for (const auto &candidate : recordingQueues) {
    // Every queue is processed; a success does not end the loop early.
    if (this->begin_recording(candidate)) {
      anyStarted = true;
    }
  }
  return anyStarted;
}

// Stops recording on every queue attached to this graph.
//
// @returns The result of graph_impl::clear_queues().
template <>
bool command_graph<graph_state::modifiable>::end_recording() {
  const bool anyQueueStopped = impl->clear_queues();
  return anyQueueStopped;
}

// Takes recordingQueue out of the recording state for this graph.
//
// @returns True if the queue was recording to this graph and has been
// detached, false if the queue was not recording to any graph.
// Throws sycl::exception (errc::invalid) if the queue is recording to a
// different graph.
template <>
bool command_graph<graph_state::modifiable>::end_recording(
    queue recordingQueue) {
  auto targetQueue = sycl::detail::getSyclObjImpl(recordingQueue);

  if (targetQueue->getCommandGraph() != impl) {
    // Not recording to this graph: either idle, or owned by another graph.
    if (targetQueue->getCommandGraph() == nullptr) {
      return false;
    }
    throw sycl::exception(make_error_code(errc::invalid),
                          "end_recording called for a queue which is recording "
                          "to a different graph.");
  }

  // Detach the queue and drop it from this graph's bookkeeping.
  targetQueue->setCommandGraph(nullptr);
  impl->remove_queue(targetQueue);
  return true;
}
// Stops recording on each queue in recordingQueues, in order, by delegating
// to the single-queue overload.
//
// @returns True if at least one queue changed state.
template <>
bool command_graph<graph_state::modifiable>::end_recording(
    const std::vector<queue> &recordingQueues) {
  bool anyStopped = false;
  for (const auto &candidate : recordingQueues) {
    // Every queue is processed; a success does not end the loop early.
    if (this->end_recording(candidate)) {
      anyStopped = true;
    }
  }
  return anyStopped;
}

} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
53 changes: 51 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <sycl/detail/cg_types.hpp>
#include <sycl/ext/oneapi/experimental/graph.hpp>
#include <sycl/handler.hpp>

Expand Down Expand Up @@ -53,6 +54,8 @@ struct node_impl {

std::function<void(sycl::handler &)> MBody;

/// Copy of the kernel arguments passed to the constructor. For kind_pointer
/// entries the constructor replaces MPtr with the pointer value it pointed
/// at, so has_arg() can compare pointer values directly.
std::vector<sycl::detail::ArgDesc> MArgs;

void exec(sycl::detail::queue_ptr q);

void register_successor(node_ptr n) {
Expand All @@ -65,7 +68,17 @@ struct node_impl {
/// Returns a copy of the event held in MEvent.
sycl::event get_event(void) const { return MEvent; }

template <typename T>
node_impl(graph_ptr g, T cgf) : MScheduled(false), MGraph(g), MBody(cgf) {}
/// Builds a node for graph g from the command-group function cgf, keeping a
/// copy of the kernel arguments args.
///
/// For each kind_pointer argument the stored MPtr is replaced by the pointer
/// value it points at, so that later comparisons are made on the actual USM
/// pointer. These stored copies cannot be used for submission by code which
/// expects MPtr to be a void**.
node_impl(graph_ptr g, T cgf, const std::vector<sycl::detail::ArgDesc> &args)
    : MScheduled(false), MGraph(g), MBody(cgf), MArgs(args) {
  for (auto &storedArg : MArgs) {
    if (storedArg.MType != sycl::detail::kernel_param_kind_t::kind_pointer) {
      continue;
    }
    storedArg.MPtr = *static_cast<void **>(storedArg.MPtr);
  }
}

// Recursively adding nodes to execution stack:
void topology_sort(std::list<node_ptr> &schedule) {
Expand All @@ -76,6 +89,20 @@ struct node_impl {
}
schedule.push_front(node_ptr(this));
}

/// Checks whether this node stores an argument equal to arg.
///
/// Two arguments are treated as equal when their MType and MSize match and
/// their pointer values are the same.
///
/// @param arg The argument to look for.
/// @param dereferencePtr If true, arg.MPtr is treated as a void** (as for
/// arguments coming directly from the handler) and is dereferenced before the
/// comparison.
/// @return True if a matching argument is stored in MArgs.
bool has_arg(const sycl::detail::ArgDesc &arg, bool dereferencePtr = false) {
  for (auto &stored : MArgs) {
    const bool sameTypeAndSize =
        arg.MType == stored.MType && arg.MSize == stored.MSize;
    if (!sameTypeAndSize) {
      continue;
    }

    // Only dereference once type and size already match.
    void *candidate = arg.MPtr;
    if (dereferencePtr) {
      candidate = *static_cast<void **>(arg.MPtr);
    }
    if (candidate == stored.MPtr) {
      return true;
    }
  }
  return false;
}
};

struct graph_impl {
Expand All @@ -93,9 +120,31 @@ struct graph_impl {
void remove_root(node_ptr n);

template <typename T>
node_ptr add(graph_ptr impl, T cgf, const std::vector<node_ptr> &dep = {});
node_ptr add(graph_ptr impl, T cgf,
const std::vector<sycl::detail::ArgDesc> &args,
const std::vector<node_ptr> &dep = {});

/// Default constructor; initializes MFirst to true.
graph_impl() : MFirst(true) {}

/// Add a queue to the set of queues which are currently recording to this
/// graph. Only the graph's own set is updated here; adding a queue that is
/// already in the set leaves the set unchanged.
/// @param recordingQueue The queue implementation to track.
void add_queue(sycl::detail::queue_ptr recordingQueue) {
MRecordingQueues.insert(recordingQueue);
}

/// Remove a queue from the set of queues which are currently recording to
/// this graph. Only the graph's own set is updated here; removing a queue
/// that is not in the set has no effect.
/// @param recordingQueue The queue implementation to stop tracking.
void remove_queue(sycl::detail::queue_ptr recordingQueue) {
MRecordingQueues.erase(recordingQueue);
}

/// Remove all queues which are recording to this graph, also sets all queues
/// cleared back to the executing state.
/// @return True if any queues were removed.
bool clear_queues();

private:
/// Queues currently recording to this graph; maintained by add_queue(),
/// remove_queue() and clear_queues().
std::set<sycl::detail::queue_ptr> MRecordingQueues;
};

} // namespace detail
Expand Down
Loading