23 changes: 21 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -127,6 +127,24 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
         self.in_features = in_features
         self.out_features = out_features

+    def _get_weight_dtype(self) -> torch.dtype:
+        """
+        Get the dtype for initializing LoRA parameters.
+        Handles both regular and quantized layers.
+        """
+        base_layer = self.get_base_layer()
+
+        # For quantized layers (BitsAndBytes) - check compute_dtype first
+        if hasattr(base_layer, 'compute_dtype'):
+            return base_layer.compute_dtype
+
+        # For regular layers, use weight dtype
+        if hasattr(base_layer, 'weight'):
+            return base_layer.weight.dtype
+
+        # Fallback to float32
+        return torch.float32
+
     def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
         """Return a matching LoRA variant for this layer type.

@@ -193,8 +211,9 @@ def update_layer(
         self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

         # Actual trainable parameters
-        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
-        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)
+        weight_dtype = self._get_weight_dtype()
+        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False, dtype=weight_dtype)
+        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias, dtype=weight_dtype)
         self.lora_bias[adapter_name] = lora_bias

         if use_rslora:
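
For context, here is a minimal, self-contained sketch (not part of the diff) of the dtype resolution this change performs. `QuantizedLinearStub` is a hypothetical stand-in for a bitsandbytes quantized layer that exposes `compute_dtype`, and `get_weight_dtype` simply mirrors the logic of the `_get_weight_dtype` helper above; the `nn.Linear(..., dtype=...)` call mirrors how `update_layer` now builds the LoRA A/B matrices.

    # Hedged sketch: illustrates the dtype resolution from the diff, not PEFT's actual API.
    import torch
    import torch.nn as nn


    class QuantizedLinearStub(nn.Module):
        """Hypothetical stand-in for a quantized (bitsandbytes-style) layer with compute_dtype."""

        def __init__(self, compute_dtype=torch.bfloat16):
            super().__init__()
            self.compute_dtype = compute_dtype


    def get_weight_dtype(base_layer: nn.Module) -> torch.dtype:
        # Mirrors _get_weight_dtype from the diff above.
        if hasattr(base_layer, "compute_dtype"):  # quantized layers expose compute_dtype
            return base_layer.compute_dtype
        if hasattr(base_layer, "weight"):  # regular layers: follow the weight dtype
            return base_layer.weight.dtype
        return torch.float32  # fallback


    for base in (QuantizedLinearStub(), nn.Linear(16, 16, dtype=torch.float16), nn.Module()):
        dtype = get_weight_dtype(base)
        lora_A = nn.Linear(16, 8, bias=False, dtype=dtype)  # as update_layer now does
        print(type(base).__name__, "->", lora_A.weight.dtype)
    # Expected output:
    # QuantizedLinearStub -> torch.bfloat16
    # Linear -> torch.float16
    # Module -> torch.float32

The net effect is that freshly created LoRA parameters start out in the same dtype as the base layer's weights (or its compute dtype for quantized layers) instead of defaulting to float32.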