From cdc5d30407a92abc0052c2d0095050ebd76c70e5 Mon Sep 17 00:00:00 2001 From: Jinheng Li Date: Mon, 8 Sep 2025 17:05:24 +0800 Subject: [PATCH 1/4] Remove LoRA additional vocabulary support This commit completely removes the deprecated lora_extra_vocab_size feature that allowed LoRA adapters to extend model vocabulary. The feature has been deprecated and is no longer needed. Changes: - Remove lora_extra_vocab_size field from LoRAConfig - Remove extra_vocab_size parameter from all punica wrapper methods - Update embeddings_indices calculation to use only base vocab size - Remove all lora_vocab = 0 assignments from model files - Fix IndentationError in models caused by empty if blocks - Update all tests to use new function signatures - Remove all references and comments about extra vocabulary support The removal simplifies the codebase and eliminates dead code paths that were no longer being used. Signed-off-by: Jinheng Li --- tests/lora/test_layers.py | 51 ++++++----------- tests/lora/test_lora_manager.py | 27 +++------ vllm/engine/arg_utils.py | 4 -- vllm/lora/lora.py | 5 -- vllm/lora/models.py | 57 +------------------ vllm/lora/punica_wrapper/punica_base.py | 7 +-- vllm/lora/punica_wrapper/punica_gpu.py | 4 +- vllm/lora/punica_wrapper/punica_tpu.py | 2 - vllm/lora/punica_wrapper/punica_xpu.py | 4 +- vllm/lora/punica_wrapper/utils.py | 11 ++-- vllm/lora/worker_manager.py | 10 ---- vllm/model_executor/models/apertus.py | 6 +- vllm/model_executor/models/bamba.py | 7 +-- vllm/model_executor/models/commandr.py | 7 +-- vllm/model_executor/models/exaone.py | 7 +-- vllm/model_executor/models/exaone4.py | 7 +-- vllm/model_executor/models/falcon_h1.py | 7 +-- vllm/model_executor/models/gpt_bigcode.py | 7 +-- vllm/model_executor/models/granite.py | 7 +-- vllm/model_executor/models/granitemoe.py | 7 +-- .../model_executor/models/granitemoehybrid.py | 7 +-- .../model_executor/models/granitemoeshared.py | 7 +-- vllm/model_executor/models/grok1.py | 7 +-- vllm/model_executor/models/hunyuan_v1.py | 5 +- vllm/model_executor/models/jamba.py | 7 +-- vllm/model_executor/models/lfm2.py | 7 +-- vllm/model_executor/models/llama.py | 7 +-- vllm/model_executor/models/mamba.py | 7 +-- vllm/model_executor/models/mamba2.py | 7 +-- vllm/model_executor/models/minicpm.py | 7 +-- vllm/model_executor/models/minicpm_eagle.py | 7 +-- vllm/model_executor/models/mixtral.py | 7 +-- vllm/model_executor/models/nemotron.py | 7 +-- vllm/model_executor/models/nemotron_h.py | 7 +-- vllm/model_executor/models/nemotron_nas.py | 7 +-- vllm/model_executor/models/phi4flash.py | 2 - vllm/model_executor/models/phi4mm.py | 2 - vllm/model_executor/models/phimoe.py | 7 +-- vllm/model_executor/models/solar.py | 7 +-- vllm/model_executor/models/step3_text.py | 2 - vllm/model_executor/models/zamba2.py | 7 +-- vllm/v1/worker/lora_model_runner_mixin.py | 1 - vllm/v1/worker/tpu_model_runner.py | 2 - vllm/worker/model_runner.py | 1 - 44 files changed, 89 insertions(+), 289 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6735b7cd9e43..4791236100e2 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -293,8 +293,7 @@ def create_random_embedding_layer(): prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + vocab_size) lora_result = lora_embedding(torch.cat(inputs)) @@ -331,8 +330,7 @@ def create_random_embedding_layer(): prompt_mapping, is_prefill=stage) 
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + vocab_size) lora_result = lora_embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs)) @@ -371,7 +369,7 @@ def create_random_embedding_layer(): embedding.weight.data = embedding_data embedding.weight.data[vocab_size:, :] = 0 expanded_embedding = VocabParallelEmbedding( - vocab_size + lora_config.lora_extra_vocab_size * max_loras, + vocab_size, 256, org_num_embeddings=vocab_size) expanded_embedding.weight.data[:vocab_size, :] = embedding_data @@ -392,7 +390,7 @@ def create_random_embedding_layer(): id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size)), + (256, vocab_size)), generate_embeddings_tensor=256, ) @@ -417,8 +415,7 @@ def create_random_embedding_layer(): prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + vocab_size) original_inputs = deepcopy(inputs) # Force some of the inputs to be in the extended embeddings range @@ -473,8 +470,7 @@ def create_random_embedding_layer(): prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + vocab_size) lora_result = lora_embedding(torch.cat(original_inputs)) expected_result = expanded_embedding(torch.cat(inputs)) @@ -505,14 +501,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, lora_dtype=torch.float16) def _pretest(): - linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + linear = ParallelLMHead(vocab_size, 1024, vocab_size, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size) + vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device, None) @@ -550,9 +546,7 @@ def _pretest(): lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) + vocab_size) input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( @@ -566,8 +560,7 @@ def _pretest(): org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - logits_processor.org_vocab_size = (vocab_size + - lora_config.lora_extra_vocab_size) + logits_processor.org_vocab_size = vocab_size expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] @@ -599,9 +592,7 @@ def _pretest(): lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) + vocab_size) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), @@ -684,9 +675,7 @@ def create_random_linear_replicated_layer(): lora_mapping, id_to_index, max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) + 512) lora_result = lora_linear(torch.cat(inputs))[0] @@ -721,7 +710,7 @@ def create_random_linear_replicated_layer(): is_prefill=stage) punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + 512) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] @@ -807,9 +796,7 @@ def create_random_linear_parallel_layer(): lora_mapping, id_to_index, 
max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) + 512) lora_result = lora_linear(torch.cat(inputs))[0] @@ -844,7 +831,7 @@ def create_random_linear_parallel_layer(): is_prefill=stage) punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + 512) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] @@ -954,9 +941,7 @@ class FakeConfig: lora_mapping, id_to_index, max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) + 512) lora_result = lora_linear(torch.cat(inputs))[0] @@ -995,9 +980,7 @@ class FakeConfig: lora_mapping, id_to_index, max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) + 512) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index a5802c108c6b..ff06a1a424d9 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -41,8 +41,6 @@ def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( os.path.join(sql_lora_files, "adapter_model.safetensors")) - new_embeddings = load_file( - os.path.join(sql_lora_files, "new_embeddings.safetensors")) peft_helper = PEFTHelper.from_local_dir(sql_lora_files, max_position_embeddings=4096) @@ -50,10 +48,7 @@ def test_from_lora_tensors(sql_lora_files, device): 1, tensors, peft_helper=peft_helper, - device=device, - embeddings=new_embeddings, - embedding_modules=EMBEDDING_MODULES, - embedding_padding_modules=EMBEDDING_PADDING_MODULES) + device=device) for module_name, lora in lora_model.loras.items(): assert lora.module_name == module_name assert lora.rank == 8 @@ -65,15 +60,8 @@ def test_from_lora_tensors(sql_lora_files, device): assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" assert lora.lora_a.shape[1] == 8 - embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), None) - if embeddings_module: - assert torch.equal( - lora.embeddings_tensor, - new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( - device=lora.embeddings_tensor.device)) - else: - assert lora.embeddings_tensor is None + # No embeddings tensor since additional vocabulary is removed + assert lora.embeddings_tensor is None def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], @@ -437,8 +425,8 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, ) worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, - dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size, - lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + dummy_model.unpadded_vocab_size, + lora_config, device, EMBEDDING_MODULES) worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) @@ -518,9 +506,8 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, max_loras=4, lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( - 4, 2, dummy_model_gate_up.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + 4, 2, dummy_model_gate_up.unpadded_vocab_size, lora_config, device, + EMBEDDING_MODULES) worker_adapter_manager.create_lora_manager(dummy_model_gate_up) dummy_lora_files = f"{tmp_path}/lora_adapter" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4831cb5348c7..aa918e2effe6 100644 --- a/vllm/engine/arg_utils.py +++ 
b/vllm/engine/arg_utils.py @@ -398,7 +398,6 @@ class EngineArgs: fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype - lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight num_gpu_blocks_override: Optional[ @@ -831,8 +830,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) - lora_group.add_argument("--lora-extra-vocab-size", - **lora_kwargs["lora_extra_vocab_size"]) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], @@ -1411,7 +1408,6 @@ def create_engine_config( max_loras=self.max_loras, default_mm_loras=self.default_mm_loras, fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 958364fca592..87f94c2f9591 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -58,11 +58,6 @@ def output_dim(self) -> int: def is_packed(self) -> bool: return False - @property - def extra_vocab_size(self) -> int: - return self.embeddings_tensor.shape[ - 0] if self.embeddings_tensor is not None else 0 - @classmethod def from_config( cls, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 771243805491..02151c5c65b0 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -91,11 +91,6 @@ def clone(self, lora_model_id: int) -> "LoRAModel": loras=self.loras.copy(), ) - @property - def extra_vocab_size(self) -> int: - return max(lora.extra_vocab_size - for lora in self.loras.values()) if self.loras else 0 - def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) @@ -112,10 +107,6 @@ def from_lora_tensors( peft_helper: PEFTHelper, device: str = "cuda", dtype: Optional[torch.dtype] = None, - embeddings: Optional[dict[str, torch.Tensor]] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[dict[str, str]] = None, - embedding_padding_modules: Optional[list[str]] = None, weights_mapper: Optional[WeightsMapper] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" @@ -125,21 +116,8 @@ def from_lora_tensors( module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( tensor_name, weights_mapper) if module_name not in loras: - lora_embeddings_tensor = None - if embeddings: - assert embedding_modules is not None - embeddings_module = next( - (k for k in embedding_modules if k in module_name), - None) - if embeddings_module: - lora_embeddings_tensor = embeddings[ - embedding_modules[embeddings_module]].to( - device=device, dtype=dtype) - if pin_memory: - lora_embeddings_tensor = ( - lora_embeddings_tensor.pin_memory()) loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper, lora_embeddings_tensor) + module_name, peft_helper, None) if is_bias: loras[module_name].bias = tensor.to(device=device, @@ -157,15 +135,6 @@ def from_lora_tensors( else: loras[module_name].lora_b = tensor.to(device=device, dtype=dtype).t() - assert embedding_padding_modules is not None - if any(name in module_name - 
for name in embedding_padding_modules - ) and target_embedding_padding is not None: - lora_b = loras[module_name].lora_b - assert target_embedding_padding >= lora_b.shape[1] - addition = target_embedding_padding - lora_b.shape[1] - loras[module_name].lora_b = torch.nn.functional.pad( - lora_b, (0, addition)) if pin_memory: loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() @@ -185,9 +154,6 @@ def from_local_checkpoint( lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[dict[str, str]] = None, - embedding_padding_modules: Optional[list[str]] = None, weights_mapper: Optional[WeightsMapper] = None, tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. @@ -208,10 +174,6 @@ def from_local_checkpoint( lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") - new_embeddings_tensor_path = os.path.join( - lora_dir, "new_embeddings.safetensors") - new_embeddings_bin_file_path = os.path.join(lora_dir, - "new_embeddings.bin") tensors: dict[str, torch.Tensor] = {} unexpected_modules: list[Union[list[str], str]] = [] @@ -290,15 +252,6 @@ def check_unexpected_modules(modules: dict): else: raise ValueError(f"{lora_dir} doesn't contain tensors") - embeddings = None - if os.path.isfile(new_embeddings_tensor_path): - embeddings = safetensors.torch.load_file( - new_embeddings_tensor_path) - elif os.path.isfile(new_embeddings_bin_file_path): - embeddings = torch.load(new_embeddings_bin_file_path, - map_location=device, - weights_only=True) - return cls.from_lora_tensors( lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, @@ -306,10 +259,6 @@ def check_unexpected_modules(modules: dict): peft_helper=peft_helper, device=device, dtype=dtype, - embeddings=embeddings, - target_embedding_padding=target_embedding_padding, - embedding_modules=embedding_modules, - embedding_padding_modules=embedding_padding_modules, weights_mapper=weights_mapper) @@ -447,7 +396,6 @@ def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size, ) def remove_all_adapters(self): @@ -539,8 +487,7 @@ def create_dummy_lora( if module_name not in self.packed_modules: assert embedding_modules is not None if parts[-1] in embedding_modules: - input_dim = (module.base_layer.org_vocab_size + - self.lora_config.lora_extra_vocab_size if + input_dim = (module.base_layer.org_vocab_size if hasattr(module.base_layer, "org_vocab_size") else module.base_layer.weight.shape[1]) output_dim = module.base_layer.embedding_dim if hasattr( diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b3413de1c816..6e6328324797 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -31,7 +31,6 @@ def update_metadata( lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ) -> None: """ @@ -170,7 +169,6 @@ def _update_base_metadata( lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, - extra_vocab_size: int, ): ( base_indices, @@ -183,7 +181,6 @@ def _update_base_metadata( lora_index_to_id, max_loras, vocab_size, - extra_vocab_size, self.device, ) 
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) @@ -302,10 +299,10 @@ def embeddings_indices(self) -> torch.Tensor: def update_metadata(self, mapping: "LoRAMapping", lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): + vocab_size: int, **kwargs): self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + vocab_size) if mapping.is_prefill: # Update metadata required for prefill-related operators. diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 2db0e9fee142..93164fa94f19 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -53,11 +53,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, def update_metadata(self, mapping: LoRAMapping, lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): + vocab_size: int, **kwargs): self.is_prefill = mapping.is_prefill self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + vocab_size) # Prepare cuda kernel metadata tensors self.token_mapping_meta.prepare_tensors(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 07dc337a1cc8..2900c2ecfcf4 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -320,7 +320,6 @@ def _update_base_metadata( lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, - extra_vocab_size: int, ): # Make sure we don't accidentally collect outside operations xm.mark_step() @@ -342,7 +341,6 @@ def _update_base_metadata( lora_index_to_id, max_loras, vocab_size, - extra_vocab_size, "cpu", ) self._token_lora_indices = self._pad_to_shape( diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 163bb412235c..ca97ced8cb55 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -35,11 +35,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, def update_metadata(self, mapping: LoRAMapping, lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): + vocab_size: int, **kwargs): self.is_prefill = mapping.is_prefill self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + vocab_size) def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index d22c29da1c61..2a79c556897b 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -46,7 +46,6 @@ def convert_mapping( lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, - extra_vocab_size: int, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]: """Converts LoRAMapping to index tensors. @@ -56,7 +55,6 @@ def convert_mapping( lora_index_to_id: List mapping LoRA ids to LoRA indices. max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. Returns: A tuple of tensors: @@ -71,9 +69,8 @@ def convert_mapping( Same as sampler_indices, but -1 is replaced with max_loras. 
embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. + requests to embedding indices. First row is always zeros, + second row is for the LoRA.lora_a embeddings. indices_len: List of lengths of the above tensors. It contains (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices). @@ -105,8 +102,8 @@ def convert_mapping( dtype=torch.long, device=device) embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), + torch.zeros_like(indices[2]), + indices[2] * vocab_size, ]) embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, embeddings_indices) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3a807b1e161d..b2949019f4d4 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -38,13 +38,11 @@ def __init__( lora_config: LoRAConfig, device: torch.device, embedding_modules: dict[str, str], - embedding_padding_modules: list[str], lora_model_cls: type[LoRAModel] = LoRAModel, max_position_embeddings: Optional[int] = None, ): self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules - self.embedding_padding_modules = embedding_padding_modules self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = max_num_batched_tokens @@ -120,10 +118,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + - self.lora_config.lora_extra_vocab_size, - embedding_modules=self.embedding_modules, - embedding_padding_modules=self.embedding_padding_modules, tensorizer_config_dict=lora_request.tensorizer_config_dict, weights_mapper=hf_to_vllm_mapper) @@ -140,10 +134,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For BadRequestError raise e - if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " - f"is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}.") return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index f6400b05e110..562c215170c9 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -332,9 +332,7 @@ def __init__(self, self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -505,8 +503,6 @@ def __init__(self, if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 397089f31cdf..2fbfa55967be 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -276,9 +276,8 @@ 
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -504,8 +503,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = BambaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 7f87e31abdcd..c444b18b1738 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -286,9 +286,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.config = config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) @@ -419,8 +418,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # enabled assert config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.quant_config = quant_config self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index f503fb0f9364..dca33c4efffe 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -317,9 +317,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.wte = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -491,8 +490,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 9f7d57d93814..2b8a58a4adbd 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -302,9 +302,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + 
self.vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): self.embed_tokens = VocabParallelEmbedding( @@ -474,8 +473,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 757051b3b144..3e39971ab78c 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -419,9 +419,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank: @@ -594,8 +593,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.tie_word_embeddings = config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size self.mamba_cache: Optional[MambaCacheManager] = None - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 745d0b775999..a51cad3418e3 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -212,9 +212,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not config.add_cross_attention self.embed_dim = config.hidden_size - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.wte = VocabParallelEmbedding(self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size) @@ -305,8 +304,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=self.config.vocab_size, prefix=maybe_prefix(prefix, "lm_head")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 4f9cc2532bd8..6a9527d26120 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -264,9 +264,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -423,8 +422,6 @@ def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index da16c72000c0..3890a4a6af44 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -262,9 +262,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config # Required by MixtralModel - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -476,8 +475,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = GraniteMoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 79c6d8146ba9..4a8e2b6095ed 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -327,9 +327,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -601,8 +600,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 0b568a4b2268..b0e56feb09c8 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -158,9 +158,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config # Required by MixtralModel self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -277,8 +276,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, 
config.hidden_size, diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index a59113438337..a05201d0b7a9 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -298,9 +298,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embedding_multiplier_scale = getattr( config, "embedding_multiplier_scale", @@ -487,8 +486,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index db054b5c537e..efe537badd46 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -566,9 +566,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5b8fbc722686..f35a21cc1de4 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -293,9 +293,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -492,8 +491,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = JambaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 927f78c4e4b4..5eeaa3c0dbf9 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -318,9 +318,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = 
config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -504,8 +503,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: if get_pp_group().is_last_rank: self.unpadded_vocab_size = self.config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f8ea2111fed5..59bfaf783754 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -342,9 +342,8 @@ def __init__(self, self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -540,8 +539,6 @@ def __init__(self, if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 9d1017dac8aa..b9aeb55fb812 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -101,9 +101,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): is_lora_enabled = bool(lora_config) self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embeddings = VocabParallelEmbedding( @@ -210,8 +209,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.backbone = MambaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size if config.tie_word_embeddings: self.lm_head = self.backbone.embeddings else: diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index b1a4138cb8f6..5dfa42bf49ea 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -105,9 +105,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not is_lora_enabled self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embeddings = VocabParallelEmbedding( @@ -267,8 +266,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.backbone = Mamba2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 
c7be7f76dba1..42279abe4281 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -365,9 +365,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, @@ -536,8 +535,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 848a97b8bb2a..fa962c3e6512 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -150,9 +150,8 @@ def __init__(self, self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.fc = torch.nn.Linear(self.config.hidden_size * 2, self.config.hidden_size, @@ -327,8 +326,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): start_layer=target_layer_num) unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 8b3474d80953..b6d910f06199 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -294,9 +294,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -496,8 +495,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = MixtralModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 21f785e4b91a..0ffbfa181768 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -304,9 +304,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size 
+ lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -455,8 +454,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 1e1f0524bd06..44b733cacf23 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -331,9 +331,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -555,8 +554,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = NemotronHModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index f8e38dcd80b5..b4bea204679c 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -227,9 +227,8 @@ def __init__( self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -419,8 +418,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index c4548ee168bd..4efd8604eaaa 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -618,8 +618,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config=cache_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index b3fc55dab6ec..3da83aa2200c 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -981,8 
+981,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 01d16f1f2c38..43c3b1ddb9d4 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -455,9 +455,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.config = config self.quant_config = quant_config @@ -632,8 +631,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = PhiMoEModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 94c862258b7a..4c58a88b6dd4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -272,9 +272,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): @@ -458,8 +457,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index b8733fa5e612..62648f33dc41 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -379,8 +379,6 @@ def __init__( if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index e601bc3adb6e..3926d96fb4bf 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -679,9 +679,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: assert not is_lora_enabled self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = 
config.vocab_size self.org_vocab_size = config.vocab_size # Initialize token embeddings @@ -925,8 +924,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size # Initialize core model self.model = Zamba2Model(vllm_config=vllm_config, diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 01d5f0525c4e..9c15fd7888e8 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -55,7 +55,6 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig, lora_config, device, model.embedding_modules, - model.embedding_padding_modules, max_position_embeddings=text_config.max_position_embeddings, ) return self.lora_manager.create_lora_manager(model) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 43f12912707f..d185f037b798 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -186,8 +186,6 @@ def __init__( self.hidden_size = model_config.get_hidden_size() self.vocab_size = model_config.get_vocab_size() - if self.lora_config is not None: - self.vocab_size += self.lora_config.lora_extra_vocab_size # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 88f83c9dd7e6..3fae02c9570e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1072,7 +1072,6 @@ def load_model(self) -> None: self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules, max_position_embeddings=text_config. 
max_position_embeddings, ) From ccad11db8f61a7277f5c2fc48fd57dcb16472c1f Mon Sep 17 00:00:00 2001 From: Jinheng LI Date: Wed, 17 Sep 2025 23:26:52 +0800 Subject: [PATCH 2/4] pre-commit changes Signed-off-by: Jinheng LI --- tests/lora/test_layers.py | 52 +++++++++--------------------- tests/lora/test_lora_manager.py | 14 ++++---- vllm/lora/models.py | 21 ++++++------ vllm/v1/worker/tpu_model_runner.py | 1 - 4 files changed, 31 insertions(+), 57 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 4791236100e2..dfd01b6cf096 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -369,9 +369,7 @@ def create_random_embedding_layer(): embedding.weight.data = embedding_data embedding.weight.data[vocab_size:, :] = 0 expanded_embedding = VocabParallelEmbedding( - vocab_size, - 256, - org_num_embeddings=vocab_size) + vocab_size, 256, org_num_embeddings=vocab_size) expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place @@ -389,8 +387,7 @@ def create_random_embedding_layer(): lora_dict, _ = populate_loras( id_to_index, layer=lora_embedding, - layer_weights=torch.zeros( - (256, vocab_size)), + layer_weights=torch.zeros((256, vocab_size)), generate_embeddings_tensor=256, ) @@ -507,8 +504,7 @@ def _pretest(): params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 - logits_processor = LogitsProcessor( - vocab_size, vocab_size) + logits_processor = LogitsProcessor(vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device, None) @@ -542,11 +538,8 @@ def _pretest(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size) input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( @@ -588,11 +581,8 @@ def _pretest(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), @@ -671,11 +661,8 @@ def create_random_linear_replicated_layer(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512) lora_result = lora_linear(torch.cat(inputs))[0] @@ -792,11 +779,8 @@ def create_random_linear_parallel_layer(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512) lora_result = lora_linear(torch.cat(inputs))[0] @@ -937,11 +921,8 @@ class FakeConfig: prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512) lora_result = lora_linear(torch.cat(inputs))[0] @@ -976,11 +957,8 @@ class FakeConfig: prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - 
lora_mapping, - id_to_index, - max_loras, - 512) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index ff06a1a424d9..938fa31515f5 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -44,11 +44,10 @@ def test_from_lora_tensors(sql_lora_files, device): peft_helper = PEFTHelper.from_local_dir(sql_lora_files, max_position_embeddings=4096) - lora_model = LoRAModel.from_lora_tensors( - 1, - tensors, - peft_helper=peft_helper, - device=device) + lora_model = LoRAModel.from_lora_tensors(1, + tensors, + peft_helper=peft_helper, + device=device) for module_name, lora in lora_model.loras.items(): assert lora.module_name == module_name assert lora.rank == 8 @@ -424,9 +423,8 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, lora_dtype=DEFAULT_DTYPE, ) worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, - dummy_model.unpadded_vocab_size, - lora_config, device, EMBEDDING_MODULES) + 4, 2, dummy_model.unpadded_vocab_size, lora_config, device, + EMBEDDING_MODULES) worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 02151c5c65b0..75ad3e888021 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -252,14 +252,13 @@ def check_unexpected_modules(modules: dict): else: raise ValueError(f"{lora_dir} doesn't contain tensors") - return cls.from_lora_tensors( - lora_model_id=get_lora_id() - if lora_model_id is None else lora_model_id, - tensors=tensors, - peft_helper=peft_helper, - device=device, - dtype=dtype, - weights_mapper=weights_mapper) + return cls.from_lora_tensors(lora_model_id=get_lora_id() if + lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + weights_mapper=weights_mapper) class LoRAModelManager(AdapterModelManager): @@ -487,9 +486,9 @@ def create_dummy_lora( if module_name not in self.packed_modules: assert embedding_modules is not None if parts[-1] in embedding_modules: - input_dim = (module.base_layer.org_vocab_size if - hasattr(module.base_layer, "org_vocab_size") - else module.base_layer.weight.shape[1]) + input_dim = (module.base_layer.org_vocab_size if hasattr( + module.base_layer, "org_vocab_size") else + module.base_layer.weight.shape[1]) output_dim = module.base_layer.embedding_dim if hasattr( module.base_layer, "embedding_dim") else module.base_layer.weight.shape[0] diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index d185f037b798..c0521d834ffd 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -186,7 +186,6 @@ def __init__( self.hidden_size = model_config.get_hidden_size() self.vocab_size = model_config.get_vocab_size() - # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope From 8954e0619a27902bc97c45d3025caa81cb29635e Mon Sep 17 00:00:00 2001 From: Jinheng Li Date: Thu, 18 Sep 2025 00:11:53 +0800 Subject: [PATCH 3/4] Fix unused variable errors after LoRA vocabulary removal Remove unused lora_config assignments in model classes where the LoRA additional vocabulary support was removed. The lora_config variable was being assigned but never used after the vocabulary size logic was simplified. 
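
For reference, the removed pattern in each affected model __init__ looks
roughly like the following sketch (illustrative only; the surrounding
context lines vary slightly per model, see the hunks below):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config    # still used further down
        quant_config = vllm_config.quant_config    # still used further down
        # removed: lora_config = vllm_config.lora_config  (assigned, never read)
        self.config = config
        self.quant_config = quant_config
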
Signed-off-by: Jinheng Li --- vllm/model_executor/models/apertus.py | 1 - vllm/model_executor/models/bamba.py | 1 - vllm/model_executor/models/commandr.py | 2 -- vllm/model_executor/models/exaone.py | 1 - vllm/model_executor/models/exaone4.py | 1 - vllm/model_executor/models/falcon_h1.py | 1 - vllm/model_executor/models/gpt_bigcode.py | 1 - vllm/model_executor/models/granitemoe.py | 1 - vllm/model_executor/models/granitemoehybrid.py | 1 - vllm/model_executor/models/granitemoeshared.py | 1 - vllm/model_executor/models/grok1.py | 1 - vllm/model_executor/models/hunyuan_v1.py | 1 - vllm/model_executor/models/jamba.py | 1 - vllm/model_executor/models/lfm2.py | 1 - vllm/model_executor/models/llama.py | 1 - vllm/model_executor/models/minicpm.py | 1 - vllm/model_executor/models/minicpm_eagle.py | 1 - vllm/model_executor/models/mixtral.py | 1 - vllm/model_executor/models/nemotron.py | 1 - vllm/model_executor/models/nemotron_h.py | 1 - vllm/model_executor/models/nemotron_nas.py | 1 - vllm/model_executor/models/phimoe.py | 1 - vllm/model_executor/models/solar.py | 1 - 23 files changed, 24 deletions(-) diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 562c215170c9..10d3c90c685e 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -328,7 +328,6 @@ def __init__(self, config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 2fbfa55967be..79d9545bcee5 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -273,7 +273,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # No additional vocabulary support for LoRA diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index c444b18b1738..31fb635a48c8 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -282,7 +282,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.quant_config = quant_config self.config = config @@ -412,7 +411,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # currently all existing command R models have `tie_word_embeddings` # enabled diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index dca33c4efffe..343fd17746a5 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -313,7 +313,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/exaone4.py 
b/vllm/model_executor/models/exaone4.py index 2b8a58a4adbd..051e3845811f 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -298,7 +298,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 3e39971ab78c..8623c0d8a121 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -416,7 +416,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # No additional vocabulary support for LoRA diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index a51cad3418e3..d1c30a2742a0 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -206,7 +206,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config assert not config.add_cross_attention diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 3890a4a6af44..4a38ec86d491 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -258,7 +258,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 4a8e2b6095ed..186b51221597 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -324,7 +324,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # No additional vocabulary support for LoRA diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index b0e56feb09c8..abbeba8d3798 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -153,7 +153,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index a05201d0b7a9..c43f95eefaea 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -293,7 +293,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: 
str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index efe537badd46..95625599440a 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -561,7 +561,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index f35a21cc1de4..b33aac1887f8 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -290,7 +290,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # No additional vocabulary support for LoRA diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 5eeaa3c0dbf9..017f63d41f10 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -315,7 +315,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # No additional vocabulary support for LoRA diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 59bfaf783754..4c38fdcd050c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -338,7 +338,6 @@ def __init__(self, config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 42279abe4281..652a5db6412a 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -360,7 +360,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.cache_config = cache_config diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index fa962c3e6512..9438fdfb3c9e 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -145,7 +145,6 @@ def __init__(self, config = vllm_config.speculative_config.draft_model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.cache_config = cache_config diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b6d910f06199..991baf88e497 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -289,7 
+289,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config parallel_config = vllm_config.parallel_config self.config = config diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 0ffbfa181768..65b7a3798185 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -300,7 +300,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 44b733cacf23..d23040b52116 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -328,7 +328,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # No additional vocabulary support for LoRA diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index b4bea204679c..efd625940e82 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -222,7 +222,6 @@ def __init__( config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 43c3b1ddb9d4..e3ee4384a1c0 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -453,7 +453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config # No additional vocabulary support for LoRA self.vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 4c58a88b6dd4..803f83a1b520 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -268,7 +268,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config From 44f75d4838d1a8ed986c4b188cb20fb4c951ccdc Mon Sep 17 00:00:00 2001 From: Jinheng Li Date: Thu, 18 Sep 2025 01:05:08 +0800 Subject: [PATCH 4/4] Remove lora_extra_vocab_size from LoRA config directly Signed-off-by: Jinheng Li --- vllm/config/lora.py | 20 +------- vllm/lora/layers/logits_processor.py | 51 ++++--------------- vllm/lora/layers/vocal_parallel_embedding.py | 52 +++----------------- vllm/model_executor/models/qwen3_next.py | 7 +-- vllm/model_executor/models/qwen3_next_mtp.py | 5 +- 5 files changed, 20 insertions(+), 115 deletions(-) diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 
3fe28f5dad4f..5c0e71677758 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -44,11 +44,6 @@ class LoRAConfig: `max_loras`.""" lora_dtype: Union[torch.dtype, LoRADType] = "auto" """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: int = 256 - """(Deprecated) Maximum size of extra vocabulary that can be present in a - LoRA adapter. Will be removed in v0.12.0.""" - lora_vocab_padding_size: ClassVar[int] = current_platform\ - .get_lora_vocab_padding_size() default_mm_loras: Optional[dict[str, str]] = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a @@ -80,21 +75,13 @@ def compute_hash(self) -> str: factors.append(self.max_loras) factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) - factors.append(self.lora_extra_vocab_size) - factors.append(self.lora_vocab_padding_size) factors.append(self.bias_enabled) hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): - # Deprecation warning for lora_extra_vocab_size - logger.warning( - "`lora_extra_vocab_size` is deprecated and will be removed " - "in v0.12.0. Additional vocabulary support for " - "LoRA adapters is being phased out.") - - # Deprecation warning for enable_lora_bias +# Deprecation warning for enable_lora_bias if self.bias_enabled: logger.warning("`enable_lora_bias` is deprecated " "and will be removed in v0.12.0.") @@ -102,15 +89,10 @@ def __post_init__(self): # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) - possible_lora_extra_vocab_size = (256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " f"{possible_max_ranks}.") - if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError( - f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index a50dcfa748f2..a3a41f437db6 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -103,20 +103,13 @@ def create_lora_weights( max_loras, 1, # Pad for kernel compatibility - math.ceil(self.base_layer.vocab_size / - lora_config.lora_vocab_padding_size) * - lora_config.lora_vocab_padding_size, + math.ceil(self.base_layer.vocab_size / 256) * 256, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, device=self.device, ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) + # No additional vocabulary support for LoRA if self.sharded_to_full_mapping is not None: self.sharded_to_full_mapping_gpu = torch.tensor( self.sharded_to_full_mapping, @@ -128,7 +121,7 @@ def create_lora_weights( def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") + # No additional vocabulary support for LoRA def set_lora( self, @@ -145,12 +138,8 @@ def set_lora( self.lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, 
non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ] = embeddings_tensor + # No additional vocabulary support for LoRA + # embeddings_tensor parameter is no longer used def _get_logits( self, @@ -188,37 +177,15 @@ def _get_logits( # token_id: [0, 1, 2, 3, 4, 5, -1, -1] logits = logits[:, self.sharded_to_full_mapping_gpu] - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, - hidden_states.T, - out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values( - lora_logits.dtype) - - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT + # No additional vocabulary support for LoRA - skip lora_logits computation + # The original logits computation remains unchanged indices_padded = self.punica_wrapper.sampler_indices_padded if current_platform.is_tpu() or current_platform.is_xpu(): indices_padded = indices_padded[:logits.size(0)] - lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, - posinf=pos_inf, - neginf=neg_inf)) - - logits[:, - self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits + # Continue with the base logits without additional vocabulary + # No additional vocabulary assignment needed lora_output: Optional[ torch.Tensor] = self.punica_wrapper.add_lora_logits( diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index 4d6218d97097..54478e7f1514 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -30,37 +30,13 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: - if self.base_layer.num_added_embeddings_per_partition > 0: - # We can start adding lora weights - self.embeddings_weights = self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:self. 
- base_layer.num_org_embeddings_per_partition + - self.base_layer.num_added_embeddings_per_partition] - self.embeddings_slice = ( - self.base_layer.shard_indices.added_vocab_start_index - - self.base_layer.org_vocab_size, - self.base_layer.shard_indices.added_vocab_end_index - - self.base_layer.org_vocab_size) - self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:].fill_(0) - else: - self.embeddings_slice = None - self.embeddings_weights = None - - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) + # No additional vocabulary support for LoRA + self.embeddings_slice = None + self.embeddings_weights = None self.lora_a_stacked = torch.zeros( ( max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, + self.base_layer.org_vocab_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -84,7 +60,7 @@ def create_lora_weights( def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 + # No additional vocabulary support for LoRA def set_lora( self, @@ -100,22 +76,8 @@ def set_lora( self.lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0]:self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + # No additional vocabulary support for LoRA + # embeddings_tensor parameter is no longer used def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index fe63e9303235..f58b3eacf9db 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -879,16 +879,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - lora_config = vllm_config.lora_config speculative_config = vllm_config.speculative_config enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, @@ -1075,8 +1072,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen3NextModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = 
ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 190a1750e673..46dcdb277f1d 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -45,9 +45,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config: Qwen3NextConfig = model_config.hf_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + # No additional vocabulary support for LoRA + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.mtp_start_layer_idx = config.num_hidden_layers
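
With lora_extra_vocab_size gone from LoRAConfig, the config is built from the
remaining fields only. A minimal usage sketch, assuming LoRAConfig stays
importable from vllm.config and using arbitrarily chosen values:

    from vllm.config import LoRAConfig

    # lora_extra_vocab_size is no longer a field, so passing it here would be
    # rejected as an unexpected keyword argument.
    lora_config = LoRAConfig(
        max_lora_rank=16,  # still validated against (8, 16, 32, 64, 128, 256, 320, 512)
        max_loras=2,       # must be >= 1
    )

With the padding ClassVar removed as well, the logits-processor hunk above now
pads the stacked LoRA B weights to a multiple of 256 directly via math.ceil().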
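
Callers of the punica wrapper see the same simplification: update_metadata now
takes only the base vocabulary size. A minimal fragment of the new call shape,
reusing the argument names from tests/lora/test_layers.py (mapping construction
and wrapper setup are omitted here):

    # lora_mapping, id_to_index, max_loras and vocab_size are prepared as in
    # the tests; the former extra-vocab-size argument has been dropped.
    punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                   vocab_size)

Model code mirrors this: embeddings and LM heads are sized with
config.vocab_size alone, as in the qwen3_next changes above.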