Skip to content

bpo-45046: Support context managers in unittest #28045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 8, 2022
42 changes: 42 additions & 0 deletions Doc/library/unittest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,16 @@ Test cases
.. versionadded:: 3.1


.. method:: enterContext(cm)

Enter the supplied :term:`context manager`. If successful, also
add its :meth:`~object.__exit__` method as a cleanup function by
:meth:`addCleanup` and return the result of the
:meth:`~object.__enter__` method.

.. versionadded:: 3.11


.. method:: doCleanups()

This method is called unconditionally after :meth:`tearDown`, or
Expand All @@ -1510,6 +1520,7 @@ Test cases

.. versionadded:: 3.1


.. classmethod:: addClassCleanup(function, /, *args, **kwargs)

Add a function to be called after :meth:`tearDownClass` to cleanup
Expand All @@ -1524,6 +1535,16 @@ Test cases
.. versionadded:: 3.8


.. classmethod:: enterClassContext(cm)

Enter the supplied :term:`context manager`. If successful, also
add its :meth:`~object.__exit__` method as a cleanup function by
:meth:`addClassCleanup` and return the result of the
:meth:`~object.__enter__` method.

.. versionadded:: 3.11


.. classmethod:: doClassCleanups()

This method is called unconditionally after :meth:`tearDownClass`, or
Expand Down Expand Up @@ -1571,6 +1592,16 @@ Test cases

This method accepts a coroutine that can be used as a cleanup function.

.. coroutinemethod:: enterAsyncContext(cm)

Enter the supplied :term:`asynchronous context manager`. If successful,
also add its :meth:`~object.__aexit__` method as a cleanup function by
:meth:`addAsyncCleanup` and return the result of the
:meth:`~object.__aenter__` method.

.. versionadded:: 3.11


.. method:: run(result=None)

Sets up a new event loop to run the test, collecting the result into
Expand Down Expand Up @@ -2465,6 +2496,16 @@ To add cleanup code that must be run even in the case of an exception, use
.. versionadded:: 3.8


.. classmethod:: enterModuleContext(cm)

Enter the supplied :term:`context manager`. If successful, also
add its :meth:`~object.__exit__` method as a cleanup function by
:func:`addModuleCleanup` and return the result of the
:meth:`~object.__enter__` method.

.. versionadded:: 3.11


.. function:: doModuleCleanups()

This function is called unconditionally after :func:`tearDownModule`, or
Expand All @@ -2480,6 +2521,7 @@ To add cleanup code that must be run even in the case of an exception, use

.. versionadded:: 3.8


Signal Handling
---------------

Expand Down
12 changes: 12 additions & 0 deletions Doc/whatsnew/3.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,18 @@ unicodedata
* The Unicode database has been updated to version 14.0.0. (:issue:`45190`).


unittest
--------

* Added methods :meth:`~unittest.TestCase.enterContext` and
:meth:`~unittest.TestCase.enterClassContext` of class
:class:`~unittest.TestCase`, method
:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of
class :class:`~unittest.IsolatedAsyncioTestCase` and function
:func:`unittest.enterModuleContext`.
(Contributed by Serhiy Storchaka in :issue:`45046`.)


venv
----

Expand Down
4 changes: 1 addition & 3 deletions Lib/distutils/tests/test_build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def setUp(self):
# bpo-30132: On Windows, a .pdb file may be created in the current
# working directory. Create a temporary working directory to cleanup
# everything at the end of the test.
change_cwd = os_helper.change_cwd(self.tmp_dir)
change_cwd.__enter__()
self.addCleanup(change_cwd.__exit__, None, None, None)
self.enterContext(os_helper.change_cwd(self.tmp_dir))

def tearDown(self):
import site
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test__osx_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def setUp(self):
self.maxDiff = None
self.prog_name = 'bogus_program_xxxx'
self.temp_path_dir = os.path.abspath(os.getcwd())
self.env = os_helper.EnvironmentVarGuard()
self.addCleanup(self.env.__exit__)
self.env = self.enterContext(os_helper.EnvironmentVarGuard())
for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS',
'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC',
'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS',
Expand Down
6 changes: 2 additions & 4 deletions Lib/test/test_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def setUp(self):
# The tests assume that line wrapping occurs at 80 columns, but this
# behaviour can be overridden by setting the COLUMNS environment
# variable. To ensure that this width is used, set COLUMNS to 80.
env = os_helper.EnvironmentVarGuard()
env = self.enterContext(os_helper.EnvironmentVarGuard())
env['COLUMNS'] = '80'
self.addCleanup(env.__exit__)


