Description
The current handling of `response_masks` inside the `batch_forward_pass` function does not take padding into account, which results in a shape mismatch during masking. I think the response mask should not be concatenated with `torch.zeros(query_length)`, and the masking operation should be done without slicing it.
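For reference, this is roughly how I read the relevant part of the non-encoder-decoder branch (paraphrased and simplified, not a verbatim quote; variable names follow the rest of this issue):

```python
# Paraphrased sketch of the current masking logic (simplified, may not match
# the actual source line-for-line):
for j in range(len(query_batch)):
    start = len(query_batch[j]) - 1
    if attention_mask[j, 0] == 0:  # left padding shifts the start index
        start += attention_mask[j, :].nonzero()[0]
    end = start + len(response_batch[j])

    if response_masks is not None:
        # query-length zeros are prepended to the response mask here ...
        response_masks_batch[j] = torch.cat(
            (torch.zeros_like(query_batch[j]), response_masks_batch[j])
        )

    masks[j, :start] = 0
    masks[j, end:] = 0
    if response_masks is not None:
        # ... but start/end already include the padding offset, so this slice
        # of the concatenated mask can be shorter than end - start
        masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
```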
An example with a batch size of 2:
- The first sample in the batch has a query size of 10 and a response size of 9 (the response mask also has size 9).
- The second sample in the batch has a query size of 10 and a response size of 5 (the response mask also has size 5).
- With the concatenation, `response_masks_batch[1]` has a size of 15. `start` will be 14 for the second sample (due to the padding) and `end` will be 19.
Hence, the operation
`masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]`
will yield
`RuntimeError: The size of tensor a (5) must match the size of tensor b (1) at non-singleton dimension 0`
since `response_masks_batch[1][14:19]` is the same as `response_masks_batch[1][14:15]`, which has length 1.
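For concreteness, here is the length bookkeeping for the second sample in isolation (dummy values, only the sizes matter):

```python
import torch

# Second sample from the example: query length 10, response mask length 5,
# start shifted to 14 by the left padding, end = start + 5 = 19.
query = torch.zeros(10)
response_mask = torch.ones(5)

padded_mask = torch.cat((torch.zeros_like(query), response_mask))  # length 15
start, end = 14, 19

print(padded_mask[start:end].shape)  # torch.Size([1]) -> shorter than end - start == 5
print(response_mask.shape)           # torch.Size([5]) -> already matches end - start
```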
A better approach would be to remove the concatenation of the response mask and to remove the slicing of it, since the response mask already has length `end - start`, which is equal to the length of `masks[j, start:end]`.
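In code, the change I have in mind would look roughly like this (only the response-mask handling changes; the rest of the loop stays as it is):

```python
# Proposed handling (sketch): use the per-sample response mask directly.
if response_masks is not None:
    # No torch.zeros(query_length) prepended and no [start:end] slice:
    # response_masks_batch[j] already has length end - start.
    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j]
```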