Skip to content

Commit 457ef12

Browse files
authored
Merge pull request #118 from Avasam/patch-1
complete `scipy.fft.dct`, and add relevant type aliases
2 parents 228ca1b + eba7dec commit 457ef12

File tree

2 files changed

+46
-39
lines changed

2 files changed

+46
-39
lines changed

scipy-stubs/_typing.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ CorrelateMode: TypeAlias = Literal["valid", "same", "full"]
9393
# scipy literals
9494
NanPolicy: TypeAlias = Literal["raise", "propagate", "omit"]
9595
Alternative: TypeAlias = Literal["two-sided", "less", "greater"]
96+
DCTType: TypeAlias = Literal[1, 2, 3, 4]
97+
NormalizationMode: TypeAlias = Literal["backward", "ortho", "forward"]
9698

9799
# used in `scipy.linalg.blas` and `scipy.linalg.lapack`
98100
@type_check_only

scipy-stubs/fft/_realtransforms.pyi

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,88 @@
1-
from scipy._typing import Untyped
1+
import numpy.typing as npt
2+
from numpy._typing import _ArrayLikeNumber_co
3+
from scipy._typing import DCTType, NormalizationMode, Untyped
24

35
def dctn(
46
x: Untyped,
5-
type: int = 2,
7+
type: DCTType = 2,
68
s: Untyped | None = None,
79
axes: Untyped | None = None,
8-
norm: Untyped | None = None,
10+
norm: NormalizationMode | None = None,
911
overwrite_x: bool = False,
10-
workers: Untyped | None = None,
12+
workers: int | None = None,
1113
*,
12-
orthogonalize: Untyped | None = None,
14+
orthogonalize: bool | None = None,
1315
) -> Untyped: ...
1416
def idctn(
1517
x: Untyped,
16-
type: int = 2,
18+
type: DCTType = 2,
1719
s: Untyped | None = None,
1820
axes: Untyped | None = None,
19-
norm: Untyped | None = None,
21+
norm: NormalizationMode | None = None,
2022
overwrite_x: bool = False,
21-
workers: Untyped | None = None,
22-
orthogonalize: Untyped | None = None,
23+
workers: int | None = None,
24+
orthogonalize: bool | None = None,
2325
) -> Untyped: ...
2426
def dstn(
2527
x: Untyped,
26-
type: int = 2,
28+
type: DCTType = 2,
2729
s: Untyped | None = None,
2830
axes: Untyped | None = None,
29-
norm: Untyped | None = None,
31+
norm: NormalizationMode | None = None,
3032
overwrite_x: bool = False,
31-
workers: Untyped | None = None,
32-
orthogonalize: Untyped | None = None,
33+
workers: int | None = None,
34+
orthogonalize: bool | None = None,
3335
) -> Untyped: ...
3436
def idstn(
3537
x: Untyped,
36-
type: int = 2,
38+
type: DCTType = 2,
3739
s: Untyped | None = None,
3840
axes: Untyped | None = None,
39-
norm: Untyped | None = None,
41+
norm: NormalizationMode | None = None,
4042
overwrite_x: bool = False,
41-
workers: Untyped | None = None,
42-
orthogonalize: Untyped | None = None,
43+
workers: int | None = None,
44+
orthogonalize: bool | None = None,
4345
) -> Untyped: ...
46+
47+
# We could use overloads based on the type of x to get more accurate return type
48+
# see https://github.com/jorenham/scipy-stubs/pull/118#discussion_r1807957439
4449
def dct(
45-
x: Untyped,
46-
type: int = 2,
47-
n: Untyped | None = None,
50+
x: _ArrayLikeNumber_co,
51+
type: DCTType = 2,
52+
n: int | None = None,
4853
axis: int = -1,
49-
norm: Untyped | None = None,
54+
norm: NormalizationMode | None = None,
5055
overwrite_x: bool = False,
51-
workers: Untyped | None = None,
52-
orthogonalize: Untyped | None = None,
53-
) -> Untyped: ...
56+
workers: int | None = None,
57+
orthogonalize: bool | None = None,
58+
) -> npt.NDArray[Untyped]: ...
5459
def idct(
5560
x: Untyped,
56-
type: int = 2,
57-
n: Untyped | None = None,
61+
type: DCTType = 2,
62+
n: int | None = None,
5863
axis: int = -1,
59-
norm: Untyped | None = None,
64+
norm: NormalizationMode | None = None,
6065
overwrite_x: bool = False,
61-
workers: Untyped | None = None,
62-
orthogonalize: Untyped | None = None,
66+
workers: int | None = None,
67+
orthogonalize: bool | None = None,
6368
) -> Untyped: ...
6469
def dst(
6570
x: Untyped,
66-
type: int = 2,
67-
n: Untyped | None = None,
71+
type: DCTType = 2,
72+
n: int | None = None,
6873
axis: int = -1,
69-
norm: Untyped | None = None,
74+
norm: NormalizationMode | None = None,
7075
overwrite_x: bool = False,
71-
workers: Untyped | None = None,
72-
orthogonalize: Untyped | None = None,
76+
workers: int | None = None,
77+
orthogonalize: bool | None = None,
7378
) -> Untyped: ...
7479
def idst(
7580
x: Untyped,
76-
type: int = 2,
77-
n: Untyped | None = None,
81+
type: DCTType = 2,
82+
n: int | None = None,
7883
axis: int = -1,
79-
norm: Untyped | None = None,
84+
norm: NormalizationMode | None = None,
8085
overwrite_x: bool = False,
81-
workers: Untyped | None = None,
82-
orthogonalize: Untyped | None = None,
86+
workers: int | None = None,
87+
orthogonalize: bool | None = None,
8388
) -> Untyped: ...

0 commit comments

Comments
 (0)