diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index e7ce8a661fc8..956b7bfaf522 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -81,14 +81,11 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 class LlamaRotaryEmbedding(torch.nn.Module):
diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py
index a88ba62056d4..819d0d1dacab 100644
--- a/src/transformers/models/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/open_llama/modeling_open_llama.py
@@ -91,14 +91,11 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
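For reference, below is a minimal, self-contained sketch of the RMSNorm forward pass after this change: the variance is still computed in float32 for stability, but the output is cast back to the input's dtype rather than to the weight's dtype. Only the body of forward() mirrors the diff; the class name "RMSNorm", the __init__ lines beyond variance_epsilon, and the usage at the bottom are illustrative assumptions, not taken from the patch.

# Sketch of the patched RMSNorm behavior (illustrative; see the diff above for the actual change).
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Weight/parameter setup assumed here; the diff only shows variance_epsilon.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the variance in float32 regardless of the input or weight dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the caller's dtype instead of the weight's dtype.
        return (self.weight * hidden_states).to(input_dtype)


# With this change the output dtype follows the input dtype even when the
# weights are kept in float32 (e.g. under mixed-precision setups).
norm = RMSNorm(hidden_size=8)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
print(norm(x).dtype)  # torch.bfloat16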