Cuda op for ngram repeat blocking #40
Merged
Changes from all commits (15 commits)
54efa6f  Cuda op for ngram repeat blocking
062a9c7  clean up
99ecab6  Unit test for cuda op
c8b451d  unit test updated, minor updates in cpp/cu code
25d06f9  updating with code clean PR
b36a314  Rebased on new codebase, updated all benchmarks
c9fc88c  Merge branch 'main' into cuda_op_ngram_block
ea3b370  Update README.md (NickNickGo)
bddb09b  Update README.md (NickNickGo)
532d12b  Update README.md (NickNickGo)
ada0682  minor change in kernel (NickNickGo)
ee29e85  Merge branch 'cuda_op_ngram_block' of https://github.com/NickNickGo/f…
a215fbf  Merge branch 'main' into cuda_op_ngram_block
6a6877e  changing install order (NickNickGo)
5b65505  Merge branch 'cuda_op_ngram_block' of https://github.com/NickNickGo/f…
New file: C++ (pybind11) binding for the CUDA op.

```cpp
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

#include <torch/extension.h>
#include <vector>

/*
CPP Binding for CUDA OP
*/

// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens,
                                         torch::Tensor lprobs, int bsz,
                                         int step, int beam_size,
                                         int no_repeat_ngram_size) {
  CHECK_INPUT(tokens);
  CHECK_INPUT(lprobs);
  assert(bsz > 0);
  assert(step >= 0);
  assert(beam_size > 0);
  assert(no_repeat_ngram_size > 0);

  return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size,
                                         no_repeat_ngram_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &ngram_repeat_block_forward,
        "No Repeat Ngram Block forward (CUDA)");
}
```
New file: CUDA kernel implementation.

```cuda
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(long* __restrict__ tokens,
                                  float* __restrict__ lprobs,
                                  int max_predict_len, int vocab_size,
                                  int no_repeat_ngram_size) {
  auto row = blockIdx.x;
  auto col = threadIdx.x;
  auto start = row * (max_predict_len) + col;
  // Each thread compares the ngram starting at its
  // thread index with the final ngram starting at
  // step - no_repeat_ngram_size + 2
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;
  extern __shared__ long tokens_shm[];
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  if (is_banned == true) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor tokens,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size) {
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0) return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  int shared_mem_size = (step + 1) * sizeof(long);

  // Launching N blocks where N is the number of samples in a batch (beams*bsz)
  // Launching T threads where T is the number of previous ngrams in a sample
  // Allocating shared mem per block for faster access of input tokens, since
  // each token will be accessed N times to compare with the current ngram,
  // where N is the ngram size.
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
```
New file: Python wrapper module for the extension.

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

""" Wrapper for ngram_repeat_block cuda extension """
from torch import nn
from torch.autograd import Function
import ngram_repeat_block_cuda


class NGramRepeatBlockFunction(Function):
    """
    Forwards inputs to the ngram_repeat_block cuda extension.
    Backward method not needed.
    """
    def forward(self, tokens, lprobs, bsz,
                step, beam_size, no_repeat_ngram_size):
        """
        Args:
            tokens(Tensor): Input tokens (bsz*beam, seq_len)
            lprobs(Tensor): likelihood probability,
                expected to be updated in place (bsz*beam, vocab_size)
            bsz(int): batch size
            step(int): current step
            beam_size(int): beam size
            no_repeat_ngram_size(int): Ngram size
        """
        outputs = ngram_repeat_block_cuda.forward(tokens,
            lprobs, bsz, step, beam_size, no_repeat_ngram_size)
        return outputs

    def backward(*args):
        raise NotImplementedError


class NGramRepeatBlock(nn.Module):
    """ Wrapper class for calling the ngram_repeat_block cuda extension """
    def __init__(self):
        super(NGramRepeatBlock, self).__init__()

    def reset_parameters(self):
        pass

    def forward(self, tokens, lprobs, bsz,
                step, beam_size, no_repeat_ngram_size):
        """
        Args:
            tokens(Tensor): Input tokens (bsz*beam, seq_len)
            lprobs(Tensor): likelihood probability,
                expected to be updated in place (bsz*beam, vocab_size)
            bsz(int): batch size
            step(int): current step
            beam_size(int): beam size
            no_repeat_ngram_size(int): Ngram size
        """
        assert tokens.size(0) == bsz * beam_size
        assert lprobs.size(0) == bsz * beam_size

        return NGramRepeatBlockFunction.apply(tokens, lprobs,
            bsz, step, beam_size, no_repeat_ngram_size)
```