@@ -253,11 +253,18 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
253
253
}
254
254
sycl::event copy_shapes_strides_ev = std::get<2 >(ptr_size_event_tuple1);
255
255
256
+ std::vector<sycl::event> all_deps;
257
+ all_deps.reserve (depends.size () + 1 );
258
+ all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
259
+ all_deps.push_back (copy_shapes_strides_ev);
260
+
261
+ assert (all_deps.size () == depends.size () + 1 );
262
+
256
263
repeat_ev =
257
264
fn (exec_q, src_axis_nelems, src_data_p, dst_data_p, reps_data_p,
258
265
cumsum_data_p, src_nd, packed_src_shape_strides,
259
266
dst_shape_vec[0 ], dst_strides_vec[0 ], reps_shape_vec[0 ],
260
- reps_strides_vec[0 ], depends );
267
+ reps_strides_vec[0 ], all_deps );
261
268
262
269
sycl::event cleanup_tmp_allocations_ev =
263
270
exec_q.submit ([&](sycl::handler &cgh) {
@@ -496,10 +503,10 @@ py_repeat_by_sequence(const dpctl::tensor::usm_ndarray &src,
496
503
497
504
assert (all_deps.size () == depends.size () + 1 );
498
505
499
- sycl::event repeat_ev =
500
- fn ( exec_q, src_sz, src_data_p, dst_data_p, reps_data_p, cumsum_data_p,
501
- src_nd, packed_src_shapes_strides, dst_shape_vec[0 ],
502
- dst_strides_vec[ 0 ], reps_shape_vec[0 ], reps_strides_vec[0 ], depends );
506
+ sycl::event repeat_ev = fn (
507
+ exec_q, src_sz, src_data_p, dst_data_p, reps_data_p, cumsum_data_p,
508
+ src_nd, packed_src_shapes_strides, dst_shape_vec[ 0 ], dst_strides_vec [0 ],
509
+ reps_shape_vec[0 ], reps_strides_vec[0 ], all_deps );
503
510
504
511
sycl::event cleanup_tmp_allocations_ev =
505
512
exec_q.submit ([&](sycl::handler &cgh) {
@@ -652,7 +659,7 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
652
659
653
660
repeat_ev = fn (exec_q, dst_axis_nelems, src_data_p, dst_data_p, reps,
654
661
src_nd, packed_src_shape_strides, dst_shape_vec[0 ],
655
- dst_strides_vec[0 ], depends );
662
+ dst_strides_vec[0 ], all_deps );
656
663
657
664
sycl::event cleanup_tmp_allocations_ev =
658
665
exec_q.submit ([&](sycl::handler &cgh) {
@@ -856,7 +863,7 @@ py_repeat_by_scalar(const dpctl::tensor::usm_ndarray &src,
856
863
857
864
sycl::event repeat_ev = fn (exec_q, dst_sz, src_data_p, dst_data_p, reps,
858
865
src_nd, packed_src_shape_strides,
859
- dst_shape_vec[0 ], dst_strides_vec[0 ], depends );
866
+ dst_shape_vec[0 ], dst_strides_vec[0 ], all_deps );
860
867
861
868
sycl::event cleanup_tmp_allocations_ev =
862
869
exec_q.submit ([&](sycl::handler &cgh) {
0 commit comments