Skip to content

Commit ebf118a

Browse files
authored
Repeat Python bindings properly pass host task dependencies (#1430)
The 1D variant of repeat was not passed the event dependency of the host task that allocates shapes and strides on the device. This caused sporadic segfaults, where the kernel would attempt to access unallocated device data.
1 parent d4cc465 commit ebf118a

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

dpctl/tensor/libtensor/source/repeat.cpp

+14-7
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,18 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
253253
}
254254
sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
255255

256+
std::vector<sycl::event> all_deps;
257+
all_deps.reserve(depends.size() + 1);
258+
all_deps.insert(all_deps.end(), depends.begin(), depends.end());
259+
all_deps.push_back(copy_shapes_strides_ev);
260+
261+
assert(all_deps.size() == depends.size() + 1);
262+
256263
repeat_ev =
257264
fn(exec_q, src_axis_nelems, src_data_p, dst_data_p, reps_data_p,
258265
cumsum_data_p, src_nd, packed_src_shape_strides,
259266
dst_shape_vec[0], dst_strides_vec[0], reps_shape_vec[0],
260-
reps_strides_vec[0], depends);
267+
reps_strides_vec[0], all_deps);
261268

262269
sycl::event cleanup_tmp_allocations_ev =
263270
exec_q.submit([&](sycl::handler &cgh) {
@@ -496,10 +503,10 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
496503

497504
assert(all_deps.size() == depends.size() + 1);
498505

499-
sycl::event repeat_ev =
500-
fn(exec_q, src_sz, src_data_p, dst_data_p, reps_data_p, cumsum_data_p,
501-
src_nd, packed_src_shapes_strides, dst_shape_vec[0],
502-
dst_strides_vec[0], reps_shape_vec[0], reps_strides_vec[0], depends);
506+
sycl::event repeat_ev = fn(
507+
exec_q, src_sz, src_data_p, dst_data_p, reps_data_p, cumsum_data_p,
508+
src_nd, packed_src_shapes_strides, dst_shape_vec[0], dst_strides_vec[0],
509+
reps_shape_vec[0], reps_strides_vec[0], all_deps);
503510

504511
sycl::event cleanup_tmp_allocations_ev =
505512
exec_q.submit([&](sycl::handler &cgh) {
@@ -652,7 +659,7 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
652659

653660
repeat_ev = fn(exec_q, dst_axis_nelems, src_data_p, dst_data_p, reps,
654661
src_nd, packed_src_shape_strides, dst_shape_vec[0],
655-
dst_strides_vec[0], depends);
662+
dst_strides_vec[0], all_deps);
656663

657664
sycl::event cleanup_tmp_allocations_ev =
658665
exec_q.submit([&](sycl::handler &cgh) {
@@ -856,7 +863,7 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
856863

857864
sycl::event repeat_ev = fn(exec_q, dst_sz, src_data_p, dst_data_p, reps,
858865
src_nd, packed_src_shape_strides,
859-
dst_shape_vec[0], dst_strides_vec[0], depends);
866+
dst_shape_vec[0], dst_strides_vec[0], all_deps);
860867

861868
sycl::event cleanup_tmp_allocations_ev =
862869
exec_q.submit([&](sycl::handler &cgh) {

0 commit comments

Comments
 (0)