From e46bcd2ce82c159c49b276ac7946f4e2cca3c397 Mon Sep 17 00:00:00 2001
From: NickNickGo
Date: Tue, 15 Dec 2020 21:11:32 +0000
Subject: [PATCH] exception guarding cuda extension import

---
 .../fairseq/beam_search_optimizer.py          | 101 +++++++++---------
 .../transformers/beam_search_optimizer.py     |  12 ++-
 2 files changed, 62 insertions(+), 51 deletions(-)

diff --git a/fastseq/optimizer/fairseq/beam_search_optimizer.py b/fastseq/optimizer/fairseq/beam_search_optimizer.py
index f984f175..96ee993f 100644
--- a/fastseq/optimizer/fairseq/beam_search_optimizer.py
+++ b/fastseq/optimizer/fairseq/beam_search_optimizer.py
@@ -15,7 +15,6 @@ from fairseq.modules.multihead_attention import MultiheadAttention
 from fairseq.search import BeamSearch
 from fairseq.sequence_generator import SequenceGenerator
 
-from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
 from fastseq.utils.api_decorator import replace
 
 @replace(BeamSearch)
@@ -432,6 +431,49 @@ class SequenceGeneratorV2(SequenceGenerator):
     Sequence Generator is optimized by reducing the cached memory usage during
     the encoding period for beam search.
     """
+    @torch.no_grad()
+    def apply_no_repeat_ngram_cpu(self, tokens,lprobs, bsz,step,
+                            beam_size, no_repeat_ngram_size):
+        """ Fairseq implementation of blocking
+            repeated ngrams
+        """
+        banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
+        cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
+        check_start_pos = step + 2 - no_repeat_ngram_size
+        for bbsz_idx in range(bsz * beam_size):
+            for i in range(check_start_pos):
+                is_banned = True
+                for k in range(no_repeat_ngram_size - 1):
+                    if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
+                        bbsz_idx, check_start_pos + k]:
+                        is_banned = False
+                        break
+                if is_banned:
+                    banned_list[bbsz_idx].append(
+                        cpu_tokens[bbsz_idx,
+                                   i + no_repeat_ngram_size - 1])
+
+        def calculate_banned_tokens(bbsz_idx):
+            """before decoding the next token, prevent decoding
+            of ngrams that have already appeared
+            """
+            banned_tokens_per_sample = [
+                (bbsz_idx, t) for t in banned_list[bbsz_idx]
+            ]
+            return banned_tokens_per_sample
+
+        banned_tokens = []
+        if step + 2 - no_repeat_ngram_size >= 0:
+            for bbsz_idx in range(bsz * beam_size):
+                banned_tokens.extend(calculate_banned_tokens(bbsz_idx))
+
+        if banned_tokens:
+            banned_tokens = torch.LongTensor(banned_tokens)
+            lprobs.index_put_(
+                tuple(banned_tokens.t()),
+                lprobs.new_tensor([-math.inf] * len(banned_tokens)))
+
+        return lprobs
 
     @torch.no_grad()
     def _generate(self,
@@ -459,7 +501,13 @@ def _generate(self,
         bsz = input_size[0]
         src_len = input_size[1]
         beam_size = self.beam_size
-        self.no_repeat_ngram_op = NGramRepeatBlock()
+        cuda_ngram_op_import = True
+        try:
+            #pylint: disable=import-outside-toplevel
+            from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
+            self.no_repeat_ngram_op = NGramRepeatBlock()
+        except:
+            cuda_ngram_op_import = False
 
         if self.match_source_len:
             max_len = src_lengths.max().item()
@@ -524,49 +572,6 @@ def is_finished(sent, step, unfin_idx):
                 return True
             return False
 
-        def apply_no_repeat_ngram_cpu(self, tokens,lprobs, bsz,step,
-                            beam_size, no_repeat_ngram_size):
-            """ Fairseq implementation of blocking
-                repeated ngrams
-            """
-            banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
-            cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
-            check_start_pos = step + 2 - no_repeat_ngram_size
-            for bbsz_idx in range(bsz * beam_size):
-                for i in range(check_start_pos):
-                    is_banned = True
-                    for k in range(no_repeat_ngram_size - 1):
-                        if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
-                            bbsz_idx, check_start_pos + k]:
-                            is_banned = False
-                            break
-                    if is_banned:
-                        banned_list[bbsz_idx].append(
-                            cpu_tokens[bbsz_idx,
-                                       i + no_repeat_ngram_size - 1])
-
-            def calculate_banned_tokens(bbsz_idx):
-                """before decoding the next token, prevent decoding
-                of ngrams that have already appeared
-                """
-                banned_tokens_per_sample = [
-                    (bbsz_idx, t) for t in banned_list[bbsz_idx]
-                ]
-                return banned_tokens_per_sample
-
-            banned_tokens = []
-            if step + 2 - no_repeat_ngram_size >= 0:
-                for bbsz_idx in range(bsz * beam_size):
-                    banned_tokens.extend(calculate_banned_tokens(bbsz_idx))
-
-            if banned_tokens:
-                banned_tokens = torch.LongTensor(banned_tokens)
-                lprobs.index_put_(
-                    tuple(banned_tokens.t()),
-                    lprobs.new_tensor([-math.inf] * len(banned_tokens)))
-
-            return lprobs
-
         def finalize_hypos(step, bbsz_idx, eos_scores):
             """
             Finalize the given hypotheses at this step, while keeping the total
@@ -731,12 +736,12 @@ def replicate_first_beam(tensor, mask):
 
             if self.no_repeat_ngram_size > 0:
                 #Applying Cuda Op for NGram repeat Blocking
-                if (tokens.is_cuda and lprobs.is_cuda):
+                if (tokens.is_cuda and lprobs.is_cuda and cuda_ngram_op_import):
                     lprobs = self.no_repeat_ngram_op(tokens,lprobs, bsz,
                             step, beam_size, self.no_repeat_ngram_size)
                 else:
-                    lprobs = apply_no_repeat_ngram_cpu(tokens, lprobs, bsz,
-                        step, beam_size, self.ngram_repeat_block_size)
+                    lprobs = self.apply_no_repeat_ngram_cpu(tokens, lprobs, bsz,
+                        step, beam_size, self.no_repeat_ngram_size)
 
             cand_scores, cand_indices, cand_beams = self.search.step(
                 step,
diff --git a/fastseq/optimizer/transformers/beam_search_optimizer.py b/fastseq/optimizer/transformers/beam_search_optimizer.py
index df8061bc..4d4ed0b7 100644
--- a/fastseq/optimizer/transformers/beam_search_optimizer.py
+++ b/fastseq/optimizer/transformers/beam_search_optimizer.py
@@ -17,7 +17,6 @@
 from transformers.modeling_bart import BartForConditionalGeneration
 from transformers.modeling_t5 import T5ForConditionalGeneration
 
-from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
 from fastseq.logging import get_logger
 from fastseq.utils.api_decorator import replace
 
@@ -650,7 +649,8 @@ def _update_scores(banned_tokens):
         cpu_input_ids = input_ids.cpu()
         if no_repeat_ngram_size > 0:
             #custom op for Ngram repeat blocking
-            if (input_ids.is_cuda and scores.is_cuda):
+            if (input_ids.is_cuda and scores.is_cuda and
+                self.cuda_ngram_op_import):
                 scores = self.no_repeat_ngram_op(input_ids,scores.float(),
                     batch_size, cur_len-1, num_beams, no_repeat_ngram_size)
             else:
@@ -725,7 +725,13 @@ def _generate_beam_search(
         done = [False for _ in range(batch_size)]
 
         #NGram Repeat block Op
-        self.no_repeat_ngram_op = NGramRepeatBlock()#.to('cuda', torch.float32)
+        self.cuda_ngram_op_import = True
+        try:
+            #pylint: disable=import-outside-toplevel
+            from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
+            self.no_repeat_ngram_op = NGramRepeatBlock()
+        except:
+            self.cuda_ngram_op_import = False
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
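
Note (editor's sketch, not part of the patch): both files apply the same guarded-import pattern — try to load the CUDA NGramRepeatBlock extension, record whether the import succeeded, and fall back to the CPU n-gram blocking path when it did not. A minimal standalone illustration of that pattern follows; the helper name load_ngram_blocker and the commented usage are illustrative only, and `except Exception` stands in for the patch's bare `except:`.

def load_ngram_blocker():
    """Return (blocker, loaded): the fused CUDA op when available, else (None, False)."""
    try:
        # Import inside the function so a missing or broken CUDA build does not
        # break module import (mirrors the patch's import-outside-toplevel usage).
        from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
        return NGramRepeatBlock(), True
    except Exception:
        # CPU-only installs or failed extension builds land here.
        return None, False

# Usage sketch: prefer the CUDA op only when it loaded and the tensors are on GPU.
# blocker, cuda_ok = load_ngram_blocker()
# if cuda_ok and tokens.is_cuda and lprobs.is_cuda:
#     lprobs = blocker(tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size)
# else:
#     lprobs = apply_no_repeat_ngram_cpu(tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size)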