
Commit d1adc72 (parent: f98dd62)

Move weight loading logic to FlexOlmoModel

Signed-off-by: Shane A <[email protected]>

1 file changed: 60 additions, 57 deletions


vllm/model_executor/models/flex_olmo.py (60 additions, 57 deletions)

@@ -41,15 +41,14 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import FlexOlmoConfig
 
-from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
-
 logger = init_logger(__name__)
 
 
@@ -307,6 +306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         config = vllm_config.model_config.hf_config
         assert isinstance(config, FlexOlmoConfig)
+        self.config = config
 
         self.vocab_size = config.vocab_size
 
@@ -359,58 +359,6 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
-
-class FlexOlmoForCausalLM(nn.Module, SupportsPP):
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        assert isinstance(config, FlexOlmoConfig)
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = FlexOlmoModel(vllm_config=vllm_config,
-                                   prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config,
-                                      prefix=maybe_prefix(prefix, "lm_head"))
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -508,3 +456,58 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class FlexOlmoForCausalLM(nn.Module, SupportsPP):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        assert isinstance(config, FlexOlmoConfig)
+        quant_config = vllm_config.quant_config
+        self.quant_config = quant_config
+        self.model = FlexOlmoModel(vllm_config=vllm_config,
+                                   prefix=maybe_prefix(prefix, "model"))
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
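The functional core of the change is the new three-line load_weights on FlexOlmoForCausalLM: instead of implementing checkpoint loading itself, the wrapper hands the job to vLLM's AutoWeightsLoader, which walks the module tree and dispatches each named tensor to the submodule that defines its own load_weights (here, FlexOlmoModel). Below is a minimal, runnable sketch of that delegation pattern; the toy classes and the prefix-stripping loop are illustrative stand-ins, not the actual AutoWeightsLoader implementation.

    from collections.abc import Iterable

    import torch
    import torch.nn as nn


    class ToyInnerModel(nn.Module):
        """Plays the role of FlexOlmoModel: owns the detailed loading logic."""

        def __init__(self) -> None:
            super().__init__()
            self.embed = nn.Embedding(8, 4)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            # The real method also remaps stacked/expert parameter names;
            # here we just copy each tensor into the matching parameter.
            params = dict(self.named_parameters())
            loaded: set[str] = set()
            for name, tensor in weights:
                params[name].data.copy_(tensor)
                loaded.add(name)
            return loaded


    class ToyForCausalLM(nn.Module):
        """Plays the role of FlexOlmoForCausalLM: a thin wrapper that delegates."""

        def __init__(self) -> None:
            super().__init__()
            self.model = ToyInnerModel()
            self.lm_head = nn.Linear(4, 8, bias=False)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            # Stand-in for `AutoWeightsLoader(self).load_weights(weights)`:
            # route "model.*" tensors to the submodule that owns them and
            # handle only the wrapper's own parameters (lm_head) here.
            loaded: set[str] = set()
            own = dict(self.named_parameters())
            inner: list[tuple[str, torch.Tensor]] = []
            for name, tensor in weights:
                if name.startswith("model."):
                    inner.append((name[len("model."):], tensor))
                else:
                    own[name].data.copy_(tensor)
                    loaded.add(name)
            loaded |= {f"model.{n}" for n in self.model.load_weights(inner)}
            return loaded


    checkpoint = [
        ("model.embed.weight", torch.zeros(8, 4)),
        ("lm_head.weight", torch.zeros(8, 4)),
    ]
    print(sorted(ToyForCausalLM().load_weights(checkpoint)))
    # -> ['lm_head.weight', 'model.embed.weight']

Structuring it this way keeps the expert and stacked-parameter bookkeeping in one place on the inner model, and lets thin wrapper classes reuse it by delegation, which is the usual motivation when vLLM model files are refactored like this.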

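The body of FlexOlmoModel.load_weights is unchanged context and mostly elided from the diff; the visible `stacked_params_mapping = [` line is vLLM's standard mechanism for fused layers, where separate checkpoint tensors such as q_proj, k_proj, and v_proj are copied into shards of one fused parameter (moving this method is presumably also why `self.config = config` was added to FlexOlmoModel, since the loader needs config values). As a rough, self-contained illustration of the remapping loop, with hypothetical parameter names and a simplified slice-based copy in place of vLLM's shard-aware weight loaders:

    import torch

    # Mirrors the shape of vLLM's stacked_params_mapping entries:
    # (checkpoint substring, fused parameter substring, shard id).
    stacked_params_mapping = [
        ("q_proj", "qkv_proj", "q"),
        ("k_proj", "qkv_proj", "k"),
        ("v_proj", "qkv_proj", "v"),
    ]

    hidden = 4
    # One fused QKV weight standing in for a QKVParallelLinear parameter.
    params = {"layers.0.self_attn.qkv_proj.weight": torch.empty(3 * hidden, hidden)}
    shard_offsets = {"q": 0, "k": hidden, "v": 2 * hidden}


    def load_stacked(name: str, tensor: torch.Tensor) -> str:
        """Copy a per-projection checkpoint tensor into its slice of the
        fused parameter, returning the fused parameter's name."""
        for ckpt_key, fused_key, shard_id in stacked_params_mapping:
            if ckpt_key not in name:
                continue
            fused_name = name.replace(ckpt_key, fused_key)
            offset = shard_offsets[shard_id]
            params[fused_name][offset:offset + tensor.shape[0]].copy_(tensor)
            return fused_name
        raise KeyError(f"not a stacked weight: {name}")


    for proj in ("q_proj", "k_proj", "v_proj"):
        print(load_stacked(f"layers.0.self_attn.{proj}.weight",
                           torch.ones(hidden, hidden)))

In the real loader, the copy is performed by the parameter's registered weight_loader(param, loaded_weight, shard_id), which also accounts for tensor-parallel partitioning; weights that match no mapping entry fall through to default_weight_loader, as the `weight_loader(param, loaded_weight)` context line in the last hunk shows.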