
Commit b8fb6ea

mingfeima authored and facebook-github-bot committed
Improve bmm() performance on CPU when input tensor is non-contiguous (pytorch#19338)
Summary: This PR aims to improve Transformer performance on CPU, `bmm()` is one of the major bottlenecks now. Current logic of `bmm()` on CPU only uses MKL batch gemm when the inputs `A` and `B` are contiguous or transposed. So when `A` or `B` is a slice of a larger tensor, it falls to a slower path. `A` and `B` are both 3D tensors. MKL is able to handle the batch matrix multiplication on occasion that `A.stride(1) == 1 || A.stride(2) == 1` and `B.stride(1) == || B.stride(2) == 1`. From [fairseq](https://github.com/pytorch/fairseq) implementation of Transformer, multi-head attention has two places to call bmm(), [here](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py#L167) and [here](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py#L197), `q`, `k`, `v` are all slices from larger tensor. So the `bmm()` falls to slow path at the moment. Results on Xeon 6148 (20*2 cores 2.5GHz) indicate this PR improves Transformer training performance by **48%** (seconds per iteration reduced from **5.48** to **3.70**), the inference performance should also be boosted. Before: ``` | epoch 001: 0%| | 27/25337 [02:27<38:31:26, 5.48s/it, loss=16.871, nll_loss=16.862, ppl=119099.70, wps=865, ups=0, wpb=4715.778, bsz=129.481, num_updates=27, lr=4.05e-06, gnorm=9.133, ``` After: ``` | epoch 001: 0%| | 97/25337 [05:58<25:55:49, 3.70s/it, loss=14.736, nll_loss=14.571, ppl=24339.38, wps=1280, ups=0, wpb=4735.299, bsz=131.134, num_updates=97, lr=1.455e-05, gnorm=3.908, ``` Pull Request resolved: pytorch#19338 Differential Revision: D14986346 Pulled By: soumith fbshipit-source-id: 827106245af908b8a4fda69ed0288d322b028f08
1 parent 12d6f79 commit b8fb6ea

2 files changed (+12, -12 lines)


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -297,8 +297,8 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
   }
 
   auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
-    return (t.stride(2) == 1 && t.stride(1) == t.size(2))
-        || (t.stride(1) == 1 && t.stride(2) == t.size(1));
+    return (t.stride(2) == 1 && t.stride(1) >= t.size(2))
+        || (t.stride(1) == 1 && t.stride(2) >= t.size(1));
   };
 
   if (contraction_size * res_rows * res_cols < 400) {
```

aten/src/ATen/native/mkl/LinearAlgebra.cpp

Lines changed: 10 additions & 10 deletions

```diff
@@ -34,21 +34,17 @@ namespace at { namespace native {
 
 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
     const int batch_size, const int M, const int N, const int K, const float alpha,
-    const float** A, const float** B, const float beta, float** C) {
-  const int lda = (trans_A == CblasNoTrans) ? K : M;
-  const int ldb = (trans_B == CblasNoTrans) ? N : K;
-  const int ldc = N;
+    const float** A, const int lda, const float** B, const int ldb, const float beta,
+    float** C, const int ldc) {
 
   cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
     A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
 }
 
 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
     const int batch_size, const int M, const int N, const int K, const double alpha,
-    const double** A, const double** B, const double beta, double** C) {
-  const int lda = (trans_A == CblasNoTrans) ? K : M;
-  const int ldb = (trans_B == CblasNoTrans) ? N : K;
-  const int ldc = N;
+    const double** A, const int lda, const double** B, const int ldb, const double beta,
+    double** C, const int ldc) {
 
   cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
     A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
@@ -57,7 +53,7 @@ static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANS
 template <typename scalar_t>
 static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
   auto is_transposed = [&](const Tensor& t) {
-    return t.stride(0) == 1 && t.stride(1) == t.size(0);
+    return t.stride(0) == 1 && t.stride(1) >= t.size(0);
   };
   const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
   const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;
@@ -69,6 +65,10 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c
   scalar_t alpha = alpha_.to<scalar_t>();
   scalar_t beta = beta_.to<scalar_t>();
 
+  const int lda = is_transposed(mat1[0]) ? mat1[0].stride(1) : mat1[0].stride(0);
+  const int ldb = is_transposed(mat2[0]) ? mat2[0].stride(1) : mat2[0].stride(0);
+  const int ldc = res[0].stride(0);
+
   std::vector<const scalar_t*> A(batch_size);
   std::vector<const scalar_t*> B(batch_size);
   std::vector<scalar_t*> C(batch_size);
@@ -78,7 +78,7 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c
     C[batch] = res[batch].data<scalar_t>();
   }
 
-  gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), B.data(), beta, C.data());
+  gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc);
 }
 
 Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
```
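The `gemm_batched` wrappers previously derived the leading dimensions from `M`/`N`/`K`, which is only valid for fully packed matrices; they now accept the real leading dimensions computed from the tensors' strides. As a reminder of the semantics, here is a minimal sketch (not PyTorch or MKL code; `naive_gemm` is a hypothetical stand-in for the batched MKL call): the leading dimension is the distance in elements between consecutive rows of a row-major matrix, so a submatrix embedded in a wider buffer can be multiplied in place by passing the buffer's row width:

```cpp
#include <cstdio>

// Sketch only: for a row-major matrix stored inside a wider buffer, element
// (i, j) lives at data[i * ld + j], where ld is the row stride ("leading
// dimension"). Passing ld explicitly, as the new gemm_batched signature does,
// lets a GEMM read a submatrix in place instead of requiring a packed copy.
static void naive_gemm(int M, int N, int K,
                       const float* A, int lda,
                       const float* B, int ldb,
                       float* C, int ldc) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * lda + k] * B[k * ldb + j];
      }
      C[i * ldc + j] = acc;
    }
  }
}

int main() {
  // A is the left 2x2 block of a 2x4 buffer, so lda = 4 even though K = 2.
  float A_buf[2 * 4] = {1, 2, 9, 9,
                        3, 4, 9, 9};
  float B[2 * 2] = {1, 0,
                    0, 1};          // identity, ldb = 2
  float C[2 * 2] = {0, 0, 0, 0};    // ldc = 2

  naive_gemm(/*M=*/2, /*N=*/2, /*K=*/2, A_buf, /*lda=*/4, B, /*ldb=*/2, C, /*ldc=*/2);
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // prints 1 2 / 3 4
  return 0;
}
```

In the PR itself, those leading dimensions come directly from the strides of `mat1[0]`, `mat2[0]`, and `res[0]`, which is what lets non-contiguous batch items flow through `cblas_sgemm_batch`/`cblas_dgemm_batch` unchanged.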
