@@ -64,46 +64,16 @@ static sycl::event div_impl(sycl::queue exec_q,
64
64
{
65
65
type_utils::validate_type_for_device<T>(exec_q);
66
66
67
- std::cerr << " enter div_impl" << std::endl;
67
+ const T* a = reinterpret_cast <const T*>(in_a);
68
+ const T* b = reinterpret_cast <const T*>(in_b);
69
+ T* y = reinterpret_cast <T*>(out_y);
68
70
69
- const T* _a = reinterpret_cast <const T*>(in_a);
70
- const T* _b = reinterpret_cast <const T*>(in_b);
71
- T* _y = reinterpret_cast <T*>(out_y);
72
-
73
- std::cerr << " casting is done" << std::endl;
74
-
75
- T* a = sycl::malloc_device<T>(n, exec_q);
76
- T* b = sycl::malloc_device<T>(n, exec_q);
77
- T* y = sycl::malloc_device<T>(n, exec_q);
78
-
79
- std::cerr << " malloc is done" << std::endl;
80
-
81
- exec_q.copy (_a, a, n).wait ();
82
- exec_q.copy (_b, b, n).wait ();
83
- exec_q.copy (_y, y, n).wait ();
84
-
85
- std::cerr << " copy is done" << std::endl;
86
-
87
- sycl::event ev = mkl_vm::div (exec_q,
71
+ return mkl_vm::div (exec_q,
88
72
n, // number of elements to be calculated
89
73
a, // pointer `a` containing 1st input vector of size n
90
74
b, // pointer `b` containing 2nd input vector of size n
91
75
y, // pointer `y` to the output vector of size n
92
76
depends);
93
- ev.wait ();
94
-
95
- std::cerr << " div is done" << std::endl;
96
-
97
- exec_q.copy (y, _y, n).wait ();
98
-
99
- std::cerr << " copy is done" << std::endl;
100
-
101
- sycl::free (a, exec_q);
102
- sycl::free (b, exec_q);
103
- sycl::free (y, exec_q);
104
-
105
- std::cerr << " leaving div_impl" << std::endl;
106
- return sycl::event ();
107
77
}
108
78
109
79
std::pair<sycl::event, sycl::event> div (sycl::queue exec_q,
@@ -205,20 +175,9 @@ std::pair<sycl::event, sycl::event> div(sycl::queue exec_q,
205
175
throw py::value_error (" No div implementation defined" );
206
176
}
207
177
sycl::event sum_ev = div_fn (exec_q, src_nelems, src1_data, src2_data, dst_data, depends);
208
- // sum_ev.wait();
209
-
210
- // int* dummy = sycl::malloc_device<int>(1, exec_q);
211
- // sycl::event cleanup_ev = exec_q.submit([&](sycl::handler& cgh) {
212
- // // cgh.depends_on(sum_ev);
213
- // auto ctx = exec_q.get_context();
214
- // cgh.host_task([dummy, ctx]() {
215
- // // dummy host task to pass into keep_args_alive
216
- // sycl::free(dummy, ctx);
217
- // });
218
- // });
219
-
220
- // sycl::event ht_ev = dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, {sum_ev});
221
- // return std::make_pair(ht_ev, sum_ev);
178
+
179
+ sycl::event ht_ev = dpctl::utils::keep_args_alive (exec_q, {src1, src2, dst}, {sum_ev});
180
+ return std::make_pair (ht_ev, sum_ev);
222
181
return std::make_pair (sycl::event (), sycl::event ());
223
182
}
224
183
@@ -227,6 +186,7 @@ bool can_call_div(sycl::queue exec_q,
227
186
dpctl::tensor::usm_ndarray src2,
228
187
dpctl::tensor::usm_ndarray dst)
229
188
{
189
+ #if INTEL_MKL_VERSION >= 20230002
230
190
// check type_nums
231
191
int src1_typenum = src1.get_typenum ();
232
192
int src2_typenum = src2.get_typenum ();
@@ -325,6 +285,16 @@ bool can_call_div(sycl::queue exec_q,
325
285
return false ;
326
286
}
327
287
return true ;
288
+ #else
289
+ // In OneMKL 2023.1.0 the call of oneapi::mkl::vm::div() is going to dead lock
290
+ // inside ~usm_wrapper_to_host()->{...; q_->wait_and_throw(); ...}
291
+
292
+ (void )exec_q;
293
+ (void )src1;
294
+ (void )src2;
295
+ (void )dst;
296
+ return false ;
297
+ #endif // INTEL_MKL_VERSION >= 20230002
328
298
}
329
299
330
300
template <typename fnT, typename T>
0 commit comments