Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 3300e3b

Browse files
authored
Specialized mask kernel for to_mask(dim=2) (#466)
1 parent 6161ad1 commit 3300e3b

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

nestedtensor/csrc/cuda/mha.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,6 @@ at::Tensor bt_min_mha(
4848
// auto start = std::chrono::system_clock::now();
4949
auto options =
5050
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
51-
at::Tensor input_mask = to_mask(query, 2);
52-
input_mask = input_mask.to(options);
53-
int64_t batch_size = input_mask.size(0);
54-
int64_t seq_len = input_mask.size(1);
5551
int64_t embedding_dim = head_dim * num_heads; //*(opt_sizes[2]);
5652
int64_t head_num = num_heads;
5753
int64_t size_per_head = embedding_dim / head_num;
@@ -65,6 +61,8 @@ at::Tensor bt_min_mha(
6561
at::Tensor query_buf = packed_padded_chunks[0];
6662
at::Tensor key_buf = packed_padded_chunks[1];
6763
at::Tensor val_buf = packed_padded_chunks[2];
64+
int64_t batch_size = query_buf.size(0);
65+
int64_t seq_len = query_buf.size(1);
6866

6967
query_buf = query_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
7068
key_buf = key_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
@@ -75,6 +73,8 @@ at::Tensor bt_min_mha(
7573

7674
auto mask_options =
7775
torch::TensorOptions().dtype(query.dtype()).device(torch::kCUDA);
76+
at::Tensor input_mask = to_mask(query, 2);
77+
input_mask = input_mask.to(options);
7878
at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(mask_options);
7979
attr_mask = attr_mask * attr_mask.transpose(2, 3);
8080

nestedtensor/csrc/masking.cpp

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -433,6 +433,22 @@ Tensor to_mask(
433433
for (int64_t i = 1; i < *mask_dim; i++) {
434434
max_size.push_back(tmp_max_size[i - 1]);
435435
}
436+
if (*mask_dim == 2 && get_dim(nt) == 3) {
437+
auto nt_size = get_efficient_nested_size(nt);
438+
auto esizes = nt_size.sizes();
439+
auto options = torch::TensorOptions().dtype(torch::kByte);
440+
auto result = torch::zeros({*opt_sizes[0], tmp_max_size[0]},
441+
options);
442+
uint8_t* result_data = result.data_ptr<uint8_t>();
443+
int64_t* esizes_ptr = esizes.data_ptr<int64_t>();
444+
for (int64_t i = 0; i < esizes.size(0); i++) {
445+
int64_t length = esizes_ptr[i * esizes.size(1)];
446+
for (int64_t j = 0; j < length; j++) {
447+
result_data[i * result.size(1) + j] = 1;
448+
}
449+
}
450+
return result;
451+
}
436452
return _create_nt_mask(get_efficient_nested_size(nt), max_size);
437453
}
438454
max_size = get_max_size(nt);
@@ -525,13 +541,13 @@ Tensor _collapse_two_dims_3(Tensor input, int64_t dim1, int64_t dim2) {
525541
auto input_esizes = get_efficient_nested_size(input);
526542
Tensor nt_sizes = input_esizes.sizes();
527543

528-
Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1).contiguous();
529-
Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1).contiguous();
544+
Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1);
545+
Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1);
530546

531547
Tensor new_nt_sizes;
532548
if (dim1 == 1) {
533549
Tensor collapsed_sizes = sizes_dim1 * sizes_dim2;
534-
new_nt_sizes = collapsed_sizes;
550+
new_nt_sizes = collapsed_sizes.contiguous();
535551
}
536552
auto new_esizes = torch::nested_tensor::EfficientSizeNode(input_esizes.structure(), new_nt_sizes);
537553
Tensor result = wrap_buffer(get_buffer(input), new_esizes);

nestedtensor/version.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
__version__ = '0.1.4+986cfd5'
2-
git_version = '986cfd55e2d0c8139a5e19cfca6efc740ea7ad23'
1+
__version__ = '0.1.4+5b45731'
2+
git_version = '5b457313bfb6578b43d76282b321657bf85ee1b3'
33
from nestedtensor import _C
44
if hasattr(_C, 'CUDA_VERSION'):
55
cuda = _C.CUDA_VERSION

0 commit comments

Comments (0)