10 changes: 5 additions & 5 deletions README.md
@@ -10,11 +10,11 @@ Below shows the generation speed gain by using FastSeq.
| Model | W/O FastSeq (in samples/s) | W/ FastSeq (in samples/s) | Speedup |
|------------------|:--------------------------:|:-------------------------:|:-----:|
| [ProphetNet](examples/prophetnet/README.md) | 2.7 | 10.3 | 3.8x |
| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 12.5 | 4.6x |
| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 8.1 | 2.4x |
| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 8.5 | 2.1x |
| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 7.5 | 1.6x |
| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 122.0 | 1.5x |
| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 13.3 | 5x |
| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 9.9 | 2.9x |
| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 11.9 | 3x |
| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 11.0 | 2.3x |
| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 124.0 | 1.5x |

- All benchmarking experiments were run on an NVIDIA V100 16GB GPU with [docker](docker/Dockerfile). The highest speed for each model was recorded after tuning the batch size. For parameter settings, follow the link for the corresponding model.
- `fs` stands for [Fairseq](https://github.com/pytorch/fairseq) v0.9.0, and `hf` stands for [Huggingface Transformers](https://github.com/huggingface/transformers) v3.0.2.
6 changes: 3 additions & 3 deletions benchmarks/models/fs_bart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "bart.large.cnn cnn_dm.1k/len-1024.bin valid " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 10.4 10.6
# Speed on V100 16GB 250W
grep -E "fairseq_v0.9.0 bart.large.cnn cnn_dm.1k/len-1024.bin valid 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 2.3 2.8
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.1 100
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 10.9 100
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 12.5 100
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.3 100
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.4 100
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.3 100

## Accuracy
#grep "bart.large.cnn cnn_dm/len-1024.bin valid " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 17.9 18
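The speed checks in these benchmark scripts all follow the same pattern: filter `perf` rows for one model/dataset/batch-size setting, average the throughput column (field 13, as in the `awk` programs above), and let `range.sh` fail if the average falls outside the expected interval. The Python sketch below mirrors that check; the `perf` column layout and the exact behavior of `range.sh` are assumptions rather than facts taken from this PR.

```python
# Hypothetical re-implementation of the grep | awk | range.sh speed check.
# Assumes `perf` is a whitespace-separated log with samples/s in field 13.
import re

def average_speed(perf_path, pattern, column=13):
    """Average the given 1-based column over lines matching `pattern`."""
    total, count = 0.0, 0
    with open(perf_path) as f:
        for line in f:
            if re.search(pattern, line):
                total += float(line.split()[column - 1])
                count += 1
    return total / count if count else float("nan")

def check_range(value, lo, hi):
    """Assumed equivalent of range.sh: fail when the value is out of range."""
    if not lo <= value <= hi:
        raise AssertionError(f"{value} not in [{lo}, {hi}]")

speed = average_speed(
    "perf",
    r"fairseq_v0\.9\.0\+fastseq_v.* bart\.large\.cnn cnn_dm\.1k/len-1024\.bin valid 128 ")
check_range(speed, 13.3, 100)
```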
6 changes: 3 additions & 3 deletions benchmarks/models/fs_wmt.sh
@@ -15,6 +15,6 @@ source utils.sh
grep " wmt16.en.de.32k wmt16_en_de_bpe32k/bin test " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 0.019 0.021
# Speed on V100 16GB 250W
grep -E "fairseq_v0.9.0 wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 256 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 82 85
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 256 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 116.7 1000
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 512 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 120 1000
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 1024 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 121 1000
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 256 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 117 1000
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 512 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 123 1000
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 1024 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 123 1000
6 changes: 3 additions & 3 deletions benchmarks/models/hf_bart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.8 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.0 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.3 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.6 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.9 100

## Accuracy
#grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82
4 changes: 2 additions & 2 deletions benchmarks/models/hf_distibart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.5 100
# TODO: larger batch sizes don't increase speed further
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.9 100

## Accuracy
#grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
5 changes: 2 additions & 3 deletions benchmarks/models/hf_t5.sh
@@ -14,6 +14,5 @@ source utils.sh
grep "t5-base wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.42 27.44
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.6 5.2
grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.0 7.1
grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.5 7.8

grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.0 9.2
grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 10.9 11.1
47 changes: 47 additions & 0 deletions fastseq/clib/cuda/ngram_repeat_block_cuda.cpp
@@ -0,0 +1,47 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

#include <torch/extension.h>
#include <vector>

/*
CPP Binding for CUDA OP
*/

// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens,
torch::Tensor lprobs, int bsz,
int step, int beam_size,
int no_repeat_ngram_size);

#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens,
torch::Tensor lprobs, int bsz,
int step, int beam_size,
int no_repeat_ngram_size) {
CHECK_INPUT(tokens);
CHECK_INPUT(lprobs);
assert(bsz > 0);
assert(step >= 0);
assert(beam_size > 0);
assert(no_repeat_ngram_size > 0);

return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size,
no_repeat_ngram_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &ngram_repeat_block_forward,
"No Repeat Ngram Block forward (CUDA)");
}
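This binding is compiled into the `ngram_repeat_block_cuda` module that the Python wrapper below imports. The build step itself is not part of this diff; the following sketch JIT-compiles the two source files with `torch.utils.cpp_extension.load` and calls the bound `forward` directly. The source paths, module name, and availability of a CUDA toolchain are assumptions.

```python
# Sketch: JIT-build the extension and exercise the bound forward() on dummy data.
import torch
from torch.utils.cpp_extension import load

ngram_repeat_block_cuda = load(
    name="ngram_repeat_block_cuda",
    sources=[
        "fastseq/clib/cuda/ngram_repeat_block_cuda.cpp",
        "fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
    ],
)

bsz, beam_size, step, ngram_size, vocab = 2, 4, 5, 3, 50264
tokens = torch.randint(0, vocab, (bsz * beam_size, step + 2),
                       dtype=torch.long, device="cuda")
lprobs = torch.randn(bsz * beam_size, vocab, device="cuda")

# Writes -inf into lprobs for tokens that would complete an already-seen ngram;
# the tensor is updated in place and also returned.
lprobs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, step,
                                         beam_size, ngram_size)
```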
76 changes: 76 additions & 0 deletions fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
@@ -0,0 +1,76 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(long* __restrict__ tokens,
float* __restrict__ lprobs,
int max_predict_len, int vocab_size,
int no_repeat_ngram_size) {
auto row = blockIdx.x;
auto col = threadIdx.x;
auto start = row * (max_predict_len) + col;
// Each thread compares the ngram starting at its
// thread index with the final ngram, which starts at
// position step - no_repeat_ngram_size + 2
auto check_start_pos = blockDim.x;
auto lprob_start = row * vocab_size;
bool is_banned = true;
extern __shared__ long tokens_shm[];
tokens_shm[col] = tokens[start];
if (col == blockDim.x - 1) {
for (int i=1; i<no_repeat_ngram_size; i++){
if (col+i < max_predict_len){
tokens_shm[col + i] = tokens[start + i];
}
}
}
__syncthreads();

for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
is_banned = false;
}
}
if (is_banned == true) {
auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor tokens,
torch::Tensor lprobs, int bsz,
int step, int beam_size,
int no_repeat_ngram_size) {
int threads = step - no_repeat_ngram_size + 2;
if (threads <= 0) return lprobs;
int max_predict_len = tokens.size(1);
int vocab_size = lprobs.size(1);
auto token_ptr = tokens.data_ptr<long>();
auto lprob_ptr = lprobs.data_ptr<float>();
int blocks = bsz * beam_size;
int shared_mem_size = (step + 1) * sizeof(long);

// Launching N blocks where N is the number of samples in a batch (beam_size * bsz)
// Launching T threads where T is the number of previous ngrams in a sample
// Allocating shared memory per block for faster access of input tokens, since
// each token will be accessed N times to compare with the current ngram, where
// N is the ngram size.
banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
return lprobs;
}
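As a cross-check on the kernel logic, the plain-Python reference below mirrors the banning rule: for each hypothesis it compares the trailing `(n-1)`-gram of the generated prefix with every earlier `(n-1)`-gram and, on a match, sets the log-probability of the token that previously completed that n-gram to `-inf`. This reference is not part of the PR; it is only a sketch of what `banRepeatedTokens` computes, useful for validating the CUDA output against CPU tensors.

```python
# CPU reference for the banning rule implemented by banRepeatedTokens.
# tokens: LongTensor (bsz * beam, seq_len); lprobs: FloatTensor (bsz * beam, vocab),
# modified in place, matching the semantics of the CUDA kernel.
import math

def ban_repeated_ngrams_reference(tokens, lprobs, step, no_repeat_ngram_size):
    n = no_repeat_ngram_size
    check_start = step - n + 2          # start of the trailing (n-1)-gram
    if check_start <= 0:                # not enough generated context yet
        return lprobs
    for row in range(tokens.size(0)):
        prefix = tokens[row, :step + 1].tolist()
        last_ngram = prefix[check_start:check_start + n - 1]
        for i in range(check_start):
            if prefix[i:i + n - 1] == last_ngram:
                # Ban the token that completed this ngram earlier in the prefix.
                lprobs[row, prefix[i + n - 1]] = -math.inf
    return lprobs
```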
58 changes: 58 additions & 0 deletions fastseq/ops/ngram_repeat_block.py
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

""" Wrapper for ngram_repeat_block cuda extension """
from torch import nn
from torch.autograd import Function
import ngram_repeat_block_cuda

class NGramRepeatBlockFunction(Function):
"""
forward inputs to ngram_repeat_block cuda extension
backward method not needed.

"""
def forward(self, tokens, lprobs, bsz,
step, beam_size, no_repeat_ngram_size):
"""
Args:
tokens(Tensor): Input tokens(Bsz*beam, seq_len)
lprobs(Tensor): likelihood probability
Expected to be updated in place.(Bsz*beam, vocab_size)
bsz(int): batch size
step(int): current step
beam_size(int): beam size
no_repeat_ngram_size(int): Ngram size
"""
outputs = ngram_repeat_block_cuda.forward(tokens,
lprobs, bsz, step, beam_size, no_repeat_ngram_size)
return outputs

def backward(*args):
raise NotImplementedError

class NGramRepeatBlock(nn.Module):
""" Wrapper class for calling ngram_repeat_block cuda extension """
def __init__(self):
super(NGramRepeatBlock, self).__init__()

def reset_parameters(self):
pass

def forward(self, tokens, lprobs, bsz,
step, beam_size, no_repeat_ngram_size):
"""
Args:
tokens(Tensor): Input tokens(Bsz*beam, seq_len)
lprobs(Tensor): likelihood probability,
Expected to be updated in place.(Bsz*beam, vocab_size)
bsz(int): batch size
step(int): current step
beam_size(int): beam size
no_repeat_ngram_size(int): Ngram size
"""
assert tokens.size(0) == bsz * beam_size
assert lprobs.size(0) == bsz * beam_size

return NGramRepeatBlockFunction.apply(tokens, lprobs,
bsz, step, beam_size, no_repeat_ngram_size)
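A minimal usage sketch of the wrapper, assuming the `ngram_repeat_block_cuda` extension has been built and a GPU is available:

```python
# Ban repeated 3-grams for a batch of 2 sentences decoded with beam size 4.
import torch
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock

bsz, beam_size, step, ngram_size, vocab = 2, 4, 6, 3, 50264
tokens = torch.randint(0, vocab, (bsz * beam_size, step + 2),
                       dtype=torch.long, device="cuda")
lprobs = torch.randn(bsz * beam_size, vocab, device="cuda")

no_repeat_ngram_op = NGramRepeatBlock()
# Entries of lprobs for tokens that would complete an already-generated 3-gram
# are set to -inf before the next beam-search step.
lprobs = no_repeat_ngram_op(tokens, lprobs, bsz, step, beam_size, ngram_size)
```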
42 changes: 5 additions & 37 deletions fastseq/optimizer/fairseq/beam_search_optimizer.py
@@ -14,7 +14,7 @@
from fairseq.models.transformer import TransformerEncoder, TransformerModel
from fairseq.modules.multihead_attention import MultiheadAttention
from fairseq.sequence_generator import SequenceGenerator

from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
from fastseq.utils.api_decorator import replace

@replace(TransformerEncoder)
@@ -429,6 +429,7 @@ def _generate(self,
bsz = input_size[0]
src_len = input_size[1]
beam_size = self.beam_size
self.no_repeat_ngram_op = NGramRepeatBlock()

if self.match_source_len:
max_len = src_lengths.max().item()
@@ -640,24 +641,6 @@ def replicate_first_beam(tensor, mask):
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf

if self.no_repeat_ngram_size > 0:
# for each beam and batch sentence, generate a list of previous ngrams
banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
check_start_pos = step + 2 - self.no_repeat_ngram_size
for bbsz_idx in range(bsz * beam_size):
for i in range(check_start_pos):
is_banned = True
for k in range(self.no_repeat_ngram_size - 1):
if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
bbsz_idx, check_start_pos + k]:
is_banned = False
break
if is_banned:
banned_list[bbsz_idx].append(
cpu_tokens[bbsz_idx,
i + self.no_repeat_ngram_size - 1])

# Record attention scores
if avg_attn_scores is not None:
if attn is None:
@@ -674,24 +657,9 @@ def replicate_first_beam(tensor, mask):
self.search.set_src_lengths(src_lengths)

if self.no_repeat_ngram_size > 0:

def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
banned_tokens_per_sample = [
(bbsz_idx, t) for t in banned_list[bbsz_idx]
]
return banned_tokens_per_sample

banned_tokens = []
if step + 2 - self.no_repeat_ngram_size >= 0:
for bbsz_idx in range(bsz * beam_size):
banned_tokens.extend(calculate_banned_tokens(bbsz_idx))

if banned_tokens:
banned_tokens = torch.LongTensor(banned_tokens)
lprobs.index_put_(
tuple(banned_tokens.t()),
lprobs.new_tensor([-math.inf] * len(banned_tokens)))
# Apply the CUDA op for n-gram repeat blocking
lprobs = self.no_repeat_ngram_op(tokens, lprobs, bsz, step,
beam_size, self.no_repeat_ngram_size)

cand_scores, cand_indices, cand_beams = self.search.step(
step,
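With this change, any fairseq beam search that sets `no_repeat_ngram_size > 0` goes through the CUDA op instead of the removed CPU loop. The end-to-end sketch below assumes the usual FastSeq pattern of importing `fastseq` before fairseq and a locally downloaded `bart.large.cnn` checkpoint; neither detail comes from this diff.

```python
# Sketch: exercise the optimized SequenceGenerator through fairseq's BART API.
import fastseq  # noqa: F401 -- patches fairseq's SequenceGenerator
import torch
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained("bart.large.cnn", checkpoint_file="model.pt")
bart.cuda().eval().half()

with torch.no_grad():
    summaries = bart.sample(
        ["FastSeq provides efficient implementations of popular sequence models."],
        beam=4,
        lenpen=2.0,
        max_len_b=140,
        min_len=55,
        no_repeat_ngram_size=3,  # routed to the CUDA ngram-blocking op
    )
print(summaries[0])
```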
18 changes: 7 additions & 11 deletions fastseq/optimizer/transformers/beam_search_optimizer.py
@@ -17,6 +17,7 @@

from transformers.modeling_bart import BartForConditionalGeneration
from transformers.modeling_t5 import T5ForConditionalGeneration
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock

from fastseq.logging import get_logger
from fastseq.utils.api_decorator import replace
@@ -648,17 +649,9 @@ def _update_scores(banned_tokens):

cpu_input_ids = input_ids.cpu()
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively
# generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_ngram_tokens = calc_banned_ngram_tokens_v2(
cpu_input_ids,
num_batch_hypotheses,
no_repeat_ngram_size,
cur_len,
self.config.pad_token_id)
_update_scores(banned_ngram_tokens)
# Custom CUDA op for n-gram repeat blocking
scores = self.no_repeat_ngram_op(input_ids, scores.float(),
batch_size, cur_len - 1, num_beams, no_repeat_ngram_size)

if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
@@ -721,6 +714,9 @@ def _generate_beam_search(
# done sentences
done = [False for _ in range(batch_size)]

# N-gram repeat blocking op
self.no_repeat_ngram_op = NGramRepeatBlock()  # .to('cuda', torch.float32)

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask,
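Likewise, on the Huggingface side, generation with `num_beams > 1` and `no_repeat_ngram_size > 0` now reaches the CUDA op inside `_generate_beam_search`. A minimal sketch follows, assuming FastSeq is imported before `transformers` so its patches take effect; the model choice and import order are assumptions, not facts from this diff.

```python
# Sketch: optimized beam search through the patched transformers generate().
import fastseq  # noqa: F401 -- patches BartForConditionalGeneration
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
model.cuda().eval()

article = "FastSeq speeds up text generation for popular sequence models."
batch = tokenizer([article], return_tensors="pt", truncation=True)

with torch.no_grad():
    summary_ids = model.generate(
        batch["input_ids"].cuda(),
        attention_mask=batch["attention_mask"].cuda(),
        num_beams=4,
        max_length=140,
        no_repeat_ngram_size=3,  # routed to the CUDA ngram-blocking op
    )
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```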