 import pprint
 import time
 from collections.abc import Sequence
-from contextlib import ExitStack
 from typing import Any, Callable, Optional
-from unittest.mock import patch
 
 import torch
 import torch.fx as fx
 
 import vllm.envs as envs
 from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import weak_ref_tensors
+from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname
 
 from .compiler_interface import (CompilerInterface, EagerAdaptor,
                                  InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
-from .monitor import end_monitoring_torch_compile
 from .pass_manager import PostGradPassManager
 
 logger = init_logger(__name__)
@@ -297,7 +295,9 @@ def call_module(self, target: torch.fx.node.Target,
                     num_graphs=len(self.compile_submod_names),
                     runtime_shape=None)
 
-            self.module.__dict__[target] = PiecewiseBackend(
+            piecewise_backend = resolve_obj_by_qualname(
+                current_platform.get_piecewise_backend_cls())
+            self.module.__dict__[target] = piecewise_backend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape, self.vllm_backend)
@@ -341,7 +341,7 @@ def __init__(
     ):
         global global_graph_pool
         if global_graph_pool is None:
-            global_graph_pool = torch.cuda.graph_pool_handle()
+            global_graph_pool = current_platform.graph_pool_handle()
 
         # TODO: in the future, if we want to use multiple
         # streams, it might not be safe to share a global pool.
@@ -558,197 +558,3 @@ def copy_and_call(*args):
             return self.split_gm(*list_args)
 
         return copy_and_call
-
-
-@dataclasses.dataclass
-class ConcreteSizeEntry:
-    runtime_shape: int
-    need_to_compile: bool  # the size is in compile_sizes
-    use_cudagraph: bool  # the size is in cudagraph_capture_sizes
-
-    compiled: bool = False
-    runnable: Callable = None  # type: ignore
-    num_finished_warmup: int = 0
-    cudagraph: Optional[torch.cuda.CUDAGraph] = None
-    output: Optional[Any] = None
-
-    # for cudagraph debugging, track the input addresses
-    # during capture, and check if they are the same during replay
-    input_addresses: Optional[list[int]] = None
-
-
-class PiecewiseBackend:
-
-    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
-                 graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: list[int],
-                 compiled_graph_for_general_shape: Callable,
-                 vllm_backend: VllmBackend):
-        """
-        The backend for piecewise compilation.
-        It mainly handles the compilation and cudagraph capturing.
-
-        We will compile `self.graph` once for the general shape,
-        and then compile for different shapes specified in
-        `compilation_config.compile_sizes`.
-
-        Independently, we will capture cudagraph for different shapes.
-
-        If a shape needs both compilation and cudagraph, we will
-        compile it first, and then capture cudagraph.
-        """
-        self.graph = graph
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
-        self.graph_pool = graph_pool
-        self.piecewise_compile_index = piecewise_compile_index
-        self.total_piecewise_compiles = total_piecewise_compiles
-        self.vllm_backend = vllm_backend
-
-        self.is_first_graph = piecewise_compile_index == 0
-        self.is_last_graph = (
-            piecewise_compile_index == total_piecewise_compiles - 1)
-
-        self.compile_sizes: set[int] = set(
-            self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: set[int] = set(
-            self.compilation_config.cudagraph_capture_sizes
-        ) if self.compilation_config.use_cudagraph else set()
-
-        self.first_run_finished = False
-
-        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
-
-        self.sym_shape_indices = sym_shape_indices
-
-        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
-
-        # the entries for different shapes that we need to either
-        # compile or capture cudagraph
-        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
-
-        # to_be_compiled_sizes tracks the remaining sizes to compile,
-        # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
-        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(
-                runtime_shape=shape,
-                need_to_compile=shape in self.compile_sizes,
-                use_cudagraph=shape in self.cudagraph_capture_sizes,
-            )
-
-    def check_for_ending_compilation(self):
-        if self.is_last_graph and not self.to_be_compiled_sizes:
-            # no specific sizes to compile
-            # save the hash of the inductor graph for the next run
-            self.vllm_backend.compiler_manager.save_to_file()
-            end_monitoring_torch_compile(self.vllm_config)
-
-    def __call__(self, *args) -> Any:
-        if not self.first_run_finished:
-            self.first_run_finished = True
-            self.check_for_ending_compilation()
-            return self.compiled_graph_for_general_shape(*args)
-
-        runtime_shape = args[self.sym_shape_indices[0]]
-        if runtime_shape not in self.concrete_size_entries:
-            # we don't need to do anything for this shape
-            return self.compiled_graph_for_general_shape(*args)
-
-        entry = self.concrete_size_entries[runtime_shape]
-
-        if entry.runnable is None:
-            entry.runnable = self.compiled_graph_for_general_shape
-
-        if entry.need_to_compile and not entry.compiled:
-            entry.compiled = True
-            self.to_be_compiled_sizes.remove(runtime_shape)
-            # args are real arguments
-            entry.runnable = self.vllm_backend.compiler_manager.compile(
-                self.graph,
-                args,
-                self.compilation_config.inductor_compile_config,
-                self.compilation_config,
-                graph_index=self.piecewise_compile_index,
-                num_graphs=self.total_piecewise_compiles,
-                runtime_shape=runtime_shape)
-
-            # finished compilations for all required shapes
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                self.check_for_ending_compilation()
-
-        if not entry.use_cudagraph:
-            return entry.runnable(*args)
-
-        if entry.cudagraph is None:
-            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
-                entry.num_finished_warmup += 1
-                if self.is_first_graph:
-                    logger.debug(
-                        "Warming up %s/%s for shape %s",
-                        entry.num_finished_warmup,
-                        self.compilation_config.cudagraph_num_of_warmups,
-                        runtime_shape)
-                return entry.runnable(*args)
-
-            if self.is_first_graph:
-                # Since we capture cudagraph for many different shapes and
-                # capturing is fast, we don't need to log it for every shape.
-                # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
-                             runtime_shape)
-
-            input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
-            entry.input_addresses = input_addresses
-            cudagraph = torch.cuda.CUDAGraph()
-
-            with ExitStack() as stack:
-                if not self.is_first_graph:
-                    # during every model forward, we will capture
-                    # many pieces of cudagraphs (roughly one per layer).
-                    # running gc again and again across layers will
-                    # make the cudagraph capture very slow.
-                    # therefore, we only run gc for the first graph,
-                    # and disable gc for the rest of the graphs.
-                    stack.enter_context(patch("gc.collect", lambda: None))
-                    stack.enter_context(
-                        patch("torch.cuda.empty_cache", lambda: None))
-
-                # mind-exploding: carefully manage the reference and memory.
-                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                    # `output` is managed by pytorch's cudagraph pool
-                    output = entry.runnable(*args)
-                    if self.is_last_graph:
-                        # by converting it to weak ref,
-                        # the original `output` will immediately be released
-                        # to save memory. It is only safe to do this for
-                        # the last graph, because the output of the last graph
-                        # will not be used by any other cuda graph.
-                        output = weak_ref_tensors(output)
-
-            # here we always use weak ref for the output
-            # to save memory
-            entry.output = weak_ref_tensors(output)
-            entry.cudagraph = cudagraph
-
-            compilation_counter.num_cudagraph_caputured += 1
-
-            # important: we need to return the output, rather than
-            # the weak ref of the output, so that pytorch can correctly
-            # manage the memory during cuda graph capture
-            return output
-
-        if self.is_debugging_mode:
-            # check if the input addresses are the same
-            new_input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
-            assert new_input_addresses == entry.input_addresses, (
-                "Input addresses for cudagraphs are different during replay."
-                f" Expected {entry.input_addresses}, got {new_input_addresses}"
-            )
-
-        entry.cudagraph.replay()
-        return entry.output
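
The heart of the change is visible in the call_module hunk: instead of constructing PiecewiseBackend directly, the backend class is now looked up from the active platform, so each platform can supply its own piecewise backend through current_platform.get_piecewise_backend_cls() and resolve_obj_by_qualname. The sketch below illustrates only the qualname-resolution idea; the resolver body, the _DemoPlatform class, and the placeholder qualname are assumptions made for illustration, not code taken from vLLM.

import importlib
from typing import Any


def resolve_obj_by_qualname(qualname: str) -> Any:
    """Resolve a string like 'pkg.module.ClassName' to the named object."""
    # Split off the final attribute name, import the module, then look it up.
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


class _DemoPlatform:
    """Hypothetical stand-in for the platform object used in the diff."""

    def get_piecewise_backend_cls(self) -> str:
        # Placeholder qualname for illustration; a real platform would
        # return the qualified name of its own piecewise backend class.
        return "collections.OrderedDict"


current_platform = _DemoPlatform()
piecewise_backend = resolve_obj_by_qualname(
    current_platform.get_piecewise_backend_cls())
print(piecewise_backend)  # <class 'collections.OrderedDict'>

The same indirection motivates the current_platform.graph_pool_handle() change above: platform-specific behavior lives behind the platform object rather than being hard-coded to torch.cuda.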