diff --git a/README.md b/README.md index 9afa04d3..ce753d86 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,11 @@ Below shows the generation speed gain by using FastSeq. | Model | W/O FastSeq (in samples/s) | W/ FastSeq (in samples/s) | Speedup | |------------------|:--------------------------:|:-------------------------:|:-----:| | [ProphetNet](examples/prophetnet/README.md) | 2.7 | 10.3 | 3.8x | -| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 14.5 | 5.4x | -| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 6.4 | 1.9x | -| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 6.5 | 1.6x | -| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 7.5 | 1.6x | -| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 135.0 | 1.6x | +| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 13.3 | 5x | +| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 11.0 | 3.2x | +| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 13.5 | 3.4x | +| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 17.0 | 3.5x | +| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 124.0 | 1.5x | - All the following benchmarking experiments run on NVIDIA-V100-16GB with [docker](docker/Dockerfile). Highest speed recorded for each model by tuning batch size. For parameter setting details, click link of corresponding model. - `fs` stands for [Fairseq](https://github.com/pytorch/fairseq) 0.9.0 version, `hf` stands for [Huggingface Transformers](https://github.com/huggingface/transformers) 3.0.2 version. diff --git a/benchmarks/models/hf_bart.sh b/benchmarks/models/hf_bart.sh index 033ce696..893bf6a6 100755 --- a/benchmarks/models/hf_bart.sh +++ b/benchmarks/models/hf_bart.sh @@ -20,9 +20,9 @@ source utils.sh grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35 # Speed on V100 16GB 250W grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.4 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.9 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 10.7 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.0 100 ## Accuracy #grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82 diff --git a/benchmarks/models/hf_distibart.sh b/benchmarks/models/hf_distibart.sh index 1ff95925..4a1b95c7 100755 --- a/benchmarks/models/hf_distibart.sh +++ b/benchmarks/models/hf_distibart.sh @@ -20,9 +20,9 @@ source utils.sh grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3 # Speed on V100 16GB 250W grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2 -grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100 +grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.3 100 # todo: bigger bs doesn't increase speed -grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100 +grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.5 100 ## Accuracy #grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1 diff --git a/benchmarks/models/hf_mbart.sh b/benchmarks/models/hf_mbart.sh index 6b5393ab..58891834 100755 --- a/benchmarks/models/hf_mbart.sh +++ b/benchmarks/models/hf_mbart.sh @@ -14,5 +14,5 @@ source utils.sh # Accuracy grep "facebook/mbart-large-en-ro wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.79 27.95 # Speed on V100 16GB 250W -grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.8 6.2 +grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.6 7.7 grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.0 100 diff --git a/benchmarks/models/hf_t5.sh b/benchmarks/models/hf_t5.sh index 796f5903..a7817b7c 100755 --- a/benchmarks/models/hf_t5.sh +++ b/benchmarks/models/hf_t5.sh @@ -14,6 +14,5 @@ source utils.sh grep "t5-base wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.42 27.44 # Speed on V100 16GB 250W grep -E "transformers_v3.0.2 t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.6 5.2 -grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.8 7.3 -grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.5 8.0 - +grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 14.4 14.8 +grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 16.8 17.0 diff --git a/fastseq_cli/transformers_generate.py b/fastseq_cli/transformers_generate.py index dda5eb56..bf30c28a 100644 --- a/fastseq_cli/transformers_generate.py +++ b/fastseq_cli/transformers_generate.py @@ -2,21 +2,115 @@ import argparse import json from pathlib import Path - -import torch +from multiprocessing import Process, Queue from tqdm import tqdm - -from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score +import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +GENERATE_FINISHED = 'done' +POSTPROCESS_FINISHED = None + +class TokenizeDataset(torch.utils.data.Dataset): + """Characterizes a dataset for PyTorch""" + def __init__(self, examples, tokenizer, model_name, prefix): + """Multiprocess Dataloader. + Args: + examples (List(str)): a list of input sentences. + tokenizer (AutoTokenizer): instance of AutoTokenizer. + model_name (string): model name. + prefix (string): input example prefix if any. + """ + self.examples = examples + self.tokenizer= tokenizer + self.model_name = model_name + self.prefix = prefix + self.return_tensors="pt" + self.truncation=True + self.padding="max_length" + + def __len__(self): + return len(self.examples) + + def __getitem__(self, index): + batch = self.examples[index] + if "t5" in self.model_name: + batch = self.prefix + batch + batch = self.tokenizer(batch, + return_tensors=self.return_tensors, + truncation=self.truncation, + padding=self.padding) + return batch['input_ids'], batch['attention_mask'] + +class IOProcess (Process): + """ Write detokenized output to file in order.""" + def __init__(self, msg_queue, fout): + super(IOProcess, self).__init__() + self.msg_queue = msg_queue + self.fout = fout + self.waiting_for=0 + self.dec_buf = {} + + def process_dec(self, dec): + for hypothesis in dec: + self.fout.write(hypothesis + "\n") + self.fout.flush() + + def process_buffer(self): + while self.waiting_for in self.dec_buf: + self.process_dec(self.dec_buf[self.waiting_for]) + del self.dec_buf[self.waiting_for] + self.waiting_for+=1 -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i:i + n] + def run(self): + while True: + ind, dec = self.msg_queue.get() + if dec == GENERATE_FINISHED: + break + elif ind != self.waiting_for: + self.dec_buf[ind] = dec + else: + self.process_dec(dec) + self.waiting_for+=1 + self.process_buffer() + self.process_buffer() + assert not self.dec_buf, "IO Buffer not empty" + self.msg_queue.close() + self.msg_queue.join_thread() +class PostProcess(Process): + """ Parallel detokenization """ + def __init__(self, tokenizer, data_queue, msg_queue, + skip_special_tokens, clean_up_tokenization_spaces): + super(PostProcess, self).__init__() + self.data_queue = data_queue + self.msg_queue = msg_queue + self.tokenizer = tokenizer + self.clean_up_tokenization_spaces = clean_up_tokenization_spaces + self.skip_special_tokens = skip_special_tokens + + def run(self): + while True: + ind, summaries = self.data_queue.get() + if summaries == GENERATE_FINISHED: + self.data_queue.put((-1, POSTPROCESS_FINISHED)) + break + elif summaries == POSTPROCESS_FINISHED: + self.data_queue.put((-1, POSTPROCESS_FINISHED)) + break + else: + dec = self.tokenizer.batch_decode(summaries, + skip_special_tokens = self.skip_special_tokens, + clean_up_tokenization_spaces = + self.clean_up_tokenization_spaces) + self.msg_queue.put((ind, dec)) + + self.data_queue.close() + self.data_queue.join_thread() + self.msg_queue.close() + self.msg_queue.join_thread() def generate_summaries_or_translations( examples: list, @@ -29,6 +123,10 @@ def generate_summaries_or_translations( decoder_start_token_id=None, fastseq_opt=True, no_repeat_ngram_size=None, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + preprocess_cpu_num=2, + postprocess_cpu_num=2, **gen_kwargs, ) -> None: """Run generation""" @@ -46,16 +144,29 @@ def generate_summaries_or_translations( # update config with summarization specific params use_task_specific_params(model, task) + data_queue = Queue() + msg_queue = Queue() + p_list = [] + + for i in range(postprocess_cpu_num): + p = PostProcess(tokenizer, data_queue, msg_queue, + skip_special_tokens, clean_up_tokenization_spaces) + p_list.append(p) + p.start() - for batch in tqdm(list(chunks(examples, batch_size))): - if "t5" in model_name: - batch = [model.config.prefix + text for text in batch] - batch = tokenizer(batch, - return_tensors="pt", - truncation=True, - padding="max_length").to(device) + io_process = IOProcess( msg_queue, fout) + io_process.start() + dataset = TokenizeDataset(examples, tokenizer, model_name, + model.config.prefix) + training_generator = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, num_workers = preprocess_cpu_num, + drop_last=True) + for ind, batch in tqdm(enumerate(training_generator)): + input_ids, attention_mask = batch + input_ids = input_ids.view(batch_size, -1).to(device) + attention_mask = attention_mask.view(batch_size, -1).to(device) input_ids, attention_mask = trim_batch( - **batch, pad_token_id=tokenizer.pad_token_id) + input_ids, tokenizer.pad_token_id, attention_mask) summaries = model.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -63,13 +174,14 @@ def generate_summaries_or_translations( no_repeat_ngram_size=no_repeat_ngram_size, **gen_kwargs, ) - dec = tokenizer.batch_decode(summaries, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) - for hypothesis in dec: - fout.write(hypothesis + "\n") - fout.flush() - + summaries_cpu = summaries.cpu() + data_queue.put((ind, summaries_cpu)) + data_queue.put((-1, GENERATE_FINISHED)) + for p in p_list: + p.join() + msg_queue.put((-1, GENERATE_FINISHED)) + io_process.join() + fout.close() def run_generate(): """Entrance is here.""" @@ -118,6 +230,19 @@ def run_generate(): parser.add_argument("--without_fastseq_opt", action="store_true") parser.add_argument("--no_repeat_ngram_size", type=int, default=None, required=False, help="size of no repeat ngram") + parser.add_argument("--include_special_tokens", action="store_true") + parser.add_argument("--clean_up_tokenization_spaces", action="store_true") + parser.add_argument("--preprocess_cpu_num", + type=int, + default=2, + required=False, + help="pre-processing worker threads") + parser.add_argument("--postprocess_cpu_num", + type=int, + default=2, + required=False, + help="post-processing worker threads") + args = parser.parse_args() examples = [ " " + x.rstrip() if "t5" in args.model_name else x.rstrip() @@ -137,7 +262,11 @@ def run_generate(): decoder_start_token_id=args.decoder_start_token_id, fastseq_opt=not args.without_fastseq_opt, no_repeat_ngram_size=args.no_repeat_ngram_size, - ) + skip_special_tokens=not args.include_special_tokens, + clean_up_tokenization_spaces=args.clean_up_tokenization_spaces, + preprocess_cpu_num=args.preprocess_cpu_num, + postprocess_cpu_num=args.postprocess_cpu_num, + ) if args.reference_path is None: return # Compute scores