@@ -221,6 +221,37 @@ def pytest_configure(config):
221
221
reference_dir = reference_dir ,
222
222
generate_dir = generate_dir ,
223
223
default_format = default_format ))
224
+ else :
225
+ config .pluginmanager .register (ArrayInterceptor (config ))
226
+
227
+
228
+ def generate_test_name (item ):
229
+ """
230
+ Generate a unique name for this test.
231
+ """
232
+ if item .cls is not None :
233
+ name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
234
+ else :
235
+ name = f"{ item .module .__name__ } .{ item .name } "
236
+ return name
237
+
238
+
239
+ def wrap_array_interceptor (plugin , item ):
240
+ """
241
+ Intercept and store arrays returned by test functions.
242
+ """
243
+ # Only intercept array on marked array tests
244
+ if item .get_closest_marker ('array_compare' ) is not None :
245
+
246
+ # Use the full test name as a key to ensure correct array is being retrieved
247
+ test_name = generate_test_name (item )
248
+
249
+ def array_interceptor (store , obj ):
250
+ def wrapper (* args , ** kwargs ):
251
+ store .return_value [test_name ] = obj (* args , ** kwargs )
252
+ return wrapper
253
+
254
+ item .obj = array_interceptor (plugin , item .obj )
224
255
225
256
226
257
class ArrayComparison (object ):
@@ -230,12 +261,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format
230
261
self .reference_dir = reference_dir
231
262
self .generate_dir = generate_dir
232
263
self .default_format = default_format
264
+ self .return_value = {}
233
265
234
- def pytest_runtest_setup (self , item ):
266
+ @pytest .hookimpl (hookwrapper = True )
267
+ def pytest_runtest_call (self , item ):
235
268
236
269
compare = item .get_closest_marker ('array_compare' )
237
270
238
271
if compare is None :
272
+ yield
239
273
return
240
274
241
275
file_format = compare .kwargs .get ('file_format' , self .default_format )
@@ -255,85 +289,95 @@ def pytest_runtest_setup(self, item):
255
289
256
290
write_kwargs = compare .kwargs .get ('write_kwargs' , {})
257
291
258
- original = item .function
292
+ reference_dir = compare .kwargs .get ('reference_dir' , None )
293
+ if reference_dir is None :
294
+ if self .reference_dir is None :
295
+ reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), 'reference' )
296
+ else :
297
+ reference_dir = self .reference_dir
298
+ else :
299
+ if not reference_dir .startswith (('http://' , 'https://' )):
300
+ reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir )
259
301
260
- @wraps (item .function )
261
- def item_function_wrapper (* args , ** kwargs ):
302
+ baseline_remote = reference_dir .startswith ('http' )
262
303
263
- reference_dir = compare .kwargs .get ('reference_dir' , None )
264
- if reference_dir is None :
265
- if self .reference_dir is None :
266
- reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), 'reference' )
267
- else :
268
- reference_dir = self .reference_dir
269
- else :
270
- if not reference_dir .startswith (('http://' , 'https://' )):
271
- reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir )
272
-
273
- baseline_remote = reference_dir .startswith ('http' )
274
-
275
- # Run test and get figure object
276
- import inspect
277
- if inspect .ismethod (original ): # method
278
- array = original (* args [1 :], ** kwargs )
279
- else : # function
280
- array = original (* args , ** kwargs )
281
-
282
- # Find test name to use as plot name
283
- filename = compare .kwargs .get ('filename' , None )
284
- if filename is None :
285
- if single_reference :
286
- filename = original .__name__ + '.' + extension
287
- else :
288
- filename = item .name + '.' + extension
289
- filename = filename .replace ('[' , '_' ).replace (']' , '_' )
290
- filename = filename .replace ('_.' + extension , '.' + extension )
291
-
292
- # What we do now depends on whether we are generating the reference
293
- # files or simply running the test.
294
- if self .generate_dir is None :
295
-
296
- # Save the figure
297
- result_dir = tempfile .mkdtemp ()
298
- test_array = os .path .abspath (os .path .join (result_dir , filename ))
299
-
300
- FORMATS [file_format ].write (test_array , array , ** write_kwargs )
301
-
302
- # Find path to baseline array
303
- if baseline_remote :
304
- baseline_file_ref = _download_file (reference_dir + filename )
305
- else :
306
- baseline_file_ref = os .path .abspath (os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir , filename ))
307
-
308
- if not os .path .exists (baseline_file_ref ):
309
- raise Exception ("""File not found for comparison test
310
- Generated file:
311
- \t {test}
312
- This is expected for new tests.""" .format (
313
- test = test_array ))
314
-
315
- # setuptools may put the baseline arrays in non-accessible places,
316
- # copy to our tmpdir to be sure to keep them in case of failure
317
- baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
318
- shutil .copyfile (baseline_file_ref , baseline_file )
319
-
320
- identical , msg = FORMATS [file_format ].compare (baseline_file , test_array , atol = atol , rtol = rtol )
321
-
322
- if identical :
323
- shutil .rmtree (result_dir )
324
- else :
325
- raise Exception (msg )
304
+ # Run test and get array object
305
+ wrap_array_interceptor (self , item )
306
+ yield
307
+ test_name = generate_test_name (item )
308
+ if test_name not in self .return_value :
309
+ # Test function did not complete successfully
310
+ return
311
+ array = self .return_value [test_name ]
312
+
313
+ # Find test name to use as plot name
314
+ filename = compare .kwargs .get ('filename' , None )
315
+ if filename is None :
316
+ filename = item .name + '.' + extension
317
+ if not single_reference :
318
+ filename = filename .replace ('[' , '_' ).replace (']' , '_' )
319
+ filename = filename .replace ('_.' + extension , '.' + extension )
320
+
321
+ # What we do now depends on whether we are generating the reference
322
+ # files or simply running the test.
323
+ if self .generate_dir is None :
324
+
325
+ # Save the figure
326
+ result_dir = tempfile .mkdtemp ()
327
+ test_array = os .path .abspath (os .path .join (result_dir , filename ))
326
328
329
+ FORMATS [file_format ].write (test_array , array , ** write_kwargs )
330
+
331
+ # Find path to baseline array
332
+ if baseline_remote :
333
+ baseline_file_ref = _download_file (reference_dir + filename )
327
334
else :
335
+ baseline_file_ref = os .path .abspath (os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir , filename ))
336
+
337
+ if not os .path .exists (baseline_file_ref ):
338
+ raise Exception ("""File not found for comparison test
339
+ Generated file:
340
+ \t {test}
341
+ This is expected for new tests.""" .format (
342
+ test = test_array ))
328
343
329
- if not os .path .exists (self .generate_dir ):
330
- os .makedirs (self .generate_dir )
344
+ # setuptools may put the baseline arrays in non-accessible places,
345
+ # copy to our tmpdir to be sure to keep them in case of failure
346
+ baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
347
+ shutil .copyfile (baseline_file_ref , baseline_file )
331
348
332
- FORMATS [file_format ].write ( os . path . abspath ( os . path . join ( self . generate_dir , filename )), array , ** write_kwargs )
349
+ identical , msg = FORMATS [file_format ].compare ( baseline_file , test_array , atol = atol , rtol = rtol )
333
350
334
- pytest .skip ("Skipping test, since generating data" )
351
+ if identical :
352
+ shutil .rmtree (result_dir )
353
+ else :
354
+ raise Exception (msg )
335
355
336
- if item .cls is not None :
337
- setattr (item .cls , item .function .__name__ , item_function_wrapper )
338
356
else :
339
- item .obj = item_function_wrapper
357
+
358
+ if not os .path .exists (self .generate_dir ):
359
+ os .makedirs (self .generate_dir )
360
+
361
+ FORMATS [file_format ].write (os .path .abspath (os .path .join (self .generate_dir , filename )), array , ** write_kwargs )
362
+
363
+ pytest .skip ("Skipping test, since generating data" )
364
+
365
+
366
+ class ArrayInterceptor :
367
+ """
368
+ This is used in place of ArrayComparison when the array comparison option is not used,
369
+ to make sure that we still intercept arrays returned by tests.
370
+ """
371
+
372
+ def __init__ (self , config ):
373
+ self .config = config
374
+ self .return_value = {}
375
+
376
+ @pytest .hookimpl (hookwrapper = True )
377
+ def pytest_runtest_call (self , item ):
378
+
379
+ if item .get_closest_marker ('array_compare' ) is not None :
380
+ wrap_array_interceptor (self , item )
381
+
382
+ yield
383
+ return
0 commit comments