Skip to content

Commit 4e41318

Browse files
Merge pull request #1176 from IntelPython/revert-returning-host-tasks
Return events for comp. tasks, rather than temp-clean up host tasks
2 parents 193a8b1 + 80f0463 commit 4e41318

6 files changed

+8
-10
lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -988,8 +988,7 @@ std::pair<sycl::event, sycl::event> py_nonzero(
988988
sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive(
989989
exec_q, {cumsum, indexes}, host_task_events);
990990

991-
return std::make_pair(py_obj_management_host_task_ev,
992-
temporaries_cleanup_ev);
991+
return std::make_pair(py_obj_management_host_task_ev, non_zero_indexes_ev);
993992
}
994993

995994
} // namespace py_internal

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
272272
host_task_events.push_back(temporaries_cleanup_ev);
273273

274274
return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events),
275-
temporaries_cleanup_ev);
275+
copy_and_cast_generic_ev);
276276
}
277277

278278
void init_copy_and_cast_usm_to_usm_dispatch_tables(void)

dpctl/tensor/libtensor/source/copy_for_reshape.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
167167
host_task_events.push_back(temporaries_cleanup_ev);
168168

169169
return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events),
170-
temporaries_cleanup_ev);
170+
copy_for_reshape_event);
171171
}
172172

173173
void init_copy_for_reshape_dispatch_vectors(void)

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
540540
sycl::event arg_cleanup_ev =
541541
keep_args_alive(exec_q, {src, py_ind, dst}, host_task_events);
542542

543-
return std::make_pair(arg_cleanup_ev, temporaries_cleanup_ev);
543+
return std::make_pair(arg_cleanup_ev, take_generic_ev);
544544
}
545545

546546
std::pair<sycl::event, sycl::event>
@@ -854,7 +854,7 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
854854
sycl::event arg_cleanup_ev =
855855
keep_args_alive(exec_q, {dst, py_ind, val}, host_task_events);
856856

857-
return std::make_pair(arg_cleanup_ev, temporaries_cleanup_ev);
857+
return std::make_pair(arg_cleanup_ev, put_generic_ev);
858858
}
859859

860860
void init_advanced_indexing_dispatch_tables(void)

dpctl/tensor/libtensor/source/triul_ctor.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ usm_ndarray_triul(sycl::queue exec_q,
202202
}
203203

204204
auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
205-
cgh.depends_on({tri_ev});
205+
cgh.depends_on(tri_ev);
206206
auto ctx = exec_q.get_context();
207207
cgh.host_task(
208208
[shp_host_shape_and_strides, dev_shape_and_strides, ctx]() {
@@ -213,8 +213,7 @@ usm_ndarray_triul(sycl::queue exec_q,
213213
});
214214

215215
return std::make_pair(
216-
keep_args_alive(exec_q, {src, dst}, {temporaries_cleanup_ev}),
217-
temporaries_cleanup_ev);
216+
keep_args_alive(exec_q, {src, dst}, {temporaries_cleanup_ev}), tri_ev);
218217
}
219218

220219
void init_triul_ctor_dispatch_vectors(void)

dpctl/tensor/libtensor/source/where.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ py_where(dpctl::tensor::usm_ndarray condition,
244244
sycl::event arg_cleanup_ev =
245245
keep_args_alive(exec_q, {x1, x2, condition, dst}, host_task_events);
246246

247-
return std::make_pair(arg_cleanup_ev, temporaries_cleanup_ev);
247+
return std::make_pair(arg_cleanup_ev, where_ev);
248248
}
249249

250250
void init_where_dispatch_tables(void)

0 commit comments

Comments
 (0)