Skip to content

Commit 446eac6

Browse files
authored
Fix torchhub due to numerical changes in torch.sum (#2361)
1 parent bb14c2b commit 446eac6

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

test/test_hub.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def sum_of_model_parameters(model):
1313
return s
1414

1515

16-
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.99609375
16+
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625
1717

1818

1919
@unittest.skipIf('torchvision' in sys.modules,
@@ -31,8 +31,9 @@ def test_load_from_github(self):
3131
'resnet18',
3232
pretrained=True,
3333
progress=False)
34-
self.assertEqual(sum_of_model_parameters(hub_model).item(),
35-
SUM_OF_PRETRAINED_RESNET18_PARAMS)
34+
self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
35+
SUM_OF_PRETRAINED_RESNET18_PARAMS,
36+
places=2)
3637

3738
def test_set_dir(self):
3839
temp_dir = tempfile.gettempdir()
@@ -42,8 +43,9 @@ def test_set_dir(self):
4243
'resnet18',
4344
pretrained=True,
4445
progress=False)
45-
self.assertEqual(sum_of_model_parameters(hub_model).item(),
46-
SUM_OF_PRETRAINED_RESNET18_PARAMS)
46+
self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(),
47+
SUM_OF_PRETRAINED_RESNET18_PARAMS,
48+
places=2)
4749
self.assertTrue(os.path.exists(temp_dir + '/pytorch_vision_master'))
4850
shutil.rmtree(temp_dir + '/pytorch_vision_master')
4951

0 commit comments

Comments
 (0)