@@ -107,9 +107,11 @@ def __init__(self,
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.mamba_d_state,
                                 conv_kernel_size=config.mamba_d_conv,
@@ -120,7 +122,9 @@ def __init__(self,
                                 use_bias=config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled
+                                )
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@ def forward(
 
 class JambaAttentionDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )
 
+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
             layer_idx = int(prefix.rsplit(".", 1)[1])
             layer_class = ALL_DECODER_LAYER_TYPES[
                 config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
+        "in_proj": ["in_proj"],
     }
 
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
+        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
@@ -423,9 +426,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
         if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
+            not self.model_config.enforce_eager:
             if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
+                vllm_config.compilation_config.max_capture_size:
                 self.max_batch_size = \
                     vllm_config.compilation_config.max_capture_size
             else:
@@ -446,7 +449,6 @@ def forward(self,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
        if self.mamba_cache is None:
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
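
For context, a minimal standalone sketch of the pattern this change applies: a model-level LoRA flag is computed once from the config and threaded through **kwargs into heterogeneous decoder-layer constructors, so layer types that do not use the flag can simply ignore it. The names MambaLikeLayer, AttentionLikeLayer, and build_layers below are illustrative stand-ins, not vLLM APIs.

from typing import List, Optional


class MambaLikeLayer:
    """Stand-in for the Mamba decoder layer: it consumes the flag."""

    def __init__(self,
                 layer_idx: int,
                 is_lora_enabled: Optional[bool] = False,
                 **kwargs) -> None:
        # In the real layer this flag is forwarded into the mixer.
        self.layer_idx = layer_idx
        self.is_lora_enabled = is_lora_enabled


class AttentionLikeLayer:
    """Stand-in for the attention decoder layer: **kwargs absorbs the flag."""

    def __init__(self, layer_idx: int, **kwargs) -> None:
        self.layer_idx = layer_idx


def build_layers(layer_types: List[str], lora_config: object) -> list:
    # Computed once per model, like extra_kwargs in the diff above.
    extra_kwargs = {"is_lora_enabled": bool(lora_config)}
    classes = {"mamba": MambaLikeLayer, "attention": AttentionLikeLayer}
    return [
        classes[kind](layer_idx=i, **extra_kwargs)
        for i, kind in enumerate(layer_types)
    ]


layers = build_layers(["mamba", "attention", "mamba"], lora_config=None)
assert layers[0].is_lora_enabled is False  # no LoRA config -> flag is False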