|
| 1 | +import os |
| 2 | +from typing import Iterator, List |
| 3 | +import itertools |
| 4 | + |
| 5 | +from generate.utils import Arg, Func, SELF, STUB_DIR |
| 6 | + |
| 7 | +SIZE = Arg(name="size", ty="_Size", optional=False) |
| 8 | +NOSIZE = Arg(name="size", ty="None", optional=True) |
| 9 | + |
| 10 | + |
| 11 | +def distribution(name: str, return_type: str, *args: Arg) -> Iterator[Func]: |
| 12 | + def func(return_type: str, args: List[Arg], size: bool) -> Func: |
| 13 | + args = [SELF] + list(args) |
| 14 | + if size: |
| 15 | + args.append(SIZE) |
| 16 | + else: |
| 17 | + args.append(NOSIZE) |
| 18 | + return Func(name, return_type, args, overload=True) |
| 19 | + |
| 20 | + yield func(return_type, args, size=False) |
| 21 | + yield func( |
| 22 | + "ndarray", |
| 23 | + [Arg(arg.name, f"_ScalarOrArray[{arg.ty}]") for arg in args], |
| 24 | + size=True, |
| 25 | + ) |
| 26 | + |
| 27 | + for i in range(len(args)): |
| 28 | + for combo in itertools.combinations(args, i + 1): |
| 29 | + yield func( |
| 30 | + "ndarray", |
| 31 | + [ |
| 32 | + Arg( |
| 33 | + arg.name, |
| 34 | + f"_ArrayLike[{arg.ty}]" if arg in combo else arg.ty, |
| 35 | + optional=False, |
| 36 | + ) |
| 37 | + for arg in args |
| 38 | + ], |
| 39 | + size=False, |
| 40 | + ) |
| 41 | + |
| 42 | + |
| 43 | +functions = list( |
| 44 | + distribution("beta", "float", Arg("a", "float"), Arg("b", "float")) |
| 45 | +) |
| 46 | +functions.append(Func("bytes", "builtins.bytes", [SELF, Arg("length", "int")])) |
| 47 | +functions += list( |
| 48 | + distribution("binomial", "int", Arg("n", "int"), Arg("p", "float")) |
| 49 | +) |
| 50 | +functions += list(distribution("chisquared", "float", Arg("df", "int"))) |
| 51 | + |
| 52 | + |
| 53 | +def choice(return_type: str, ty: str, size: bool) -> Func: |
| 54 | + return Func( |
| 55 | + "choice", |
| 56 | + return_type, |
| 57 | + [ |
| 58 | + SELF, |
| 59 | + Arg("a", ty), |
| 60 | + SIZE if size else NOSIZE, |
| 61 | + Arg("replace", "bool", True), |
| 62 | + Arg("p", "Optional[_ArrayLike[float]]", True), |
| 63 | + ], |
| 64 | + overload=True, |
| 65 | + ) |
| 66 | + |
| 67 | + |
| 68 | +functions += [ |
| 69 | + choice("int", "int", False), |
| 70 | + choice("_T", "Sequence[_T]", False), |
| 71 | + choice("Any", "ndarray", False), |
| 72 | + choice("ndarray", "Union[int, ndarray]", True), |
| 73 | +] |
| 74 | +functions.append( |
| 75 | + Func( |
| 76 | + "dirichlet", |
| 77 | + "ndarray", |
| 78 | + [ |
| 79 | + SELF, |
| 80 | + Arg("alpha", "_ArrayLike[float]"), |
| 81 | + Arg("size", "Optional[_Size]", True), |
| 82 | + ], |
| 83 | + ) |
| 84 | +) |
| 85 | +functions += list( |
| 86 | + distribution("exponential", "float", Arg("scale", "float", True)) |
| 87 | +) |
| 88 | +functions += list( |
| 89 | + distribution("f", "float", Arg("dfnum", "float"), Arg("dfden", "float")) |
| 90 | +) |
| 91 | +functions += list( |
| 92 | + distribution( |
| 93 | + "gamma", "float", Arg("shape", "float"), Arg("scale", "float", True) |
| 94 | + ) |
| 95 | +) |
| 96 | +functions += list(distribution("geometric", "float", Arg("p", "float"))) |
| 97 | +functions.append( |
| 98 | + Func("get_state", "Tuple[str, ndarray, int, int, float]", [SELF]) |
| 99 | +) |
| 100 | +functions += list( |
| 101 | + distribution( |
| 102 | + "gumbel", |
| 103 | + "float", |
| 104 | + Arg("loc", "float", True), |
| 105 | + Arg("scale", "float", True), |
| 106 | + ) |
| 107 | +) |
| 108 | +functions += list( |
| 109 | + distribution( |
| 110 | + "hypergeometric", |
| 111 | + "int", |
| 112 | + Arg("ngood", "int"), |
| 113 | + Arg("nbad", "int"), |
| 114 | + Arg("nsample", "int"), |
| 115 | + ) |
| 116 | +) |
| 117 | +functions += list( |
| 118 | + distribution( |
| 119 | + "laplace", |
| 120 | + "float", |
| 121 | + Arg("loc", "float", True), |
| 122 | + Arg("scale", "float", True), |
| 123 | + ) |
| 124 | +) |
| 125 | +functions += list( |
| 126 | + distribution( |
| 127 | + "logistic", |
| 128 | + "float", |
| 129 | + Arg("loc", "float", True), |
| 130 | + Arg("scale", "float", True), |
| 131 | + ) |
| 132 | +) |
| 133 | +functions += list( |
| 134 | + distribution( |
| 135 | + "lognormal", |
| 136 | + "float", |
| 137 | + Arg("mean", "float", True), |
| 138 | + Arg("sigma", "float", True), |
| 139 | + ) |
| 140 | +) |
| 141 | +functions += list(distribution("logseries", "int", Arg("p", "float"))) |
| 142 | +functions.append( |
| 143 | + Func( |
| 144 | + "multinomial", |
| 145 | + "ndarray", |
| 146 | + [SELF, Arg("n", "int"), Arg("size", "Optional[_Size]", True)], |
| 147 | + ) |
| 148 | +) |
| 149 | +functions.append( |
| 150 | + Func( |
| 151 | + "multivariate_normal", |
| 152 | + "ndarray", |
| 153 | + [ |
| 154 | + SELF, |
| 155 | + Arg("mean", "ndarray"), |
| 156 | + Arg("cov", "ndarray"), |
| 157 | + Arg("size", "Optional[_Size]", True), |
| 158 | + Arg("check_valid", "str", True), |
| 159 | + Arg("tol", "float", True), |
| 160 | + ], |
| 161 | + ) |
| 162 | +) |
| 163 | +functions += list( |
| 164 | + distribution( |
| 165 | + "negative_binomial", "int", Arg("n", "int"), Arg("p", "float") |
| 166 | + ) |
| 167 | +) |
| 168 | +functions += list( |
| 169 | + distribution( |
| 170 | + "noncentral_chisquare", |
| 171 | + "float", |
| 172 | + Arg("df", "float"), |
| 173 | + Arg("nonc", "float"), |
| 174 | + ) |
| 175 | +) |
| 176 | +functions += list( |
| 177 | + distribution( |
| 178 | + "noncentral_f", |
| 179 | + "float", |
| 180 | + Arg("dfnum", "float"), |
| 181 | + Arg("dfden", "float"), |
| 182 | + Arg("nonc", "float"), |
| 183 | + ) |
| 184 | +) |
| 185 | +functions += list( |
| 186 | + distribution( |
| 187 | + "normal", |
| 188 | + "float", |
| 189 | + Arg("loc", "float", True), |
| 190 | + Arg("scale", "float", True), |
| 191 | + ) |
| 192 | +) |
| 193 | +functions += list(distribution("pareto", "float", Arg("a", "float"))) |
| 194 | +functions.append( |
| 195 | + Func("permutation", "ndarray", [SELF, Arg("x", "Union[int, ndarray]")]) |
| 196 | +) |
| 197 | +functions += list(distribution("poisson", "float", Arg("lam", "float", True))) |
| 198 | +functions += list(distribution("power", "float", Arg("a", "float"))) |
| 199 | +functions += [ |
| 200 | + Func("rand", "float", [SELF], overload=True), |
| 201 | + Func( |
| 202 | + "rand", |
| 203 | + "ndarray", |
| 204 | + [SELF, Arg("d0", "int"), Arg("*dn", "int")], |
| 205 | + overload=True, |
| 206 | + ), |
| 207 | +] |
| 208 | +# TODO(alan): dtype parameter |
| 209 | +functions += list( |
| 210 | + distribution("randint", "int", Arg("low", "int"), Arg("high", "int", True)) |
| 211 | +) |
| 212 | +functions += [ |
| 213 | + Func("randn", "float", [SELF], overload=True), |
| 214 | + Func( |
| 215 | + "randn", |
| 216 | + "ndarray", |
| 217 | + [SELF, Arg("d0", "int"), Arg("*dn", "int")], |
| 218 | + overload=True, |
| 219 | + ), |
| 220 | +] |
| 221 | +functions += list( |
| 222 | + distribution( |
| 223 | + "random_integers", "int", Arg("low", "int"), Arg("high", "int", True) |
| 224 | + ) |
| 225 | +) |
| 226 | +functions += list(distribution("random_sample", "int")) |
| 227 | +functions += list(distribution("rayleigh", "float", Arg("scale", "float"))) |
| 228 | +functions += [ |
| 229 | + Func( |
| 230 | + "seed", |
| 231 | + "None", |
| 232 | + [SELF, Arg("seed", "Union[None, int, Tuple[int], List[int]]", True)], |
| 233 | + ), |
| 234 | + Func( |
| 235 | + "set_state", |
| 236 | + "None", |
| 237 | + [SELF, Arg("state", "Tuple[str, ndarray, int, int, float]")], |
| 238 | + ), |
| 239 | + Func("shuffle", "None", [SELF, Arg("x", "_ArrayLike[Any]")]), |
| 240 | +] |
| 241 | +functions += list(distribution("standard_cauchy", "float")) |
| 242 | +functions += list(distribution("standard_exponential", "float")) |
| 243 | +functions += list(distribution("standard_gamma", "float")) |
| 244 | +functions += list(distribution("standard_normal", "float")) |
| 245 | +functions += list(distribution("standard_t", "float", Arg("t", "int"))) |
| 246 | +functions += list(distribution("tomaxint", "int", Arg("t", "int"))) |
| 247 | +functions += list( |
| 248 | + distribution( |
| 249 | + "triangular", |
| 250 | + "float", |
| 251 | + Arg("left", "float"), |
| 252 | + Arg("mode", "float"), |
| 253 | + Arg("right", "float"), |
| 254 | + ) |
| 255 | +) |
| 256 | +functions += list( |
| 257 | + distribution( |
| 258 | + "uniform", |
| 259 | + "float", |
| 260 | + Arg("low", "float", True), |
| 261 | + Arg("high", "float", True), |
| 262 | + ) |
| 263 | +) |
| 264 | +functions += list( |
| 265 | + distribution( |
| 266 | + "vonmises", "float", Arg("mu", "float"), Arg("kappa", "float") |
| 267 | + ) |
| 268 | +) |
| 269 | +functions += list( |
| 270 | + distribution("wald", "float", Arg("mean", "float"), Arg("scale", "float")) |
| 271 | +) |
| 272 | +functions += list(distribution("weibull", "float", Arg("a", "float"))) |
| 273 | +functions += list(distribution("zipf", "int", Arg("a", "float"))) |
| 274 | + |
| 275 | + |
| 276 | +imports = """\ |
| 277 | +import builtins |
| 278 | +from typing import ( |
| 279 | + Any, List, overload, Optional, Sequence, Tuple, TypeVar, Union |
| 280 | +) |
| 281 | +from numpy import ndarray |
| 282 | +""" |
| 283 | + |
| 284 | +typevars = """\ |
| 285 | +_Size = Union[int, Sequence[int]] |
| 286 | +_T = TypeVar("_T") |
| 287 | +_ArrayLike = Union[Sequence[_T], ndarray] |
| 288 | +_ScalarOrArray = Union[_T, Sequence[_T], ndarray] |
| 289 | +""" |
| 290 | + |
| 291 | +with open(os.path.join(STUB_DIR, "random", "mtrand.pyi"), "w") as fout: |
| 292 | + fout.write(imports) |
| 293 | + fout.write(typevars) |
| 294 | + fout.write( |
| 295 | + """ |
| 296 | +class RandomState: |
| 297 | + def __init__( |
| 298 | + self, state: Union[None, int, List[int], Tuple[int]] = ... |
| 299 | + ) -> None: ... |
| 300 | +""" |
| 301 | + ) |
| 302 | + prev = None |
| 303 | + for func in functions: |
| 304 | + if func.name != prev: |
| 305 | + fout.write("\n") |
| 306 | + prev = func.name |
| 307 | + fout.write(func.render(indent=4)) |
| 308 | + |
| 309 | +with open(os.path.join(STUB_DIR, "random", "__init__.pyi"), "w") as fout: |
| 310 | + fout.write(imports) |
| 311 | + fout.write("from . import mtrand\n") |
| 312 | + fout.write(typevars) |
| 313 | + fout.write( |
| 314 | + """\ |
| 315 | +RandomState = mtrand.RandomState |
| 316 | +""" |
| 317 | + ) |
| 318 | + |
| 319 | + prev = None |
| 320 | + for func in functions: |
| 321 | + if func.name != prev: |
| 322 | + fout.write("\n") |
| 323 | + prev = func.name |
| 324 | + |
| 325 | + # Leave out first argument (SELF) |
| 326 | + assert func.args[0] == SELF |
| 327 | + func = Func(func.name, func.return_type, func.args[1:], func.overload) |
| 328 | + fout.write(func.render()) |
0 commit comments