1
1
import importlib
2
2
import os
3
3
import sys
4
- import re
5
4
import tempfile
6
5
import time
7
6
import traceback
8
7
9
- from typing import List , Tuple , Any , Callable , Union , cast , Optional
10
- from types import TracebackType
8
+ from typing import List , Tuple , Any , Callable , Union , cast , Optional , Iterable
9
+ from types import TracebackType , MethodType
11
10
12
11
13
12
# TODO remove global state
@@ -105,30 +104,23 @@ def fail() -> None:
105
104
raise AssertionFailure ()
106
105
107
106
108
- class TestCase :
109
- def __init__ (self , name : str , suite : 'Optional[Suite]' = None ,
110
- func : Optional [Callable [[], None ]] = None ) -> None :
111
- self .func = func
107
+ class BaseTestCase :
108
+ """Common base class for _MyUnitTestCase and DataDrivenTestCase.
109
+
110
+ Handles temporary folder creation and deletion.
111
+ """
112
+ def __init__ (self , name : str ) -> None :
112
113
self .name = name
113
- self .suite = suite
114
114
self .old_cwd = None # type: Optional[str]
115
115
self .tmpdir = None # type: Optional[tempfile.TemporaryDirectory[str]]
116
116
117
- def run (self ) -> None :
118
- if self .func :
119
- self .func ()
120
-
121
- def set_up (self ) -> None :
117
+ def setup (self ) -> None :
122
118
self .old_cwd = os .getcwd ()
123
119
self .tmpdir = tempfile .TemporaryDirectory (prefix = 'mypy-test-' )
124
120
os .chdir (self .tmpdir .name )
125
121
os .mkdir ('tmp' )
126
- if self .suite :
127
- self .suite .set_up ()
128
122
129
- def tear_down (self ) -> None :
130
- if self .suite :
131
- self .suite .tear_down ()
123
+ def teardown (self ) -> None :
132
124
assert self .old_cwd is not None and self .tmpdir is not None , \
133
125
"test was not properly set up"
134
126
os .chdir (self .old_cwd )
@@ -140,35 +132,51 @@ def tear_down(self) -> None:
140
132
self .tmpdir = None
141
133
142
134
135
+ class _MyUnitTestCase (BaseTestCase ):
136
+ """A concrete, myunit-specific test case, a wrapper around a method to run."""
137
+
138
+ def __init__ (self , name : str , suite : 'Suite' , run : Callable [[], None ]) -> None :
139
+ super ().__init__ (name )
140
+ self .run = run
141
+ self .suite = suite
142
+
143
+ def setup (self ) -> None :
144
+ super ().setup ()
145
+ self .suite .setup ()
146
+
147
+ def teardown (self ) -> None :
148
+ self .suite .teardown () # No-op
149
+ super ().teardown ()
150
+
151
+
143
152
class Suite :
144
- def __init__ (self ) -> None :
145
- self .prefix = typename (type (self )) + '.'
146
- # Each test case is either a TestCase object or (str, function).
147
- self ._test_cases = [] # type: List[Any]
148
- self .init ()
153
+ """Abstract class for myunit test suites - node in the tree whose leaves are _MyUnitTestCases.
149
154
150
- def set_up (self ) -> None :
151
- pass
155
+ The children `cases` are looked up during __init__, looking for attributes named test_*
156
+ they are either no-arg methods or of a pair (name, Suite).
157
+ """
152
158
153
- def tear_down (self ) -> None :
154
- pass
159
+ cases = None # type: Iterable[Union[_MyUnitTestCase, Tuple[str, Suite]]]
155
160
156
- def init (self ) -> None :
161
+ def __init__ (self ) -> None :
162
+ self .prefix = typename (type (self )) + '.'
163
+ self .cases = []
157
164
for m in dir (self ):
158
- if m .startswith ('test ' ):
165
+ if m .startswith ('test_ ' ):
159
166
t = getattr (self , m )
160
167
if isinstance (t , Suite ):
161
- self .add_test ((m + '.' , t ))
168
+ self .cases . append ((m + '.' , t ))
162
169
else :
163
- self .add_test (TestCase (m , self , getattr (self , m )))
170
+ assert isinstance (t , MethodType )
171
+ self .cases .append (_MyUnitTestCase (m , self , t ))
164
172
165
- def add_test (self , test : Union [TestCase ,
166
- Tuple [str , Callable [[], None ]],
167
- Tuple [str , 'Suite' ]]) -> None :
168
- self ._test_cases .append (test )
173
+ def setup (self ) -> None :
174
+ """Set up fixtures"""
175
+ pass
169
176
170
- def cases (self ) -> List [Any ]:
171
- return self ._test_cases [:]
177
+ def teardown (self ) -> None :
178
+ # This method is not overridden in practice
179
+ pass
172
180
173
181
def skip (self ) -> None :
174
182
raise SkipTestCaseException ()
@@ -250,10 +258,11 @@ def main(args: Optional[List[str]] = None) -> None:
250
258
sys .exit (1 )
251
259
252
260
253
- def run_test_recursive (test : Any , num_total : int , num_fail : int , num_skip : int ,
261
+ def run_test_recursive (test : Union [_MyUnitTestCase , Tuple [str , Suite ], ListSuite ],
262
+ num_total : int , num_fail : int , num_skip : int ,
254
263
prefix : str , depth : int ) -> Tuple [int , int , int ]:
255
- """The first argument may be TestCase , Suite or (str, Suite)."""
256
- if isinstance (test , TestCase ):
264
+ """The first argument may be _MyUnitTestCase , Suite or (str, Suite)."""
265
+ if isinstance (test , _MyUnitTestCase ):
257
266
name = prefix + test .name
258
267
for pattern in patterns :
259
268
if match_pattern (name , pattern ):
@@ -275,7 +284,7 @@ def run_test_recursive(test: Any, num_total: int, num_fail: int, num_skip: int,
275
284
suite = test
276
285
suite_prefix = test .prefix
277
286
278
- for stest in suite .cases () :
287
+ for stest in suite .cases :
279
288
new_prefix = prefix
280
289
if depth > 0 :
281
290
new_prefix = prefix + suite_prefix
@@ -284,22 +293,22 @@ def run_test_recursive(test: Any, num_total: int, num_fail: int, num_skip: int,
284
293
return num_total , num_fail , num_skip
285
294
286
295
287
- def run_single_test (name : str , test : Any ) -> Tuple [bool , bool ]:
296
+ def run_single_test (name : str , test : _MyUnitTestCase ) -> Tuple [bool , bool ]:
288
297
if is_verbose :
289
298
sys .stderr .write (name )
290
299
sys .stderr .flush ()
291
300
292
301
time0 = time .time ()
293
- test .set_up () # FIX: check exceptions
294
- exc_traceback = None # type: Any
302
+ test .setup () # FIX: check exceptions
303
+ exc_traceback = None # type: Optional[TracebackType]
295
304
try :
296
305
test .run ()
297
306
except BaseException as e :
298
307
if isinstance (e , KeyboardInterrupt ):
299
308
raise
300
309
exc_type , exc_value , exc_traceback = sys .exc_info ()
301
310
finally :
302
- test .tear_down ()
311
+ test .teardown ()
303
312
times .append ((time .time () - time0 , name ))
304
313
305
314
if exc_traceback :
0 commit comments