|
| 1 | +import enum |
| 2 | +import io |
1 | 3 | import re
|
2 | 4 | import sys
|
3 | 5 | import os
|
| 6 | +import tokenize |
4 | 7 |
|
5 | 8 | from typing import Optional, Tuple, Sequence, MutableSequence, List, MutableMapping, IO, NamedTuple
|
6 | 9 | from types import ModuleType
|
@@ -120,64 +123,112 @@ def write_header(file: IO[str], module_name: Optional[str] = None,
|
120 | 123 | '# NOTE: This dynamically typed stub was automatically generated by stubgen.\n\n')
|
121 | 124 |
|
122 | 125 |
|
class State(enum.Enum):
    """Parser states used while scanning a docstring for function signatures.

    The states are kept on a stack, so a nested construct (a bracketed type,
    a default value, ...) is pushed on entry and popped when it ends.
    """

    # Bottom of the stack: scanning text outside of any signature.
    INIT = 1
    # The function name was just seen; an opening '(' must follow.
    FUNCTION_NAME = 2
    # Inside the parenthesised argument list, collecting an argument name.
    ARGUMENT_LIST = 3
    # After the ':' of an argument, collecting its type.
    ARGUMENT_TYPE = 4
    # After the '=' of an argument, collecting its default value.
    ARGUMENT_DEFAULT = 5
    # After '->', collecting the return type up to the end of the line.
    RETURN_VALUE = 6
    # Inside a nested '[', '(' or '{'; its content is accumulated verbatim.
    OPEN_BRACKET = 7
| 134 | + |
| 135 | + |
def infer_sig_from_docstring(docstr: str, name: str) -> Optional[List[TypedFunctionSig]]:
    """Extract the signatures of function *name* documented in a docstring.

    The docstring is run through the Python tokenizer and a small stack-based
    state machine (see ``State``) looks for text shaped like::

        <name>(<arg>[: <type>][= <default>], ...) [-> <return type>]

    Several signatures may be present (e.g. for overloaded functions); one
    entry is produced per logical line that contains a signature.

    Args:
        docstr: The docstring to scan.
        name: The function name whose signatures should be collected.

    Returns:
        None if the docstring is empty, otherwise the list of signatures
        found (possibly empty). The return type is 'Any' when a signature
        has no '->' annotation. If the tokenizer gives up part-way through,
        the signatures collected up to that point are returned.
    """
    if not docstr:
        return None

    # Tokens that only carry layout or commentary. They must never leak into
    # an argument name, type, default or return type.
    layout_tokens = (tokenize.ENCODING, tokenize.NL, tokenize.COMMENT,
                     tokenize.INDENT, tokenize.DEDENT)

    state = [State.INIT]
    accumulator = ""
    arg_type = None  # type: Optional[str]
    arg_name = ""
    arg_default = None  # type: Optional[str]
    ret_type = "Any"
    found = False
    args = []  # type: List[TypedArgSig]
    signatures = []  # type: List[TypedFunctionSig]
    try:
        for token in tokenize.tokenize(io.BytesIO(docstr.encode('utf-8')).readline):
            is_op = token.type == tokenize.OP
            current = state[-1]

            if token.type == tokenize.NAME and token.string == name and current == State.INIT:
                state.append(State.FUNCTION_NAME)

            elif is_op and token.string == '(' and current == State.FUNCTION_NAME:
                state.pop()
                accumulator = ""
                found = True
                state.append(State.ARGUMENT_LIST)

            elif current == State.FUNCTION_NAME:
                # Reset state, function name not followed by '('.
                state.pop()

            elif is_op and token.string in ('[', '(', '{') and current != State.INIT:
                # Only track nesting inside a signature; brackets in ordinary
                # prose must not disturb the state stack.
                accumulator += token.string
                state.append(State.OPEN_BRACKET)

            elif is_op and token.string in (']', ')', '}') and current == State.OPEN_BRACKET:
                accumulator += token.string
                state.pop()

            elif is_op and token.string == ':' and current == State.ARGUMENT_LIST:
                arg_name = accumulator
                accumulator = ""
                state.append(State.ARGUMENT_TYPE)

            elif is_op and token.string == '=' and current in (State.ARGUMENT_LIST,
                                                               State.ARGUMENT_TYPE):
                if current == State.ARGUMENT_TYPE:
                    arg_type = accumulator
                    state.pop()
                else:
                    arg_name = accumulator
                accumulator = ""
                state.append(State.ARGUMENT_DEFAULT)

            elif is_op and token.string in (',', ')') and current in (State.ARGUMENT_LIST,
                                                                      State.ARGUMENT_DEFAULT,
                                                                      State.ARGUMENT_TYPE):
                if current == State.ARGUMENT_DEFAULT:
                    arg_default = accumulator
                    state.pop()
                elif current == State.ARGUMENT_TYPE:
                    arg_type = accumulator
                    state.pop()
                else:
                    arg_name = accumulator

                if token.string == ')':
                    # Leave the argument list.
                    state.pop()
                # An empty name means "name()" or a trailing comma: there is
                # no argument to record.
                if arg_name:
                    args.append(TypedArgSig(name=arg_name, type=arg_type, default=arg_default))
                arg_name = ""
                arg_type = None
                arg_default = None
                accumulator = ""

            elif is_op and token.string == '->' and current == State.INIT and found:
                # A return annotation only counts directly after a signature;
                # arrows elsewhere in the text are ordinary prose.
                accumulator = ""
                state.append(State.RETURN_VALUE)

            elif (token.type in (tokenize.NEWLINE, tokenize.ENDMARKER)
                  and current in (State.INIT, State.RETURN_VALUE)):
                # End of the logical line (or of the input, for a final line
                # that produced no NEWLINE token): flush a pending signature.
                if current == State.RETURN_VALUE:
                    ret_type = accumulator
                    accumulator = ""
                    state.pop()

                if found:
                    signatures.append(TypedFunctionSig(name=name, args=args, ret_type=ret_type))
                    found = False
                    args = []
                    ret_type = 'Any'
                # Leave state as INIT.

            elif token.type not in layout_tokens:
                accumulator += token.string
    except (tokenize.TokenError, IndentationError):
        # Docstrings are free-form text, so the tokenizer may reject them
        # (unbalanced brackets, inconsistent dedent, ...). Return as much as
        # was collected before the failure.
        pass
    return signatures
181 | 232 |
|
182 | 233 |
|
183 | 234 | def infer_arg_sig_from_docstring(docstr: str) -> ArgList:
|
|
0 commit comments