Use torch.bool instead of torch.int64 for non-persistent causal mask buffer
#29241
Adding `self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)` in @ArthurZucker's rewrite of llama & gemma adds a 500 MB overhead when serializing to ONNX / TorchScript IR / PyTorch ExportedProgram (https://pytorch.org/docs/stable/export.html) for `max_position_embeddings=8182`. Essentially, these IRs do not support non-persistent buffers. One quick fix is to use `torch.bool` instead of `torch.int64`, but bool is still 8 bits in PyTorch (pytorch/pytorch#41571) and the overhead is still ~70 MB.
The lowered overhead is acceptable to me, but this won't scale to 10M context length.
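For reference, a minimal sketch of the size difference (the `MaskHolder` module below is a hypothetical stand-in, not the actual modeling code from this PR): it registers the same triangular mask as a non-persistent buffer and prints its in-memory footprint for `torch.int64` vs `torch.bool`, which roughly reproduces the ~500 MB vs ~70 MB figures above.

```python
import torch
import torch.nn as nn


class MaskHolder(nn.Module):
    # Hypothetical stand-in for the llama/gemma mask setup in this PR.
    def __init__(self, max_position_embeddings: int, dtype: torch.dtype):
        super().__init__()
        causal_mask = torch.full(
            (max_position_embeddings, max_position_embeddings), fill_value=1, dtype=dtype
        )
        # persistent=False keeps the mask out of state_dict, but exported IRs
        # (ONNX / TorchScript / ExportedProgram) still materialize the buffer.
        self.register_buffer(
            "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
        )


for dtype in (torch.int64, torch.bool):
    mask = MaskHolder(8182, dtype).causal_mask
    print(f"{dtype}: ~{mask.numel() * mask.element_size() / 2**20:.0f} MiB")
# torch.int64: ~511 MiB, torch.bool: ~64 MiB (bool still takes one byte per element)
```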