|
41 | 41 | from vllm.model_executor.layers.vocab_parallel_embedding import (
|
42 | 42 | ParallelLMHead, VocabParallelEmbedding)
|
43 | 43 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 44 | +from vllm.model_executor.models.interfaces import SupportsPP |
| 45 | +from vllm.model_executor.models.utils import ( |
| 46 | + AutoWeightsLoader, is_pp_missing_parameter, |
| 47 | + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) |
44 | 48 | from vllm.model_executor.sampling_metadata import SamplingMetadata
|
45 | 49 | from vllm.sequence import IntermediateTensors
|
46 | 50 | from vllm.transformers_utils.configs import FlexOlmoConfig
|
47 | 51 |
|
48 |
| -from .interfaces import SupportsPP |
49 |
| -from .utils import (is_pp_missing_parameter, |
50 |
| - make_empty_intermediate_tensors_factory, make_layers, |
51 |
| - maybe_prefix) |
52 |
| - |
53 | 52 | logger = init_logger(__name__)
|
54 | 53 |
|
55 | 54 |
|
@@ -307,6 +306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
307 | 306 |
|
308 | 307 | config = vllm_config.model_config.hf_config
|
309 | 308 | assert isinstance(config, FlexOlmoConfig)
|
| 309 | + self.config = config |
310 | 310 |
|
311 | 311 | self.vocab_size = config.vocab_size
|
312 | 312 |
|
@@ -359,58 +359,6 @@ def forward(
|
359 | 359 | hidden_states = self.norm(hidden_states)
|
360 | 360 | return hidden_states
|
361 | 361 |
|
362 |
| - |
363 |
| -class FlexOlmoForCausalLM(nn.Module, SupportsPP): |
364 |
| - |
365 |
| - fall_back_to_pt_during_load = False |
366 |
| - |
367 |
| - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
368 |
| - super().__init__() |
369 |
| - config = vllm_config.model_config.hf_config |
370 |
| - assert isinstance(config, FlexOlmoConfig) |
371 |
| - quant_config = vllm_config.quant_config |
372 |
| - self.config = config |
373 |
| - self.quant_config = quant_config |
374 |
| - self.model = FlexOlmoModel(vllm_config=vllm_config, |
375 |
| - prefix=maybe_prefix(prefix, "model")) |
376 |
| - self.lm_head = ParallelLMHead(config.vocab_size, |
377 |
| - config.hidden_size, |
378 |
| - quant_config=quant_config, |
379 |
| - prefix=maybe_prefix(prefix, "lm_head")) |
380 |
| - self.logits_processor = LogitsProcessor(config.vocab_size) |
381 |
| - self.sampler = get_sampler() |
382 |
| - |
383 |
| - self.make_empty_intermediate_tensors = ( |
384 |
| - self.model.make_empty_intermediate_tensors) |
385 |
| - |
386 |
| - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
387 |
| - return self.model.get_input_embeddings(input_ids) |
388 |
| - |
389 |
| - def forward( |
390 |
| - self, |
391 |
| - input_ids: torch.Tensor, |
392 |
| - positions: torch.Tensor, |
393 |
| - intermediate_tensors: Optional[IntermediateTensors] = None, |
394 |
| - inputs_embeds: Optional[torch.Tensor] = None, |
395 |
| - ) -> Union[torch.Tensor, IntermediateTensors]: |
396 |
| - hidden_states = self.model(input_ids, positions, intermediate_tensors, |
397 |
| - inputs_embeds) |
398 |
| - return hidden_states |
399 |
| - |
400 |
| - def compute_logits(self, hidden_states: torch.Tensor, |
401 |
| - sampling_metadata: SamplingMetadata) -> torch.Tensor: |
402 |
| - logits = self.logits_processor(self.lm_head, hidden_states, |
403 |
| - sampling_metadata) |
404 |
| - return logits |
405 |
| - |
406 |
| - def sample( |
407 |
| - self, |
408 |
| - logits: Optional[torch.Tensor], |
409 |
| - sampling_metadata: SamplingMetadata, |
410 |
| - ) -> Optional[SamplerOutput]: |
411 |
| - next_tokens = self.sampler(logits, sampling_metadata) |
412 |
| - return next_tokens |
413 |
| - |
414 | 362 | def load_weights(self, weights: Iterable[tuple[str,
|
415 | 363 | torch.Tensor]]) -> set[str]:
|
416 | 364 | stacked_params_mapping = [
|
@@ -508,3 +456,58 @@ def load_weights(self, weights: Iterable[tuple[str,
|
508 | 456 | weight_loader(param, loaded_weight)
|
509 | 457 | loaded_params.add(name)
|
510 | 458 | return loaded_params
|
| 459 | + |
| 460 | + |
class FlexOlmoForCausalLM(nn.Module, SupportsPP):
    """Causal language-model wrapper around :class:`FlexOlmoModel`.

    Combines the ``FlexOlmoModel`` backbone with a ``ParallelLMHead``
    output projection, a ``LogitsProcessor`` and a sampler, and forwards
    weight loading to ``AutoWeightsLoader``.
    """

    # NOTE(review): presumably read by the checkpoint loader to decide
    # whether ``.pt`` files may be used as a fallback -- confirm against the
    # model loader before relying on it.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits/sampling helpers.

        Args:
            vllm_config: Engine configuration. Its
                ``model_config.hf_config`` must be a ``FlexOlmoConfig``
                (asserted below); its ``quant_config`` is stored on the
                instance and handed to the LM head.
            prefix: Name prefix combined, via ``maybe_prefix``, with
                ``"model"`` and ``"lm_head"`` to name the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        assert isinstance(config, FlexOlmoConfig)
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config

        # Sub-modules are created in the same order as before so that
        # module registration order is unchanged.
        self.model = FlexOlmoModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()

        # Re-export the backbone's factory under the same attribute name on
        # the wrapper.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        All four arguments are passed positionally, in this order, to
        ``self.model``; no LM head is applied here (see
        :meth:`compute_logits`).
        """
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the logits processor's output for ``hidden_states``.

        The processor is called with the LM head, the hidden states and the
        sampling metadata, in that order.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights by delegating to ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns; annotated
            as ``set[str]`` to match ``FlexOlmoModel.load_weights``, which
            returns the names of the parameters it loaded.
        """
        return AutoWeightsLoader(self).load_weights(weights)
0 commit comments