From b56003d638f51974bbca556d77e8c3798cc0eeaf Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 21 May 2025 19:59:33 +0000 Subject: [PATCH 1/2] Make Olmo2Model weight loading return loaded weights Signed-off-by: Shane A --- vllm/model_executor/models/olmo2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 0a1fb10c186e..a8aeeeee76b3 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -314,7 +314,7 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -325,6 +325,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() for name, loaded_weight in weights: if is_pp_missing_parameter(name, self): continue @@ -347,6 +348,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Olmo2ForCausalLM(nn.Module, SupportsPP): From ad225169af2b9cbcc25632216c10384a5dce166e Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 21 May 2025 20:32:46 +0000 Subject: [PATCH 2/2] Fix formatting using pre-commit Signed-off-by: Shane A --- vllm/model_executor/models/olmo2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index a8aeeeee76b3..33adacdae5f5 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -314,7 +314,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"),