diff --git a/fastseq/optimizer/fairseq/beam_search_optimizer_v1.py b/fastseq/optimizer/fairseq/beam_search_optimizer_v1.py
index 83adba4e..76223959 100644
--- a/fastseq/optimizer/fairseq/beam_search_optimizer_v1.py
+++ b/fastseq/optimizer/fairseq/beam_search_optimizer_v1.py
@@ -195,8 +195,6 @@ def forward(
                     ],
                     dim=1)
 
-        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
-                                self.head_dim).transpose(0, 1)
         if k is not None:
             kv_bsz = k.size(1)
             k = k.contiguous().view(-1, kv_bsz * self.num_heads,
@@ -283,14 +281,16 @@ def forward(
                     dim=1)
 
         if self.encoder_decoder_attention and bsz != kv_bsz:
-            attn_weights = torch.einsum(
-                'bxhtd,bhsd->bxhts',
-                q.view(kv_bsz, -1, self.num_heads,
-                       *q.size()[1:]),
-                k.view(kv_bsz, self.num_heads,
-                       *k.size()[1:]))
-            attn_weights = attn_weights.reshape(-1, *attn_weights.size()[-2:])
+            # query size (1, B*b*h, c_embed) => (B*h, b, c)
+            q = q.view(tgt_len, -1, self.beam_size, self.num_heads,
+                self.head_dim).permute(1, 3, 2, 0, 4).contiguous(
+                ).view(kv_bsz * self.num_heads, self.beam_size, self.head_dim)
+            attn_weights = torch.bmm(q, k.transpose(1, 2))
+            attn_weights = attn_weights.view(-1, tgt_len,
+                                             *attn_weights.size()[-1:])
         else:
+            q = q.contiguous().view(tgt_len, bsz * self.num_heads,
+                                    self.head_dim).transpose(0, 1)
             attn_weights = torch.bmm(q, k.transpose(1, 2))
         attn_weights = self.apply_sparse_mask(
             attn_weights, tgt_len, src_len, bsz)
@@ -306,15 +306,14 @@ def forward(
 
         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
-                                             src_len)
+            # attn_weights size (B*b*h / B*h*b, 1, S) => (B, h*b, S)
+            attn_weights = attn_weights.view(kv_bsz, -1, src_len)
             if not self.tpu:
-                attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads,
-                                                 tgt_len, src_len)
                 attn_weights = attn_weights.masked_fill(
-                    key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(
+                    key_padding_mask.unsqueeze(1).to(
                         torch.bool), float("-inf"))
             else:
+                # Not supported
                 attn_weights = attn_weights.transpose(0, 2)
                 attn_weights = attn_weights.masked_fill(
                     key_padding_mask, float('-inf'))
@@ -323,6 +322,11 @@ def forward(
                                          src_len)
 
         if before_softmax:
+            # attn_weights size (B*h*b, 1, S) => (B*b*h, 1, S)
+            if self.encoder_decoder_attention and bsz != kv_bsz:
+                attn_weights = attn_weights.view(kv_bsz, self.num_heads,
+                    self.beam_size, tgt_len, src_len).permute(0, 2, 1, 3, 4
+                    ).contiguous().view(-1, tgt_len, src_len)
             return attn_weights, v
 
         attn_weights_float = utils.softmax(attn_weights,
@@ -335,18 +339,26 @@ def forward(
 
         assert v is not None
         if self.encoder_decoder_attention and bsz != kv_bsz:
-            attn = torch.einsum(
-                'bxhts,bhsd->bxhtd',
-                attn_probs.view(kv_bsz, -1, self.num_heads,
-                                *attn_probs.size()[1:]),
-                v.view(kv_bsz, self.num_heads,
-                       *v.size()[1:]))
-            attn = attn.reshape(-1, *attn.size()[-2:])
+            # attn_probs size (B*h*b, 1, S) => (B*h, b, S)
+            attn_probs = attn_probs.view(-1, self.beam_size, src_len)
+        attn = torch.bmm(attn_probs, v)
+
+        if self.encoder_decoder_attention and bsz != kv_bsz:
+            assert list(
+                attn.size()) == [kv_bsz * self.num_heads,
+                                 self.beam_size, self.head_dim]
         else:
-            attn = torch.bmm(attn_probs, v)
-        assert list(
-            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        if (self.onnx_trace and attn.size(1) == 1):
+            assert list(
+                attn.size()) == [bsz * self.num_heads,
+                                 tgt_len, self.head_dim]
+
+        if self.encoder_decoder_attention and bsz != kv_bsz:
+            # attn size (B*h, b, c) => (1, B*b, c_embed)
+            attn = attn.view(kv_bsz, self.num_heads,
+                self.beam_size, self.head_dim).permute(0, 2, 1, 3
+                ).contiguous().view(tgt_len, bsz, embed_dim)
+            # .view(tgt_len, -1, self.head_dim)
+        elif (self.onnx_trace and attn.size(1) == 1):
             # when ONNX tracing a single decoder step (sequence length == 1)
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
@@ -356,9 +368,15 @@ def forward(
         attn = self.out_proj(attn)
 
         if need_weights:
-            attn_weights = attn_weights_float.view(bsz, self.num_heads,
-                                                   tgt_len,
-                                                   src_len).transpose(1, 0)
+            # attn_weights size (B*h*b, 1, S) => (h, B*b, 1, S)
+            if self.encoder_decoder_attention and bsz != kv_bsz:
+                attn_weights = attn_weights_float.view(kv_bsz, self.num_heads,
+                    self.beam_size, tgt_len, src_len).permute(1, 0, 2, 3, 4
+                    ).contiguous().view(self.num_heads, bsz, tgt_len, src_len)
+            else:
+                attn_weights = attn_weights_float.view(bsz, self.num_heads,
+                                                       tgt_len,
+                                                       src_len).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
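
Sanity check (not part of the patch): the sketch below is a minimal, self-contained illustration of the @@ -283,14 +281,16 @@ hunk, i.e. that scoring the whole beam with a single torch.bmm over a (B*h, b, d) query reproduces the torch.einsum('bxhtd,bhsd->bxhts', ...) scores it replaces, up to the (beam, head) vs. (head, beam) flattening order that the later permute/view calls account for. All names and sizes here (B, b, h, d, w_old, w_new, etc.) are made up for illustration; tgt_len is 1 because the decoder attends one step at a time during incremental decoding.

import torch

# Illustrative sizes (assumptions, not taken from the patch):
# B = source batch (kv_bsz), b = beam size, h = heads, d = head dim.
B, b, h, d, src_len, tgt_len = 2, 3, 4, 8, 5, 1
bsz = B * b                                   # decoder batch = B * beam size

q = torch.randn(tgt_len, bsz, h * d)          # one decoder step per hypothesis
k = torch.randn(B * h, src_len, d)            # keys cached once per source sentence

# Removed path: einsum over an explicit beam axis.
q_flat = q.contiguous().view(tgt_len, bsz * h, d).transpose(0, 1)    # (B*b*h, 1, d)
w_old = torch.einsum('bxhtd,bhsd->bxhts',
                     q_flat.view(B, -1, h, *q_flat.size()[1:]),      # (B, b, h, 1, d)
                     k.view(B, h, *k.size()[1:]))                    # (B, h, s, d)
w_old = w_old.reshape(-1, *w_old.size()[-2:])                        # (B*b*h, 1, s)

# Added path: fold the beam into the batch dimension of a single bmm.
q_new = (q.view(tgt_len, -1, b, h, d)                                # (1, B, b, h, d)
         .permute(1, 3, 2, 0, 4)                                     # (B, h, b, 1, d)
         .contiguous()
         .view(B * h, b, d))                                         # (B*h, b, d)
w_new = torch.bmm(q_new, k.transpose(1, 2))                          # (B*h, b, s)
w_new = w_new.view(-1, tgt_len, w_new.size(-1))                      # (B*h*b, 1, s)

# Same scores, different flattening order: (B, b, h, ...) vs (B, h, b, ...).
assert torch.allclose(
    w_old.view(B, b, h, tgt_len, src_len),
    w_new.view(B, h, b, tgt_len, src_len).permute(0, 2, 1, 3, 4),
    atol=1e-5)
print("bmm-based beam attention matches the einsum formulation")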