|
5 | 5 | import functools
|
6 | 6 | from test import support
|
7 | 7 | import unittest
|
| 8 | +import traceback |
8 | 9 |
|
9 | 10 | from test.test_contextlib import TestBaseExitStack
|
10 | 11 |
|
@@ -125,6 +126,62 @@ async def woohoo():
|
125 | 126 | raise ZeroDivisionError()
|
126 | 127 | self.assertEqual(state, [1, 42, 999])
|
127 | 128 |
|
| 129 | + @_async_test |
| 130 | + async def test_contextmanager_traceback(self): |
| 131 | + @asynccontextmanager |
| 132 | + async def f(): |
| 133 | + yield |
| 134 | + |
| 135 | + try: |
| 136 | + async with f(): |
| 137 | + 1/0 |
| 138 | + except ZeroDivisionError as e: |
| 139 | + frames = traceback.extract_tb(e.__traceback__) |
| 140 | + |
| 141 | + self.assertEqual(len(frames), 1) |
| 142 | + self.assertEqual(frames[0].name, 'test_contextmanager_traceback') |
| 143 | + self.assertEqual(frames[0].line, '1/0') |
| 144 | + |
| 145 | + # Repeat with RuntimeError (which goes through a different code path) |
| 146 | + class RuntimeErrorSubclass(RuntimeError): |
| 147 | + pass |
| 148 | + |
| 149 | + try: |
| 150 | + async with f(): |
| 151 | + raise RuntimeErrorSubclass(42) |
| 152 | + except RuntimeErrorSubclass as e: |
| 153 | + frames = traceback.extract_tb(e.__traceback__) |
| 154 | + |
| 155 | + self.assertEqual(len(frames), 1) |
| 156 | + self.assertEqual(frames[0].name, 'test_contextmanager_traceback') |
| 157 | + self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)') |
| 158 | + |
| 159 | + class StopIterationSubclass(StopIteration): |
| 160 | + pass |
| 161 | + |
| 162 | + class StopAsyncIterationSubclass(StopAsyncIteration): |
| 163 | + pass |
| 164 | + |
| 165 | + for stop_exc in ( |
| 166 | + StopIteration('spam'), |
| 167 | + StopAsyncIteration('ham'), |
| 168 | + StopIterationSubclass('spam'), |
| 169 | + StopAsyncIterationSubclass('spam') |
| 170 | + ): |
| 171 | + with self.subTest(type=type(stop_exc)): |
| 172 | + try: |
| 173 | + async with f(): |
| 174 | + raise stop_exc |
| 175 | + except type(stop_exc) as e: |
| 176 | + self.assertIs(e, stop_exc) |
| 177 | + frames = traceback.extract_tb(e.__traceback__) |
| 178 | + else: |
| 179 | + self.fail(f'{stop_exc} was suppressed') |
| 180 | + |
| 181 | + self.assertEqual(len(frames), 1) |
| 182 | + self.assertEqual(frames[0].name, 'test_contextmanager_traceback') |
| 183 | + self.assertEqual(frames[0].line, 'raise stop_exc') |
| 184 | + |
128 | 185 | @_async_test
|
129 | 186 | async def test_contextmanager_no_reraise(self):
|
130 | 187 | @asynccontextmanager
|
|
0 commit comments