@@ -178,7 +178,6 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
178
178
const size_t shape_size,
179
179
const size_t input_size,
180
180
const size_t result_size,
181
- _Descriptor_type& desc,
182
181
size_t inverse,
183
182
const size_t norm)
184
183
{
@@ -187,14 +186,15 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
187
186
(void )input_size;
188
187
(void )result_size;
189
188
190
- if (!shape_size) {
189
+ if (!shape_size)
190
+ {
191
191
return ;
192
192
}
193
193
194
194
sycl::queue queue = *(reinterpret_cast <sycl::queue*>(q_ref));
195
195
196
- _DataType_input* array_1 = static_cast <_DataType_input *>(const_cast <void *>(array1_in));
197
- _DataType_output* result = static_cast <_DataType_output *>(result_out);
196
+ _DataType_input* array_1 = static_cast <_DataType_input*>(const_cast <void *>(array1_in));
197
+ _DataType_output* result = static_cast <_DataType_output*>(result_out);
198
198
199
199
const size_t n_iter =
200
200
std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
@@ -204,39 +204,49 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
204
204
double backward_scale = 1 .;
205
205
double forward_scale = 1 .;
206
206
207
- if (norm == 0 ) { // norm = "backward"
207
+ if (norm == 0 ) // norm = "backward"
208
+ {
208
209
backward_scale = 1 . / shift;
209
- } else if (norm == 1 ) { // norm = "forward"
210
+ }
211
+ else if (norm == 1 ) // norm = "forward"
212
+ {
210
213
forward_scale = 1 . / shift;
211
- } else { // norm = "ortho"
212
- if (inverse) {
214
+ }
215
+ else // norm = "ortho"
216
+ {
217
+ if (inverse)
218
+ {
213
219
backward_scale = 1 . / sqrt (shift);
214
- } else {
220
+ }
221
+ else
222
+ {
215
223
forward_scale = 1 . / sqrt (shift);
216
224
}
217
225
}
218
226
219
- desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
220
- desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
221
- // enum value from math library C interface
222
- // instead of mkl_dft::config_value::NOT_INPLACE
223
- desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
224
- desc.commit (queue);
225
-
226
- std::vector<sycl::event> fft_events;
227
- fft_events.reserve (n_iter);
228
-
229
- for (size_t i = 0 ; i < n_iter; ++i) {
230
- if (inverse) {
231
- fft_events.push_back (mkl_dft::compute_backward (desc, array_1 + i * shift, result + i * shift));
232
- } else {
233
- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift, result + i * shift));
227
+ std::vector<sycl::event> fft_events (n_iter);
228
+
229
+ for (size_t i = 0 ; i < n_iter; ++i)
230
+ {
231
+ std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(shift);
232
+ desc->set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
233
+ desc->set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
234
+ desc->set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
235
+ desc->commit (queue);
236
+
237
+ if (inverse)
238
+ {
239
+ fft_events[i] = mkl_dft::compute_backward<_Descriptor_type, _DataType_input, _DataType_output>(
240
+ *desc, array_1 + i * shift, result + i * shift);
241
+ }
242
+ else
243
+ {
244
+ fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
245
+ *desc, array_1 + i * shift, result + i * shift);
234
246
}
235
247
}
236
248
237
249
sycl::event::wait (fft_events);
238
-
239
- return ;
240
250
}
241
251
242
252
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
@@ -251,7 +261,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
251
261
const size_t shape_size,
252
262
const size_t input_size,
253
263
const size_t result_size,
254
- _Descriptor_type& desc,
255
264
size_t inverse,
256
265
const size_t norm,
257
266
const size_t real)
@@ -260,14 +269,15 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
260
269
(void )input_size;
261
270
262
271
DPCTLSyclEventRef event_ref = nullptr ;
263
- if (!shape_size) {
272
+ if (!shape_size)
273
+ {
264
274
return event_ref;
265
275
}
266
276
267
277
sycl::queue queue = *(reinterpret_cast <sycl::queue*>(q_ref));
268
278
269
- _DataType_input* array_1 = static_cast <_DataType_input *>(const_cast <void *>(array1_in));
270
- _DataType_output* result = static_cast <_DataType_output *>(result_out);
279
+ _DataType_input* array_1 = static_cast <_DataType_input*>(const_cast <void *>(array1_in));
280
+ _DataType_output* result = static_cast <_DataType_output*>(result_out);
271
281
272
282
const size_t n_iter =
273
283
std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
@@ -278,38 +288,52 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
278
288
double backward_scale = 1 .;
279
289
double forward_scale = 1 .;
280
290
281
- if (norm == 0 ) { // norm = "backward"
282
- if (inverse) {
291
+ if (norm == 0 ) // norm = "backward"
292
+ {
293
+ if (inverse)
294
+ {
283
295
forward_scale = 1 . / result_shift;
284
- } else {
296
+ }
297
+ else
298
+ {
285
299
backward_scale = 1 . / result_shift;
286
300
}
287
- } else if (norm == 1 ) { // norm = "forward"
288
- if (inverse) {
301
+ }
302
+ else if (norm == 1 ) // norm = "forward"
303
+ {
304
+ if (inverse)
305
+ {
289
306
backward_scale = 1 . / result_shift;
290
- } else {
307
+ }
308
+ else
309
+ {
291
310
forward_scale = 1 . / result_shift;
292
311
}
293
- } else { // norm = "ortho"
312
+ }
313
+ else // norm = "ortho"
314
+ {
294
315
forward_scale = 1 . / sqrt (result_shift);
295
316
}
296
317
297
- desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
298
- desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
299
- desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
318
+ std::vector<sycl::event> fft_events (n_iter);
300
319
301
- desc.commit (queue);
302
-
303
- std::vector<sycl::event> fft_events;
304
- fft_events.reserve (n_iter);
305
-
306
- for (size_t i = 0 ; i < n_iter; ++i) {
307
- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
320
+ for (size_t i = 0 ; i < n_iter; ++i)
321
+ {
322
+ std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(input_shift);
323
+ desc->set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
324
+ desc->set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
325
+ desc->set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
326
+ desc->commit (queue);
327
+
328
+ // real result_size = 2 * result_size, because real type of "result" is twice wider than '_DataType_output'
329
+ fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
330
+ *desc, array_1 + i * input_shift, result + i * result_shift * 2 );
308
331
}
309
332
310
333
sycl::event::wait (fft_events);
311
334
312
- if (real) { // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
335
+ if (real) // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
336
+ {
313
337
return event_ref;
314
338
}
315
339
@@ -325,19 +349,22 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
325
349
size_t j = global_id[1 ];
326
350
{
327
351
*(reinterpret_cast <std::complex<_DataType_output>*>(result) + result_shift * (i + 1 ) - (j + 1 )) =
328
- std::conj (*(reinterpret_cast <std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1 )));
352
+ std::conj (
353
+ *(reinterpret_cast <std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1 )));
329
354
}
330
355
}
331
356
};
332
357
333
358
auto kernel_func = [&](sycl::handler& cgh) {
334
- cgh.parallel_for <class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel <_DataType_input, _DataType_output, _Descriptor_type>>(
359
+ cgh.parallel_for <
360
+ class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel <_DataType_input, _DataType_output, _Descriptor_type>>(
335
361
gws, kernel_parallel_for_func);
336
362
};
337
363
338
364
event = queue.submit (kernel_func);
339
365
340
- if (inverse) {
366
+ if (inverse)
367
+ {
341
368
event.wait ();
342
369
event = oneapi::mkl::vm::conj (queue,
343
370
result_size,
@@ -346,7 +373,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
346
373
}
347
374
348
375
event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
349
-
350
376
return DPCTLEvent_Copy (event_ref);
351
377
}
352
378
@@ -375,43 +401,35 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
375
401
const size_t input_size =
376
402
std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<shape_elem_type>());
377
403
378
- size_t dim = input_shape[shape_size - 1 ];
379
-
380
404
if constexpr (std::is_same<_DataType_output, std::complex<float >>::value ||
381
405
std::is_same<_DataType_output, std::complex<double >>::value)
382
406
{
383
407
if constexpr (std::is_same<_DataType_input, std::complex<double >>::value &&
384
408
std::is_same<_DataType_output, std::complex<double >>::value)
385
409
{
386
- desc_dp_cmplx_t desc (dim);
387
410
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
388
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
411
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
389
412
}
390
413
/* complex-to-complex, single precision */
391
414
else if constexpr (std::is_same<_DataType_input, std::complex<float >>::value &&
392
415
std::is_same<_DataType_output, std::complex<float >>::value)
393
416
{
394
- desc_sp_cmplx_t desc (dim);
395
417
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
396
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
418
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
397
419
}
398
420
/* real-to-complex, double precision */
399
421
else if constexpr (std::is_same<_DataType_input, double >::value &&
400
422
std::is_same<_DataType_output, std::complex<double >>::value)
401
423
{
402
- desc_dp_real_t desc (dim);
403
-
404
424
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
405
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0 );
425
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0 );
406
426
}
407
427
/* real-to-complex, single precision */
408
428
else if constexpr (std::is_same<_DataType_input, float >::value &&
409
429
std::is_same<_DataType_output, std::complex<float >>::value)
410
430
{
411
- desc_sp_real_t desc (dim); // try: 2 * result_size
412
-
413
431
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
414
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0 );
432
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0 );
415
433
}
416
434
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
417
435
std::is_same<_DataType_input, int64_t >::value)
@@ -428,9 +446,8 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
428
446
DPCTLEvent_WaitAndThrow (event_ref);
429
447
DPCTLEvent_Delete (event_ref);
430
448
431
- desc_dp_real_t desc (dim);
432
449
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
433
- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0 );
450
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0 );
434
451
435
452
DPCTLEvent_WaitAndThrow (event_ref);
436
453
DPCTLEvent_Delete (event_ref);
@@ -537,26 +554,21 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
537
554
const size_t input_size =
538
555
std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<shape_elem_type>());
539
556
540
- size_t dim = input_shape[shape_size - 1 ];
541
-
542
557
if constexpr (std::is_same<_DataType_output, std::complex<float >>::value ||
543
558
std::is_same<_DataType_output, std::complex<double >>::value)
544
559
{
545
560
if constexpr (std::is_same<_DataType_input, double >::value &&
546
- std::is_same<_DataType_output, std::complex<double >>::value)
561
+ std::is_same<_DataType_output, std::complex<double >>::value)
547
562
{
548
- desc_dp_real_t desc (dim);
549
-
550
563
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
551
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1 );
564
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1 );
552
565
}
553
566
/* real-to-complex, single precision */
554
567
else if constexpr (std::is_same<_DataType_input, float >::value &&
555
568
std::is_same<_DataType_output, std::complex<float >>::value)
556
569
{
557
- desc_sp_real_t desc (dim); // try: 2 * result_size
558
570
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
559
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1 );
571
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1 );
560
572
}
561
573
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
562
574
std::is_same<_DataType_input, int64_t >::value)
@@ -573,9 +585,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
573
585
DPCTLEvent_WaitAndThrow (event_ref);
574
586
DPCTLEvent_Delete (event_ref);
575
587
576
- desc_dp_real_t desc (dim);
577
588
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
578
- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1 );
589
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1 );
579
590
580
591
DPCTLEvent_WaitAndThrow (event_ref);
581
592
DPCTLEvent_Delete (event_ref);
0 commit comments