3636# pylint: disable=protected-access
3737# pylint: disable=no-name-in-module
3838
39+ from collections .abc import Sequence
40+
3941import dpctl
4042import dpctl .tensor ._tensor_impl as ti
4143import dpctl .utils as dpu
4244import numpy
43- from dpctl .tensor ._numpy_helper import normalize_axis_index
45+ from dpctl .tensor ._numpy_helper import (
46+ normalize_axis_index ,
47+ normalize_axis_tuple ,
48+ )
4449from dpctl .utils import ExecutionPlacementError
4550
4651import dpnp
5459
5560__all__ = [
5661 "dpnp_fft" ,
62+ "dpnp_fftn" ,
5763]
5864
5965
@@ -159,6 +165,37 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides):
159165 return result
160166
161167
168+ # TODO: c2r keyword is place holder for irfftn
169+ def _cook_nd_args (a , s = None , axes = None , c2r = False ):
170+ if s is None :
171+ shapeless = True
172+ if axes is None :
173+ s = list (a .shape )
174+ else :
175+ s = numpy .take (a .shape , axes )
176+ else :
177+ shapeless = False
178+
179+ for s_i in s :
180+ if s_i is not None and s_i < 1 and s_i != - 1 :
181+ raise ValueError (
182+ f"Invalid number of FFT data points ({ s_i } ) specified."
183+ )
184+
185+ if axes is None :
186+ axes = list (range (- len (s ), 0 ))
187+
188+ if len (s ) != len (axes ):
189+ raise ValueError ("Shape and axes have different lengths." )
190+
191+ s = list (s )
192+ if c2r and shapeless :
193+ s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
194+ # use the whole input array along axis `i` if `s[i] == -1`
195+ s = [a .shape [_a ] if _s == - 1 else _s for _s , _a in zip (s , axes )]
196+ return s , axes
197+
198+
162199def _copy_array (x , complex_input ):
163200 """
164201 Creating a C-contiguous copy of input array if input array has a negative
@@ -204,6 +241,80 @@ def _copy_array(x, complex_input):
204241 return x , copy_flag
205242
206243
244+ def _extract_axes_chunk (a , s , chunk_size = 3 ):
245+ """
246+ Classify the first input into a list of lists with each list containing
247+ only unique values in reverse order and its length is at most `chunk_size`.
248+ The second input is also classified into a list of lists with each list
249+ containing the corresponding values of the first input.
250+
251+ Parameters
252+ ----------
253+ a : list or tuple of ints
254+ The first input.
255+ s : list or tuple of ints
256+ The second input.
257+ chunk_size : int
258+ Maximum number of elements in each chunk.
259+
260+ Return
261+ ------
262+ out : a tuple of two lists
263+ The first element of output is a list of lists with each list
264+ containing only unique values in revere order and its length is
265+ at most `chunk_size`.
266+ The second element of output is a list of lists with each list
267+ containing the corresponding values of the first input.
268+
269+ Examples
270+ --------
271+ >>> axes = (0, 1, 2, 3, 4)
272+ >>> shape = (7, 8, 10, 9, 5)
273+ >>> _extract_axes_chunk(axes, shape, chunk_size=3)
274+ ([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])
275+
276+ >>> axes = (1, 0, 3, 2, 4, 4)
277+ >>> shape = (7, 8, 10, 5, 7, 6)
278+ >>> _extract_axes_chunk(axes, shape, chunk_size=3)
279+ ([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])
280+
281+ """
282+
283+ a_chunks = []
284+ a_current_chunk = []
285+ seen_elements = set ()
286+
287+ s_chunks = []
288+ s_current_chunk = []
289+
290+ for a_elem , s_elem in zip (a , s ):
291+ if a_elem in seen_elements :
292+ # If element is already seen, start a new chunk
293+ a_chunks .append (a_current_chunk [::- 1 ])
294+ s_chunks .append (s_current_chunk [::- 1 ])
295+ a_current_chunk = [a_elem ]
296+ s_current_chunk = [s_elem ]
297+ seen_elements = {a_elem }
298+ else :
299+ a_current_chunk .append (a_elem )
300+ s_current_chunk .append (s_elem )
301+ seen_elements .add (a_elem )
302+
303+ if len (a_current_chunk ) == chunk_size :
304+ a_chunks .append (a_current_chunk [::- 1 ])
305+ s_chunks .append (s_current_chunk [::- 1 ])
306+ a_current_chunk = []
307+ s_current_chunk = []
308+ seen_elements = set ()
309+
310+ # Add the last chunk if it's not empty
311+ if a_current_chunk :
312+ a_chunks .append (a_current_chunk [::- 1 ])
313+ s_chunks .append (s_current_chunk [::- 1 ])
314+
315+ return a_chunks [::- 1 ], s_chunks [::- 1 ]
316+
317+
207318def _fft (a , norm , out , forward , in_place , c2c , axes = None ):
208319 """Calculates FFT of the input array along the specified axes."""
209320
@@ -238,7 +349,11 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None):
238349
239350def _scale_result (res , a_shape , norm , forward , index ):
240351 """Scale the result of the FFT according to `norm`."""
241- scale = numpy .prod (a_shape [index :], dtype = res .real .dtype )
352+ if res .dtype in [dpnp .float32 , dpnp .complex64 ]:
353+ dtype = dpnp .float32
354+ else :
355+ dtype = dpnp .float64
356+ scale = numpy .prod (a_shape [index :], dtype = dtype )
242357 norm_factor = 1
243358 if norm == "ortho" :
244359 norm_factor = numpy .sqrt (scale )
@@ -293,7 +408,7 @@ def _truncate_or_pad(a, shape, axes):
293408 return a
294409
295410
296- def _validate_out_keyword (a , out , axis , c2r , r2c ):
411+ def _validate_out_keyword (a , out , s , axes , c2r , r2c ):
297412 """Validate out keyword argument."""
298413 if out is not None :
299414 dpnp .check_supported_arrays_type (out )
@@ -305,16 +420,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
305420 "Input and output allocation queues are not compatible"
306421 )
307422
308- # validate out shape
309- expected_shape = a .shape
423+ # validate out shape against the final shape,
424+ # intermediate shapes may vary
425+ expected_shape = list (a .shape )
426+ for s_i , axis in zip (s [::- 1 ], axes [::- 1 ]):
427+ expected_shape [axis ] = s_i
310428 if r2c :
311- expected_shape = list (a .shape )
312- expected_shape [axis ] = a .shape [axis ] // 2 + 1
313- expected_shape = tuple (expected_shape )
314- if out .shape != expected_shape :
429+ expected_shape [axes [- 1 ]] = expected_shape [axes [- 1 ]] // 2 + 1
430+
431+ if out .shape != tuple (expected_shape ):
315432 raise ValueError (
316433 "output array has incorrect shape, expected "
317- f"{ expected_shape } , got { out .shape } ."
434+ f"{ tuple ( expected_shape ) } , got { out .shape } ."
318435 )
319436
320437 # validate out data type
@@ -328,9 +445,33 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
328445 raise TypeError ("output array should have complex data type." )
329446
330447
448+ def _validate_s_axes (a , s , axes ):
449+ if axes is not None :
450+ # validate axes is a sequence and
451+ # each axis is an integer within the range
452+ normalize_axis_tuple (list (set (axes )), a .ndim , "axes" )
453+
454+ if s is not None :
455+ raise_error = False
456+ if isinstance (s , Sequence ):
457+ if any (not isinstance (s_i , int ) for s_i in s ):
458+ raise_error = True
459+ else :
460+ raise_error = True
461+
462+ if raise_error :
463+ raise TypeError ("`s` must be `None` or a sequence of integers." )
464+
465+ if axes is None :
466+ raise ValueError (
467+ "`axes` should not be `None` if `s` is not `None`."
468+ )
469+
470+
331471def dpnp_fft (a , forward , real , n = None , axis = - 1 , norm = None , out = None ):
332472 """Calculates 1-D FFT of the input array along axis"""
333473
474+ _check_norm (norm )
334475 a_ndim = a .ndim
335476 if a_ndim == 0 :
336477 raise ValueError ("Input array must be at least 1D" )
@@ -354,7 +495,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
354495
355496 _check_norm (norm )
356497 a = _truncate_or_pad (a , n , axis )
357- _validate_out_keyword (a , out , axis , c2r , r2c )
498+ _validate_out_keyword (a , out , ( n ,), ( axis ,) , c2r , r2c )
358499 # if input array is copied, in-place FFT can be used
359500 a , in_place = _copy_array (a , c2c or c2r )
360501 if not in_place and out is not None :
@@ -377,3 +518,71 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
377518 c2c = c2c ,
378519 axes = axis ,
379520 )
521+
522+
523+ def dpnp_fftn (a , forward , s = None , axes = None , norm = None , out = None ):
524+ """Calculates N-D FFT of the input array along axes"""
525+
526+ _check_norm (norm )
527+ if isinstance (axes , (list , tuple )) and len (axes ) == 0 :
528+ return a
529+
530+ if a .ndim == 0 :
531+ if axes is not None :
532+ raise IndexError (
533+ "Input array is 0-dimensional while axis is not `None`."
534+ )
535+
536+ return a
537+
538+ _validate_s_axes (a , s , axes )
539+ s , axes = _cook_nd_args (a , s , axes )
540+ # TODO: False and False are place holder for future development of
541+ # rfft2, irfft2, rfftn, irfftn
542+ _validate_out_keyword (a , out , s , axes , False , False )
543+ # TODO: True is place holder for future development of
544+ # rfft2, irfft2, rfftn, irfftn
545+ a , in_place = _copy_array (a , True )
546+
547+ len_axes = len (axes )
548+ # OneMKL supports up to 3-dimensional FFT on GPU
549+ # repeated axis in OneMKL FFT is not allowed
550+ if len_axes > 3 or len (set (axes )) < len_axes :
551+ axes_chunk , shape_chunk = _extract_axes_chunk (axes , s , chunk_size = 3 )
552+ for s_chunk , a_chunk in zip (shape_chunk , axes_chunk ):
553+ a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
554+ if out is not None and out .shape == a .shape :
555+ tmp_out = out
556+ else :
557+ tmp_out = None
558+ a = _fft (
559+ a ,
560+ norm = norm ,
561+ out = tmp_out ,
562+ forward = forward ,
563+ in_place = in_place ,
564+ # TODO: c2c=True is place holder for future development of
565+ # rfft2, irfft2, rfftn, irfftn
566+ c2c = True ,
567+ axes = a_chunk ,
568+ )
569+ return a
570+
571+ a = _truncate_or_pad (a , s , axes )
572+ if a .size == 0 :
573+ return dpnp .get_result_array (a , out = out , casting = "same_kind" )
574+ if a .ndim == len_axes :
575+ # non-batch FFT
576+ axes = None
577+
578+ return _fft (
579+ a ,
580+ norm = norm ,
581+ out = out ,
582+ forward = forward ,
583+ in_place = in_place ,
584+ # TODO: c2c=True is place holder for future development of
585+ # rfft2, irfft2, rfftn, irfftn
586+ c2c = True ,
587+ axes = axes ,
588+ )
0 commit comments