Commit d896135

Nayef211 and nayef211 authored
Parameterize jit and non-jit model integration tests (#1502)
* Updated max seq length for truncate in xlmr base. Updated xlmr docs. Moved xlmr tests to integration tests
* Removing changes to truncate transform
* Remove documentation changes from PR
* Parameterized model tests
* Added nested_params helper method. Updated model integration test to parameterize a single method covering jit and non-jit tests
* Added docstring for unit tests

Co-authored-by: nayef211 <[email protected]>
1 parent: 0bcab91 · commit: d896135

File tree

2 files changed: +72 −19 lines changed


test/common/parameterized_utils.py

Lines changed: 44 additions & 1 deletion
@@ -1,10 +1,53 @@
 import json
+from itertools import product
 
-from parameterized import param
+from parameterized import param, parameterized
 
 from .assets import get_asset_path
 
 
 def load_params(*paths):
     with open(get_asset_path(*paths), "r") as file:
         return [param(json.loads(line)) for line in file]
+
+
+def _name_func(func, _, params):
+    strs = []
+    for arg in params.args:
+        if isinstance(arg, tuple):
+            strs.append("_".join(str(a) for a in arg))
+        else:
+            strs.append(str(arg))
+    # sanitize the test name
+    name = "_".join(strs).replace(".", "_")
+    return f"{func.__name__}_{name}"
+
+
+def nested_params(*params_set):
+    """Generate the cartesian product of the given list of parameters.
+    Args:
+        params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
+            all the parameters have to be specified with the class, only using kwargs.
+    """
+    flatten = [p for params in params_set for p in params]
+
+    # Parameters to be nested are given as list of plain objects
+    if all(not isinstance(p, param) for p in flatten):
+        args = list(product(*params_set))
+        return parameterized.expand(args, name_func=_name_func)
+
+    # Parameters to be nested are given as list of `parameterized.param`
+    if not all(isinstance(p, param) for p in flatten):
+        raise TypeError(
+            "When using ``parameterized.param``, "
+            "all the parameters have to be of the ``param`` type."
+        )
+    if any(p.args for p in flatten):
+        raise ValueError(
+            "When using ``parameterized.param``, "
+            "all the parameters have to be provided as keyword argument."
+        )
+    args = [param()]
+    for params in params_set:
+        args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
+    return parameterized.expand(args)
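
For orientation, here is a minimal usage sketch of the new nested_params helper in its plain-object mode: each axis is a list of values, the helper takes their cartesian product, and _name_func derives a readable test name from each combination. The test class, values, and import path below are illustrative assumptions, not part of the commit.

# Illustrative sketch, not part of the commit. The import path assumes the
# tests run from the repository root; adjust to match the actual layout.
import unittest

from test.common.parameterized_utils import nested_params  # hypothetical path


class NestedParamsExample(unittest.TestCase):
    @nested_params(
        ["model_a", "model_b"],  # first axis: a model identifier
        [True, False],           # second axis: e.g. a jit flag
    )
    def test_example(self, name, is_jit):
        # Expands into 4 tests, one per (name, is_jit) combination from the
        # cartesian product; _name_func yields names such as
        # "test_example_model_a_True", replacing "." with "_" so that
        # asset-style values like "xlmr.base.output.pt" stay valid names.
        self.assertIsInstance(is_jit, bool)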

test/integration_tests/test_models.py

Lines changed: 28 additions & 18 deletions
@@ -1,5 +1,4 @@
 import torch
-from parameterized import parameterized
 from torchtext.models import (
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
@@ -8,6 +7,7 @@
 )
 
 from ..common.assets import get_asset_path
+from ..common.parameterized_utils import nested_params
 from ..common.torchtext_test_case import TorchtextTestCase
 
 TEST_MODELS_PARAMETERIZED_ARGS = [
@@ -27,30 +27,40 @@
 
 
 class TestModels(TorchtextTestCase):
-    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
-    def test_model(self, expected_asset_name, test_text, model_bundler):
+    @nested_params(
+        [
+            ("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
+            ("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
+            (
+                "roberta.base.output.pt",
+                "Roberta base Model Comparison",
+                ROBERTA_BASE_ENCODER,
+            ),
+            (
+                "roberta.large.output.pt",
+                "Roberta base Model Comparison",
+                ROBERTA_LARGE_ENCODER,
+            ),
+        ],
+        [True, False],
+    )
+    def test_model(self, model_args, is_jit):
+        """Verify pre-trained XLM-R and Roberta models in torchtext produce
+        the same output as the reference implementation within fairseq
+        """
+        expected_asset_name, test_text, model_bundler = model_args
+
         expected_asset_path = get_asset_path(expected_asset_name)
 
         transform = model_bundler.transform()
         model = model_bundler.get_model()
         model = model.eval()
 
+        if is_jit:
+            transform = torch.jit.script(transform)
+            model = torch.jit.script(model)
+
         model_input = torch.tensor(transform([test_text]))
         actual = model(model_input)
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
-
-    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
-    def test_model_jit(self, expected_asset_name, test_text, model_bundler):
-        expected_asset_path = get_asset_path(expected_asset_name)
-
-        transform = model_bundler.transform()
-        transform_jit = torch.jit.script(transform)
-        model = model_bundler.get_model()
-        model = model.eval()
-        model_jit = torch.jit.script(model)
-
-        model_input = torch.tensor(transform_jit([test_text]))
-        actual = model_jit(model_input)
-        expected = torch.load(expected_asset_path)
-        torch.testing.assert_close(actual, expected)
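
The test above exercises only the plain-object branch of nested_params. As a complement, here is a hedged sketch of the kwargs-only parameterized.param branch, in which keyword arguments are merged across axes rather than passed as tuples; the class, values, and import path are illustrative, not from the commit.

# Illustrative sketch, not part of the commit, showing the param-based branch.
import unittest

from parameterized import param

from test.common.parameterized_utils import nested_params  # hypothetical path


class NestedParamsKwargsExample(unittest.TestCase):
    @nested_params(
        [param(dtype="float32"), param(dtype="float16")],
        [param(device="cpu"), param(device="cuda")],
    )
    def test_kwargs(self, dtype, device):
        # Four tests receive the kwarg-merged cartesian product, e.g.
        # dtype="float32", device="cuda". Mixing param with plain objects
        # raises TypeError; positional args inside param raise ValueError.
        self.assertIn(device, {"cpu", "cuda"})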
