@@ -1,5 +1,4 @@
 import torch
-from parameterized import parameterized
 from torchtext.models import (
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
@@ -8,6 +7,7 @@
 )
 
 from ..common.assets import get_asset_path
+from ..common.parameterized_utils import nested_params
 from ..common.torchtext_test_case import TorchtextTestCase
 
 TEST_MODELS_PARAMETERIZED_ARGS = [
@@ -27,30 +27,40 @@
 
 
 class TestModels(TorchtextTestCase):
-    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
-    def test_model(self, expected_asset_name, test_text, model_bundler):
+    @nested_params(
+        [
+            ("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
+            ("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
+            (
+                "roberta.base.output.pt",
+                "Roberta base Model Comparison",
+                ROBERTA_BASE_ENCODER,
+            ),
+            (
+                "roberta.large.output.pt",
+                "Roberta base Model Comparison",
+                ROBERTA_LARGE_ENCODER,
+            ),
+        ],
+        [True, False],
+    )
+    def test_model(self, model_args, is_jit):
+        """Verify pre-trained XLM-R and Roberta models in torchtext produce
+        the same output as the reference implementation within fairseq
+        """
+        expected_asset_name, test_text, model_bundler = model_args
+
         expected_asset_path = get_asset_path(expected_asset_name)
 
         transform = model_bundler.transform()
         model = model_bundler.get_model()
         model = model.eval()
 
+        if is_jit:
+            transform = torch.jit.script(transform)
+            model = torch.jit.script(model)
+
         model_input = torch.tensor(transform([test_text]))
         actual = model(model_input)
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
-
-    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
-    def test_model_jit(self, expected_asset_name, test_text, model_bundler):
-        expected_asset_path = get_asset_path(expected_asset_name)
-
-        transform = model_bundler.transform()
-        transform_jit = torch.jit.script(transform)
-        model = model_bundler.get_model()
-        model = model.eval()
-        model_jit = torch.jit.script(model)
-
-        model_input = torch.tensor(transform_jit([test_text]))
-        actual = model_jit(model_input)
-        expected = torch.load(expected_asset_path)
-        torch.testing.assert_close(actual, expected)
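Note on the decorator: `nested_params` is imported from the shared test utilities (`..common.parameterized_utils`), which are not part of this diff. Judging from how it is used above, it expands the decorated test over the cross product of its argument lists, so every (asset, text, bundler) tuple runs once in eager mode and once under TorchScript. A minimal sketch of that behavior, assuming it simply wraps `parameterized.expand` (the actual helper and its test-naming scheme may differ):

from itertools import product

from parameterized import parameterized


def nested_params(*param_lists):
    # Hypothetical sketch of the helper used above: build every combination
    # of the given parameter lists and hand them to parameterized.expand,
    # so each combination becomes its own test case.
    return parameterized.expand(list(product(*param_lists)))

Applied to the decorator above, this yields eight test cases (four model bundles times two values of is_jit), which is what allows the former test_model and test_model_jit methods to collapse into a single test.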