Skip to content

Commit faebf3c

Browse files
AnOctopusGuido van Rossum
authored and
Guido van Rossum
committed
Fix sys.stdout overriding in mypy.api (#6750)
Overriding sys.stdout and sys.stderr in mypy.api is not threadsafe. This causes problems sometimes when using the api in pyls for example. This started with the changes made by @elkhadiy to fix #6125. Fixes #6125.
1 parent a0e9a99 commit faebf3c

File tree

8 files changed

+135
-85
lines changed

8 files changed

+135
-85
lines changed

mypy/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Mypy type checker command line tool."""
22

3+
import sys
34
from mypy.main import main
45

56

67
def console_entry() -> None:
7-
main(None)
8+
main(None, sys.stdout, sys.stderr)
89

910

1011
if __name__ == '__main__':
11-
main(None)
12+
main(None, sys.stdout, sys.stderr)

mypy/api.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,36 +44,31 @@
4444

4545
import sys
4646
from io import StringIO
47-
from typing import List, Tuple, Callable
47+
from typing import List, Tuple, Union, TextIO, Callable
48+
from mypy_extensions import DefaultArg
4849

4950

50-
def _run(f: Callable[[], None]) -> Tuple[str, str, int]:
51-
old_stdout = sys.stdout
52-
new_stdout = StringIO()
53-
sys.stdout = new_stdout
51+
def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> Tuple[str, str, int]:
5452

55-
old_stderr = sys.stderr
56-
new_stderr = StringIO()
57-
sys.stderr = new_stderr
53+
stdout = StringIO()
54+
stderr = StringIO()
5855

5956
try:
60-
f()
57+
main_wrapper(stdout, stderr)
6158
exit_status = 0
6259
except SystemExit as system_exit:
6360
exit_status = system_exit.code
64-
finally:
65-
sys.stdout = old_stdout
66-
sys.stderr = old_stderr
6761

68-
return new_stdout.getvalue(), new_stderr.getvalue(), exit_status
62+
return stdout.getvalue(), stderr.getvalue(), exit_status
6963

7064

7165
def run(args: List[str]) -> Tuple[str, str, int]:
7266
# Lazy import to avoid needing to import all of mypy to call run_dmypy
7367
from mypy.main import main
74-
return _run(lambda: main(None, args=args))
68+
return _run(lambda stdout, stderr: main(None, args=args,
69+
stdout=stdout, stderr=stderr))
7570

7671

7772
def run_dmypy(args: List[str]) -> Tuple[str, str, int]:
7873
from mypy.dmypy import main
79-
return _run(lambda: main(args))
74+
return _run(lambda stdout, stderr: main(args))

mypy/build.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import types
2626

2727
from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List,
28-
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable)
28+
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable, TextIO)
2929
MYPY = False
3030
if MYPY:
3131
from typing import ClassVar
@@ -128,6 +128,8 @@ def build(sources: List[BuildSource],
128128
alt_lib_path: Optional[str] = None,
129129
flush_errors: Optional[Callable[[List[str], bool], None]] = None,
130130
fscache: Optional[FileSystemCache] = None,
131+
stdout: Optional[TextIO] = None,
132+
stderr: Optional[TextIO] = None,
131133
) -> BuildResult:
132134
"""Analyze a program.
133135
@@ -159,9 +161,11 @@ def default_flush_errors(new_messages: List[str], is_serious: bool) -> None:
159161
messages.extend(new_messages)
160162

161163
flush_errors = flush_errors or default_flush_errors
164+
stdout = stdout or sys.stdout
165+
stderr = stderr or sys.stderr
162166

163167
try:
164-
result = _build(sources, options, alt_lib_path, flush_errors, fscache)
168+
result = _build(sources, options, alt_lib_path, flush_errors, fscache, stdout, stderr)
165169
result.errors = messages
166170
return result
167171
except CompileError as e:
@@ -180,6 +184,8 @@ def _build(sources: List[BuildSource],
180184
alt_lib_path: Optional[str],
181185
flush_errors: Callable[[List[str], bool], None],
182186
fscache: Optional[FileSystemCache],
187+
stdout: TextIO,
188+
stderr: TextIO,
183189
) -> BuildResult:
184190
# This seems the most reasonable place to tune garbage collection.
185191
gc.set_threshold(150 * 1000)
@@ -197,7 +203,7 @@ def _build(sources: List[BuildSource],
197203

198204
source_set = BuildSourceSet(sources)
199205
errors = Errors(options.show_error_context, options.show_column_numbers)
200-
plugin, snapshot = load_plugins(options, errors)
206+
plugin, snapshot = load_plugins(options, errors, stdout)
201207

202208
# Construct a build manager object to hold state during the build.
203209
#
@@ -212,12 +218,14 @@ def _build(sources: List[BuildSource],
212218
plugins_snapshot=snapshot,
213219
errors=errors,
214220
flush_errors=flush_errors,
215-
fscache=fscache)
221+
fscache=fscache,
222+
stdout=stdout,
223+
stderr=stderr)
216224
manager.trace(repr(options))
217225

218226
reset_global_state()
219227
try:
220-
graph = dispatch(sources, manager)
228+
graph = dispatch(sources, manager, stdout)
221229
if not options.fine_grained_incremental:
222230
TypeState.reset_all_subtype_caches()
223231
return BuildResult(manager, graph)
@@ -319,7 +327,10 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
319327
return toplevel_priority
320328

321329

322-
def load_plugins(options: Options, errors: Errors) -> Tuple[Plugin, Dict[str, str]]:
330+
def load_plugins(options: Options,
331+
errors: Errors,
332+
stdout: TextIO,
333+
) -> Tuple[Plugin, Dict[str, str]]:
323334
"""Load all configured plugins.
324335
325336
Return a plugin that encapsulates all plugins chained together. Always
@@ -383,7 +394,8 @@ def plugin_error(message: str) -> None:
383394
try:
384395
plugin_type = getattr(module, func_name)(__version__)
385396
except Exception:
386-
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path))
397+
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path),
398+
file=stdout)
387399
raise # Propagate to display traceback
388400

389401
if not isinstance(plugin_type, type):
@@ -398,7 +410,8 @@ def plugin_error(message: str) -> None:
398410
custom_plugins.append(plugin_type(options))
399411
snapshot[module_name] = take_module_snapshot(module)
400412
except Exception:
401-
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
413+
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__),
414+
file=stdout)
402415
raise # Propagate to display traceback
403416
# Custom plugins take precedence over the default plugin.
404417
return ChainedPlugin(options, custom_plugins + [default_plugin]), snapshot
@@ -496,8 +509,10 @@ def __init__(self, data_dir: str,
496509
errors: Errors,
497510
flush_errors: Callable[[List[str], bool], None],
498511
fscache: FileSystemCache,
512+
stdout: TextIO,
513+
stderr: TextIO,
499514
) -> None:
500-
super().__init__()
515+
super().__init__(stdout, stderr)
501516
self.start_time = time.time()
502517
self.data_dir = data_dir
503518
self.errors = errors
@@ -558,7 +573,7 @@ def __init__(self, data_dir: str,
558573
self.plugin = plugin
559574
self.plugins_snapshot = plugins_snapshot
560575
self.old_plugins_snapshot = read_plugins_snapshot(self)
561-
self.quickstart_state = read_quickstart_file(options)
576+
self.quickstart_state = read_quickstart_file(options, self.stdout)
562577

563578
def dump_stats(self) -> None:
564579
self.log("Stats:")
@@ -904,7 +919,9 @@ def read_plugins_snapshot(manager: BuildManager) -> Optional[Dict[str, str]]:
904919
return snapshot
905920

906921

907-
def read_quickstart_file(options: Options) -> Optional[Dict[str, Tuple[float, int, str]]]:
922+
def read_quickstart_file(options: Options,
923+
stdout: TextIO,
924+
) -> Optional[Dict[str, Tuple[float, int, str]]]:
908925
quickstart = None # type: Optional[Dict[str, Tuple[float, int, str]]]
909926
if options.quickstart_file:
910927
# This is very "best effort". If the file is missing or malformed,
@@ -918,7 +935,7 @@ def read_quickstart_file(options: Options) -> Optional[Dict[str, Tuple[float, in
918935
for file, (x, y, z) in raw_quickstart.items():
919936
quickstart[file] = (x, y, z)
920937
except Exception as e:
921-
print("Warning: Failed to load quickstart file: {}\n".format(str(e)))
938+
print("Warning: Failed to load quickstart file: {}\n".format(str(e)), file=stdout)
922939
return quickstart
923940

924941

@@ -1769,7 +1786,8 @@ def wrap_context(self, check_blockers: bool = True) -> Iterator[None]:
17691786
except CompileError:
17701787
raise
17711788
except Exception as err:
1772-
report_internal_error(err, self.path, 0, self.manager.errors, self.options)
1789+
report_internal_error(err, self.path, 0, self.manager.errors,
1790+
self.options, self.manager.stdout, self.manager.stderr)
17731791
self.manager.errors.set_import_context(save_import_context)
17741792
# TODO: Move this away once we've removed the old semantic analyzer?
17751793
if check_blockers:
@@ -2429,7 +2447,10 @@ def log_configuration(manager: BuildManager) -> None:
24292447
# The driver
24302448

24312449

2432-
def dispatch(sources: List[BuildSource], manager: BuildManager) -> Graph:
2450+
def dispatch(sources: List[BuildSource],
2451+
manager: BuildManager,
2452+
stdout: TextIO,
2453+
) -> Graph:
24332454
log_configuration(manager)
24342455

24352456
t0 = time.time()
@@ -2454,11 +2475,11 @@ def dispatch(sources: List[BuildSource], manager: BuildManager) -> Graph:
24542475
fm_cache_size=len(manager.find_module_cache.results),
24552476
)
24562477
if not graph:
2457-
print("Nothing to do?!")
2478+
print("Nothing to do?!", file=stdout)
24582479
return graph
24592480
manager.log("Loaded graph with %d nodes (%.3f sec)" % (len(graph), t1 - t0))
24602481
if manager.options.dump_graph:
2461-
dump_graph(graph)
2482+
dump_graph(graph, stdout)
24622483
return graph
24632484

24642485
# Fine grained dependencies that didn't have an associated module in the build
@@ -2480,7 +2501,7 @@ def dispatch(sources: List[BuildSource], manager: BuildManager) -> Graph:
24802501
manager.log("Error reading fine-grained dependencies cache -- aborting cache load")
24812502
manager.cache_enabled = False
24822503
manager.log("Falling back to full run -- reloading graph...")
2483-
return dispatch(sources, manager)
2504+
return dispatch(sources, manager, stdout)
24842505

24852506
# If we are loading a fine-grained incremental mode cache, we
24862507
# don't want to do a real incremental reprocess of the
@@ -2528,7 +2549,7 @@ def dumps(self) -> str:
25282549
json.dumps(self.deps))
25292550

25302551

2531-
def dump_graph(graph: Graph) -> None:
2552+
def dump_graph(graph: Graph, stdout: TextIO = sys.stdout) -> None:
25322553
"""Dump the graph as a JSON string to stdout.
25332554
25342555
This copies some of the work by process_graph()
@@ -2562,7 +2583,7 @@ def dump_graph(graph: Graph) -> None:
25622583
if (dep_id != node.node_id and
25632584
(dep_id not in node.deps or pri < node.deps[dep_id])):
25642585
node.deps[dep_id] = pri
2565-
print("[" + ",\n ".join(node.dumps() for node in nodes) + "\n]")
2586+
print("[" + ",\n ".join(node.dumps() for node in nodes) + "\n]", file=stdout)
25662587

25672588

25682589
def load_graph(sources: List[BuildSource], manager: BuildManager,

mypy/dmypy_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ def cmd_run(self, version: str, args: Sequence[str]) -> Dict[str, object]:
308308
if self.fine_grained_manager:
309309
manager = self.fine_grained_manager.manager
310310
start_plugins_snapshot = manager.plugins_snapshot
311-
_, current_plugins_snapshot = mypy.build.load_plugins(options, manager.errors)
311+
_, current_plugins_snapshot = mypy.build.load_plugins(options,
312+
manager.errors,
313+
sys.stdout)
312314
if current_plugins_snapshot != start_plugins_snapshot:
313315
return {'restart': 'plugins changed'}
314316
except InvalidSourceList as err:

mypy/errors.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import traceback
44
from collections import OrderedDict, defaultdict
55

6-
from typing import Tuple, List, TypeVar, Set, Dict, Optional
6+
from typing import Tuple, List, TypeVar, Set, Dict, Optional, TextIO
77

88
from mypy.scope import Scope
99
from mypy.options import Options
@@ -585,19 +585,27 @@ def remove_path_prefix(path: str, prefix: Optional[str]) -> str:
585585
return path
586586

587587

588-
def report_internal_error(err: Exception, file: Optional[str], line: int,
589-
errors: Errors, options: Options) -> None:
588+
def report_internal_error(err: Exception,
589+
file: Optional[str],
590+
line: int,
591+
errors: Errors,
592+
options: Options,
593+
stdout: Optional[TextIO] = None,
594+
stderr: Optional[TextIO] = None,
595+
) -> None:
590596
"""Report internal error and exit.
591597
592598
This optionally starts pdb or shows a traceback.
593599
"""
600+
stdout = (stdout or sys.stdout)
601+
stderr = (stderr or sys.stderr)
594602
# Dump out errors so far, they often provide a clue.
595603
# But catch unexpected errors rendering them.
596604
try:
597605
for msg in errors.new_messages():
598606
print(msg)
599607
except Exception as e:
600-
print("Failed to dump errors:", repr(e), file=sys.stderr)
608+
print("Failed to dump errors:", repr(e), file=stderr)
601609

602610
# Compute file:line prefix for official-looking error messages.
603611
if file:
@@ -612,11 +620,11 @@ def report_internal_error(err: Exception, file: Optional[str], line: int,
612620
print('{}error: INTERNAL ERROR --'.format(prefix),
613621
'please report a bug at https://github.com/python/mypy/issues',
614622
'version: {}'.format(mypy_version),
615-
file=sys.stderr)
623+
file=stderr)
616624

617625
# If requested, drop into pdb. This overrides show_tb.
618626
if options.pdb:
619-
print('Dropping into pdb', file=sys.stderr)
627+
print('Dropping into pdb', file=stderr)
620628
import pdb
621629
pdb.post_mortem(sys.exc_info()[2])
622630

@@ -627,15 +635,15 @@ def report_internal_error(err: Exception, file: Optional[str], line: int,
627635
if not options.pdb:
628636
print('{}: note: please use --show-traceback to print a traceback '
629637
'when reporting a bug'.format(prefix),
630-
file=sys.stderr)
638+
file=stderr)
631639
else:
632640
tb = traceback.extract_stack()[:-2]
633641
tb2 = traceback.extract_tb(sys.exc_info()[2])
634642
print('Traceback (most recent call last):')
635643
for s in traceback.format_list(tb + tb2):
636644
print(s.rstrip('\n'))
637-
print('{}: {}'.format(type(err).__name__, err))
638-
print('{}: note: use --pdb to drop into pdb'.format(prefix), file=sys.stderr)
645+
print('{}: {}'.format(type(err).__name__, err), file=stdout)
646+
print('{}: note: use --pdb to drop into pdb'.format(prefix), file=stderr)
639647

640648
# Exit. The caller has nothing more to say.
641649
# We use exit code 2 to signal that this is no ordinary error.

0 commit comments

Comments
 (0)