Merged

52 commits (diff shown from 17 commits)
e0bb716
feat:trace v1
Jul 2, 2025
a7414f7
Merge pull request #1 from RichardoMrMu/feat-trace-v1-aftermerge
RichardoMrMu Jul 2, 2025
440ca59
fix: ttft calculation
hcyezhang Jul 2, 2025
a30adc7
Merge pull request #2 from RichardoMrMu/main-ttft-fix
hcyezhang Jul 2, 2025
8afb03e
fix: merge error by accident
hcyezhang Aug 4, 2025
e0af39b
Merge pull request #3 from hcyezhang/main
RichardoMrMu Aug 4, 2025
a5462a1
Merge branch 'main' into fix_conflict
RichardoMrMu Aug 5, 2025
b5c27ed
Update vllm/v1/engine/async_llm.py
RichardoMrMu Aug 7, 2025
7b1de1c
Update vllm/v1/engine/output_processor.py
RichardoMrMu Aug 7, 2025
cdf0d9f
fix: gen meta directly from enginecorequest.sampling_params
hcyezhang Aug 7, 2025
4661667
Merge pull request #4 from hcyezhang/main
hcyezhang Aug 7, 2025
8e3887c
Update vllm/v1/engine/processor.py
RichardoMrMu Aug 15, 2025
3d65643
Merge branch 'main' into fix_conflict
RichardoMrMu Aug 15, 2025
6bea3fa
fix:pre-commit
Aug 15, 2025
1a5af39
Merge pull request #5 from RichardoMrMu/fix_conflict_2
RichardoMrMu Aug 15, 2025
4e623e3
Merge branch 'main' into fix_conflict
RichardoMrMu Aug 20, 2025
dd8c2a0
fix:pre-commit
Aug 20, 2025
47bea22
Merge branch 'main' into fix_conflict
RichardoMrMu Aug 20, 2025
33d736e
Merge branch 'main' into fix_conflict
RichardoMrMu Aug 20, 2025
6699296
remove v0 guard for tests
simon-mo Aug 24, 2025
c182529
Merge branch 'main' into fix_conflict
simon-mo Aug 24, 2025
86e4321
Merge branch 'main' into fix_conflict
RichardoMrMu Aug 25, 2025
516b954
change: test_tracing.py gpu_memory_utilization=0.3 to avoid oom
Aug 25, 2025
71012c0
test: timeout to 10
Aug 25, 2025
baa6b85
change: set env VLLM_USE_V1 1
Aug 26, 2025
05b2e69
test: set env VLLM_USE_V1 0
Sep 1, 2025
5f51aa1
Merge branch 'main' into fix_conflict
RichardoMrMu Sep 1, 2025
e1113e9
test: set env VLLM_USE_V1 1
Sep 1, 2025
81decbd
fix: tracing ut - tracer not initialized
hcyezhang Sep 2, 2025
b0f85e6
Merge branch 'fix_conflict' into main
hcyezhang Sep 2, 2025
38434cd
Merge pull request #6 from hcyezhang/main
RichardoMrMu Sep 2, 2025
eedf207
test:
Sep 2, 2025
28c0de7
Merge remote-tracking branch 'origin/fix_conflict' into fix_conflict
Sep 2, 2025
c255374
Merge branch 'main' into fix_conflict
RichardoMrMu Sep 2, 2025
73daf4d
test:disable_log_stats=False
Sep 2, 2025
0623cd7
test:format
Sep 2, 2025
57dbf9f
test:no model name
Sep 2, 2025
cdb9c48
add tracing ut for v1
ChrisYangAI Sep 9, 2025
2afc5bd
Merge pull request #7 from RichardoMrMu/chris_traceut_fix
RichardoMrMu Sep 9, 2025
daa13c8
reformat
ChrisYangAI Sep 9, 2025
bdb8847
reformat
ChrisYangAI Sep 9, 2025
8cf7c88
fix precommit error
ChrisYangAI Sep 9, 2025
a2b5346
fix precommit error
ChrisYangAI Sep 9, 2025
57c0df6
Merge pull request #8 from RichardoMrMu/chris_traceut_fix
RichardoMrMu Sep 9, 2025
27d6c69
[CI][Fix] deterministic seed for flaky CI runs on structured outputs …
aarnphm Sep 7, 2025
1d100d0
[CI/Build] Disable flaky test_structured_output tests (#24404)
22quinn Sep 8, 2025
a5c7f83
Merge pull request #9 from RichardoMrMu/fix_guidedcodeing_ut_failure
RichardoMrMu Sep 10, 2025
6370955
Merge branch 'main' into fix_conflict
ChrisYangAI Sep 10, 2025
bce28cc
fix trace test pipeline config
ChrisYangAI Sep 10, 2025
204a6b7
Merge pull request #10 from RichardoMrMu/fix_trace_test_config
RichardoMrMu Sep 10, 2025
e03076c
Merge branch 'main' into fix_conflict
ChrisYangAI Sep 10, 2025
23e74d3
Merge branch 'main' into fix_conflict
ChrisYangAI Sep 10, 2025
6 changes: 0 additions & 6 deletions vllm/engine/arg_utils.py
@@ -1440,12 +1440,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# No OTLP observability so far.
if (self.otlp_traces_endpoint or self.collect_detailed_traces):
_raise_or_fallback(feature_name="--otlp-traces-endpoint",
recommend_to_remove=False)
return False

# V1 supports N-gram, Medusa, and Eagle speculative decoding.
if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"):
5 changes: 5 additions & 0 deletions vllm/tracing.py
@@ -119,6 +119,11 @@ class SpanAttributes:
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
"gen_ai.latency.time_in_model_execute")
GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
"gen_ai.latency.time_in_model_prefill"
GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
"gen_ai.latency.time_in_model_inference"


def contains_trace_headers(headers: Mapping[str, str]) -> bool:
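The new attributes split a request's server-side latency into phases. A minimal sketch of how they relate to the per-request timestamps V1 already tracks, mirroring the computation added to do_tracing in vllm/v1/engine/output_processor.py below (the function and argument names here are illustrative only):

from vllm.tracing import SpanAttributes

def latency_breakdown(queued_ts: float, scheduled_ts: float,
                      first_token_ts: float,
                      last_token_ts: float) -> dict[str, float]:
    """Derive per-phase latencies from RequestStateStats-style timestamps."""
    return {
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE:
            scheduled_ts - queued_ts,        # waiting to be scheduled
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL:
            first_token_ts - scheduled_ts,   # prompt processing
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE:
            last_token_ts - first_token_ts,  # token generation
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE:
            last_token_ts - scheduled_ts,    # prefill + decode
    }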
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -863,9 +863,9 @@ def update_from_output(
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
))

else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
7 changes: 5 additions & 2 deletions vllm/v1/engine/__init__.py
@@ -3,7 +3,7 @@

import enum
import time
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union

import msgspec
@@ -69,6 +69,8 @@ class EngineCoreRequest(
current_wave: int = 0
priority: int = 0

trace_headers: Optional[Mapping[str, str]] = None


class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
@@ -114,6 +116,7 @@ class EngineCoreOutput(
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None

trace_headers: Optional[Mapping[str, str]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0

@@ -147,7 +150,7 @@ class EngineCoreOutputs(
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

#NOTE(Nick): We could consider ways to make this more compact,
# NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout

engine_index: int = 0
9 changes: 8 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -25,6 +25,7 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -96,6 +97,7 @@ def __init__(

self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
self.log_requests = log_requests
self.log_stats = log_stats

@@ -118,6 +120,11 @@ def __init__(
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
if self.observability_config.otlp_traces_endpoint is not None:
tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
self.output_processor.tracer = tracer

# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
@@ -569,7 +576,7 @@ async def get_tokenizer(
return self.tokenizer.get_lora_tokenizer(lora_request)

async def is_tracing_enabled(self) -> bool:
return False
return self.observability_config.otlp_traces_endpoint is not None

async def do_log_stats(
self,
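With the tracer now built from observability_config, V1 tracing is turned on by pointing the engine at an OTLP collector. A minimal usage sketch, assuming AsyncEngineArgs and AsyncLLM.from_engine_args behave as on current main (the model name and collector address are placeholders, and a running collector is assumed):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(
            model="facebook/opt-125m",                     # placeholder model
            otlp_traces_endpoint="http://localhost:4317",  # OTLP collector
        ))
    # Setting the endpoint both attaches a tracer to the OutputProcessor
    # and makes is_tracing_enabled() report True.
    assert await engine.is_tracing_enabled()

asyncio.run(main())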
70 changes: 63 additions & 7 deletions vllm/v1/engine/output_processor.py
@@ -11,6 +11,8 @@
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -71,7 +73,6 @@ def get_nowait(

@dataclass
class OutputProcessorOutput:

request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
reqs_to_abort: list[str]

@@ -274,16 +275,13 @@ def _new_pooling_output(
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""

def __init__(
self,
tokenizer: TokenizerGroup,
log_stats: bool,
):
def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
self.tracer: Optional[Tracer] = None

def get_num_unfinished_requests(self):
return len(self.request_states)
@@ -441,14 +439,72 @@ def process_outputs(
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,
iteration_stats)

if self.tracer:
self.do_tracing(engine_core_output, req_state,
iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)

return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)

def do_tracing(self, engine_core_output: EngineCoreOutput,
req_state: RequestState,
iteration_stats: Optional[IterationStats]) -> None:
assert req_state.stats is not None
assert iteration_stats is not None
assert self.tracer is not None

arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers)
with (self.tracer.start_as_current_span(
"llm_request",
kind=SpanKind.SERVER,
context=trace_context,
start_time=arrival_time_nano_seconds) as span):
metrics = req_state.stats
e2e_time = iteration_stats.iteration_timestamp - \
metrics.arrival_time
queued_time = metrics.scheduled_ts - metrics.queued_ts
prefill_time = metrics.first_token_ts - metrics.scheduled_ts
decode_time = metrics.last_token_ts - metrics.first_token_ts
inference_time = metrics.last_token_ts - metrics.scheduled_ts
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
metrics.first_token_latency)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(req_state.prompt_token_ids))
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
prefill_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
decode_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
inference_time)

# meta
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
req_state.request_id)
if req_state.parent_req and req_state.parent_req.sampling_params:

Inline review comment on the line above:

Reviewer: If it is a pooling model or params.n is 1, then the following attributes are missing?

if is_pooling or params.n == 1:

Is it possible to add these attributes regardless of model type and params.n?

Contributor: Yeah, good point. I added fields in RequestState to hold these sampling_params from EngineCoreRequest. Or please let me know if there's a better way to do it :)

span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
req_state.parent_req.sampling_params.top_p)
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
req_state.parent_req.sampling_params.max_tokens)
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
req_state.parent_req.sampling_params.temperature)
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
req_state.parent_req.sampling_params.n)

def _update_stats_from_output(self, req_state: RequestState,
engine_core_output: EngineCoreOutput,
engine_core_timestamp: Optional[float],
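On the review thread above: as written, the sampling-parameter attributes are emitted only when a ParentRequest with sampling_params exists, so they are skipped for pooling requests or when params.n == 1. The contributor's reply and the commit "fix: gen meta directly from enginecorequest.sampling_params" point toward keeping the sampling params from the original EngineCoreRequest on the request state instead. A minimal sketch of that direction (the helper name is hypothetical, not part of this diff):

from typing import Optional

from vllm.sampling_params import SamplingParams
from vllm.tracing import SpanAttributes

def set_sampling_attributes(span,
                            sampling_params: Optional[SamplingParams]) -> None:
    """Record sampling parameters on the span whenever they are available,
    independent of ParentRequest or params.n."""
    if sampling_params is None:
        return
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                       sampling_params.top_p)
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                       sampling_params.max_tokens)
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                       sampling_params.temperature)
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, sampling_params.n)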
3 changes: 1 addition & 2 deletions vllm/v1/engine/processor.py
@@ -237,8 +237,6 @@ def process_inputs(
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.")

data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@@ -345,6 +343,7 @@ def process_inputs(
cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank,
trace_headers=trace_headers,
)

def _validate_model_inputs(self,
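Since the processor now forwards trace_headers instead of rejecting them, a caller's trace context can flow through EngineCoreRequest and EngineCoreOutput into do_tracing, which extracts it with extract_trace_context so the llm_request span becomes a child of the caller's span. A hedged sketch of producing such headers with the standard OpenTelemetry propagator (assumes a tracer provider is already configured on the client side; passing the headers into the engine is not shown):

from opentelemetry import trace
from opentelemetry.propagate import inject

tracer = trace.get_tracer("my-client")  # client-side tracer (illustrative)

with tracer.start_as_current_span("client_request"):
    trace_headers: dict[str, str] = {}
    inject(trace_headers)  # fills in the W3C "traceparent" header
    # trace_headers can now accompany the request (e.g. via the trace_headers
    # argument accepted by Processor.process_inputs), linking the engine's
    # "llm_request" span to the client span.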
4 changes: 4 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -68,6 +68,9 @@ class RequestStateStats:
first_token_ts: float = 0.0
last_token_ts: float = 0.0

# first token latency
first_token_latency: float = 0.0


@dataclass
class FinishedRequestStats:
@@ -116,6 +119,7 @@ def update_from_output(self, output: "EngineCoreOutput",

first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.first_token_latency = first_token_latency

req_stats.num_generation_tokens += num_new_generation_tokens

6 changes: 5 additions & 1 deletion vllm/v1/request.py
@@ -3,6 +3,7 @@

import enum
import time
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

@@ -38,6 +39,7 @@ def __init__(
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
trace_headers: Optional[Mapping[str, str]] = None,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
@@ -105,7 +107,8 @@ def __init__(
# they should also be updated simultaneously.
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)

# trace_headers
self.trace_headers = trace_headers
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
@@ -150,6 +153,7 @@ def from_engine_core_request(
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
trace_headers=request.trace_headers,
block_hasher=block_hasher,
)

Expand Down