2
2
//
3
3
// Data Parallel Control (dpctl)
4
4
//
5
- // Copyright 2020-2022 Intel Corporation
5
+ // Copyright 2020-2023 Intel Corporation
6
6
//
7
7
// Licensed under the Apache License, Version 2.0 (the "License");
8
8
// you may not use this file except in compliance with the License.
29
29
// /
30
30
// ===----------------------------------------------------------------------===//
31
31
32
+ #pragma once
32
33
#include " Python.h"
33
34
#include " syclinterface/dpctl_data_types.h"
35
+ #include " syclinterface/dpctl_sycl_type_casters.hpp"
34
36
#include < CL/sycl.hpp>
35
37
36
- int async_dec_ref (DPCTLSyclQueueRef QRef,
37
- PyObject **obj_array,
38
- size_t obj_array_size,
39
- DPCTLSyclEventRef *ERefs,
40
- size_t nERefs)
38
+ DPCTLSyclEventRef async_dec_ref (DPCTLSyclQueueRef QRef,
39
+ PyObject **obj_array,
40
+ size_t obj_array_size,
41
+ DPCTLSyclEventRef *depERefs,
42
+ size_t nDepERefs,
43
+ int *status)
41
44
{
45
+ using dpctl::syclinterface::unwrap;
46
+ using dpctl::syclinterface::wrap;
42
47
43
- sycl::queue *q = reinterpret_cast <sycl::queue * >(QRef);
48
+ sycl::queue *q = unwrap <sycl::queue>(QRef);
44
49
45
- std::vector<PyObject *> obj_vec;
46
- obj_vec.reserve (obj_array_size);
47
- for (size_t obj_id = 0 ; obj_id < obj_array_size; ++obj_id) {
48
- obj_vec.push_back (obj_array[obj_id]);
49
- }
50
+ std::vector<PyObject *> obj_vec (obj_array, obj_array + obj_array_size);
50
51
51
52
try {
52
- q->submit ([&](sycl::handler &cgh) {
53
- for (size_t ev_id = 0 ; ev_id < nERefs; ++ev_id) {
54
- cgh.depends_on (
55
- *(reinterpret_cast <sycl::event *>(ERefs[ev_id])));
53
+ sycl::event ht_ev = q->submit ([&](sycl::handler &cgh) {
54
+ for (size_t ev_id = 0 ; ev_id < nDepERefs; ++ev_id) {
55
+ cgh.depends_on (*(unwrap<sycl::event>(depERefs[ev_id])));
56
56
}
57
57
cgh.host_task ([obj_array_size, obj_vec]() {
58
58
// if the main thread has not finilized the interpreter yet
@@ -66,9 +66,21 @@ int async_dec_ref(DPCTLSyclQueueRef QRef,
66
66
}
67
67
});
68
68
});
69
+
70
+ constexpr int result_ok = 0 ;
71
+
72
+ *status = result_ok;
73
+ auto e_ptr = new sycl::event (ht_ev);
74
+ return wrap<sycl::event>(e_ptr);
69
75
} catch (const std::exception &e) {
70
- return 1 ;
76
+ constexpr int result_std_exception = 1 ;
77
+
78
+ *status = result_std_exception;
79
+ return nullptr ;
71
80
}
72
81
73
- return 0 ;
82
+ constexpr int result_other_abnormal = 2 ;
83
+
84
+ *status = result_other_abnormal;
85
+ return nullptr ;
74
86
}
0 commit comments