Description
System Info
- `transformers` version: 4.33.1
- Platform: Linux-5.4.0-147-generic-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.17.1
- Safetensors version: 0.3.3
- Accelerate version: 0.23.0
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: A100 40GB
- Using distributed or parallel set-up in script?: No
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
As we discussed in this thread: #25598 (comment)
The hidden states may be cast to float16 even when we are using bf16 mixed-precision training.
transformers/src/transformers/models/llama/modeling_llama.py, lines 485 to 487 at commit 78dd120:

```python
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
```
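As one illustration of the direction a fix could take (a sketch only, not the actual patch), the cast could follow the dtype of the attention projection weights instead of hard-coding `torch.float16`; under bf16 training those weights are bfloat16. The helper name and the `proj_weight` argument below are hypothetical:

```python
import torch

def cast_qkv_like_weights(query_states, key_states, value_states, proj_weight):
    """Illustrative sketch: downcast fp32 activations to the dtype of the
    projection weights (e.g. self.q_proj.weight) rather than a hard-coded fp16."""
    target_dtype = proj_weight.dtype  # bfloat16 under bf16 training, float16 under fp16
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states
```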
It may be difficult to figure out the correct data type if the model is loaded in 4/8-bit mode.
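For illustration, a rough way the compute dtype might be recovered when the weights themselves are quantized could look like the following; the attribute names (`quantization_config`, `bnb_4bit_compute_dtype`, `torch_dtype`) follow the bitsandbytes integration as I understand it and should be treated as assumptions:

```python
import torch

def guess_compute_dtype(model):
    """Best-effort guess of the activation dtype for a possibly-quantized model."""
    quant_cfg = getattr(model.config, "quantization_config", None)
    if quant_cfg is not None:
        # 4-bit bitsandbytes models carry an explicit compute dtype
        compute_dtype = getattr(quant_cfg, "bnb_4bit_compute_dtype", None)
        if compute_dtype is not None:
            return compute_dtype
    # Fall back to the dtype the checkpoint was loaded with, if recorded
    if getattr(model.config, "torch_dtype", None) is not None:
        return model.config.torch_dtype
    return torch.float16  # the current hard-coded behavior
```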
Expected behavior
The hidden states should be cast to bfloat16 when training with bf16 mixed precision.
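A small generic PyTorch check (not the Llama code itself) of what bf16 autocast produces upstream of the cast:

```python
import torch

proj = torch.nn.Linear(16, 16).cuda()
x = torch.randn(2, 4, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    hidden = proj(x)

print(hidden.dtype)  # torch.bfloat16 -> a hard-coded .to(torch.float16) changes precision silently
```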