Skip to content

Commit 7ed9b40

Browse files
authored
Allocate a separate FFT descriptor per FFT compute function (#1322)
1 parent 1c7b85f commit 7ed9b40

File tree

1 file changed

+87
-76
lines changed

1 file changed

+87
-76
lines changed

dpnp/backend/kernels/dpnp_krnl_fft.cpp

+87 -76
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,6 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
178178
const size_t shape_size,
179179
const size_t input_size,
180180
const size_t result_size,
181-
_Descriptor_type& desc,
182181
size_t inverse,
183182
const size_t norm)
184183
{
@@ -187,14 +186,15 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
187186
(void)input_size;
188187
(void)result_size;
189188

190-
if (!shape_size) {
189+
if (!shape_size)
190+
{
191191
return;
192192
}
193193

194194
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
195195

196-
_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
197-
_DataType_output* result = static_cast<_DataType_output *>(result_out);
196+
_DataType_input* array_1 = static_cast<_DataType_input*>(const_cast<void*>(array1_in));
197+
_DataType_output* result = static_cast<_DataType_output*>(result_out);
198198

199199
const size_t n_iter =
200200
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
@@ -204,39 +204,49 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
204204
double backward_scale = 1.;
205205
double forward_scale = 1.;
206206

207-
if (norm == 0) { // norm = "backward"
207+
if (norm == 0) // norm = "backward"
208+
{
208209
backward_scale = 1. / shift;
209-
} else if (norm == 1) { // norm = "forward"
210+
}
211+
else if (norm == 1) // norm = "forward"
212+
{
210213
forward_scale = 1. / shift;
211-
} else { // norm = "ortho"
212-
if (inverse) {
214+
}
215+
else // norm = "ortho"
216+
{
217+
if (inverse)
218+
{
213219
backward_scale = 1. / sqrt(shift);
214-
} else {
220+
}
221+
else
222+
{
215223
forward_scale = 1. / sqrt(shift);
216224
}
217225
}
218226

219-
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
220-
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
221-
// enum value from math library C interface
222-
// instead of mkl_dft::config_value::NOT_INPLACE
223-
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
224-
desc.commit(queue);
225-
226-
std::vector<sycl::event> fft_events;
227-
fft_events.reserve(n_iter);
228-
229-
for (size_t i = 0; i < n_iter; ++i) {
230-
if (inverse) {
231-
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * shift, result + i * shift));
232-
} else {
233-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
227+
std::vector<sycl::event> fft_events(n_iter);
228+
229+
for (size_t i = 0; i < n_iter; ++i)
230+
{
231+
std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(shift);
232+
desc->set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
233+
desc->set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
234+
desc->set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
235+
desc->commit(queue);
236+
237+
if (inverse)
238+
{
239+
fft_events[i] = mkl_dft::compute_backward<_Descriptor_type, _DataType_input, _DataType_output>(
240+
*desc, array_1 + i * shift, result + i * shift);
241+
}
242+
else
243+
{
244+
fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
245+
*desc, array_1 + i * shift, result + i * shift);
234246
}
235247
}
236248

237249
sycl::event::wait(fft_events);
238-
239-
return;
240250
}
241251

242252
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
@@ -251,7 +261,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
251261
const size_t shape_size,
252262
const size_t input_size,
253263
const size_t result_size,
254-
_Descriptor_type& desc,
255264
size_t inverse,
256265
const size_t norm,
257266
const size_t real)
@@ -260,14 +269,15 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
260269
(void)input_size;
261270

262271
DPCTLSyclEventRef event_ref = nullptr;
263-
if (!shape_size) {
272+
if (!shape_size)
273+
{
264274
return event_ref;
265275
}
266276

267277
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
268278

269-
_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
270-
_DataType_output* result = static_cast<_DataType_output *>(result_out);
279+
_DataType_input* array_1 = static_cast<_DataType_input*>(const_cast<void*>(array1_in));
280+
_DataType_output* result = static_cast<_DataType_output*>(result_out);
271281

272282
const size_t n_iter =
273283
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
@@ -278,38 +288,52 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
278288
double backward_scale = 1.;
279289
double forward_scale = 1.;
280290

281-
if (norm == 0) { // norm = "backward"
282-
if (inverse) {
291+
if (norm == 0) // norm = "backward"
292+
{
293+
if (inverse)
294+
{
283295
forward_scale = 1. / result_shift;
284-
} else {
296+
}
297+
else
298+
{
285299
backward_scale = 1. / result_shift;
286300
}
287-
} else if (norm == 1) { // norm = "forward"
288-
if (inverse) {
301+
}
302+
else if (norm == 1) // norm = "forward"
303+
{
304+
if (inverse)
305+
{
289306
backward_scale = 1. / result_shift;
290-
} else {
307+
}
308+
else
309+
{
291310
forward_scale = 1. / result_shift;
292311
}
293-
} else { // norm = "ortho"
312+
}
313+
else // norm = "ortho"
314+
{
294315
forward_scale = 1. / sqrt(result_shift);
295316
}
296317

297-
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
298-
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
299-
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
318+
std::vector<sycl::event> fft_events(n_iter);
300319

301-
desc.commit(queue);
302-
303-
std::vector<sycl::event> fft_events;
304-
fft_events.reserve(n_iter);
305-
306-
for (size_t i = 0; i < n_iter; ++i) {
307-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
320+
for (size_t i = 0; i < n_iter; ++i)
321+
{
322+
std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(input_shift);
323+
desc->set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
324+
desc->set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
325+
desc->set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
326+
desc->commit(queue);
327+
328+
// real result_size = 2 * result_size, because real type of "result" is twice wider than '_DataType_output'
329+
fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
330+
*desc, array_1 + i * input_shift, result + i * result_shift * 2);
308331
}
309332

310333
sycl::event::wait(fft_events);
311334

312-
if (real) { // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
335+
if (real) // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
336+
{
313337
return event_ref;
314338
}
315339

@@ -325,19 +349,22 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
325349
size_t j = global_id[1];
326350
{
327351
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * (i + 1) - (j + 1)) =
328-
std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
352+
std::conj(
353+
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
329354
}
330355
}
331356
};
332357

333358
auto kernel_func = [&](sycl::handler& cgh) {
334-
cgh.parallel_for<class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
359+
cgh.parallel_for<
360+
class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
335361
gws, kernel_parallel_for_func);
336362
};
337363

338364
event = queue.submit(kernel_func);
339365

340-
if (inverse) {
366+
if (inverse)
367+
{
341368
event.wait();
342369
event = oneapi::mkl::vm::conj(queue,
343370
result_size,
@@ -346,7 +373,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
346373
}
347374

348375
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
349-
350376
return DPCTLEvent_Copy(event_ref);
351377
}
352378

@@ -375,43 +401,35 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
375401
const size_t input_size =
376402
std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<shape_elem_type>());
377403

378-
size_t dim = input_shape[shape_size - 1];
379-
380404
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
381405
std::is_same<_DataType_output, std::complex<double>>::value)
382406
{
383407
if constexpr (std::is_same<_DataType_input, std::complex<double>>::value &&
384408
std::is_same<_DataType_output, std::complex<double>>::value)
385409
{
386-
desc_dp_cmplx_t desc(dim);
387410
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
388-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
411+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
389412
}
390413
/* complex-to-complex, single precision */
391414
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
392415
std::is_same<_DataType_output, std::complex<float>>::value)
393416
{
394-
desc_sp_cmplx_t desc(dim);
395417
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
396-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
418+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
397419
}
398420
/* real-to-complex, double precision */
399421
else if constexpr (std::is_same<_DataType_input, double>::value &&
400422
std::is_same<_DataType_output, std::complex<double>>::value)
401423
{
402-
desc_dp_real_t desc(dim);
403-
404424
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
405-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
425+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0);
406426
}
407427
/* real-to-complex, single precision */
408428
else if constexpr (std::is_same<_DataType_input, float>::value &&
409429
std::is_same<_DataType_output, std::complex<float>>::value)
410430
{
411-
desc_sp_real_t desc(dim); // try: 2 * result_size
412-
413431
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
414-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
432+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0);
415433
}
416434
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
417435
std::is_same<_DataType_input, int64_t>::value)
@@ -428,9 +446,8 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
428446
DPCTLEvent_WaitAndThrow(event_ref);
429447
DPCTLEvent_Delete(event_ref);
430448

