@@ -414,6 +414,9 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
414414 const size_t norm,
415415 const DPCTLEventVectorRef dep_event_vec_ref)
416416{
417+ static_assert (sycl::detail::is_complex<_DataType_output>::value,
418+ " Output data type must be a complex type." );
419+
417420 DPCTLSyclEventRef event_ref = nullptr ;
418421
419422 if (!shape_size || !array1_in || !result_out) {
@@ -476,8 +479,10 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
476479 else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
477480 std::is_same<_DataType_input, int64_t >::value)
478481 {
479- double *array1_copy = reinterpret_cast <double *>(
480- dpnp_memory_alloc_c (q_ref, input_size * sizeof (double )));
482+ using CastType = typename _DataType_output::value_type;
483+
484+ CastType *array1_copy = reinterpret_cast <CastType *>(
485+ dpnp_memory_alloc_c (q_ref, input_size * sizeof (CastType)));
481486
482487 shape_elem_type *copy_strides = reinterpret_cast <shape_elem_type *>(
483488 dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
@@ -486,15 +491,17 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
486491 dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
487492 *copy_shape = input_size;
488493 shape_elem_type copy_shape_size = 1 ;
489- event_ref = dpnp_copyto_c<_DataType_input, double >(
494+ event_ref = dpnp_copyto_c<_DataType_input, CastType >(
490495 q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
491496 copy_strides, array1_in, input_size, copy_shape_size,
492497 copy_shape, copy_strides, NULL , dep_event_vec_ref);
493498 DPCTLEvent_WaitAndThrow (event_ref);
494499 DPCTLEvent_Delete (event_ref);
495500
496- event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double ,
497- desc_dp_real_t >(
501+ event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
502+ CastType, CastType,
503+ std::conditional_t <std::is_same<CastType, double >::value,
504+ desc_dp_real_t , desc_sp_real_t >>(
498505 q_ref, array1_copy, result_out, input_shape, result_shape,
499506 shape_size, input_size, result_size, inverse, norm, 0 );
500507
@@ -577,6 +584,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
577584 const size_t norm,
578585 const DPCTLEventVectorRef dep_event_vec_ref)
579586{
587+ static_assert (sycl::detail::is_complex<_DataType_output>::value,
588+ " Output data type must be a complex type." );
580589 DPCTLSyclEventRef event_ref = nullptr ;
581590
582591 if (!shape_size || !array1_in || !result_out) {
@@ -617,8 +626,10 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
617626 else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
618627 std::is_same<_DataType_input, int64_t >::value)
619628 {
620- double *array1_copy = reinterpret_cast <double *>(
621- dpnp_memory_alloc_c (q_ref, input_size * sizeof (double )));
629+ using CastType = typename _DataType_output::value_type;
630+
631+ CastType *array1_copy = reinterpret_cast <CastType *>(
632+ dpnp_memory_alloc_c (q_ref, input_size * sizeof (CastType)));
622633
623634 shape_elem_type *copy_strides = reinterpret_cast <shape_elem_type *>(
624635 dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
@@ -627,15 +638,17 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
627638 dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
628639 *copy_shape = input_size;
629640 shape_elem_type copy_shape_size = 1 ;
630- event_ref = dpnp_copyto_c<_DataType_input, double >(
641+ event_ref = dpnp_copyto_c<_DataType_input, CastType >(
631642 q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
632643 copy_strides, array1_in, input_size, copy_shape_size,
633644 copy_shape, copy_strides, NULL , dep_event_vec_ref);
634645 DPCTLEvent_WaitAndThrow (event_ref);
635646 DPCTLEvent_Delete (event_ref);
636647
637- event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double ,
638- desc_dp_real_t >(
648+ event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
649+ CastType, CastType,
650+ std::conditional_t <std::is_same<CastType, double >::value,
651+ desc_dp_real_t , desc_sp_real_t >>(
639652 q_ref, array1_copy, result_out, input_shape, result_shape,
640653 shape_size, input_size, result_size, inverse, norm, 1 );
641654
@@ -721,9 +734,11 @@ void func_map_init_fft_func(func_map_t &fmap)
721734 dpnp_fft_fft_default_c<std::complex <double >, std::complex <double >>};
722735
723736 fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_INT][eft_INT] = {
724- eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t , std::complex <double >>};
737+ eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t , std::complex <double >>,
738+ eft_C64, (void *)dpnp_fft_fft_ext_c<int32_t , std::complex <float >>};
725739 fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_LNG][eft_LNG] = {
726- eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t , std::complex <double >>};
740+ eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t , std::complex <double >>,
741+ eft_C64, (void *)dpnp_fft_fft_ext_c<int64_t , std::complex <float >>};
727742 fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_FLT][eft_FLT] = {
728743 eft_C64, (void *)dpnp_fft_fft_ext_c<float , std::complex <float >>};
729744 fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_DBL][eft_DBL] = {
@@ -748,9 +763,11 @@ void func_map_init_fft_func(func_map_t &fmap)
748763 (void *)dpnp_fft_rfft_default_c<double , std::complex <double >>};
749764
750765 fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_INT][eft_INT] = {
751- eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t , std::complex <double >>};
766+ eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t , std::complex <double >>,
767+ eft_C64, (void *)dpnp_fft_rfft_ext_c<int32_t , std::complex <float >>};
752768 fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_LNG][eft_LNG] = {
753- eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t , std::complex <double >>};
769+ eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t , std::complex <double >>,
770+ eft_C64, (void *)dpnp_fft_rfft_ext_c<int64_t , std::complex <float >>};
754771 fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_FLT][eft_FLT] = {
755772 eft_C64, (void *)dpnp_fft_rfft_ext_c<float , std::complex <float >>};
756773 fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_DBL][eft_DBL] = {
0 commit comments