Skip to content

Refactor & enable JIT tests in all models and add warnings if skipped #3033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .circleci/unittest/linux/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ set -e
eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
3 changes: 2 additions & 1 deletion .circleci/unittest/windows/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ set -e
eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
conda activate ./env

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
15 changes: 11 additions & 4 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import io
import torch
import errno
import warnings
import __main__

from numbers import Number
Expand Down Expand Up @@ -265,14 +265,21 @@ def assertTensorsEqual(a, b):
else:
super(TestCase, self).assertEqual(x, y, message)

def checkModule(self, nn_module, args, unwrapper=None, skip=False):
def check_jit_scriptable(self, nn_module, args, unwrapper=None, skip=False):
"""
Check that a nn.Module's results in TorchScript match eager and that it
can be exported
"""
if not TEST_WITH_SLOW or skip:
# TorchScript is not enabled, skip these tests
return
msg = "The check_jit_scriptable test for {} was skipped. " \
"This test checks if the module's results in TorchScript " \
"match eager and that it can be exported. To run these " \
"tests make sure you set the environment variable " \
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \
"manually skipped.".format(nn_module.__class__.__name__)
warnings.warn(msg, RuntimeWarning)
return None

sm = torch.jit.script(nn_module)

Expand All @@ -284,7 +291,7 @@ def checkModule(self, nn_module, args, unwrapper=None, skip=False):
if unwrapper:
script_out = unwrapper(script_out)

self.assertEqual(eager_out, script_out)
self.assertEqual(eager_out, script_out, prec=1e-4)
self.assertExportImportModule(sm, args)

return sm
Expand Down
79 changes: 19 additions & 60 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,44 +38,16 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


# models that are in torch hub, as well as r3d_18. we tried testing all models
# but the test was too slow. not included are detection models, because
# they are not yet supported in JIT.
# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
script_test_models = {
'deeplabv3_resnet50': {},
'deeplabv3_resnet101': {},
'mobilenet_v2': {},
'resnext50_32x4d': {},
'fcn_resnet50': {},
'fcn_resnet101': {},
'googlenet': {
'unwrapper': lambda x: x.logits
},
'densenet121': {},
'resnet18': {},
'alexnet': {},
'shufflenet_v2_x1_0': {},
'squeezenet1_0': {},
'vgg11': {},
'inception_v3': {
'unwrapper': lambda x: x.logits
},
'r3d_18': {},
"fasterrcnn_resnet50_fpn": {
'unwrapper': lambda x: x[1]
},
"maskrcnn_resnet50_fpn": {
'unwrapper': lambda x: x[1]
},
"keypointrcnn_resnet50_fpn": {
'unwrapper': lambda x: x[1]
},
"retinanet_resnet50_fpn": {
'unwrapper': lambda x: x[1]
}
script_model_unwrapper = {
'googlenet': lambda x: x.logits,
'inception_v3': lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
}


Expand All @@ -97,12 +69,6 @@ def get_available_video_models():


class ModelTester(TestCase):
def checkModule(self, model, name, args):
if name not in script_test_models:
return
unwrapper = script_test_models[name].get('unwrapper', None)
return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)

def _test_classification_model(self, name, input_shape, dev):
set_rng_seed(0)
# passing num_class equal to a number other than 1000 helps in making the test
Expand All @@ -114,7 +80,7 @@ def _test_classification_model(self, name, input_shape, dev):
out = model(x)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertEqual(out.shape[-1], 50)
self.checkModule(model, name, (x,))
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

if dev == "cuda":
with torch.cuda.amp.autocast():
Expand All @@ -134,7 +100,7 @@ def _test_segmentation_model(self, name, dev):
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
self.checkModule(model, name, (x,))
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

if dev == "cuda":
with torch.cuda.amp.autocast():
Expand Down Expand Up @@ -209,18 +175,7 @@ def compute_mean_std(tensor):
return True # Full validation performed

full_validation = check_out(out)

scripted_model = torch.jit.script(model)
scripted_model.eval()
scripted_out = scripted_model(model_input)[1]
self.assertEqual(scripted_out[0]["boxes"], out[0]["boxes"])
self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
# labels currently float in script: need to investigate (though same result)
self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
# don't check script because we are compiling it here:
# TODO: refactor tests
# self.check_script(model, name)
self.checkModule(model, name, ([x],))
self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))

if dev == "cuda":
with torch.cuda.amp.autocast():
Expand Down Expand Up @@ -270,7 +225,7 @@ def _test_video_model(self, name, dev):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.checkModule(model, name, (x,))
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
self.assertEqual(out.shape[-1], 50)

if dev == "cuda":
Expand Down Expand Up @@ -345,11 +300,13 @@ def test_inceptionv3_eval(self):
kwargs['transform_input'] = True
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
name = "inception_v3"
model = models.Inception3(**kwargs)
model.aux_logits = False
model.AuxLogits = None
m = torch.jit.script(model.eval())
self.checkModule(m, "inception_v3", torch.rand(1, 3, 299, 299))
model = model.eval()
x = torch.rand(1, 3, 299, 299)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

def test_fasterrcnn_double(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
Expand All @@ -371,12 +328,14 @@ def test_googlenet_eval(self):
kwargs['transform_input'] = True
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
name = "googlenet"
model = models.GoogLeNet(**kwargs)
model.aux_logits = False
model.aux1 = None
model.aux2 = None
m = torch.jit.script(model.eval())
self.checkModule(m, "googlenet", torch.rand(1, 3, 224, 224))
model = model.eval()
x = torch.rand(1, 3, 224, 224)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
Expand Down