|
43 | 43 | Base.ndims(T::FourierTransform) = length(T.modes) |
44 | 44 |
|
45 | 45 | function transform(ft::FourierTransform, x::AbstractArray) |
46 | | - res = Lux.Utils.eltype(x) <: Complex ? fft(x, 1:ndims(ft)) : rfft(x, 1:ndims(ft)) |
| 46 | + complex_data = Lux.Utils.eltype(x) <: Complex |
| 47 | + res = complex_data ? fft(x, 1:ndims(ft)) : rfft(x, 1:ndims(ft)) |
47 | 48 | if ft.shift && ndims(ft) > 1 |
48 | | - res = fftshift(res, 1:ndims(ft)) |
| 49 | + res = fftshift(res, (1 + !complex_data):ndims(ft)) |
49 | 50 | end |
50 | 51 | return res |
51 | 52 | end |
52 | 53 |
|
53 | 54 | function low_pass(ft::FourierTransform, x_fft::AbstractArray) |
54 | | - return view(x_fft,(map(d -> 1:d, ft.modes)...),:,:) |
| 55 | + return view(x_fft, (map(d -> 1:d, ft.modes)...), :, :) |
55 | 56 | end |
56 | 57 |
|
57 | 58 | truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft) |
58 | 59 |
|
59 | 60 | function inverse( |
60 | 61 | ft::FourierTransform, x_fft::AbstractArray{T,N}, x::AbstractArray{T2,N} |
61 | 62 | ) where {T,T2,N} |
| 63 | + complex_data = Lux.Utils.eltype(x) <: Complex |
| 64 | + |
62 | 65 | if ft.shift && ndims(ft) > 1 |
63 | | - x_fft = fftshift(x_fft, 1:ndims(ft)) |
| 66 | + x_fft = fftshift(x_fft, (1 + !complex_data):ndims(ft)) |
64 | 67 | end |
65 | 68 |
|
66 | | - if Lux.Utils.eltype(x) <: Complex |
| 69 | + if complex_data |
67 | 70 | return ifft(x_fft, 1:ndims(ft)) |
68 | 71 | else |
69 | 72 | return real(irfft(x_fft, size(x, 1), 1:ndims(ft))) |
|
0 commit comments