Conversation

@fxmarty (Contributor) commented Feb 23, 2024

Registering `self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)` in @ArthurZucker's rewrite of llama & gemma adds a ~500 MB overhead when serializing to ONNX, TorchScript IR, or a PyTorch ExportedProgram (https://pytorch.org/docs/stable/export.html), for max_position_embeddings=8182.

Essentially, these IRs do not support non-persistent buffers. One quick fix is to use `torch.bool` instead of `torch.int64`, but bool is still 8 bits in PyTorch (pytorch/pytorch#41571) and the overhead is still ~70 MB.
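A quick back-of-the-envelope check (illustrative only, assuming a square `[seq, seq]` mask and the dtype sizes quoted above) reproduces both figures:

```python
# Rough check of the serialized causal-mask overhead described above.
# Assumes a dense square mask of shape [seq, seq]; not actual export code.
max_position_embeddings = 8182  # value quoted in the PR description

n_elements = max_position_embeddings ** 2

bytes_int64 = n_elements * 8  # torch.int64: 8 bytes per element
bytes_bool = n_elements * 1   # torch.bool: stored as 1 byte per element

print(f"int64 mask: {bytes_int64 / 1e6:.0f} MB")  # ~536 MB
print(f"bool mask:  {bytes_bool / 1e6:.0f} MB")   # ~67 MB
```

This matches the ~500 MB overhead observed with `torch.int64` and the ~70 MB that remains with `torch.bool`.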

The reduced overhead is acceptable to me, but this won't scale to a 10M context length.
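For context, the pattern under discussion looks roughly like the following. This is a minimal, simplified sketch (not the actual transformers code; `MaskModule` and its shapes are hypothetical) of registering the causal mask as a non-persistent `torch.bool` buffer:

```python
# Simplified sketch of the discussed pattern: a non-persistent causal-mask
# buffer stored as torch.bool instead of torch.int64. Not the real
# transformers implementation; MaskModule is a hypothetical example.
import torch


class MaskModule(torch.nn.Module):
    def __init__(self, max_position_embeddings: int = 128):
        super().__init__()
        causal_mask = torch.full(
            (max_position_embeddings, max_position_embeddings),
            fill_value=True,
            dtype=torch.bool,  # 1 byte/element instead of 8 for int64
        )
        # persistent=False keeps the mask out of the state_dict, but
        # exported IRs (ONNX / TorchScript / ExportedProgram) may still
        # serialize it as a constant, hence the size overhead above.
        self.register_buffer(
            "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[-1]
        # Zero out positions above the diagonal (future tokens).
        return x.masked_fill(self.causal_mask[:seq_len, :seq_len], 0.0)


m = MaskModule()
print(m.causal_mask.dtype)              # torch.bool
print("causal_mask" in m.state_dict())  # False (non-persistent)
```

The non-persistent buffer never reaches the checkpoint, which is why the overhead only shows up in the exported graphs rather than in the saved weights.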

@fxmarty fxmarty requested a review from ArthurZucker February 23, 2024 11:13
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@xenova (Contributor) commented Feb 23, 2024

Can confirm this shrank my tiny-random-GemmaForCausalLM ONNX export from ~500 MB to ~70 MB (PR). Ideally there would be no overhead, but I think this helps a ton for now!

@fxmarty fxmarty requested a review from amyeroberts February 26, 2024 10:28
@amyeroberts (Contributor) left a comment

LGTM - thanks for digging into this and fixing!

Happy to merge once slow model tests for gemma and llama are confirmed to be passing.

@fxmarty (Contributor, Author) commented Feb 26, 2024

@amyeroberts Running on an A100, I can confirm that no additional tests fail with `RUN_SLOW=1 CUDA_VISIBLE_DEVICES=2 pytest tests/ -k "llama or gemma" -s -vvvvv` compared to running on main.

@fxmarty fxmarty merged commit 24d59c7 into huggingface:main Feb 26, 2024
ArthurZucker pushed a commit that referenced this pull request Feb 28, 2024
…ask buffer (#29241)

use torch.bool instead of torch.int64
ArthurZucker pushed a commit that referenced this pull request Mar 1, 2024
…ask buffer (#29241)

use torch.bool instead of torch.int64


4 participants