Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 20ddd20

Browse files
committed
Add generate framework
1 parent f32978e commit 20ddd20

File tree

3 files changed

+373
-0
lines changed

3 files changed

+373
-0
lines changed

generate/random.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
import os
2+
from typing import Iterator, List
3+
import itertools
4+
5+
from generate.utils import Arg, Func, SELF, STUB_DIR
6+
7+
SIZE = Arg(name="size", ty="_Size", optional=False)
8+
NOSIZE = Arg(name="size", ty="None", optional=True)
9+
10+
11+
def distribution(name: str, return_type: str, *args: Arg) -> Iterator[Func]:
12+
def func(return_type: str, args: List[Arg], size: bool) -> Func:
13+
args = [SELF] + list(args)
14+
if size:
15+
args.append(SIZE)
16+
else:
17+
args.append(NOSIZE)
18+
return Func(name, return_type, args, overload=True)
19+
20+
yield func(return_type, args, size=False)
21+
yield func(
22+
"ndarray",
23+
[Arg(arg.name, f"_ScalarOrArray[{arg.ty}]") for arg in args],
24+
size=True,
25+
)
26+
27+
for i in range(len(args)):
28+
for combo in itertools.combinations(args, i + 1):
29+
yield func(
30+
"ndarray",
31+
[
32+
Arg(
33+
arg.name,
34+
f"_ArrayLike[{arg.ty}]" if arg in combo else arg.ty,
35+
optional=False,
36+
)
37+
for arg in args
38+
],
39+
size=False,
40+
)
41+
42+
43+
functions = list(
44+
distribution("beta", "float", Arg("a", "float"), Arg("b", "float"))
45+
)
46+
functions.append(Func("bytes", "builtins.bytes", [SELF, Arg("length", "int")]))
47+
functions += list(
48+
distribution("binomial", "int", Arg("n", "int"), Arg("p", "float"))
49+
)
50+
functions += list(distribution("chisquared", "float", Arg("df", "int")))
51+
52+
53+
def choice(return_type: str, ty: str, size: bool) -> Func:
54+
return Func(
55+
"choice",
56+
return_type,
57+
[
58+
SELF,
59+
Arg("a", ty),
60+
SIZE if size else NOSIZE,
61+
Arg("replace", "bool", True),
62+
Arg("p", "Optional[_ArrayLike[float]]", True),
63+
],
64+
overload=True,
65+
)
66+
67+
68+
functions += [
69+
choice("int", "int", False),
70+
choice("_T", "Sequence[_T]", False),
71+
choice("Any", "ndarray", False),
72+
choice("ndarray", "Union[int, ndarray]", True),
73+
]
74+
functions.append(
75+
Func(
76+
"dirichlet",
77+
"ndarray",
78+
[
79+
SELF,
80+
Arg("alpha", "_ArrayLike[float]"),
81+
Arg("size", "Optional[_Size]", True),
82+
],
83+
)
84+
)
85+
functions += list(
86+
distribution("exponential", "float", Arg("scale", "float", True))
87+
)
88+
functions += list(
89+
distribution("f", "float", Arg("dfnum", "float"), Arg("dfden", "float"))
90+
)
91+
functions += list(
92+
distribution(
93+
"gamma", "float", Arg("shape", "float"), Arg("scale", "float", True)
94+
)
95+
)
96+
functions += list(distribution("geometric", "float", Arg("p", "float")))
97+
functions.append(
98+
Func("get_state", "Tuple[str, ndarray, int, int, float]", [SELF])
99+
)
100+
functions += list(
101+
distribution(
102+
"gumbel",
103+
"float",
104+
Arg("loc", "float", True),
105+
Arg("scale", "float", True),
106+
)
107+
)
108+
functions += list(
109+
distribution(
110+
"hypergeometric",
111+
"int",
112+
Arg("ngood", "int"),
113+
Arg("nbad", "int"),
114+
Arg("nsample", "int"),
115+
)
116+
)
117+
functions += list(
118+
distribution(
119+
"laplace",
120+
"float",
121+
Arg("loc", "float", True),
122+
Arg("scale", "float", True),
123+
)
124+
)
125+
functions += list(
126+
distribution(
127+
"logistic",
128+
"float",
129+
Arg("loc", "float", True),
130+
Arg("scale", "float", True),
131+
)
132+
)
133+
functions += list(
134+
distribution(
135+
"lognormal",
136+
"float",
137+
Arg("mean", "float", True),
138+
Arg("sigma", "float", True),
139+
)
140+
)
141+
functions += list(distribution("logseries", "int", Arg("p", "float")))
142+
functions.append(
143+
Func(
144+
"multinomial",
145+
"ndarray",
146+
[SELF, Arg("n", "int"), Arg("size", "Optional[_Size]", True)],
147+
)
148+
)
149+
functions.append(
150+
Func(
151+
"multivariate_normal",
152+
"ndarray",
153+
[
154+
SELF,
155+
Arg("mean", "ndarray"),
156+
Arg("cov", "ndarray"),
157+
Arg("size", "Optional[_Size]", True),
158+
Arg("check_valid", "str", True),
159+
Arg("tol", "float", True),
160+
],
161+
)
162+
)
163+
functions += list(
164+
distribution(
165+
"negative_binomial", "int", Arg("n", "int"), Arg("p", "float")
166+
)
167+
)
168+
functions += list(
169+
distribution(
170+
"noncentral_chisquare",
171+
"float",
172+
Arg("df", "float"),
173+
Arg("nonc", "float"),
174+
)
175+
)
176+
functions += list(
177+
distribution(
178+
"noncentral_f",
179+
"float",
180+
Arg("dfnum", "float"),
181+
Arg("dfden", "float"),
182+
Arg("nonc", "float"),
183+
)
184+
)
185+
functions += list(
186+
distribution(
187+
"normal",
188+
"float",
189+
Arg("loc", "float", True),
190+
Arg("scale", "float", True),
191+
)
192+
)
193+
functions += list(distribution("pareto", "float", Arg("a", "float")))
194+
functions.append(
195+
Func("permutation", "ndarray", [SELF, Arg("x", "Union[int, ndarray]")])
196+
)
197+
functions += list(distribution("poisson", "float", Arg("lam", "float", True)))
198+
functions += list(distribution("power", "float", Arg("a", "float")))
199+
functions += [
200+
Func("rand", "float", [SELF], overload=True),
201+
Func(
202+
"rand",
203+
"ndarray",
204+
[SELF, Arg("d0", "int"), Arg("*dn", "int")],
205+
overload=True,
206+
),
207+
]
208+
# TODO(alan): dtype parameter
209+
functions += list(
210+
distribution("randint", "int", Arg("low", "int"), Arg("high", "int", True))
211+
)
212+
functions += [
213+
Func("randn", "float", [SELF], overload=True),
214+
Func(
215+
"randn",
216+
"ndarray",
217+
[SELF, Arg("d0", "int"), Arg("*dn", "int")],
218+
overload=True,
219+
),
220+
]
221+
functions += list(
222+
distribution(
223+
"random_integers", "int", Arg("low", "int"), Arg("high", "int", True)
224+
)
225+
)
226+
functions += list(distribution("random_sample", "int"))
227+
functions += list(distribution("rayleigh", "float", Arg("scale", "float")))
228+
functions += [
229+
Func(
230+
"seed",
231+
"None",
232+
[SELF, Arg("seed", "Union[None, int, Tuple[int], List[int]]", True)],
233+
),
234+
Func(
235+
"set_state",
236+
"None",
237+
[SELF, Arg("state", "Tuple[str, ndarray, int, int, float]")],
238+
),
239+
Func("shuffle", "None", [SELF, Arg("x", "_ArrayLike[Any]")]),
240+
]
241+
functions += list(distribution("standard_cauchy", "float"))
242+
functions += list(distribution("standard_exponential", "float"))
243+
functions += list(distribution("standard_gamma", "float"))
244+
functions += list(distribution("standard_normal", "float"))
245+
functions += list(distribution("standard_t", "float", Arg("t", "int")))
246+
functions += list(distribution("tomaxint", "int", Arg("t", "int")))
247+
functions += list(
248+
distribution(
249+
"triangular",
250+
"float",
251+
Arg("left", "float"),
252+
Arg("mode", "float"),
253+
Arg("right", "float"),
254+
)
255+
)
256+
functions += list(
257+
distribution(
258+
"uniform",
259+
"float",
260+
Arg("low", "float", True),
261+
Arg("high", "float", True),
262+
)
263+
)
264+
functions += list(
265+
distribution(
266+
"vonmises", "float", Arg("mu", "float"), Arg("kappa", "float")
267+
)
268+
)
269+
functions += list(
270+
distribution("wald", "float", Arg("mean", "float"), Arg("scale", "float"))
271+
)
272+
functions += list(distribution("weibull", "float", Arg("a", "float")))
273+
functions += list(distribution("zipf", "int", Arg("a", "float")))
274+
275+
276+
imports = """\
277+
import builtins
278+
from typing import (
279+
Any, List, overload, Optional, Sequence, Tuple, TypeVar, Union
280+
)
281+
from numpy import ndarray
282+
"""
283+
284+
typevars = """\
285+
_Size = Union[int, Sequence[int]]
286+
_T = TypeVar("_T")
287+
_ArrayLike = Union[Sequence[_T], ndarray]
288+
_ScalarOrArray = Union[_T, Sequence[_T], ndarray]
289+
"""
290+
291+
with open(os.path.join(STUB_DIR, "random", "mtrand.pyi"), "w") as fout:
292+
fout.write(imports)
293+
fout.write(typevars)
294+
fout.write(
295+
"""
296+
class RandomState:
297+
def __init__(
298+
self, state: Union[None, int, List[int], Tuple[int]] = ...
299+
) -> None: ...
300+
"""
301+
)
302+
prev = None
303+
for func in functions:
304+
if func.name != prev:
305+
fout.write("\n")
306+
prev = func.name
307+
fout.write(func.render(indent=4))
308+
309+
with open(os.path.join(STUB_DIR, "random", "__init__.pyi"), "w") as fout:
310+
fout.write(imports)
311+
fout.write("from . import mtrand\n")
312+
fout.write(typevars)
313+
fout.write(
314+
"""\
315+
RandomState = mtrand.RandomState
316+
"""
317+
)
318+
319+
prev = None
320+
for func in functions:
321+
if func.name != prev:
322+
fout.write("\n")
323+
prev = func.name
324+
325+
# Leave out first argument (SELF)
326+
assert func.args[0] == SELF
327+
func = Func(func.name, func.return_type, func.args[1:], func.overload)
328+
fout.write(func.render())

