From 78ad25bffb927df688ee4855104b4d18215fc74b Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Thu, 19 Aug 2021 13:38:12 +0300 Subject: [PATCH 1/4] bpo-45046: Support context managers in unittest Add methods enterContext() and enterClassContext() in TestCase. Add method enterAsyncContext() in IsolatedAsyncioTestCase. Add function enterModuleContext(). --- Doc/library/unittest.rst | 42 +++++++ Doc/whatsnew/3.11.rst | 12 ++ Lib/distutils/tests/test_build_ext.py | 4 +- Lib/test/test__osx_support.py | 3 +- Lib/test/test_argparse.py | 6 +- Lib/test/test_getopt.py | 6 +- Lib/test/test_gettext.py | 7 +- Lib/test/test_global.py | 7 +- Lib/test/test_importlib/source/test_finder.py | 19 +-- .../test_importlib/test_namespace_pkgs.py | 7 +- Lib/test/test_nntplib.py | 2 +- Lib/test/test_peg_generator/test_c_parser.py | 4 +- Lib/test/test_poll.py | 3 +- Lib/test/test_posix.py | 12 +- Lib/test/test_set.py | 6 +- Lib/test/test_ssl.py | 6 +- Lib/test/test_tempfile.py | 6 +- Lib/test/test_urllib.py | 7 +- Lib/unittest/__init__.py | 4 +- Lib/unittest/async_case.py | 20 ++++ Lib/unittest/case.py | 32 +++++ Lib/unittest/test/test_async_case.py | 51 ++++++++ Lib/unittest/test/test_runner.py | 110 ++++++++++++++++++ .../2021-08-29-19-59-16.bpo-45046.eGq0NC.rst | 7 ++ 24 files changed, 301 insertions(+), 82 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst index f0fba94677a917..9c7d9ed5cfa8ad 100644 --- a/Doc/library/unittest.rst +++ b/Doc/library/unittest.rst @@ -1471,6 +1471,16 @@ Test cases .. versionadded:: 3.1 + .. method:: enterContext(cm) + + Enter the supplied :term:`context manager`. If successful, also + add its :meth:`~object.__exit__` method as a cleanup function by + :meth:`addCleanup` and return the result of the + :meth:`~object.__enter__` method. + + .. versionadded:: 3.11 + + .. method:: doCleanups() This method is called unconditionally after :meth:`tearDown`, or @@ -1486,6 +1496,7 @@ Test cases .. versionadded:: 3.1 + .. classmethod:: addClassCleanup(function, /, *args, **kwargs) Add a function to be called after :meth:`tearDownClass` to cleanup @@ -1500,6 +1511,16 @@ Test cases .. versionadded:: 3.8 + .. classmethod:: enterClassContext(cm) + + Enter the supplied :term:`context manager`. If successful, also + add its :meth:`~object.__exit__` method as a cleanup function by + :meth:`addClassCleanup` and return the result of the + :meth:`~object.__enter__` method. + + .. versionadded:: 3.11 + + .. classmethod:: doClassCleanups() This method is called unconditionally after :meth:`tearDownClass`, or @@ -1547,6 +1568,16 @@ Test cases This method accepts a coroutine that can be used as a cleanup function. + .. coroutinemethod:: enterAsyncContext(cm) + + Enter the supplied :term:`asynchronous context manager`. If successful, + also add its :meth:`~object.__aexit__` method as a cleanup function by + :meth:`addAsyncCleanup` and return the result of the + :meth:`~object.__aenter__` method. + + .. versionadded:: 3.11 + + .. method:: run(result=None) Sets up a new event loop to run the test, collecting the result into @@ -2437,6 +2468,16 @@ To add cleanup code that must be run even in the case of an exception, use .. versionadded:: 3.8 +.. classmethod:: enterModuleContext(cm) + + Enter the supplied :term:`context manager`. If successful, also + add its :meth:`~object.__exit__` method as a cleanup function by + :func:`addModuleCleanup` and return the result of the + :meth:`~object.__enter__` method. + + .. versionadded:: 3.11 + + .. function:: doModuleCleanups() This function is called unconditionally after :func:`tearDownModule`, or @@ -2452,6 +2493,7 @@ To add cleanup code that must be run even in the case of an exception, use .. versionadded:: 3.8 + Signal Handling --------------- diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index 306385c2a90aaf..5ec8ea0a468f02 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -227,6 +227,18 @@ sqlite3 (Contributed by Erlend E. Aasland in :issue:`44688`.) +unittest +-------- + +* Added methods :meth:`~unittest.TestCase.enterContext` and + :meth:`~unittest.TestCase.enterClassContext` of class + :class:`~unittest.TestCase`, method + :meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of + class :class:`~unittest.IsolatedAsyncioTestCase` and function + :func:`unittest.enterModuleContext`. + (Contributed by Serhiy Storchaka in :issue:`XXXXX`.) + + Removed ======= * :class:`smtpd.MailmanProxy` is now removed as it is unusable without diff --git a/Lib/distutils/tests/test_build_ext.py b/Lib/distutils/tests/test_build_ext.py index 8e7364d2a2cb5f..456186b7683ece 100644 --- a/Lib/distutils/tests/test_build_ext.py +++ b/Lib/distutils/tests/test_build_ext.py @@ -40,9 +40,7 @@ def setUp(self): # bpo-30132: On Windows, a .pdb file may be created in the current # working directory. Create a temporary working directory to cleanup # everything at the end of the test. - change_cwd = os_helper.change_cwd(self.tmp_dir) - change_cwd.__enter__() - self.addCleanup(change_cwd.__exit__, None, None, None) + self.enterContext(os_helper.change_cwd(self.tmp_dir)) def tearDown(self): import site diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py index 907ae27d529b50..4a14cb352138ef 100644 --- a/Lib/test/test__osx_support.py +++ b/Lib/test/test__osx_support.py @@ -19,8 +19,7 @@ def setUp(self): self.maxDiff = None self.prog_name = 'bogus_program_xxxx' self.temp_path_dir = os.path.abspath(os.getcwd()) - self.env = os_helper.EnvironmentVarGuard() - self.addCleanup(self.env.__exit__) + self.env = self.enterContext(os_helper.EnvironmentVarGuard()) for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS', 'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC', 'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS', diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py index d369d0fb28f4a3..685f16fb038243 100644 --- a/Lib/test/test_argparse.py +++ b/Lib/test/test_argparse.py @@ -24,9 +24,8 @@ def setUp(self): # The tests assume that line wrapping occurs at 80 columns, but this # behaviour can be overridden by setting the COLUMNS environment # variable. To ensure that this width is used, set COLUMNS to 80. - env = os_helper.EnvironmentVarGuard() + env = self.enterContext(os_helper.EnvironmentVarGuard()) env['COLUMNS'] = '80' - self.addCleanup(env.__exit__) class TempDirMixin(object): @@ -3290,9 +3289,8 @@ class TestShortColumns(HelpTestCase): but we don't want any exceptions thrown in such cases. Only ugly representation. ''' def setUp(self): - env = os_helper.EnvironmentVarGuard() + env = self.enterContext(os_helper.EnvironmentVarGuard()) env.set("COLUMNS", '15') - self.addCleanup(env.__exit__) parser_signature = TestHelpBiggerOptionals.parser_signature argument_signatures = TestHelpBiggerOptionals.argument_signatures diff --git a/Lib/test/test_getopt.py b/Lib/test/test_getopt.py index 9261276ebb9726..64b9ce01e05ea2 100644 --- a/Lib/test/test_getopt.py +++ b/Lib/test/test_getopt.py @@ -11,14 +11,10 @@ class GetoptTests(unittest.TestCase): def setUp(self): - self.env = EnvironmentVarGuard() + self.env = self.enterContext(EnvironmentVarGuard()) if "POSIXLY_CORRECT" in self.env: del self.env["POSIXLY_CORRECT"] - def tearDown(self): - self.env.__exit__() - del self.env - def assertError(self, *args, **kwargs): self.assertRaises(getopt.GetoptError, *args, **kwargs) diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py index 467652a41f0cd6..1608d1b18e98fb 100644 --- a/Lib/test/test_gettext.py +++ b/Lib/test/test_gettext.py @@ -117,6 +117,7 @@ class GettextBaseTest(unittest.TestCase): def setUp(self): + self.addCleanup(os_helper.rmtree, os.path.split(LOCALEDIR)[0]) if not os.path.isdir(LOCALEDIR): os.makedirs(LOCALEDIR) with open(MOFILE, 'wb') as fp: @@ -129,14 +130,10 @@ def setUp(self): fp.write(base64.decodebytes(UMO_DATA)) with open(MMOFILE, 'wb') as fp: fp.write(base64.decodebytes(MMO_DATA)) - self.env = os_helper.EnvironmentVarGuard() + self.env = self.enterContext(os_helper.EnvironmentVarGuard()) self.env['LANGUAGE'] = 'xx' gettext._translations.clear() - def tearDown(self): - self.env.__exit__() - del self.env - os_helper.rmtree(os.path.split(LOCALEDIR)[0]) GNU_MO_DATA_ISSUE_17898 = b'''\ 3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py index c71d055297e0c9..71dd55be2d448a 100644 --- a/Lib/test/test_global.py +++ b/Lib/test/test_global.py @@ -9,14 +9,9 @@ class GlobalTests(unittest.TestCase): def setUp(self): - self._warnings_manager = check_warnings() - self._warnings_manager.__enter__() + self.enterContext(check_warnings()) warnings.filterwarnings("error", module="") - def tearDown(self): - self._warnings_manager.__exit__(None, None, None) - - def test1(self): prog_text_1 = """\ def wrong1(): diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py index 80e930cc6a1f28..3b98e3bf62381b 100644 --- a/Lib/test/test_importlib/source/test_finder.py +++ b/Lib/test/test_importlib/source/test_finder.py @@ -158,21 +158,12 @@ def test_dir_removal_handling(self): def test_no_read_directory(self): # Issue #16730 tempdir = tempfile.TemporaryDirectory() + self.enterContext(tempdir) + # Since we muck with the permissions, we want to set them back to + # their original values to make sure the directory can be properly + # cleaned up. original_mode = os.stat(tempdir.name).st_mode - def cleanup(tempdir): - """Cleanup function for the temporary directory. - - Since we muck with the permissions, we want to set them back to - their original values to make sure the directory can be properly - cleaned up. - - """ - os.chmod(tempdir.name, original_mode) - # If this is not explicitly called then the __del__ method is used, - # but since already mucking around might as well explicitly clean - # up. - tempdir.__exit__(None, None, None) - self.addCleanup(cleanup, tempdir) + self.addCleanup(os.chmod, tempdir.name, original_mode) os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR) finder = self.get_finder(tempdir.name) found = self._find(finder, 'doesnotexist') diff --git a/Lib/test/test_importlib/test_namespace_pkgs.py b/Lib/test/test_importlib/test_namespace_pkgs.py index 3fe3ddc5898448..39acd7bf613a00 100644 --- a/Lib/test/test_importlib/test_namespace_pkgs.py +++ b/Lib/test/test_importlib/test_namespace_pkgs.py @@ -62,12 +62,7 @@ def setUp(self): self.resolved_paths = [ os.path.join(self.root, path) for path in self.paths ] - self.ctx = namespace_tree_context(path=self.resolved_paths) - self.ctx.__enter__() - - def tearDown(self): - # TODO: will we ever want to pass exc_info to __exit__? - self.ctx.__exit__(None, None, None) + self.enterContext(namespace_tree_context(path=self.resolved_paths)) class SingleNamespacePackage(NamespacePackageTest): diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index 4f0592188f8443..ad1c293ac22c1c 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -1594,7 +1594,7 @@ def setUp(self): self.addCleanup(self.background.join) self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__() - self.addCleanup(self.nntp.__exit__, None, None, None) + self.enterContext(self.nntp) def run_server(self, sock): # Could be generalized to handle more commands in separate methods diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py index b761bd493f52c7..e220c7bda715df 100644 --- a/Lib/test/test_peg_generator/test_c_parser.py +++ b/Lib/test/test_peg_generator/test_c_parser.py @@ -78,9 +78,7 @@ def setUp(self): self.skipTest("The %r command is not found" % cmd) self.old_cwd = os.getcwd() self.tmp_path = tempfile.mkdtemp() - change_cwd = os_helper.change_cwd(self.tmp_path) - change_cwd.__enter__() - self.addCleanup(change_cwd.__exit__, None, None, None) + self.enterContext(os_helper.change_cwd(self.tmp_path)) def tearDown(self): os.chdir(self.old_cwd) diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py index de62350696a920..e496925ad8f16c 100644 --- a/Lib/test/test_poll.py +++ b/Lib/test/test_poll.py @@ -124,8 +124,7 @@ def test_poll2(self): cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, bufsize=0) - proc.__enter__() - self.addCleanup(proc.__exit__, None, None, None) + self.enterContext(proc) p = proc.stdout pollster = select.poll() pollster.register( p, select.POLLIN ) diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py index e4666884ce06a1..cc8116a209287f 100644 --- a/Lib/test/test_posix.py +++ b/Lib/test/test_posix.py @@ -45,19 +45,13 @@ class PosixTester(unittest.TestCase): def setUp(self): # create empty file + self.addCleanup(os_helper.unlink, os_helper.TESTFN) with open(os_helper.TESTFN, "wb"): pass - self.teardown_files = [ os_helper.TESTFN ] - self._warnings_manager = warnings_helper.check_warnings() - self._warnings_manager.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.filterwarnings('ignore', '.* potential security risk .*', RuntimeWarning) - def tearDown(self): - for teardown_file in self.teardown_files: - os_helper.unlink(teardown_file) - self._warnings_manager.__exit__(None, None, None) - def testNoArgFunctions(self): # test posix functions which take no arguments and have # no side-effects which we need to cleanup (e.g., fork, wait, abort) @@ -958,8 +952,8 @@ def test_lchflags_symlink(self): self.assertTrue(hasattr(testfn_st, 'st_flags')) + self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK) os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK) - self.teardown_files.append(_DUMMY_SYMLINK) dummy_symlink_st = os.lstat(_DUMMY_SYMLINK) def chflags_nofollow(path, flags): diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 29bb39df76c8a5..fa93cc20eeeb58 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -955,8 +955,7 @@ def test_repr(self): class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): def setUp(self): - self._warning_filters = warnings_helper.check_warnings() - self._warning_filters.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.simplefilter('ignore', BytesWarning) self.case = "string and bytes set" self.values = ["a", "b", b"a", b"b"] @@ -964,9 +963,6 @@ def setUp(self): self.dup = set(self.values) self.length = 4 - def tearDown(self): - self._warning_filters.__exit__(None, None, None) - def test_repr(self): self.check_repr_against_values() diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index d9e184365ce082..76c1b70fafce3e 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1995,9 +1995,8 @@ def setUp(self): self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) self.server_context.load_cert_chain(SIGNED_CERTFILE) server = ThreadedEchoServer(context=self.server_context) + self.enterContext(server) self.server_addr = (HOST, server.port) - server.__enter__() - self.addCleanup(server.__exit__, None, None, None) def test_connect(self): with test_wrap_socket(socket.socket(socket.AF_INET), @@ -3708,8 +3707,7 @@ def _recvfrom_into(): def test_recv_zero(self): server = ThreadedEchoServer(CERTFILE) - server.__enter__() - self.addCleanup(server.__exit__, None, None) + self.enterContext(server) s = socket.create_connection((HOST, server.port)) self.addCleanup(s.close) s = test_wrap_socket(s, suppress_ragged_eofs=False) diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py index f1d483733e2675..bd85c8030dd6da 100644 --- a/Lib/test/test_tempfile.py +++ b/Lib/test/test_tempfile.py @@ -71,14 +71,10 @@ class BaseTestCase(unittest.TestCase): b_check = re.compile(br"^[a-z0-9_-]{8}$") def setUp(self): - self._warnings_manager = warnings_helper.check_warnings() - self._warnings_manager.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.filterwarnings("ignore", category=RuntimeWarning, message="mktemp", module=__name__) - def tearDown(self): - self._warnings_manager.__exit__(None, None, None) - def nameCheck(self, name, dir, pre, suf): (ndir, nbase) = os.path.split(name) npre = nbase[:len(pre)] diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py index 82f1d9dc2e7bb3..bc6e74c291ac1c 100644 --- a/Lib/test/test_urllib.py +++ b/Lib/test/test_urllib.py @@ -232,17 +232,12 @@ class ProxyTests(unittest.TestCase): def setUp(self): # Records changes to env vars - self.env = os_helper.EnvironmentVarGuard() + self.env = self.enterContext(os_helper.EnvironmentVarGuard()) # Delete all proxy related env vars for k in list(os.environ): if 'proxy' in k.lower(): self.env.unset(k) - def tearDown(self): - # Restore all proxy related env vars - self.env.__exit__() - del self.env - def test_getproxies_environment_keep_no_proxies(self): self.env.set('NO_PROXY', 'localhost') proxies = urllib.request.getproxies_environment() diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index 348dc471f4c3d4..d54af80d9c4cd2 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -49,7 +49,7 @@ def testMultiply(self): 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 'expectedFailure', 'TextTestResult', 'installHandler', 'registerResult', 'removeResult', 'removeHandler', - 'addModuleCleanup'] + 'addModuleCleanup', 'enterModuleContext'] # Expose obsolete functions for backwards compatibility __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) @@ -58,7 +58,7 @@ def testMultiply(self): from .result import TestResult from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip, - skipIf, skipUnless, expectedFailure) + skipIf, skipUnless, expectedFailure, enterModuleContext) from .suite import BaseTestSuite, TestSuite from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, findTestCases) diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py index bfc68a76e84d93..8880e3894e8116 100644 --- a/Lib/unittest/async_case.py +++ b/Lib/unittest/async_case.py @@ -58,6 +58,26 @@ def addAsyncCleanup(self, func, /, *args, **kwargs): # 3. Regular "def func()" that returns awaitable object self.addCleanup(*(func, *args), **kwargs) + async def enterAsyncContext(self, cm): + """Enters the supplied asynchronous context manager. + + If successful, also adds its __aexit__ method as a cleanup + function and returns the result of the __aenter__ method. + """ + # We look up the special methods on the type to match the with + # statement. + cls = type(cm) + try: + enter = cls.__aenter__ + exit = cls.__aexit__ + except AttributeError: + raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " + f"not support the asynchronous context manager protocol" + ) from None + result = await enter(cm) + self.addAsyncCleanup(exit, cm, None, None, None) + return result + def _callSetUp(self): self.setUp() self._callAsync(self.asyncSetUp) diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 8775ba9e241c44..173d81e0620717 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -85,12 +85,31 @@ def _id(obj): return obj +def _enter_context(cm, addcleanup): + # We look up the special methods on the type to match the with + # statement. + cls = type(cm) + try: + enter = cls.__enter__ + exit = cls.__exit__ + except AttributeError: + raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " + f"not support the context manager protocol") from None + result = enter(cm) + addcleanup(exit, cm, None, None, None) + return result + + _module_cleanups = [] def addModuleCleanup(function, /, *args, **kwargs): """Same as addCleanup, except the cleanup items are called even if setUpModule fails (unlike tearDownModule).""" _module_cleanups.append((function, args, kwargs)) +def enterModuleContext(cm): + """Same as enterContext, but module-wide.""" + return _enter_context(cm, addModuleCleanup) + def doModuleCleanups(): """Execute all module cleanup functions. Normally called for you after @@ -409,12 +428,25 @@ def addCleanup(self, function, /, *args, **kwargs): Cleanup items are called even if setUp fails (unlike tearDown).""" self._cleanups.append((function, args, kwargs)) + def enterContext(self, cm): + """Enters the supplied context manager. + + If successful, also adds its __exit__ method as a cleanup + function and returns the result of the __enter__ method. + """ + return _enter_context(cm, self.addCleanup) + @classmethod def addClassCleanup(cls, function, /, *args, **kwargs): """Same as addCleanup, except the cleanup items are called even if setUpClass fails (unlike tearDownClass).""" cls._class_cleanups.append((function, args, kwargs)) + @classmethod + def enterClassContext(cls, cm): + """Same as enterContext, but class-wide.""" + return _enter_context(cm, cls.addClassCleanup) + def setUp(self): "Hook method for setting up the test fixture before exercising it." pass diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py index 93ef1997e0c99f..7992b3888ebbfc 100644 --- a/Lib/unittest/test/test_async_case.py +++ b/Lib/unittest/test/test_async_case.py @@ -6,6 +6,29 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +class TestCM: + def __init__(self, ordering, enter_result=None): + self.ordering = ordering + self.enter_result = enter_result + + async def __aenter__(self): + self.ordering.append('enter') + return self.enter_result + + async def __aexit__(self, *exc_info): + self.ordering.append('exit') + + +class LacksEnterAndExit: + pass +class LacksEnter: + async def __aexit__(self, *exc_info): + pass +class LacksExit: + async def __aenter__(self): + pass + + class TestAsyncCase(unittest.TestCase): def test_full_cycle(self): events = [] @@ -255,7 +278,35 @@ async def coro(): output = test.run() self.assertTrue(cancelled) + def test_enterAsyncContext(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(slf): + slf.addAsyncCleanup(events.append, 'cleanup1') + cm = TestCM(events, 42) + self.assertEqual(await slf.enterAsyncContext(cm), 42) + slf.addAsyncCleanup(events.append, 'cleanup2') + events.append('test') + test = Test('test_func') + output = test.run() + self.assertTrue(output.wasSuccessful(), output) + self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterAsyncContext_arg_errors(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(slf): + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksExit()) + + test = Test('test_func') + output = test.run() + self.assertTrue(output.wasSuccessful()) if __name__ == "__main__": diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py index dd9a1b6d9aeddf..4c41721a353394 100644 --- a/Lib/unittest/test/test_runner.py +++ b/Lib/unittest/test/test_runner.py @@ -45,6 +45,29 @@ def cleanup(ordering, blowUp=False): raise Exception('CleanUpExc') +class TestCM: + def __init__(self, ordering, enter_result=None): + self.ordering = ordering + self.enter_result = enter_result + + def __enter__(self): + self.ordering.append('enter') + return self.enter_result + + def __exit__(self, *exc_info): + self.ordering.append('exit') + + +class LacksEnterAndExit: + pass +class LacksEnter: + def __exit__(self, *exc_info): + pass +class LacksExit: + def __enter__(self): + pass + + class TestCleanUp(unittest.TestCase): def testCleanUp(self): class TestableTest(unittest.TestCase): @@ -168,6 +191,39 @@ def cleanup2(): self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2']) + def test_enterContext(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + cleanups = [] + + test.addCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(test.enterContext(cm), 42) + test.addCleanup(cleanups.append, 'cleanup2') + + self.assertTrue(test.doCleanups()) + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksExit()) + + self.assertEqual(test._cleanups, []) + + class TestClassCleanup(unittest.TestCase): def test_addClassCleanUp(self): class TestableTest(unittest.TestCase): @@ -355,6 +411,35 @@ def tearDownClass(cls): ['setUpClass', 'setUp', 'tearDownClass', 'cleanup_exc']) + def test_enterClassContext(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + cleanups = [] + + TestableTest.addClassCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(TestableTest.enterClassContext(cm), 42) + TestableTest.addClassCleanup(cleanups.append, 'cleanup2') + + TestableTest.doClassCleanups() + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterClassContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksExit()) + + self.assertEqual(TestableTest._class_cleanups, []) + class TestModuleCleanUp(unittest.TestCase): def test_add_and_do_ModuleCleanup(self): @@ -794,6 +879,31 @@ def tearDown(self): 'cleanup2', 'setUp2', 'test2', 'tearDown2', 'cleanup3', 'tearDownModule', 'cleanup1']) + def test_enterModuleContext(self): + cleanups = [] + + unittest.addModuleCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(unittest.enterModuleContext(cm), 42) + unittest.addModuleCleanup(cleanups.append, 'cleanup2') + + unittest.case.doModuleCleanups() + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterModuleContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksExit()) + + self.assertEqual(unittest.case._module_cleanups, []) + class Test_TextTestRunner(unittest.TestCase): """Tests for TextTestRunner.""" diff --git a/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst new file mode 100644 index 00000000000000..8072afaf445c50 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst @@ -0,0 +1,7 @@ +Add support of context managers in :mod:`unittest`: methods +:meth:`~unittest.TestCase.enterContext` and +:meth:`~unittest.TestCase.enterClassContext` of class +:class:`~unittest.TestCase`, method +:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of class +:class:`~unittest.IsolatedAsyncioTestCase` and function +:func:`unittest.enterModuleContext`. From 249a43434d069ca39995b71469044b23a0c2581a Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Mon, 13 Sep 2021 20:16:09 +0300 Subject: [PATCH 2/4] Update the issue number. --- Doc/whatsnew/3.11.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index bd42c2659a37f7..e6de18d24d2ffa 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -249,7 +249,7 @@ unittest :meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of class :class:`~unittest.IsolatedAsyncioTestCase` and function :func:`unittest.enterModuleContext`. - (Contributed by Serhiy Storchaka in :issue:`XXXXX`.) + (Contributed by Serhiy Storchaka in :issue:`45046`.) Removed From c3026c8b0c53b397e4806a128e6feb30f0e08bfa Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Thu, 16 Sep 2021 22:36:31 +0300 Subject: [PATCH 3/4] Minor fix in test_nntplib. --- Lib/test/test_nntplib.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index ad1c293ac22c1c..58ce10c3f68967 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -1593,8 +1593,7 @@ def setUp(self): self.background.start() self.addCleanup(self.background.join) - self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__() - self.enterContext(self.nntp) + self.nntp = self.enterContext(NNTP(socket_helper.HOST, port, usenetrc=False)) def run_server(self, sock): # Could be generalized to handle more commands in separate methods From 86c3e6ea76b2960cfa465d44113e9e91793c2423 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sun, 19 Sep 2021 16:42:58 +0300 Subject: [PATCH 4/4] Use in more tests. --- Lib/test/test_global.py | 4 +--- Lib/test/test_logging.py | 4 +--- Lib/test/test_socket.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py index 9b3c1f78345c6a..f5b38c25ea0728 100644 --- a/Lib/test/test_global.py +++ b/Lib/test/test_global.py @@ -49,9 +49,7 @@ def test4(self): def setUpModule(): - cm = warnings.catch_warnings() - cm.__enter__() - unittest.addModuleCleanup(cm.__exit__, None, None, None) + unittest.enterModuleContext(warnings.catch_warnings()) warnings.filterwarnings("error", module="") diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 211fe4bbd7bac0..8436386f266560 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -5515,9 +5515,7 @@ def test__all__(self): # why the test does this, but in any case we save the current locale # first and restore it at the end. def setUpModule(): - cm = support.run_with_locale('LC_ALL', '') - cm.__enter__() - unittest.addModuleCleanup(cm.__exit__, None, None, None) + unittest.enterModuleContext(support.run_with_locale('LC_ALL', '')) if __name__ == "__main__": diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index eeb8e8c98a1494..698938d7b66658 100755 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -336,9 +336,7 @@ def serverExplicitReady(self): self.server_ready.set() def _setUp(self): - self.wait_threads = threading_helper.wait_threads_exit() - self.wait_threads.__enter__() - self.addCleanup(self.wait_threads.__exit__, None, None, None) + self.enterContext(threading_helper.wait_threads_exit()) self.server_ready = threading.Event() self.client_ready = threading.Event()