431-
desc_dp_real_t desc(dim);
432449
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
433-
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
450+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0);
434451

435452
DPCTLEvent_WaitAndThrow(event_ref);
436453
DPCTLEvent_Delete(event_ref);
@@ -537,26 +554,21 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
537554
const size_t input_size =
538555
std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<shape_elem_type>());
539556

540-
size_t dim = input_shape[shape_size - 1];
541-
542557
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
543558
std::is_same<_DataType_output, std::complex<double>>::value)
544559
{
545560
if constexpr (std::is_same<_DataType_input, double>::value &&
546-
std::is_same<_DataType_output, std::complex<double>>::value)
561+
std::is_same<_DataType_output, std::complex<double>>::value)
547562
{
548-
desc_dp_real_t desc(dim);
549-
550563
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
551-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
564+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1);
552565
}
553566
/* real-to-complex, single precision */
554567
else if constexpr (std::is_same<_DataType_input, float>::value &&
555568
std::is_same<_DataType_output, std::complex<float>>::value)
556569
{
557-
desc_sp_real_t desc(dim); // try: 2 * result_size
558570
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
559-
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
571+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1);
560572
}
561573
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
562574
std::is_same<_DataType_input, int64_t>::value)
@@ -573,9 +585,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
573585
DPCTLEvent_WaitAndThrow(event_ref);
574586
DPCTLEvent_Delete(event_ref);
575587

576-
desc_dp_real_t desc(dim);
577588
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
578-
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
589+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1);
579590

580591
DPCTLEvent_WaitAndThrow(event_ref);
581592
DPCTLEvent_Delete(event_ref);

0 commit comments

Comments
 (0)