generate/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
3+
from dataclasses import dataclass, field
4+
from typing import List, Optional
5+
6+
import textwrap
7+
8+
STUB_DIR = os.path.abspath(os.path.join(__file__, "..", "..", "numpy-stubs"))
9+
10+
11+
@dataclass(frozen=True)
12+
class Arg:
13+
name: str
14+
ty: Optional[str] = None
15+
optional: bool = False
16+
17+
def render(self) -> str:
18+
s = self.name
19+
if self.ty is not None:
20+
s += f": {self.ty}"
21+
if self.optional:
22+
s += " = ..."
23+
return s
24+
25+
26+
@dataclass(frozen=True)
27+
class Func:
28+
name: str
29+
return_type: str
30+
31+
args: List[Arg] = field(default_factory=list)
32+
overload: bool = False
33+
34+
def render(self, indent=0) -> str:
35+
s = f"def {self.name}("
36+
s += ", ".join(arg.render() for arg in self.args)
37+
s += f") -> {self.return_type}: ...\n"
38+
39+
if self.overload:
40+
s = "@overload\n" + s
41+
return textwrap.indent(s, " " * indent)
42+
43+
44+
SELF = Arg("self")

test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
dataclasses==0.6
12
flake8==3.3.0
23
flake8-pyi==17.3.0
34
pytest==3.4.2

0 commit comments

Comments
 (0)