4
4
from dataclasses import MISSING , dataclass , field , fields
5
5
from typing import Literal , Optional , Union
6
6
7
+ from vllm .utils import print_info_once
8
+
7
9
8
10
@dataclass
9
11
class PEFTHelper :
@@ -14,21 +16,22 @@ class PEFTHelper:
14
16
15
17
bias : Literal ["none" , "all" , "lora_only" ] = field (default = "none" )
16
18
modules_to_save : Optional [list [str ]] = field (default = None )
19
+ # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
17
20
use_rslora : bool = field (default = False )
21
+ # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
18
22
use_dora : bool = field (default = False )
19
- # long lora field
23
+ # long context lora field
20
24
context_length : int = field (default = 0 )
21
25
# Extra vllm field, start with 'vllm_' to avoid conflict
26
+ vllm_lora_scaling_factor : float = field (default = 1.0 )
22
27
vllm_max_position_embeddings : Optional [int ] = field (default = False )
23
- vllm_scaling_factor : Optional [float ] = field (default = None )
28
+ vllm_long_context_scaling_factor : Optional [float ] = field (default = None )
24
29
25
30
def _validate_features (self ):
26
31
error_msg = []
27
32
28
33
if self .modules_to_save :
29
34
error_msg .append ("vLLM only supports modules_to_save being None." )
30
- if self .use_rslora :
31
- error_msg .append ("vLLM does not yet support RSLoRA." )
32
35
33
36
if self .use_dora :
34
37
error_msg .append ("vLLM does not yet support DoRA." )
@@ -38,10 +41,15 @@ def _validate_features(self):
38
41
39
42
def __post_init__(self):
    """Validate the loaded PEFT config and derive the vLLM-specific fields.

    Runs feature validation first, then computes:
      * ``vllm_lora_scaling_factor`` — ``alpha / sqrt(r)`` when rsLoRA is
        enabled (see https://arxiv.org/abs/2312.03732), else the standard
        LoRA scaling ``alpha / r``.
      * ``vllm_long_context_scaling_factor`` — the position-embedding
        extension ratio, only when ``context_length`` is set (long-context
        LoRA).
    """
    self._validate_features()
    if self.use_rslora:
        print_info_once("Loading LoRA weights trained with rsLoRA.")
        # rsLoRA stabilizes training by scaling with sqrt(r) instead of r.
        self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
    else:
        self.vllm_lora_scaling_factor = self.lora_alpha / self.r
    if self.context_length:
        # NOTE: the field is declared with `field(default=False)`, not
        # None, so an `is None` check never fires for the default and the
        # division below would be `context_length / False`
        # (ZeroDivisionError). Treat any falsy value as "unset".
        if not self.vllm_max_position_embeddings:
            self.vllm_max_position_embeddings = self.context_length
        self.vllm_long_context_scaling_factor = float(
            math.ceil(self.context_length /
                      self.vllm_max_position_embeddings))
47
55
0 commit comments