class TempDirMixin(object):
Expand Down Expand Up @@ -3428,9 +3427,8 @@ class TestShortColumns(HelpTestCase):
but we don't want any exceptions thrown in such cases. Only ugly representation.
'''
def setUp(self):
env = os_helper.EnvironmentVarGuard()
env = self.enterContext(os_helper.EnvironmentVarGuard())
env.set("COLUMNS", '15')
self.addCleanup(env.__exit__)

parser_signature = TestHelpBiggerOptionals.parser_signature
argument_signatures = TestHelpBiggerOptionals.argument_signatures
Expand Down
6 changes: 1 addition & 5 deletions Lib/test/test_getopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@

class GetoptTests(unittest.TestCase):
def setUp(self):
self.env = EnvironmentVarGuard()
self.env = self.enterContext(EnvironmentVarGuard())
if "POSIXLY_CORRECT" in self.env:
del self.env["POSIXLY_CORRECT"]

def tearDown(self):
self.env.__exit__()
del self.env

def assertError(self, *args, **kwargs):
self.assertRaises(getopt.GetoptError, *args, **kwargs)

Expand Down
7 changes: 2 additions & 5 deletions Lib/test/test_gettext.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@

class GettextBaseTest(unittest.TestCase):
def setUp(self):
self.addCleanup(os_helper.rmtree, os.path.split(LOCALEDIR)[0])
if not os.path.isdir(LOCALEDIR):
os.makedirs(LOCALEDIR)
with open(MOFILE, 'wb') as fp:
Expand All @@ -129,14 +130,10 @@ def setUp(self):
fp.write(base64.decodebytes(UMO_DATA))
with open(MMOFILE, 'wb') as fp:
fp.write(base64.decodebytes(MMO_DATA))
self.env = os_helper.EnvironmentVarGuard()
self.env = self.enterContext(os_helper.EnvironmentVarGuard())
self.env['LANGUAGE'] = 'xx'
gettext._translations.clear()

def tearDown(self):
self.env.__exit__()
del self.env
os_helper.rmtree(os.path.split(LOCALEDIR)[0])

GNU_MO_DATA_ISSUE_17898 = b'''\
3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z
Expand Down
11 changes: 2 additions & 9 deletions Lib/test/test_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@
class GlobalTests(unittest.TestCase):

def setUp(self):
self._warnings_manager = check_warnings()
self._warnings_manager.__enter__()
self.enterContext(check_warnings())
warnings.filterwarnings("error", module="<test string>")

def tearDown(self):
self._warnings_manager.__exit__(None, None, None)


def test1(self):
prog_text_1 = """\
def wrong1():
Expand Down Expand Up @@ -54,9 +49,7 @@ def test4(self):


def setUpModule():
cm = warnings.catch_warnings()
cm.__enter__()
unittest.addModuleCleanup(cm.__exit__, None, None, None)
unittest.enterModuleContext(warnings.catch_warnings())
warnings.filterwarnings("error", module="<test string>")


Expand Down
19 changes: 5 additions & 14 deletions Lib/test/test_importlib/source/test_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,12 @@ def test_dir_removal_handling(self):
def test_no_read_directory(self):
# Issue #16730
tempdir = tempfile.TemporaryDirectory()
self.enterContext(tempdir)
# Since we muck with the permissions, we want to set them back to
# their original values to make sure the directory can be properly
# cleaned up.
original_mode = os.stat(tempdir.name).st_mode
def cleanup(tempdir):
"""Cleanup function for the temporary directory.

Since we muck with the permissions, we want to set them back to
their original values to make sure the directory can be properly
cleaned up.

