                               LinearScalingRotaryEmbeddingWithLora,
                               LoRAMapping)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
@@ -104,14 +105,12 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
     def from_lora_tensors(
         cls,
         lora_model_id: int,
-        rank: int,
-        lora_alpha: int,
         tensors: Dict[str, torch.Tensor],
+        peft_helper: PEFTHelper,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
         embeddings: Optional[Dict[str, torch.Tensor]] = None,
         target_embedding_padding: Optional[int] = None,
-        scaling_factor: Optional[float] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
@@ -135,10 +134,9 @@ def from_lora_tensors(
                         if pin_memory:
                             lora_embeddings_tensor = (
                                 lora_embeddings_tensor.pin_memory())
-                loras[module_name] = LoRALayerWeights(module_name, rank,
-                                                      lora_alpha, None, None,
-                                                      None,
-                                                      lora_embeddings_tensor)
+                loras[module_name] = LoRALayerWeights.from_config(
+                    module_name, peft_helper, lora_embeddings_tensor)
+
             if is_bias:
                 loras[module_name].bias = tensor.to(device=device,
                                                     dtype=dtype).t()
@@ -170,7 +168,11 @@ def from_lora_tensors(
 
         for lora in loras.values():
             lora.optimize()
-        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
+
+        return cls(lora_model_id,
+                   peft_helper.r,
+                   loras,
+                   scaling_factor=peft_helper.vllm_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
@@ -212,6 +214,9 @@ def from_local_checkpoint(
                                                     "new_embeddings.bin")
         with open(lora_config_path) as f:
             config = json.load(f)
+
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        peft_helper = PEFTHelper.from_dict(config)
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
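With this hunk the adapter config is parsed once into a PEFTHelper right after json.load, and the base model's max_position_embeddings is injected under vllm_max_position_embeddings so the helper can derive any long-context scaling itself. Below is a minimal usage sketch of that flow, using only names visible in this diff (PEFTHelper.from_dict plus the r / target_modules / vllm_scaling_factor attributes read in later hunks) and assuming a vLLM build that includes vllm/lora/peft_helper.py; the config dict and its values are illustrative, not taken from any real adapter, and the exact set of keys the helper requires is not shown here.

from vllm.lora.peft_helper import PEFTHelper

# Stand-in for json.load() on an adapter_config.json; keys are the standard
# PEFT fields this diff reads, values are made up.
config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
}

# Mirrors the hunk above: record the base model's max position embeddings
# before building the helper.
config["vllm_max_position_embeddings"] = 4096
peft_helper = PEFTHelper.from_dict(config)

print(peft_helper.r)                    # 8
print(peft_helper.target_modules)       # ['q_proj', 'k_proj', 'v_proj', 'o_proj']
print(peft_helper.vllm_scaling_factor)  # presumably None: no context_length given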
@@ -242,7 +247,7 @@ def from_local_checkpoint(
             # When a bin file is provided, we rely on config to find unexpected
             # modules.
             unexpected_modules = []
-            target_modules = config["target_modules"]
+            target_modules = peft_helper.target_modules
             if not isinstance(target_modules, list):
                 target_modules = [target_modules]
             for module in target_modules:
@@ -256,7 +261,7 @@ def from_local_checkpoint(
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
             if unexpected_modules and not is_regex_target_modules(
-                    config["target_modules"], expected_lora_modules):
+                    peft_helper.target_modules, expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
@@ -274,30 +279,17 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path,
                                     map_location=device)
 
-        rank = config["r"]
-        lora_alpha = config["lora_alpha"]
-        context_length = config.get("context_length", None)
-        scaling_factor = None
-        if context_length:
-            if max_position_embeddings is None:
-                max_position_embeddings = context_length
-            scaling_factor = float(
-                math.ceil(context_length / max_position_embeddings))
-
         return cls.from_lora_tensors(
             lora_model_id=get_lora_id()
             if lora_model_id is None else lora_model_id,
-            rank=rank,
-            lora_alpha=lora_alpha,
             tensors=tensors,
+            peft_helper=peft_helper,
             device=device,
             dtype=dtype,
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
-            scaling_factor=scaling_factor,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules,
-        )
+            embedding_padding_modules=embedding_padding_modules)
 
 
 class LoRAModelManager(AdapterModelManager):
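The final hunk drops the inline long-context arithmetic (the math.ceil(context_length / max_position_embeddings) block), which presumably now lives behind peft_helper.vllm_scaling_factor together with the r / lora_alpha / target_modules parsing. The following is a self-contained sketch of that derivation under that assumption; the class below illustrates the logic the removed lines implemented and is not vLLM's actual PEFTHelper.

import math
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class PEFTHelperSketch:
    """Hypothetical stand-in showing the config fields this diff consumes."""
    r: int
    lora_alpha: int
    target_modules: Union[str, List[str]]
    context_length: Optional[int] = None
    vllm_max_position_embeddings: Optional[int] = None
    vllm_scaling_factor: Optional[float] = field(init=False, default=None)

    def __post_init__(self):
        # Same arithmetic as the lines removed in the hunk above: scale the
        # rotary embedding when the adapter was tuned for a longer context
        # than the base model's max_position_embeddings.
        if self.context_length:
            max_pos = self.vllm_max_position_embeddings or self.context_length
            self.vllm_scaling_factor = float(
                math.ceil(self.context_length / max_pos))


helper = PEFTHelperSketch(r=8,
                          lora_alpha=16,
                          target_modules=["q_proj", "v_proj"],
                          context_length=32768,
                          vllm_max_position_embeddings=4096)
assert helper.vllm_scaling_factor == 8.0

The or-fallback reproduces the removed branch that treated a missing max_position_embeddings as equal to context_length, which yields a scaling factor of 1.0.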