File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change
1
+ --- torch/nn/functional.py 2021-10-01 16:53:42.827338664 -0700
2
+ +++ functional.py 2021-10-01 16:53:34.639338618 -0700
3
+ @@ -4975,7 +4975,7 @@
4
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
5
+ if isinstance(embed_dim, torch.Tensor):
6
+ # embed_dim can be a tensor when JIT tracing
7
+ - head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
8
+ + head_dim = int(embed_dim.div(num_heads, rounding_mode='trunc'))
9
+ else:
10
+ head_dim = embed_dim // num_heads
11
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
12
+ @@ -5044,6 +5044,7 @@
13
+ #
14
+ # reshape q, k, v for multihead attention and make em batch first
15
+ #
16
+ + bsz = int(bsz)
17
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
18
+ if static_k is None:
19
+ k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
You can’t perform that action at this time.
0 commit comments