From ae6e5ae295e1e9f4b2f1f1246384a6b9a03e9798 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 Jan 2025 20:32:47 +0800
Subject: [PATCH 1/9] add monitor

Signed-off-by: youkaichao
---
 .../vllm_test_utils/__init__.py               |  3 +-
 .../vllm_test_utils/monitor.py                | 63 +++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 tests/vllm_test_utils/vllm_test_utils/monitor.py

diff --git a/tests/vllm_test_utils/vllm_test_utils/__init__.py b/tests/vllm_test_utils/vllm_test_utils/__init__.py
index bf0b62a5b75e..6505c81546bb 100644
--- a/tests/vllm_test_utils/vllm_test_utils/__init__.py
+++ b/tests/vllm_test_utils/vllm_test_utils/__init__.py
@@ -4,5 +4,6 @@
 """
 from .blame import BlameResult, blame
+from .monitor import MonitoredValues, monitor
 
-__all__ = ["blame", "BlameResult"]
+__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py
new file mode 100644
index 000000000000..2981e9b552ba
--- /dev/null
+++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -0,0 +1,63 @@
+import contextlib
+import dataclasses
+import sys
+import traceback
+from typing import Any, Callable, Generator, List
+
+
+@dataclasses.dataclass
+class MonitoredValues:
+    values: List[Any] = dataclasses.field(default_factory=list)
+    trace_stacks: List[str] = dataclasses.field(default_factory=list)
+
+
+@contextlib.contextmanager
+def monitor(
+        measure_func: Callable[[],
+                               Any]) -> Generator[MonitoredValues, None, None]:
+    """
+    Trace the function calls to continuously monitor the change of
+    a value.
+
+    Usage:
+
+    ```python
+
+    def measure_func():
+        ... # measure the current value
+        return current_value
+
+    with monitor(measure_func) as monitored_values:
+        # do something
+
+    monitored_values.values # all changes of the values
+    monitored_values.trace_stacks # trace stacks of every change
+    """
+    monitored_values = MonitoredValues()
+
+    def _trace_calls(frame, event, arg=None):
+        nonlocal monitored_values
+        if event in ['call', 'return']:
+            # for every function call or return
+            try:
+                # Temporarily disable the trace function
+                sys.settrace(None)
+                # do a measurement
+                current_value = measure_func()
+                if len(monitored_values.values
+                       ) == 0 or current_value != monitored_values.values[-1]:
+                    monitored_values.values.append(current_value)
+                    monitored_values.trace_stacks.append("".join(
+                        traceback.format_stack()))
+                # Re-enable the trace function
+                sys.settrace(_trace_calls)
+            except NameError:
+                # modules are deleted during shutdown
+                pass
+        return _trace_calls
+
+    try:
+        sys.settrace(_trace_calls)
+        yield monitored_values
+    finally:
+        sys.settrace(None)

From 9a3c9648c3b0d70957abb18e6a6c15f65e24e260 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 Jan 2025 20:34:56 +0800
Subject: [PATCH 2/9] use reserved

Signed-off-by: youkaichao
---
 vllm/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 63057153f851..7050f87dff45 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1742,10 +1742,8 @@ class MemorySnapshot:
     timestamp: float = 0.0
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        self.torch_memory_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
+        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
@@ -1827,6 +1825,8 @@ def memory_profiling(
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
     subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
     """ # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
 
     result = MemoryProfilingResult()

From fccb65fcceb6c1dcb687e12eb0350508d4da1e3e Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 Jan 2025 20:40:27 +0800
Subject: [PATCH 3/9] add printing

Signed-off-by: youkaichao
---
 tests/test_utils.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 32a6b0aed66a..0ea529cc4157 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -11,6 +11,7 @@
                          supports_kw)
 
 from .utils import error_on_warning, fork_new_process_for_each_test
+from vllm_test_utils import monitor
 
 
 @pytest.mark.asyncio
@@ -289,8 +290,16 @@ def test_memory_profiling():
 
     weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB
 
+    def measure_current_non_torch():
+        free, total = torch.cuda.mem_get_info()
+        current_used = total - free
+        current_torch = torch.cuda.memory_reserved()
+        current_non_torch = current_used - current_torch
+        return current_non_torch
+
     with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+        monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
         del spike
@@ -298,6 +307,11 @@ def test_memory_profiling():
         # Add some extra non-torch memory 256 MiB (simulate NCCL)
         handle2 = lib.cudaMalloc(256 * 1024 * 1024)
 
+    for value, stack in zip(monitored_values.values, \
+        monitored_values.trace_stacks):
+        print(f"non_torch memory changed to {value / 1024 / 1024} MiB in")
+        print(stack)
+
     # Check that the memory usage is within 5% of the expected values
     non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
     torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa

From c83b3b9f567c44538b22068b1ab708bca85e0261 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 Jan 2025 21:31:15 +0800
Subject: [PATCH 4/9] finish

Signed-off-by: youkaichao
---
 tests/test_utils.py                              | 13 ++++++++-----
 tests/vllm_test_utils/vllm_test_utils/monitor.py |  6 ++++--
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0ea529cc4157..5155c764cbf1 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,13 +5,13 @@
 
 import pytest
 import torch
+from vllm_test_utils import monitor
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                         get_open_port, memory_profiling,
                         merge_async_iterators, supports_kw)
 
 from .utils import error_on_warning, fork_new_process_for_each_test
-from vllm_test_utils import monitor
 
 
 @pytest.mark.asyncio
@@ -307,12 +307,15 @@ def measure_current_non_torch():
         # Add some extra non-torch memory 256 MiB (simulate NCCL)
         handle2 = lib.cudaMalloc(256 * 1024 * 1024)
 
-    for value, stack in zip(monitored_values.values, \
-        monitored_values.trace_stacks):
-        print(f"non_torch memory changed to {value / 1024 / 1024} MiB in")
-        print(stack)
+    # this is an analytic value, it is exact,
+    # we only have 256 MiB non-torch memory increase
+    measured_diff =monitored_values.values[-1] - monitored_values.values[0]
+    assert measured_diff == 256 * 1024 * 1024
 
     # Check that the memory usage is within 5% of the expected values
+    # 5% tolerance is caused by PyTorch caching allocator,
+    # we cannot control PyTorch's behavior of its internal buffers,
+    # which causes a small error (<10 MiB in practice)
     non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
     torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
     assert abs(non_torch_ratio - 1) <= 0.05
diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py
index 2981e9b552ba..887dc394e450 100644
--- a/tests/vllm_test_utils/vllm_test_utils/monitor.py
+++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -37,8 +37,10 @@ def measure_func():
 
     def _trace_calls(frame, event, arg=None):
         nonlocal monitored_values
-        if event in ['call', 'return']:
-            # for every function call or return
+        if event in ['line']:
+            # triggered by every line of Python code.
+            # only Python functions will trigger it,
+            # c/cpp functions will not trigger it.
             try:

From 9fc3db7b0ba4cb04a40108e2548cbe7200cedd56 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 Jan 2025 21:41:38 +0800
Subject: [PATCH 5/9] add comments

Signed-off-by: youkaichao
---
 vllm/utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 7050f87dff45..df6564087cb9 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1743,6 +1743,8 @@ class MemorySnapshot:
 
     def measure(self):
         self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # torch.cuda.memory_reserved() is how many bytes
+        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
         self.torch_memory_in_bytes = torch.cuda.memory_reserved()
         self.timestamp = time.time()
 
@@ -1820,10 +1822,10 @@ def memory_profiling(
 
     The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
 
-    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
 
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
     """ # noqa

From f07c38c7b9ddf2388be62881b5a55e4a1715c306 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 Jan 2025 22:01:27 +0800
Subject: [PATCH 6/9] restore

Signed-off-by: youkaichao
---
 vllm/utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index df6564087cb9..2660b53d7bfb 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1827,8 +1827,6 @@ def memory_profiling(
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
     subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
     """ # noqa
-    gc.collect()
-    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
 
     result = MemoryProfilingResult()

From 5f611957153a650d8ba20c926e4126c9b5642902 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 8 Jan 2025 10:30:38 +0800
Subject: [PATCH 7/9] Update tests/vllm_test_utils/vllm_test_utils/monitor.py

Co-authored-by: Cyrus Leung
---
 .../vllm_test_utils/vllm_test_utils/monitor.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py
index 887dc394e450..0ef4c0c1ff93 100644
--- a/tests/vllm_test_utils/vllm_test_utils/monitor.py
+++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -2,19 +2,20 @@
 import dataclasses
 import sys
 import traceback
-from typing import Any, Callable, Generator, List
+from typing import Any, Callable, Generator, Generic, TypeVar
 
+_T = TypeVar("_T")
 
 @dataclasses.dataclass
-class MonitoredValues:
-    values: List[Any] = dataclasses.field(default_factory=list)
-    trace_stacks: List[str] = dataclasses.field(default_factory=list)
+class MonitoredValues(Generic[_T]):
+    values: list[_T] = dataclasses.field(default_factory=list)
+    trace_stacks: list[str] = dataclasses.field(default_factory=list)
 
 
 @contextlib.contextmanager
 def monitor(
         measure_func: Callable[[],
-                               Any]) -> Generator[MonitoredValues, None, None]:
+                               _T]) -> Generator[MonitoredValues[_T], None, None]:
     """
     Trace the function calls to continuously monitor the change of
     a value.
@@ -30,10 +31,11 @@ def measure_func():
     with monitor(measure_func) as monitored_values:
         # do something
 
-    monitored_values.values # all changes of the values
-    monitored_values.trace_stacks # trace stacks of every change
+        monitored_values.values # all changes of the values
+        monitored_values.trace_stacks # trace stacks of every change
+    ```
     """
-    monitored_values = MonitoredValues()
+    monitored_values = MonitoredValues[_T]()

From c86efad5aeabeb6fb75f3ac27e17e5dfd208d8e5 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 8 Jan 2025 10:30:51 +0800
Subject: [PATCH 8/9] Update tests/test_utils.py

Co-authored-by: Cyrus Leung
---
 tests/test_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 5155c764cbf1..0285b00d73be 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -309,7 +309,7 @@ def measure_current_non_torch():
 
     # this is an analytic value, it is exact,
     # we only have 256 MiB non-torch memory increase
-    measured_diff =monitored_values.values[-1] - monitored_values.values[0]
+    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
     assert measured_diff == 256 * 1024 * 1024

From f6a6222ad04abf0e32a41fa7fd9b19156cee47f2 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 8 Jan 2025 10:36:22 +0800
Subject: [PATCH 9/9] fix lint

Signed-off-by: youkaichao
---
 tests/vllm_test_utils/vllm_test_utils/monitor.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py
index 0ef4c0c1ff93..a237f53a75d1 100644
--- a/tests/vllm_test_utils/vllm_test_utils/monitor.py
+++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -2,10 +2,11 @@
 import dataclasses
 import sys
 import traceback
-from typing import Any, Callable, Generator, Generic, TypeVar
+from typing import Callable, Generator, Generic, TypeVar
 
 _T = TypeVar("_T")
 
+
 @dataclasses.dataclass
 class MonitoredValues(Generic[_T]):
     values: list[_T] = dataclasses.field(default_factory=list)
@@ -14,8 +15,8 @@ class MonitoredValues(Generic[_T]):
 
 @contextlib.contextmanager
 def monitor(
-        measure_func: Callable[[],
-                               _T]) -> Generator[MonitoredValues[_T], None, None]:
+    measure_func: Callable[[],
+                           _T]) -> Generator[MonitoredValues[_T], None, None]:
     """
     Trace the function calls to continuously monitor the change of
     a value.
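
For reference, below is a minimal usage sketch of the `monitor` context manager added by this series, in its final form after PATCH 9/9. It assumes the `vllm_test_utils` package is importable (as it is in vLLM's test environment); the `counter` variable, the `bump` helper, and the example values in the comments are hypothetical stand-ins for the non-torch memory measurement used in `test_memory_profiling`.

```python
from vllm_test_utils import monitor

counter = 0


def bump(times: int) -> None:
    # mutate the value we want to track, one step at a time
    global counter
    for _ in range(times):
        counter += 1


def measure_current_value() -> int:
    # measurement callback passed to monitor();
    # it is called on every traced Python line, so keep it cheap
    return counter


with monitor(measure_current_value) as monitored_values:
    bump(3)

# distinct values observed while the block ran (typically [0, 1, 2, 3])
print(monitored_values.values)
# Python stack trace captured when the last change was observed
print(monitored_values.trace_stacks[-1])
```

Because the tracer reacts only to Python `'line'` events, a change made entirely inside C/C++ code (such as a raw `cudaMalloc`) is recorded at the next Python line that executes, which is consistent with the test asserting on `values[-1] - values[0]` rather than on each intermediate step.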