Skip to content

Commit 0c247be

Browse files
committed
Add a few missing type annotations in _pytest._code
These are more "dirty" than the previous batch (that's why they were left out). The trouble is that `compile` can return either a code object or an AST depending on a flag, so we need to add an overload to make the common case Union free. But it's still worthwhile.
1 parent 3e6f0f3 commit 0c247be

File tree

3 files changed

+86
-11
lines changed

3 files changed

+86
-11
lines changed

src/_pytest/_code/code.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __ne__(self, other):
6767
return not self == other
6868

6969
@property
70-
def path(self):
70+
def path(self) -> Union[py.path.local, str]:
7171
""" return a path object pointing to source code (note that it
7272
might not point to an actually existing file). """
7373
try:
@@ -335,7 +335,7 @@ def cut(
335335
(path is None or codepath == path)
336336
and (
337337
excludepath is None
338-
or not hasattr(codepath, "relto")
338+
or not isinstance(codepath, py.path.local)
339339
or not codepath.relto(excludepath)
340340
)
341341
and (lineno is None or x.lineno == lineno)

src/_pytest/_code/source.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tokenize
77
import warnings
88
from bisect import bisect_right
9+
from types import CodeType
910
from types import FrameType
1011
from typing import Iterator
1112
from typing import List
@@ -17,6 +18,10 @@
1718
import py
1819

1920
from _pytest.compat import overload
21+
from _pytest.compat import TYPE_CHECKING
22+
23+
if TYPE_CHECKING:
24+
from typing_extensions import Literal
2025

2126

2227
class Source:
@@ -120,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source":
120125
start, end = self.getstatementrange(lineno)
121126
return self[start:end]
122127

123-
def getstatementrange(self, lineno: int):
128+
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
124129
""" return (start, end) tuple which spans the minimal
125130
statement region which containing the given lineno.
126131
"""
@@ -158,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool:
158163
def __str__(self) -> str:
159164
return "\n".join(self.lines)
160165

166+
@overload
161167
def compile(
162168
self,
163-
filename=None,
164-
mode="exec",
169+
filename: Optional[str] = ...,
170+
mode: str = ...,
171+
flag: "Literal[0]" = ...,
172+
dont_inherit: int = ...,
173+
_genframe: Optional[FrameType] = ...,
174+
) -> CodeType:
175+
raise NotImplementedError()
176+
177+
@overload # noqa: F811
178+
def compile( # noqa: F811
179+
self,
180+
filename: Optional[str] = ...,
181+
mode: str = ...,
182+
flag: int = ...,
183+
dont_inherit: int = ...,
184+
_genframe: Optional[FrameType] = ...,
185+
) -> Union[CodeType, ast.AST]:
186+
raise NotImplementedError()
187+
188+
def compile( # noqa: F811
189+
self,
190+
filename: Optional[str] = None,
191+
mode: str = "exec",
165192
flag: int = 0,
166193
dont_inherit: int = 0,
167194
_genframe: Optional[FrameType] = None,
168-
):
195+
) -> Union[CodeType, ast.AST]:
169196
""" return compiled code object. if filename is None
170197
invent an artificial filename which displays
171198
the source/line position of the caller frame.
@@ -196,7 +223,9 @@ def compile(
196223
raise newex
197224
else:
198225
if flag & ast.PyCF_ONLY_AST:
226+
assert isinstance(co, ast.AST)
199227
return co
228+
assert isinstance(co, CodeType)
200229
lines = [(x + "\n") for x in self.lines]
201230
# Type ignored because linecache.cache is private.
202231
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
@@ -208,22 +237,52 @@ def compile(
208237
#
209238

210239

211-
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
240+
@overload
241+
def compile_(
242+
source: Union[str, bytes, ast.mod, ast.AST],
243+
filename: Optional[str] = ...,
244+
mode: str = ...,
245+
flags: "Literal[0]" = ...,
246+
dont_inherit: int = ...,
247+
) -> CodeType:
248+
raise NotImplementedError()
249+
250+
251+
@overload # noqa: F811
252+
def compile_( # noqa: F811
253+
source: Union[str, bytes, ast.mod, ast.AST],
254+
filename: Optional[str] = ...,
255+
mode: str = ...,
256+
flags: int = ...,
257+
dont_inherit: int = ...,
258+
) -> Union[CodeType, ast.AST]:
259+
raise NotImplementedError()
260+
261+
262+
def compile_( # noqa: F811
263+
source: Union[str, bytes, ast.mod, ast.AST],
264+
filename: Optional[str] = None,
265+
mode: str = "exec",
266+
flags: int = 0,
267+
dont_inherit: int = 0,
268+
) -> Union[CodeType, ast.AST]:
212269
""" compile the given source to a raw code object,
213270
and maintain an internal cache which allows later
214271
retrieval of the source code for the code object
215272
and any recursively created code objects.
216273
"""
217274
if isinstance(source, ast.AST):
218275
# XXX should Source support having AST?
219-
return compile(source, filename, mode, flags, dont_inherit)
276+
assert filename is not None
277+
co = compile(source, filename, mode, flags, dont_inherit)
278+
assert isinstance(co, (CodeType, ast.AST))
279+
return co
220280
_genframe = sys._getframe(1) # the caller
221281
s = Source(source)
222-
co = s.compile(filename, mode, flags, _genframe=_genframe)
223-
return co
282+
return s.compile(filename, mode, flags, _genframe=_genframe)
224283

225284

226-
def getfslineno(obj):
285+
def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
227286
""" Return source location (path, lineno) for the given object.
228287
If the source cannot be determined return ("", -1).
229288

testing/code/test_source.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import ast
55
import inspect
66
import sys
7+
from types import CodeType
78
from typing import Any
89
from typing import Dict
910
from typing import Optional
1011

12+
import py
13+
1114
import _pytest._code
1215
import pytest
1316
from _pytest._code import Source
@@ -147,6 +150,10 @@ def test_getrange(self) -> None:
147150
assert len(x.lines) == 2
148151
assert str(x) == "def f(x):\n pass"
149152

153+
def test_getrange_step_not_supported(self) -> None:
154+
with pytest.raises(IndexError, match=r"step"):
155+
self.source[::2]
156+
150157
def test_getline(self) -> None:
151158
x = self.source[0]
152159
assert x == "def f(x):"
@@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
449456
assert src == expected
450457

451458

459+
def test_compile_ast() -> None:
460+
# We don't necessarily want to support this.
461+
# This test was added just for coverage.
462+
stmt = ast.parse("def x(): pass")
463+
co = _pytest._code.compile(stmt, filename="foo.py")
464+
assert isinstance(co, CodeType)
465+
466+
452467
def test_findsource_fallback() -> None:
453468
from _pytest._code.source import findsource
454469

@@ -488,6 +503,7 @@ def f(x) -> None:
488503

489504
fspath, lineno = getfslineno(f)
490505

506+
assert isinstance(fspath, py.path.local)
491507
assert fspath.basename == "test_source.py"
492508
assert lineno == f.__code__.co_firstlineno - 1 # see findsource
493509

0 commit comments

Comments
 (0)