
Commit b3101c6

Author: Anurag Dixit (committed)
Commit message: (//core): Added the patch file for DETR
Signed-off-by: Anurag Dixit <[email protected]>
1 parent dd7979a commit b3101c6

File tree

1 file changed

+19
-0
lines changed


docker/mha.patch

Lines changed: 19 additions & 0 deletions
--- torch/nn/functional.py	2021-10-01 16:53:42.827338664 -0700
+++ functional.py	2021-10-01 16:53:34.639338618 -0700
@@ -4975,7 +4975,7 @@
         f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
     if isinstance(embed_dim, torch.Tensor):
         # embed_dim can be a tensor when JIT tracing
-        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+        head_dim = int(embed_dim.div(num_heads, rounding_mode='trunc'))
     else:
         head_dim = embed_dim // num_heads
     assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
@@ -5044,6 +5044,7 @@
     #
     # reshape q, k, v for multihead attention and make em batch first
     #
+    bsz = int(bsz)
     q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
     if static_k is None:
         k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
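The motivation for the `int(...)` casts can be sketched as follows: under `torch.jit.trace`, shape values pulled out of tensors may themselves be 0-dim `torch.Tensor`s, so truncating division returns a `Tensor` rather than a Python `int`. The snippet below is a minimal illustration of that behavior (the variable names mirror the patch; it is not the actual `multi_head_attention_forward` code):

```python
import torch

# During JIT tracing, embed_dim can arrive as a 0-dim Tensor
# instead of a plain Python int.
embed_dim = torch.tensor(256)
num_heads = 8

# Tensor.div with rounding_mode='trunc' returns another Tensor...
head_dim_tensor = embed_dim.div(num_heads, rounding_mode='trunc')

# ...so the patch wraps it in int() to get a Python int, which is
# what shape arguments like those passed to .view() expect.
head_dim = int(embed_dim.div(num_heads, rounding_mode='trunc'))

# With a plain int head_dim, the reshape in the patched code works as usual.
q = torch.randn(10, 2 * num_heads, head_dim)
```

The same reasoning applies to `bsz = int(bsz)`: when the batch size is traced as a tensor, casting it to `int` keeps the subsequent `view(...)` shape arithmetic on plain Python integers.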
