Skip to content

Commit 680bc91

Browse files
authored
fix: fft shift (#67)
1 parent ae9f586 commit 680bc91

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NeuralOperators"
22
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
33
authors = ["Avik Pal <[email protected]>"]
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/transform.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,30 @@ end
4343
Base.ndims(T::FourierTransform) = length(T.modes)
4444

4545
function transform(ft::FourierTransform, x::AbstractArray)
46-
res = Lux.Utils.eltype(x) <: Complex ? fft(x, 1:ndims(ft)) : rfft(x, 1:ndims(ft))
46+
complex_data = Lux.Utils.eltype(x) <: Complex
47+
res = complex_data ? fft(x, 1:ndims(ft)) : rfft(x, 1:ndims(ft))
4748
if ft.shift && ndims(ft) > 1
48-
res = fftshift(res, 1:ndims(ft))
49+
res = fftshift(res, (1 + !complex_data):ndims(ft))
4950
end
5051
return res
5152
end
5253

5354
function low_pass(ft::FourierTransform, x_fft::AbstractArray)
54-
return view(x_fft,(map(d -> 1:d, ft.modes)...),:,:)
55+
return view(x_fft, (map(d -> 1:d, ft.modes)...), :, :)
5556
end
5657

5758
truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
5859

5960
function inverse(
6061
ft::FourierTransform, x_fft::AbstractArray{T,N}, x::AbstractArray{T2,N}
6162
) where {T,T2,N}
63+
complex_data = Lux.Utils.eltype(x) <: Complex
64+
6265
if ft.shift && ndims(ft) > 1
63-
x_fft = fftshift(x_fft, 1:ndims(ft))
66+
x_fft = fftshift(x_fft, (1 + !complex_data):ndims(ft))
6467
end
6568

66-
if Lux.Utils.eltype(x) <: Complex
69+
if complex_data
6770
return ifft(x_fft, 1:ndims(ft))
6871
else
6972
return real(irfft(x_fft, size(x, 1), 1:ndims(ft)))

0 commit comments

Comments
 (0)