diff --git a/README.md b/README.md
index c1bee7ce..6f639a8b 100644
--- a/README.md
+++ b/README.md
@@ -11,9 +11,9 @@ Below shows the generation speed gain by using FastSeq.
 |------------------|:--------------------------:|:-------------------------:|:-----:|
 | [ProphetNet](examples/prophetnet/README.md) | 2.8 | 11.3 | 4.0x |
 | [Bart (`fs`)](examples/bart/README.md) | 2.4 | 19.7 | 8.2x |
-| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.5 | 11.4 | 3.3x |
-| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.3 | 13.8 | 3.2x |
-| [T5 (`hf`)](examples/t5/README.md) | 5.0 | 11.5 | 2.3x |
+| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.5 | 12.4 | 3.5x |
+| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.3 | 18.3 | 4.3x |
+| [T5 (`hf`)](examples/t5/README.md) | 5.0 | 23.4 | 4.7x |
 | [WMT16 En-De (`fs`)](examples/wmt/README.md) | 96.0 | 417.0 | 4.3x |

 - All benchmarking experiments run on NVIDIA-V100-16GB with [docker](docker/Dockerfile). Highest speed recorded for each model by tuning batch size. For parameter setting details, click link of corresponding model.
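The updated Speedup column can be sanity-checked against the two throughput columns; a quick arithmetic check (ratios rounded to one decimal, as in the table):

```python
# Speedup = (samples/s with FastSeq) / (samples/s without FastSeq),
# using the updated rows from the README table above.
for name, base, fast in [("Bart (hf)", 3.5, 12.4),
                         ("DistilBart (hf)", 4.3, 18.3),
                         ("T5 (hf)", 5.0, 23.4)]:
    print(f"{name}: {fast / base:.1f}x")  # 3.5x, 4.3x, 4.7x
```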
"transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.3 100 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.6 100 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.9 100 - -## Accuracy -#grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82 +#grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35 ## Speed on V100 16GB 250W -#grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 2.2 2.4 -#grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 100 -#grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.5 100 -#grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.9 100 +#grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.3 3.7 +#grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.6 100 +#grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.3 100 +#grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 12.4 100 + +# Accuracy +grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82 +# Speed on V100 16GB 250W +grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 2.2 2.4 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.6 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.3 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 12.4 100 diff --git a/benchmarks/models/hf_distibart.sh b/benchmarks/models/hf_distibart.sh index ecefc310..c33ccd53 100755 --- a/benchmarks/models/hf_distibart.sh +++ b/benchmarks/models/hf_distibart.sh @@ -10,24 +10,24 @@ source utils.sh # MODEL - distibart cnn # TASK - cnn dm val 1k set -./benchmark.sh transformers hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 --task summarization # each loop takes 7 minutes -./benchmark.sh transformers+fastseq hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64/128 --task summarization # each loop takes 7 minutes +#./benchmark.sh transformers hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 --task summarization # each loop takes 7 minutes +#./benchmark.sh transformers+fastseq hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw 
diff --git a/benchmarks/models/hf_distibart.sh b/benchmarks/models/hf_distibart.sh
index ecefc310..c33ccd53 100755
--- a/benchmarks/models/hf_distibart.sh
+++ b/benchmarks/models/hf_distibart.sh
@@ -10,24 +10,24 @@ source utils.sh

 # MODEL - distibart cnn
 # TASK - cnn dm val 1k set
-./benchmark.sh transformers hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 --task summarization # each loop takes 7 minutes
-./benchmark.sh transformers+fastseq hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64/128 --task summarization # each loop takes 7 minutes
+#./benchmark.sh transformers hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 --task summarization # each loop takes 7 minutes
+#./benchmark.sh transformers+fastseq hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64/128 --task summarization # each loop takes 7 minutes
 ## TASK - cnn dm val full set
-#./benchmark.sh transformers hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64 --task summarization # each loop takes 2.5 hours
-#./benchmark.sh transformers+fastseq hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64/128 --task summarization # each loop takes 2.5 hours
+./benchmark.sh transformers hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64 --task summarization # each loop takes 2.5 hours
+./benchmark.sh transformers+fastseq hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64/128 --task summarization # each loop takes 2.5 hours
 # Accuracy
-grep "hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
-# Speed on V100 16GB 250W
-grep -E "transformers_v3.0.2 hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.0 6.0
-grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13 100
-# todo: bigger bs doesn't increase speed
-grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.5 100
-
-## Accuracy
-#grep "hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
+#grep "hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
 ## Speed on V100 16GB 250W
-#grep -E "transformers_v3.0.2 hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 2.95 3.05
-#grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100
+#grep -E "transformers_v3.0.2 hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.0 6.0
+#grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 16.4 100
 ## todo: bigger bs doesn't increase speed
-#grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100
+#grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 18.4 100
+
+# Accuracy
+grep "hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
+# Speed on V100 16GB 250W
+grep -E "transformers_v3.0.2 hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 2.95 3.05
+grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 16.5 100
+# todo: bigger bs doesn't increase speed
+grep -E "transformers_v3.0.2\+fastseq_v.* hf.sshleifer.distilbart-cnn-12-6.tar.gz cnn_dm/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 18.3 100
diff --git a/benchmarks/models/hf_mbart.sh b/benchmarks/models/hf_mbart.sh
index 128879cc..0899a938 100755
--- a/benchmarks/models/hf_mbart.sh
+++ b/benchmarks/models/hf_mbart.sh
@@ -15,4 +15,4 @@ source utils.sh
 grep "facebook/mbart-large-en-ro wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.79 27.95
 # Speed on V100 16GB 250W
 grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.0 100
-grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.2 100
+grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.3 100
diff --git a/benchmarks/models/hf_t5.sh b/benchmarks/models/hf_t5.sh
index c4f00d01..d81dc8b2 100755
--- a/benchmarks/models/hf_t5.sh
+++ b/benchmarks/models/hf_t5.sh
@@ -14,5 +14,5 @@ source utils.sh
 grep "t5-base wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.42 27.44
 # Speed on V100 16GB 250W
 grep -E "transformers_v3.0.2 t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.6 5.2
-grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.3 100
-grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11 100
+grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 19.3 100
+grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 23.4 100
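Each verification line in these benchmark scripts follows the same shape: filter `perf` rows matching a model/dataset/batch-size prefix, average a whitespace-delimited field with awk, and assert the mean falls in a range via `range.sh`. A rough Python equivalent, assuming the same `perf` layout (field 13 holding samples/s, as in the awk scripts; a plain substring match stands in for `grep -E`):

```python
# Equivalent of: grep "<pattern>" perf | awk '{s+=$13}END{print s/NR}' | bash range.sh LO HI
def check_mean(perf_path, pattern, lo, hi, field=13):
    values = [float(line.split()[field - 1])   # awk's $13 (1-based)
              for line in open(perf_path)
              if pattern in line]              # substring match in place of grep -E
    mean = sum(values) / len(values)           # awk's running sum / NR
    assert lo <= mean <= hi, f"mean {mean:.2f} outside [{lo}, {hi}]"
    return mean

# e.g. check_mean("perf", "transformers_v3.0.2 t5-base wmt_en_ro/raw val 64 ", 4.6, 5.2)
```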
 | BatchSize           | 32            | 64             | 128            |
 |:-------------------:|:-------------:|:--------------:|:--------------:|
 | transformers-3.0.2  | 3.5 samples/s | OOM            | OOM            |
-| above + fastseq     | 7.9 samples/s | 10.7 samples/s | 11.4 samples/s |
+| above + fastseq     | 7.6 samples/s | 11.3 samples/s | 12.4 samples/s |

 ### Model
diff --git a/examples/distilbart/README.md b/examples/distilbart/README.md
index aa659c53..9913d715 100644
--- a/examples/distilbart/README.md
+++ b/examples/distilbart/README.md
@@ -11,7 +11,7 @@ More info can be found [here](https://github.com/huggingface/transformers/blob/m
 | BatchSize           | 64             | 128            |
 |:-------------------:|:--------------:|:--------------:|
 | transformers-3.0.2  | 4.3 samples/s  | OOM            |
-| above + fastseq     | 13.3 samples/s | 13.8 samples/s |
+| above + fastseq     | 16.5 samples/s | 18.3 samples/s |

 ### Model
diff --git a/examples/t5/README.md b/examples/t5/README.md
index aa6e17a7..5114a1ed 100644
--- a/examples/t5/README.md
+++ b/examples/t5/README.md
@@ -10,7 +10,7 @@ The T5 model was presented in [Exploring the Limits of Transfer Learning with a
 | BatchSize            | 64              | 128            |
 |:--------------------:|:---------------:|:--------------:|
 | transformers_v3.0.2  | 5.0 samples/s   | OOM            |
-| above + fastseq      | 9.6 samples/s   | 11.5 samples/s |
+| above + fastseq      | 19.3 samples/s  | 23.4 samples/s |

 ### Model
diff --git a/examples/wmt/README.md b/examples/wmt/README.md
index 8582ba4e..55bd3e71 100644
--- a/examples/wmt/README.md
+++ b/examples/wmt/README.md
@@ -89,6 +89,6 @@ $ fastseq-generate-for-fairseq \
     --lenpen 0.6 \
     --remove-bpe \
     --gen-subset test \
-    --post-process-workers 5
+    --postprocess-workers 5
 ```
-To get baseline speed number which doesn't use FastSeq optimizations, replace `fastseq-generate-for-fairseq` by `fairseq-generate` and remove argument `--post-process-workers 5` since it is only provided by fastseq.
+To get the baseline speed numbers without FastSeq optimizations, replace `fastseq-generate-for-fairseq` with `fairseq-generate` and remove the `--postprocess-workers 5` argument, since it is only provided by FastSeq.
diff --git a/fastseq/optimizer/fairseq/generate.py b/fastseq/optimizer/fairseq/generate.py
index 2e14175c..ffaefa9a 100755
--- a/fastseq/optimizer/fairseq/generate.py
+++ b/fastseq/optimizer/fairseq/generate.py
@@ -272,7 +272,7 @@ def add_generation_args_v1(parser):
     group = original_add_generation_args(parser)
     # fmt: off
     group.add_argument(
-        '--post-process-workers',
+        '--postprocess-workers',
         default=1,
         type=int,
         choices=range(1, 128, 1),
@@ -354,7 +354,7 @@ def main_v1(args):
     message_queue = JoinableQueue()
     p_list = []

-    for i in range(args.post_process_workers):
+    for i in range(args.postprocess_workers):
         p = PostProcess(args, task, data_queue, message_queue)
         p_list.append(p)
         p.start()
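Note that renaming the flag also renames the attribute argparse exposes on the parsed namespace, which is why `main_v1` must switch from `args.post_process_workers` to `args.postprocess_workers` in the same change; a minimal demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
# argparse derives the namespace attribute from the long option by stripping
# the leading dashes and turning '-' into '_':
parser.add_argument('--postprocess-workers', default=1, type=int,
                    choices=range(1, 128, 1))
args = parser.parse_args(['--postprocess-workers', '5'])
print(args.postprocess_workers)  # 5 (was args.post_process_workers before the rename)
```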
+ """ + super(IOProcess, self).__init__() + self.msg_queue = msg_queue + self.fout = fout + self.waiting_for=0 + self.dec_buf = {} + + def process_dec(self, dec): + for hypothesis in dec: + self.fout.write(hypothesis + "\n") + self.fout.flush() + + def process_buffer(self): + while self.waiting_for in self.dec_buf: + self.process_dec(self.dec_buf[self.waiting_for]) + del self.dec_buf[self.waiting_for] + self.waiting_for+=1 -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i:i + n] + def run(self): + while True: + ind, dec = self.msg_queue.get() + if dec == GENERATE_FINISHED: + break + elif ind != self.waiting_for: + self.dec_buf[ind] = dec + else: + self.process_dec(dec) + self.waiting_for+=1 + self.process_buffer() + self.process_buffer() + assert not self.dec_buf, "IO Buffer not empty" + self.msg_queue.close() + self.msg_queue.join_thread() +class PostProcess(Process): + """ Parallel detokenization """ + def __init__(self, tokenizer, data_queue, msg_queue, + skip_special_tokens, clean_up_tokenization_spaces): + """Async Postprocess. + Args: + data_queue : Multiprocess data Queue + msg_queue : Multiprocess message queue + tokenizer : tokenizer + clean_up_tokenization_spaces : clean_up_tokenization_spaces? + skip_special_tokens = skip_special_tokens? + """ + super(PostProcess, self).__init__() + self.data_queue = data_queue + self.msg_queue = msg_queue + self.tokenizer = tokenizer + self.clean_up_tokenization_spaces = clean_up_tokenization_spaces + self.skip_special_tokens = skip_special_tokens + + def run(self): + while True: + ind, summaries = self.data_queue.get() + if summaries == GENERATE_FINISHED: + self.data_queue.put((-1, POSTPROCESS_FINISHED)) + break + elif summaries == POSTPROCESS_FINISHED: + self.data_queue.put((-1, POSTPROCESS_FINISHED)) + break + else: + dec = self.tokenizer.batch_decode(summaries, + skip_special_tokens = self.skip_special_tokens, + clean_up_tokenization_spaces = + self.clean_up_tokenization_spaces) + self.msg_queue.put((ind, dec)) + + self.data_queue.close() + self.data_queue.join_thread() + self.msg_queue.close() + self.msg_queue.join_thread() def generate_summaries_or_translations( examples: list, @@ -29,6 +137,13 @@ def generate_summaries_or_translations( decoder_start_token_id=None, fastseq_opt=True, no_repeat_ngram_size=None, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + preprocess_workers=2, + postprocess_workers=2, + return_tensors="pt", + truncation=True, + padding="max_length", **gen_kwargs, ) -> None: """Run generation""" @@ -46,16 +161,29 @@ def generate_summaries_or_translations( # update config with summarization specific params use_task_specific_params(model, task) + data_queue = Queue() + msg_queue = Queue() + p_list = [] + + for i in range(postprocess_workers): + p = PostProcess(tokenizer, data_queue, msg_queue, + skip_special_tokens, clean_up_tokenization_spaces) + p_list.append(p) + p.start() - for batch in tqdm(list(chunks(examples, batch_size))): - if "t5" in model_name: - batch = [model.config.prefix + text for text in batch] - batch = tokenizer(batch, - return_tensors="pt", - truncation=True, - padding="max_length").to(device) + io_process = IOProcess( msg_queue, fout) + io_process.start() + dataset = TokenizeDataset(examples, tokenizer, model_name, + model.config.prefix, return_tensors, truncation, padding) + training_generator = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, num_workers = preprocess_workers, + drop_last=False) + for 
@@ -46,16 +161,29 @@
     # update config with summarization specific params
     use_task_specific_params(model, task)
+    data_queue = Queue()
+    msg_queue = Queue()
+    p_list = []
+
+    for i in range(postprocess_workers):
+        p = PostProcess(tokenizer, data_queue, msg_queue,
+                        skip_special_tokens, clean_up_tokenization_spaces)
+        p_list.append(p)
+        p.start()

-    for batch in tqdm(list(chunks(examples, batch_size))):
-        if "t5" in model_name:
-            batch = [model.config.prefix + text for text in batch]
-        batch = tokenizer(batch,
-                          return_tensors="pt",
-                          truncation=True,
-                          padding="max_length").to(device)
+    io_process = IOProcess(msg_queue, fout)
+    io_process.start()
+    dataset = TokenizeDataset(examples, tokenizer, model_name,
+                              model.config.prefix, return_tensors, truncation, padding)
+    training_generator = torch.utils.data.DataLoader(dataset,
+                                                     batch_size=batch_size,
+                                                     num_workers=preprocess_workers,
+                                                     drop_last=False)
+    for ind, batch in tqdm(enumerate(training_generator)):
+        input_ids, attention_mask = batch
+        input_ids = input_ids.view(input_ids.size(0), -1).to(device)
+        attention_mask = attention_mask.view(input_ids.size(0), -1).to(device)
         input_ids, attention_mask = trim_batch(
-            **batch, pad_token_id=tokenizer.pad_token_id)
+            input_ids, tokenizer.pad_token_id, attention_mask)
         summaries = model.generate(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -63,13 +191,14 @@
             no_repeat_ngram_size=no_repeat_ngram_size,
             **gen_kwargs,
         )
-        dec = tokenizer.batch_decode(summaries,
-                                     skip_special_tokens=True,
-                                     clean_up_tokenization_spaces=False)
-        for hypothesis in dec:
-            fout.write(hypothesis + "\n")
-        fout.flush()
-
+        summaries_cpu = summaries.cpu()
+        data_queue.put((ind, summaries_cpu))
+    data_queue.put((-1, GENERATE_FINISHED))
+    for p in p_list:
+        p.join()
+    msg_queue.put((-1, GENERATE_FINISHED))
+    io_process.join()
+    fout.close()

 def run_generate():
     """Entrance is here."""
@@ -118,6 +247,24 @@ def run_generate():
     parser.add_argument("--without_fastseq_opt", action="store_true")
     parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
                         required=False, help="size of no repeat ngram")
+    parser.add_argument("--include_special_tokens", action="store_true")
+    parser.add_argument("--clean_up_tokenization_spaces", action="store_true")
+    parser.add_argument("--preprocess_workers",
+                        type=int,
+                        default=2,
+                        required=False,
+                        help="number of pre-processing worker processes")
+    parser.add_argument("--postprocess_workers",
+                        type=int,
+                        default=1,
+                        required=False,
+                        help="number of post-processing worker processes")
+    parser.add_argument("--no_truncation", action="store_true")
+    parser.add_argument("--return_tensors", type=str, help="specify return tensors",
+                        default="pt", required=False)
+    parser.add_argument("--padding", type=str, help="specify padding",
+                        default="max_length", required=False)
+
     args = parser.parse_args()
     examples = [
         " " + x.rstrip() if "t5" in args.model_name else x.rstrip()
@@ -137,7 +284,14 @@ def run_generate():
         decoder_start_token_id=args.decoder_start_token_id,
         fastseq_opt=not args.without_fastseq_opt,
         no_repeat_ngram_size=args.no_repeat_ngram_size,
-    )
+        skip_special_tokens=not args.include_special_tokens,
+        clean_up_tokenization_spaces=args.clean_up_tokenization_spaces,
+        preprocess_workers=args.preprocess_workers,
+        postprocess_workers=args.postprocess_workers,
+        return_tensors=args.return_tensors,
+        truncation=not args.no_truncation,
+        padding=args.padding,
+    )
     if args.reference_path is None:
         return
     # Compute scores
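Taken together, the new `fastseq_cli` code forms a three-stage pipeline: DataLoader workers tokenize, the main process runs `model.generate`, and `PostProcess` workers decode in parallel while a single `IOProcess` restores input order before writing. A minimal, self-contained sketch of the same producer/worker/ordered-writer pattern (toy string payloads standing in for token batches and `batch_decode`):

```python
from multiprocessing import Process, Queue

DONE = 'done'  # sentinel, mirroring GENERATE_FINISHED

def worker(data_q, msg_q):
    """Decode stage: consumes (index, batch) pairs, emits (index, strings)."""
    while True:
        ind, payload = data_q.get()
        if payload == DONE:
            data_q.put((-1, DONE))  # re-queue the sentinel so sibling workers also stop
            break
        msg_q.put((ind, [str(x) for x in payload]))  # stand-in for batch_decode

def writer(msg_q, n_batches):
    """Write stage: buffers out-of-order results until the next expected index arrives."""
    waiting_for, buf, out = 0, {}, []
    for _ in range(n_batches):
        ind, dec = msg_q.get()
        buf[ind] = dec
        while waiting_for in buf:  # flush every batch that is now in order
            out.extend(buf.pop(waiting_for))
            waiting_for += 1
    print(out)  # ['1', '2', '3', '4', '5', '6'] regardless of worker timing

if __name__ == '__main__':
    data_q, msg_q = Queue(), Queue()
    batches = [[1, 2], [3, 4], [5, 6]]
    workers = [Process(target=worker, args=(data_q, msg_q)) for _ in range(2)]
    for p in workers:
        p.start()
    writer_p = Process(target=writer, args=(msg_q, len(batches)))
    writer_p.start()
    for ind, batch in enumerate(batches):  # producer: stands in for the generate loop
        data_q.put((ind, batch))
    data_q.put((-1, DONE))
    for p in workers:
        p.join()
    writer_p.join()
```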