Merged
Changes from 4 commits
110 changes: 37 additions & 73 deletions vllm/v1/worker/tpu_model_runner.py
@@ -15,6 +15,7 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
@@ -608,11 +609,10 @@ def execute_model(
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
selected_token_ids = self.sample_from_hidden(hidden_states,
tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]

@@ -713,17 +713,15 @@ def load_model(self) -> None:
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm.mark_step()
xm.wait_device_ops()
model = ModelWrapperV1(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
self.model = model
self.sampler = TPUSampler()

@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
def _dummy_run(self, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -774,7 +772,6 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
with set_forward_context(attn_metadata, self.vllm_config, 0):
out = self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype

@@ -786,7 +783,7 @@ def capture_model(self) -> None:
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
self._dummy_run(num_tokens)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
@@ -815,8 +812,7 @@ def capture_model(self) -> None:
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = self.sample_from_hidden(dummy_hidden, sampling_meta)
out = out.cpu()
if num_reqs_to_sample >= self.max_num_reqs:
break
@@ -874,79 +870,47 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)


class ModelWrapperV1(nn.Module):

def __init__(self, model: nn.Module):
Collaborator:
I didn't mind the wrapper too much; I think it was grouping a few related functions nicely. Still, if it benefits performance I am OK with that.

Collaborator (Author):
The wrapper was added in V0 for torch.compile. Now the @support_torch_compile decorator already wraps the model with torch.compile, so the wrapper is no longer needed.
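For readers unfamiliar with the decorator, here is a minimal, self-contained sketch of the pattern being described: a class decorator that compiles the model's forward lazily, so load_model only needs to keep a reference to the model. The decorator body below is illustrative only (and assumes a torch_xla install for the "openxla" backend); it is not the actual vllm.compilation implementation.

```python
import torch
import torch.nn as nn


def support_torch_compile_sketch(cls):
    """Illustrative stand-in for vLLM's @support_torch_compile decorator."""
    original_forward = cls.forward

    def forward(self, *args, **kwargs):
        # Compile once on first use; "openxla" mirrors the backend that
        # load_model used to pass to torch.compile before this PR.
        if not hasattr(self, "_compiled_forward"):
            self._compiled_forward = torch.compile(original_forward,
                                                   backend="openxla",
                                                   fullgraph=True,
                                                   dynamic=False)
        return self._compiled_forward(self, *args, **kwargs)

    cls.forward = forward
    return cls


@support_torch_compile_sketch
class ToyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# With the decorator in place, the runner can simply keep `self.model = model`.
```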

super().__init__()
self.model = model
self.sampler = TPUSampler()

def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model.

Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""

hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)

return hidden_states
def reset_dynamo_cache(self):
if self.is_multimodal_model:
compiled_model = self.model.language_model.model
else:
compiled_model = self.model.model
Member:
AFAIK "language_model" is not a stable attribute to reference; it is based on the HF model definition. Maybe @ywang96 @DarkLight1337 would know a stable interface to access the language-model backbone?

Collaborator (Author):
I think a potential solution is to add a query function, get_backbone_lm, for multimodal models, similar to the existing get_input_embeddings.

Member:
Yeah, we can add get_language_model to the SupportsMultiModal interface.

Collaborator:
Something like #16007?

Member:
Please leave an assert here for now; we will address this later. @lsy323
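A rough sketch of the accessor being discussed (the concrete change is what the linked PR would track); the protocol name and fallback logic below are assumptions for illustration, not the merged interface:

```python
from typing import Protocol, runtime_checkable

import torch.nn as nn


@runtime_checkable
class HasLanguageModel(Protocol):
    """Illustrative stand-in for adding get_language_model to
    SupportsMultiModal, so callers stop reaching for the HF-specific
    `language_model` attribute directly."""

    def get_language_model(self) -> nn.Module:
        ...


def resolve_backbone(model: nn.Module) -> nn.Module:
    # Prefer the stable accessor when the model exposes it; otherwise assume
    # the model itself is the text backbone (the non-multimodal case).
    if isinstance(model, HasLanguageModel):
        return model.get_language_model()
    return model
```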

if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
logger.info("Clear dynamo cache and cached dynamo bytecode.")
torch._dynamo.eval_frame.remove_from_cache(
compiled_model.original_code_object)
compiled_model.compiled_codes.clear()

def sample_from_hidden(
self,
hidden_states: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
Sample with xla-friendly function. This function is to be traced
separately for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# SamplingMetadata here for pruning output in LogitsProcessor, disabled.
logits = self.model.compute_logits(sample_hidden_states, None)
Collaborator:
Can we re-add the pruning comment here? Just in case it slips through in the future.

Collaborator (Author):
Sorry, may I ask what the 'pruning comment' is?

Collaborator:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
or something along this line to indicate why the second argument is None. That's because it could enable logits pruning, which is bad for XLA.

Collaborator (Author):
Thanks! Updated.


def sample(
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata
) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out

# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
out_tokens = torch.where(
sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
sample(logits, sampling_metadata).sampled_token_ids)
return out_tokens
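To make the "optimized if/else" comment concrete, here is a standalone sketch of the same torch.where pattern on plain tensors; the tensor shapes and names are made up for illustration:

```python
import torch


def pick_tokens(all_greedy: torch.Tensor, logits: torch.Tensor,
                sampled_token_ids: torch.Tensor) -> torch.Tensor:
    # Both branches are computed and live in the same traced graph;
    # torch.where then selects per the scalar flag, avoiding a Python-level
    # branch that would trigger a separate XLA compilation per path.
    greedy_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    return torch.where(all_greedy, greedy_token_ids, sampled_token_ids)


logits = torch.randn(4, 32)                       # [num_reqs, vocab_size]
sampled = torch.randint(0, 32, (4, 1))            # pretend sampler output
print(pick_tokens(torch.tensor(True), logits, sampled))   # greedy path
print(pick_tokens(torch.tensor(False), logits, sampled))  # sampled path
```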

def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits = self.model.compute_logits(hidden_states, None)
return logits

def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)

def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)
Comment on lines -980 to -984
Member:
Have you tested that multimodal inference still works and that these are called correctly?

Collaborator (Author):
I only ran the TPU CI to test the change; multimodal is not tested. Can you provide a script to test multimodal? I can test it on this PR.

Collaborator:
VLLM_USE_V1=1 vllm serve llava-hf/llava-1.5-7b-hf --max-model-len 4096 --max-num-seqs 8 --max-num-batched-tokens 512 --chat-template examples/template_llava.jinja

then python examples/online_serving/openai_chat_completion_client_for_multimodal.py

Collaborator:
I am adding tests in another PR, by the way.

Collaborator (Author):
Sounds good, I'll rebase onto your PR after. :)

Collaborator (Author):
Hi @NickLucche, thank you for providing the testing command for multimodal. I made this PR work for llava-hf/llava-1.5-7b-hf; the server command runs fine.

However, the client script python examples/online_serving/openai_chat_completion_client_for_multimodal.py would fail at HEAD.

Collaborator:
Main is working for me; let me try your PR.

Collaborator:
Still really slow, but it works the same on this PR on my side, thanks!

Collaborator (Author):
@NickLucche Thank you so much for trying! I set up a new conda env on my end and tried as well. I also found it's really slow; I was about to ask you about it.



def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
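A quick usage note on the helper above, which rounds a count up to the next multiple (the padding boundaries are just example values):

```python
assert _get_padded_number(5, 8) == 8    # pads 5 up to the next multiple of 8
assert _get_padded_number(16, 8) == 16  # already aligned, unchanged
```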
12 changes: 9 additions & 3 deletions vllm/v1/worker/tpu_worker.py
@@ -150,13 +150,19 @@ def determine_available_memory(self) -> int:
runner_kv_caches)

self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
self.scheduler_config.max_num_batched_tokens)

# Synchronize before measuring the memory usage.
xm.wait_device_ops()

# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self.model_runner.reset_dynamo_cache()

# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)