Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,17 @@ jobs:
echo "******* Running fairseq unittests *******"
bash tests/run_fairseq_tests.sh
echo "******* Running transformers unittests *******"
bash tests/run_transformers_tests.sh
#bash tests/run_transformers_tests.sh
echo "******* Running fastseq unittests *******"
pip install pytorch-transformers==1.0.0
python -m unittest discover -s tests/ -p 'test_*.py' -v
#pip install pytorch-transformers==1.0.0
#python -m unittest discover -s tests/ -p 'test_*.py' -v
#cd benchmarks/
#CUDA_VISIBLE_DEVICES=3 run_all_benchmarks.sh
displayName: 'run fastseq unit tests'
- task: PublishTestResults@2
condition: succeededOrFailed()
inputs:
testRunTitle: 'Publish test results for Python $(python.version)'
testResultsFiles: '/tmp/fastseq_tests/*.xml'
failTaskOnFailedTests: true

5 changes: 2 additions & 3 deletions fastseq/optimizer/fairseq/beam_search_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
from typing import Optional

import unittest
import torch
import torch.nn.functional as F
from torch import Tensor
Expand Down Expand Up @@ -50,7 +50,6 @@ class MultiheadAttentionV2(MultiheadAttention):

See "Attention Is All You Need" for more details.
"""

def __init__(self,
embed_dim,
num_heads,
Expand Down Expand Up @@ -159,7 +158,7 @@ def forward(
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
k v = None
else:
if self.beam_size > 1 and bsz == key.size(1):
# key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
Expand Down
14 changes: 13 additions & 1 deletion fastseq/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,30 @@

"""Utilities to make it easy to add unit tests"""

from inspect import getframeinfo, stack
import os
from statistics import mean, stdev
import time

from absl.testing import parameterized
from absl import flags
from absl.testing import absltest, parameterized

from fastseq.config import FASTSEQ_CACHE_DIR
from fastseq.logging import get_logger
from fastseq.utils.api_decorator import get_class

logger = get_logger(__name__)

FLAGS = flags.FLAGS

def fastseq_test_main():
caller = getframeinfo(stack()[1][0])
xml_log_file = caller.filename.replace(os.sep, '_').replace('.py', '.xml')
xml_log_file = os.path.join(os.sep, 'tmp', 'fastseq_tests', xml_log_file)
FLAGS.xml_output_file = xml_log_file
logger.info(f"Fastseq unit test log output filepath: {xml_log_file}")
absltest.main()

class TestCaseBase(parameterized.TestCase):
"""Base class used for unittest."""

Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_prophetnet_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from fastseq.utils.file_utils import decompress_file, make_dirs, wget
from fastseq.utils.test_utils import (PROPHETNET_MODEL_URLS,
CACHED_PROPHETNET_MODEL_PATHS,
TestCaseBase)
fastseq_test_main, TestCaseBase)

logger = get_logger(__name__)

Expand Down Expand Up @@ -136,4 +136,4 @@ def test_beam_search_optimizer(self, beam_size, batch_size, need_attn,
self.assertEqual(output, self.expected_outputs[i])

if __name__ == "__main__":
absltest.main()
fastseq_test_main()
4 changes: 2 additions & 2 deletions tests/optimizer/fairseq/benchmark_fairseq_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fastseq.utils.file_utils import decompress_file, make_dirs, wget
from fastseq.utils.test_utils import (BART_MODEL_URLS, CACHED_BART_MODEL_DIR,
CACHED_BART_MODEL_PATHS, BenchmarkBase,
benchmark)
benchmark, fastseq_test_main)

logger = get_logger(__name__)

Expand Down Expand Up @@ -128,4 +128,4 @@ def test_beam_search_optimizer(self, beam_size, batch_size, need_attn,


if __name__ == "__main__":
absltest.main()
fastseq_test_main()
5 changes: 3 additions & 2 deletions tests/optimizer/fairseq/test_fairseq_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from fastseq.logging import get_logger
from fastseq.utils.file_utils import decompress_file, make_dirs, wget
from fastseq.utils.test_utils import (BART_MODEL_URLS, CACHED_BART_MODEL_DIR,
CACHED_BART_MODEL_PATHS, TestCaseBase)
CACHED_BART_MODEL_PATHS,
fastseq_test_main, TestCaseBase)

logger = get_logger(__name__)

Expand Down Expand Up @@ -117,4 +118,4 @@ def test_beam_search_optimizer(self, beam_size, batch_size, need_attn,


if __name__ == "__main__":
absltest.main()
fastseq_test_main()
4 changes: 2 additions & 2 deletions tests/optimizer/transformers/test_bart_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from absl.testing import absltest, parameterized
from transformers import BartForConditionalGeneration, BartTokenizer

from fastseq.utils.test_utils import TestCaseBase
from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase


class BARTOptimizerTest(TestCaseBase):
Expand Down Expand Up @@ -183,4 +183,4 @@ def test_beam_search_optimizer(self,


if __name__ == "__main__":
absltest.main()
fastseq_test_main()
4 changes: 2 additions & 2 deletions tests/optimizer/transformers/test_t5_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import fastseq
from fastseq.logging import get_logger
from fastseq.utils.test_utils import TestCaseBase
from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase
from transformers import (T5ForConditionalGeneration, T5Tokenizer)


Expand Down Expand Up @@ -184,4 +184,4 @@ def test_beam_search_optimizer(self,


if __name__ == "__main__":
absltest.main()
fastseq_test_main()
4 changes: 3 additions & 1 deletion tests/run_fairseq_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import shutil
import unittest
import xmlrunner
from git import Repo
from absl.testing import absltest, parameterized
from pip._internal import main as pipmain
Expand Down Expand Up @@ -77,4 +78,5 @@ def test_suites(self, without_fastseq_opt, fairseq_version, blocked_tests):
assert len(test_result.errors) == 0

if __name__ == "__main__":
absltest.main()
unittest.main(
testRunner=xmlrunner.XMLTestRunner(output='/tmp/fastseq_tests/'))
1 change: 1 addition & 0 deletions tests/run_fairseq_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pip install packaging
cd ${FASTSEQ_TEST_PATH}/../
pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --editable .
pip install unittest-xml-reporting
cd tests
python run_fairseq_tests.py
deactivate
4 changes: 2 additions & 2 deletions tests/utils/test_api_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from absl.testing import absltest, parameterized
from fastseq.utils.api_decorator import get_class, override_method, add_method, export_api, replace
from fastseq.utils.test_utils import TestCaseBase
from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase


class A:
Expand Down Expand Up @@ -152,4 +152,4 @@ def name(self):


if __name__ == "__main__":
absltest.main()
fastseq_test_main()
4 changes: 2 additions & 2 deletions tests/utils/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from absl.testing import absltest, parameterized

from fastseq.utils.file_utils import decompress_file, get_temp_dir, make_dirs, wget
from fastseq.utils.test_utils import TestCaseBase
from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase


class FileUtilsTest(TestCaseBase):
Expand Down Expand Up @@ -90,4 +90,4 @@ def test_wget_and_decompress_file(self, tar_file_url, tar_file_name,


if __name__ == "__main__":
absltest.main()
fastseq_test_main()