Commit b8fb6ea
Improve bmm() performance on CPU when input tensor is non-contiguous (pytorch#19338)
Summary:
This PR aims to improve Transformer performance on CPU; `bmm()` is currently one of the major bottlenecks.
The current logic of `bmm()` on CPU only uses MKL batch gemm when the inputs `A` and `B` are contiguous or transposed, so when `A` or `B` is a slice of a larger tensor, it falls back to a slower path.
`A` and `B` are both 3D tensors. MKL is able to handle the batch matrix multiplication as long as `A.stride(1) == 1 || A.stride(2) == 1` and `B.stride(1) == 1 || B.stride(2) == 1`.
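For illustration, a minimal sketch of this stride condition in Python (the helper name `mkl_bmm_compatible` is hypothetical and not part of this PR):
```python
import torch

def mkl_bmm_compatible(t: torch.Tensor) -> bool:
    # Hypothetical helper mirroring the condition above: MKL batch gemm
    # can consume a 3D operand when each matrix in the batch has either
    # its rows or its columns laid out contiguously in memory.
    return t.dim() == 3 and (t.stride(1) == 1 or t.stride(2) == 1)

A = torch.randn(8, 64, 32)                      # contiguous: stride(2) == 1
B = torch.randn(8, 64, 32).transpose(1, 2)      # transposed: stride(1) == 1
print(mkl_bmm_compatible(A), mkl_bmm_compatible(B))  # True True
```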
In the [fairseq](https://github.com/pytorch/fairseq) implementation of the Transformer, multi-head attention calls `bmm()` in two places, [here](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py#L167) and [here](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py#L197), where `q`, `k`, and `v` are all slices of a larger tensor, so those `bmm()` calls currently fall to the slow path.
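A rough sketch of that pattern (shapes illustrative, not fairseq's actual code): slicing a fused projection yields non-contiguous views that nevertheless satisfy the stride condition above.
```python
import torch

# q, k, v taken as slices of one larger tensor, mirroring the multi-head
# attention pattern described above (shapes are illustrative).
qkv = torch.randn(8, 128, 3 * 64)       # (batch * heads, seq_len, 3 * head_dim)
q, k, v = qkv.chunk(3, dim=-1)          # each chunk is a non-contiguous view

print(q.is_contiguous())                # False: q is a slice of qkv
print(q.stride(2) == 1)                 # True: handled by MKL after this PR
attn = torch.bmm(q, k.transpose(1, 2))  # previously fell to the slow path
out = torch.bmm(torch.softmax(attn, dim=-1), v)
```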
Results on a Xeon 6148 (2 sockets × 20 cores, 2.5 GHz) indicate this PR improves Transformer training performance by **48%** (seconds per iteration reduced from **5.48** to **3.70**); inference performance should also be boosted.
Before:
```
| epoch 001: 0%| | 27/25337 [02:27<38:31:26, 5.48s/it, loss=16.871, nll_loss=16.862, ppl=119099.70, wps=865, ups=0, wpb=4715.778, bsz=129.481, num_updates=27, lr=4.05e-06, gnorm=9.133,
```
After:
```
| epoch 001: 0%| | 97/25337 [05:58<25:55:49, 3.70s/it, loss=14.736, nll_loss=14.571, ppl=24339.38, wps=1280, ups=0, wpb=4735.299, bsz=131.134, num_updates=97, lr=1.455e-05, gnorm=3.908,
```
Pull Request resolved: pytorch#19338
Differential Revision: D14986346
Pulled By: soumith
fbshipit-source-id: 827106245af908b8a4fda69ed0288d322b028f081
File tree: 2 files changed (+12, −12 lines) under `aten/src/ATen/native` and `aten/src/ATen/native/mkl`.