
Commit 99f2c94

don't cast the trainable lora layers to half precision (#1644)
* don't cast the trainable lora layers to half precision
* quality
Parent: 6401d08
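Why this matters: bfloat16 keeps float32's exponent range but only about 8 mantissa bits, so a small optimizer update applied to a half-precision adapter weight of magnitude ~1 can round away to nothing. Below is a minimal, self-contained sketch of that failure mode (the tensors and values are illustrative, not TRL code):

import torch

# bfloat16 spacing around 1.0 is 2**-8 = 0.0039, so updates much smaller
# than half of that are lost when added to a bf16 weight.
w_fp32 = torch.ones(4, dtype=torch.float32)   # adapter weight kept in fp32
w_bf16 = torch.ones(4, dtype=torch.bfloat16)  # adapter weight cast to bf16 (old behavior)
update = 1e-4                                 # a typical small optimizer step

print(w_fp32 + update)  # tensor([1.0001, 1.0001, 1.0001, 1.0001]) -- step survives
print(w_bf16 + update)  # tensor([1., 1., 1., 1.], dtype=torch.bfloat16) -- step rounds away

Keeping the trainable LoRA layers in float32, as this commit does, avoids losing those small updates.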

File tree

1 file changed (+1, -5 lines)

trl/trainer/utils.py

Lines changed: 1 addition & 5 deletions
@@ -665,12 +665,8 @@ def neftune_post_forward_hook(module, input, output):
 
 
 def peft_module_casting_to_bf16(model):
-    from peft.tuners.tuners_utils import BaseTunerLayer
-
     for name, module in model.named_modules():
-        if isinstance(module, BaseTunerLayer):
-            module = module.to(torch.bfloat16)
-        elif isinstance(module, torch.nn.LayerNorm) or "norm" in name:
+        if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
             module = module.to(torch.float32)
         elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
             if hasattr(module, "weight"):
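For context, here is a sketch of the helper as it reads after this commit. Only the lines inside the hunk above come from the diff; the `import torch` and the body below `if hasattr(module, "weight"):` fall outside the hunk and are assumptions added to make the sketch self-contained.

import torch

def peft_module_casting_to_bf16(model):
    # Trainable LoRA layers are deliberately left untouched (full precision).
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
            # Upcast norm layers to float32 for numerical stability.
            module = module.to(torch.float32)
        elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
            if hasattr(module, "weight"):
                # Assumed continuation (outside the hunk): downcast fp32
                # embedding/head weights to bfloat16.
                if module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

In a PEFT training setup this would typically be called once after the adapters are attached, leaving adapter weights in float32 while norms and embeddings get the dtypes above.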
