Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class __SYCL_EXPORT command_graph {
command_graph(const property_list &propList = {});

// Adding empty node with [0..n] predecessors:
node add(const std::vector<node> &dep = {});
node add(const std::vector<node> &dep = {}) { return add_impl(dep); }

// Adding device node:
template <typename T> node add(T cgf, const std::vector<node> &dep = {}) {
Expand Down Expand Up @@ -110,6 +110,8 @@ class __SYCL_EXPORT command_graph {
node add_impl(std::function<void(handler &)> cgf,
const std::vector<node> &dep);

node add_impl(const std::vector<node> &dep);

template <class Obj>
friend decltype(Obj::impl)
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
Expand Down
32 changes: 32 additions & 0 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,26 @@ bool check_for_arg(const sycl::detail::ArgDesc &Arg,
return SuccessorAddedDep;
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<std::shared_ptr<node_impl>> &Dep) {
const std::shared_ptr<node_impl> &NodeImpl =
std::make_shared<node_impl>(Impl);

// TODO: Encapsulate in separate function to avoid duplication
if (!Dep.empty()) {
for (auto N : Dep) {
N->register_successor(NodeImpl); // register successor
this->remove_root(NodeImpl); // remove receiver from root node
// list
}
} else {
this->add_root(NodeImpl);
}

return NodeImpl;
}

std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
std::function<void(handler &)> CGF,
Expand Down Expand Up @@ -193,6 +213,18 @@ command_graph<graph_state::modifiable>::command_graph(
const sycl::property_list &)
: impl(std::make_shared<detail::graph_impl>()) {}

template <>
node command_graph<graph_state::modifiable>::add_impl(
const std::vector<node> &Deps) {
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

std::shared_ptr<detail::node_impl> NodeImpl = impl->add(impl, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

template <>
node command_graph<graph_state::modifiable>::add_impl(
std::function<void(handler &)> CGF, const std::vector<node> &Deps) {
Expand Down
10 changes: 9 additions & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ struct node_impl {

sycl::event get_event(void) const { return MEvent; }

node_impl(const std::shared_ptr<graph_impl> &Graph)
: MScheduled(false), MGraph(Graph) {}

node_impl(
const std::shared_ptr<graph_impl> &Graph,
std::shared_ptr<sycl::detail::kernel_impl> Kernel,
Expand Down Expand Up @@ -119,7 +122,8 @@ struct node_impl {
if (!Next->MScheduled)
Next->topology_sort(Schedule);
}
Schedule.push_front(std::shared_ptr<node_impl>(this));
if (MKernel != nullptr)
Schedule.push_front(std::shared_ptr<node_impl>(this));
}

bool has_arg(const sycl::detail::ArgDesc &Arg) {
Expand Down Expand Up @@ -169,6 +173,10 @@ struct graph_impl {
const std::vector<sycl::detail::ArgDesc> &Args,
const std::vector<std::shared_ptr<node_impl>> &Dep = {});

std::shared_ptr<node_impl>
add(const std::shared_ptr<graph_impl> &Impl,
const std::vector<std::shared_ptr<node_impl>> &Dep = {});

graph_impl() : MFirst(true) {}

/// Add a queue to the set of queues which are currently recording to this
Expand Down
44 changes: 44 additions & 0 deletions sycl/test/graph/graph-explicit-empty.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
#include <sycl/sycl.hpp>

#include <sycl/ext/oneapi/experimental/graph.hpp>

int main() {

sycl::property_list properties{
sycl::property::queue::in_order{},
sycl::ext::oneapi::property::queue::lazy_execution{}};

sycl::queue q{sycl::gpu_selector_v, properties};

sycl::ext::oneapi::experimental::command_graph g;

const size_t n = 10;
float *arr = sycl::malloc_device<float>(n, q);

auto start = g.add();

auto init = g.add([&](sycl::handler &h) {
h.parallel_for(sycl::range<1>{n}, [=](sycl::id<1> idx) {
size_t i = idx;
arr[i] = 0;
});
}, {start});

auto empty = g.add({init});

g.add([&](sycl::handler &h) {
h.parallel_for(sycl::range<1>{n}, [=](sycl::id<1> idx) {
size_t i = idx;
arr[i] = 1;
});
}, {empty});

auto executable_graph = g.finalize(q.get_context());

q.submit([&](sycl::handler &h) { h.ext_oneapi_graph(executable_graph); }).wait();

sycl::free(arr, q);

return 0;
}