Skip to content

Commit 26b9c92

Browse files
Merge pull request #1395 from IntelPython/async-ref-count-increment
Added submit_keep_args_alive
2 parents b437c47 + 10722d4 commit 26b9c92

11 files changed

+524
-131
lines changed

dpctl/_backend.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,13 @@ cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
403403
void *Dest,
404404
const void *Src,
405405
size_t Count)
406+
cdef DPCTLSyclEventRef DPCTLQueue_MemcpyWithEvents(
407+
const DPCTLSyclQueueRef Q,
408+
void *Dest,
409+
const void *Src,
410+
size_t Count,
411+
const DPCTLSyclEventRef *depEvents,
412+
size_t depEventsCount)
406413
cdef DPCTLSyclEventRef DPCTLQueue_Memset(
407414
const DPCTLSyclQueueRef Q,
408415
void *Dest,

dpctl/_host_task_util.hpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// Data Parallel Control (dpctl)
44
//
5-
// Copyright 2020-2022 Intel Corporation
5+
// Copyright 2020-2023 Intel Corporation
66
//
77
// Licensed under the Apache License, Version 2.0 (the "License");
88
// you may not use this file except in compliance with the License.
@@ -29,30 +29,30 @@
2929
///
3030
//===----------------------------------------------------------------------===//
3131

32+
#pragma once
3233
#include "Python.h"
3334
#include "syclinterface/dpctl_data_types.h"
35+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
3436
#include <CL/sycl.hpp>
3537

36-
int async_dec_ref(DPCTLSyclQueueRef QRef,
37-
PyObject **obj_array,
38-
size_t obj_array_size,
39-
DPCTLSyclEventRef *ERefs,
40-
size_t nERefs)
38+
DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef QRef,
39+
PyObject **obj_array,
40+
size_t obj_array_size,
41+
DPCTLSyclEventRef *depERefs,
42+
size_t nDepERefs,
43+
int *status)
4144
{
45+
using dpctl::syclinterface::unwrap;
46+
using dpctl::syclinterface::wrap;
4247

43-
sycl::queue *q = reinterpret_cast<sycl::queue *>(QRef);
48+
sycl::queue *q = unwrap<sycl::queue>(QRef);
4449

45-
std::vector<PyObject *> obj_vec;
46-
obj_vec.reserve(obj_array_size);
47-
for (size_t obj_id = 0; obj_id < obj_array_size; ++obj_id) {
48-
obj_vec.push_back(obj_array[obj_id]);
49-
}
50+
std::vector<PyObject *> obj_vec(obj_array, obj_array + obj_array_size);
5051

5152
try {
52-
q->submit([&](sycl::handler &cgh) {
53-
for (size_t ev_id = 0; ev_id < nERefs; ++ev_id) {
54-
cgh.depends_on(
55-
*(reinterpret_cast<sycl::event *>(ERefs[ev_id])));
53+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
54+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
55+
cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
5656
}
5757
cgh.host_task([obj_array_size, obj_vec]() {
5858
// if the main thread has not finalized the interpreter yet
@@ -66,9 +66,21 @@ int async_dec_ref(DPCTLSyclQueueRef QRef,
6666
}
6767
});
6868
});
69+
70+
constexpr int result_ok = 0;
71+
72+
*status = result_ok;
73+
auto e_ptr = new sycl::event(ht_ev);
74+
return wrap<sycl::event>(e_ptr);
6975
} catch (const std::exception &e) {
70-
return 1;
76+
constexpr int result_std_exception = 1;
77+
78+
*status = result_std_exception;
79+
return nullptr;
7180
}
7281

73-
return 0;
82+
constexpr int result_other_abnormal = 2;
83+
84+
*status = result_other_abnormal;
85+
return nullptr;
7486
}

dpctl/_sycl_queue.pxd

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ cdef public api class SyclQueue (_SyclQueue) [
7070
cpdef SyclContext get_sycl_context(self)
7171
cpdef SyclDevice get_sycl_device(self)
7272
cdef DPCTLSyclQueueRef get_queue_ref(self)
73+
cpdef SyclEvent _submit_keep_args_alive(
74+
self,
75+
object args,
76+
list dEvents
77+
)
78+
cpdef SyclEvent submit_async(
79+
self,
80+
SyclKernel kernel,
81+
list args,
82+
list gS,
83+
list lS=*,
84+
list dEvents=*
85+
)
7386
cpdef SyclEvent submit(
7487
self,
7588
SyclKernel kernel,
@@ -81,6 +94,7 @@ cdef public api class SyclQueue (_SyclQueue) [
8194
cpdef void wait(self)
8295
cdef DPCTLSyclQueueRef get_queue_ref(self)
8396
cpdef memcpy(self, dest, src, size_t count)
97+
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=*)
8498
cpdef prefetch(self, ptr, size_t count=*)
8599
cpdef mem_advise(self, ptr, size_t count, int mem)
86100
cpdef SyclEvent submit_barrier(self, dependent_events=*)

0 commit comments

Comments
 (0)