"""
os.chmod(tempdir.name, original_mode)
# If this is not explicitly called then the __del__ method is used,
# but since already mucking around might as well explicitly clean
# up.
tempdir.__exit__(None, None, None)
self.addCleanup(cleanup, tempdir)
self.addCleanup(os.chmod, tempdir.name, original_mode)
os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR)
finder = self.get_finder(tempdir.name)
found = self._find(finder, 'doesnotexist')
Expand Down
7 changes: 1 addition & 6 deletions Lib/test/test_importlib/test_namespace_pkgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ def setUp(self):
self.resolved_paths = [
os.path.join(self.root, path) for path in self.paths
]
self.ctx = namespace_tree_context(path=self.resolved_paths)
self.ctx.__enter__()

def tearDown(self):
# TODO: will we ever want to pass exc_info to __exit__?
self.ctx.__exit__(None, None, None)
self.enterContext(namespace_tree_context(path=self.resolved_paths))


class SingleNamespacePackage(NamespacePackageTest):
Expand Down
4 changes: 1 addition & 3 deletions Lib/test/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5650,9 +5650,7 @@ def test__all__(self):
# why the test does this, but in any case we save the current locale
# first and restore it at the end.
def setUpModule():
cm = support.run_with_locale('LC_ALL', '')
cm.__enter__()
unittest.addModuleCleanup(cm.__exit__, None, None, None)
unittest.enterModuleContext(support.run_with_locale('LC_ALL', ''))


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test_nntplib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,8 +1593,7 @@ def setUp(self):
self.background.start()
self.addCleanup(self.background.join)

self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__()
self.addCleanup(self.nntp.__exit__, None, None, None)
self.nntp = self.enterContext(NNTP(socket_helper.HOST, port, usenetrc=False))

def run_server(self, sock):
# Could be generalized to handle more commands in separate methods
Expand Down
4 changes: 1 addition & 3 deletions Lib/test/test_peg_generator/test_c_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ def setUp(self):
self.skipTest("The %r command is not found" % cmd)
self.old_cwd = os.getcwd()
self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base)
change_cwd = os_helper.change_cwd(self.tmp_path)
change_cwd.__enter__()
self.addCleanup(change_cwd.__exit__, None, None, None)
self.enterContext(os_helper.change_cwd(self.tmp_path))

def tearDown(self):
os.chdir(self.old_cwd)
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test_poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ def test_poll2(self):
cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
bufsize=0)
proc.__enter__()
self.addCleanup(proc.__exit__, None, None, None)
self.enterContext(proc)
p = proc.stdout
pollster = select.poll()
pollster.register( p, select.POLLIN )
Expand Down
12 changes: 3 additions & 9 deletions Lib/test/test_posix.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,13 @@ class PosixTester(unittest.TestCase):

def setUp(self):
# create empty file
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
with open(os_helper.TESTFN, "wb"):
pass
self.teardown_files = [ os_helper.TESTFN ]
self._warnings_manager = warnings_helper.check_warnings()
self._warnings_manager.__enter__()
self.enterContext(warnings_helper.check_warnings())
warnings.filterwarnings('ignore', '.* potential security risk .*',
RuntimeWarning)

def tearDown(self):
for teardown_file in self.teardown_files:
os_helper.unlink(teardown_file)
self._warnings_manager.__exit__(None, None, None)

def testNoArgFunctions(self):
# test posix functions which take no arguments and have
# no side-effects which we need to cleanup (e.g., fork, wait, abort)
Expand Down Expand Up @@ -973,8 +967,8 @@ def test_lchflags_symlink(self):

self.assertTrue(hasattr(testfn_st, 'st_flags'))

self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK)
os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK)
self.teardown_files.append(_DUMMY_SYMLINK)
dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)

def chflags_nofollow(path, flags):
Expand Down
6 changes: 1 addition & 5 deletions Lib/test/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,18 +1022,14 @@ def test_repr(self):

class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
def setUp(self):
self._warning_filters = warnings_helper.check_warnings()
self._warning_filters.__enter__()
self.enterContext(warnings_helper.check_warnings())
warnings.simplefilter('ignore', BytesWarning)
self.case = "string and bytes set"
self.values = ["a", "b", b"a", b"b"]
self.set = set(self.values)
self.dup = set(self.values)
self.length = 4

def tearDown(self):
self._warning_filters.__exit__(None, None, None)

def test_repr(self):
self.check_repr_against_values()

Expand Down
4 changes: 1 addition & 3 deletions Lib/test/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,7 @@ def serverExplicitReady(self):
self.server_ready.set()

def _setUp(self):
self.wait_threads = threading_helper.wait_threads_exit()
self.wait_threads.__enter__()
self.addCleanup(self.wait_threads.__exit__, None, None, None)
self.enterContext(threading_helper.wait_threads_exit())

self.server_ready = threading.Event()
self.client_ready = threading.Event()
Expand Down
6 changes: 2 additions & 4 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,9 +1999,8 @@ def setUp(self):
self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.server_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=self.server_context)
self.enterContext(server)
self.server_addr = (HOST, server.port)
server.__enter__()
self.addCleanup(server.__exit__, None, None, None)

def test_connect(self):
with test_wrap_socket(socket.socket(socket.AF_INET),
Expand Down Expand Up @@ -3713,8 +3712,7 @@ def _recvfrom_into():

def test_recv_zero(self):
server = ThreadedEchoServer(CERTFILE)
server.__enter__()
self.addCleanup(server.__exit__, None, None)
self.enterContext(server)
s = socket.create_connection((HOST, server.port))
self.addCleanup(s.close)
s = test_wrap_socket(s, suppress_ragged_eofs=False)
Expand Down
6 changes: 1 addition & 5 deletions Lib/test/test_tempfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,10 @@ class BaseTestCase(unittest.TestCase):
b_check = re.compile(br"^[a-z0-9_-]{8}$")

def setUp(self):
self._warnings_manager = warnings_helper.check_warnings()
self._warnings_manager.__enter__()
self.enterContext(warnings_helper.check_warnings())
warnings.filterwarnings("ignore", category=RuntimeWarning,
message="mktemp", module=__name__)

def tearDown(self):
self._warnings_manager.__exit__(None, None, None)

def nameCheck(self, name, dir, pre, suf):
(ndir, nbase) = os.path.split(name)
npre = nbase[:len(pre)]
Expand Down
Loading