
The hidden states in LlamaFlashAttention2 are cast in fp16 unexpectedly #26451

Description

@hiyouga

System Info

  • transformers version: 4.33.1
  • Platform: Linux-5.4.0-147-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.17.1
  • Safetensors version: 0.3.3
  • Accelerate version: 0.23.0
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: A100 40GB
  • Using distributed or parallel set-up in script?: No

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

As we discussed in this thread: #25598 (comment)

The hidden states may be cast to float16 even when we are using bf16 mixed-precision training.

# In LlamaFlashAttention2, the inputs are cast to float16 regardless of the training dtype:
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)

It may also be difficult to figure out the correct data type when the model is loaded in 4-bit or 8-bit mode.
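
One possible direction, shown here only as a minimal sketch and not as the library's actual fix, is to infer the target dtype from a projection weight instead of hard-coding float16. The helper name cast_for_flash_attention is hypothetical, and the sketch assumes the projection weight carries the compute dtype, which does not hold for 4-/8-bit quantized weights (their storage dtype is an integer type):

import torch
import torch.nn as nn

def cast_for_flash_attention(query_states, key_states, value_states, proj_weight):
    """Sketch: pick the flash-attention compute dtype from a projection weight
    instead of hard-coding torch.float16, so bf16 training stays in bfloat16.
    Caveat: proj_weight.dtype is not a float dtype for 4-/8-bit quantized models."""
    if query_states.dtype == torch.float32:
        target_dtype = proj_weight.dtype
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states

# Usage sketch with bf16 weights and fp32 activations:
proj = nn.Linear(8, 8, dtype=torch.bfloat16)
q = k = v = torch.randn(1, 4, 8, dtype=torch.float32)
q, k, v = cast_for_flash_attention(q, k, v, proj.weight)
print(q.dtype)  # torch.bfloat16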

Expected behavior

The hidden states should be cast to bfloat16 when training with bf16 mixed precision.
