diff --git a/test/common_utils.py b/test/common_utils.py index 5368018e971..e5713dc0832 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -88,24 +88,7 @@ def is_iterable(obj): class TestCase(unittest.TestCase): precision = 1e-5 - def assertExpected(self, output, subname=None, prec=None, strip_suffix=None): - r""" - Test that a python value matches the recorded contents of a file - derived from the name of this test and subname. The value must be - pickable with `torch.save`. This file - is placed in the 'expect' directory in the same directory - as the test script. You can automatically update the recorded test - output using --accept. - - If you call this multiple times in a single function, you must - give a unique subname each time. - - strip_suffix allows different tests that expect similar numerics, e.g. - "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data. - test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass - strip_suffix="_cpu", and they would both use a data file name based on - "test_xyz". - """ + def _get_expected_file(self, subname=None, strip_suffix=None): def remove_prefix_suffix(text, prefix, suffix): if text.startswith(prefix): text = text[len(prefix):] @@ -128,33 +111,41 @@ def remove_prefix_suffix(text, prefix, suffix): subname_output = " ({})".format(subname) expected_file += "_expect.pkl" - def accept_output(update_type): - print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output)) + if not ACCEPT and not os.path.exists(expected_file): + raise RuntimeError( + ("No expect file exists for {}{}; to accept the current output, run:\n" + "python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id)) + + return expected_file + + def assertExpected(self, output, subname=None, prec=None, strip_suffix=None): + r""" + Test that a python value matches the recorded contents of a file + derived from the name of this test and subname. The value must be + pickable with `torch.save`. This file + is placed in the 'expect' directory in the same directory + as the test script. You can automatically update the recorded test + output using --accept. + + If you call this multiple times in a single function, you must + give a unique subname each time. + + strip_suffix allows different tests that expect similar numerics, e.g. + "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data. + test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass + strip_suffix="_cpu", and they would both use a data file name based on + "test_xyz". + """ + expected_file = self._get_expected_file(subname, strip_suffix) + + if ACCEPT: + print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output)) torch.save(output, expected_file) MAX_PICKLE_SIZE = 50 * 1000 # 50 KB binary_size = os.path.getsize(expected_file) self.assertTrue(binary_size <= MAX_PICKLE_SIZE) - - try: - expected = torch.load(expected_file) - except IOError as e: - if e.errno != errno.ENOENT: - raise - elif ACCEPT: - accept_output("output") - return - else: - raise RuntimeError( - ("I got this output for {}{}:\n\n{}\n\n" - "No expect file exists; to accept the current output, run:\n" - "python {} {} --accept").format(munged_id, subname_output, output, __main__.__file__, munged_id)) - - if ACCEPT: - try: - self.assertEqual(output, expected, prec=prec) - except Exception: - accept_output("updated output") else: + expected = torch.load(expected_file) self.assertEqual(output, expected, prec=prec) def assertEqual(self, x, y, prec=None, message='', allow_inf=False):