33
44import ast
55import dataclasses
6+ import hashlib
67import os
78import pprint
89import time
2021from vllm .platforms import current_platform
2122from vllm .utils import is_torch_equal_or_newer , resolve_obj_by_qualname
2223
24+ from .caching import VllmSerializableFunction
2325from .compiler_interface import (CompilerInterface , EagerAdaptor ,
2426 InductorAdaptor , InductorStandaloneAdaptor )
2527from .counter import compilation_counter
@@ -160,6 +162,7 @@ def compile(self,
160162 # there can be multiple graphs due to piecewise compilation.
161163 now = time .time ()
162164 elapsed = now - compilation_start_time
165+ compilation_config .compilation_time += elapsed
163166 if runtime_shape is None :
164167 logger .info (
165168 "Directly load the compiled graph(s) for dynamic shape "
@@ -398,35 +401,6 @@ def set_model_tag(tag: str):
398401 model_tag = old_tag
399402
400403
401- try :
402- from torch ._dynamo .aot_compile import SerializableCallable
403- except ImportError :
404- SerializableCallable = object
405-
406- assert isinstance (SerializableCallable , type )
407-
408-
409- class VllmCompiledFunction (SerializableCallable ):
410-
411- def __init__ (self , graph_module , example_inputs , vllm_config ,
412- optimized_call ):
413- self .graph_module = graph_module
414- self .example_inputs = example_inputs
415- self .vllm_config = vllm_config
416- self .optimized_call = optimized_call
417-
418- def __call__ (self , * args , ** kwargs ):
419- return self .optimized_call (* args , ** kwargs )
420-
421- @classmethod
422- def serialize_compile_artifacts (cls , compiled_fn ):
423- raise NotImplementedError ("serialization not implemented" )
424-
425- @classmethod
426- def deserialize_compile_artifacts (cls , data ):
427- raise NotImplementedError ("deserialization not implemented" )
428-
429-
430404class VllmBackend :
431405 """The compilation backend for `torch.compile` with vLLM.
432406 It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -502,7 +476,11 @@ def configure_post_pass(self):
502476 self .post_grad_pass_manager .add (inductor_config [PASS_KEY ])
503477 inductor_config [PASS_KEY ] = self .post_grad_pass_manager
504478
505- def __call__ (self , graph : fx .GraphModule , example_inputs ) -> Callable :
479+ def __call__ (self , graph : fx .GraphModule ,
480+ example_inputs ) -> VllmSerializableFunction :
481+
482+ from .caching import (_compute_code_hash ,
483+ compilation_config_hash_factors )
506484
507485 vllm_config = self .vllm_config
508486 if not self .compilation_config .cache_dir :
@@ -511,37 +489,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
511489 # the cache dir will be the same so that we can reuse the compiled
512490 # graph.
513491
514- factors = []
515- # 0. factors come from the env, for example, The values of
516- # VLLM_PP_LAYER_PARTITION will affect the computation graph.
517- env_hash = envs .compute_hash ()
518- factors .append (env_hash )
519-
520- # 1. factors come from the vllm_config (it mainly summarizes how the
521- # model is created)
522- config_hash = vllm_config .compute_hash ()
523- factors .append (config_hash )
524-
492+ factors = compilation_config_hash_factors (vllm_config )
525493 # 2. factors come from the code files that are traced by Dynamo (
526494 # it mainly summarizes how the model is used in forward pass)
527- forward_code_files = list (
528- sorted ( self .compilation_config .traced_files ) )
495+ code_hash = _compute_code_hash (
496+ self .compilation_config .traced_files )
529497 self .compilation_config .traced_files .clear ()
530- logger .debug (
531- "Traced files (to be considered for compilation cache):\n %s" ,
532- "\n " .join (forward_code_files ))
533- hash_content = []
534- for filepath in forward_code_files :
535- hash_content .append (filepath )
536- if filepath == "<string>" :
537- # This means the function was dynamically generated, with
538- # e.g. exec(). We can't actually check these.
539- continue
540- with open (filepath ) as f :
541- hash_content .append (f .read ())
542- import hashlib
543- code_hash = hashlib .md5 ("\n " .join (hash_content ).encode (),
544- usedforsecurity = False ).hexdigest ()
498+
545499 factors .append (code_hash )
546500
547501 # 3. compiler hash
@@ -634,8 +588,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
634588
635589 if self .compilation_config .cudagraph_mode == CUDAGraphMode .NONE or \
636590 not self .compilation_config .cudagraph_copy_inputs :
637- return VllmCompiledFunction (graph , example_inputs , vllm_config ,
638- self .split_gm )
591+ return VllmSerializableFunction (graph , example_inputs , self . prefix ,
592+ self .split_gm )
639593
640594 # if we need to copy input buffers for cudagraph
641595 from torch ._guards import detect_fake_mode
@@ -677,5 +631,5 @@ def copy_and_call(*args):
677631 list_args [index ] = static_tensor
678632 return self .split_gm (* list_args )
679633
680- return VllmCompiledFunction (graph , example_inputs , vllm_config ,
681- copy_and_call )
634+ return VllmSerializableFunction (graph , example_inputs , self . prefix ,
635+ copy_and_call )
0 commit comments