@@ -2,6 +2,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
@@ -30,10 +31,16 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
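
For reference, a self-contained sketch of the selection logic this hunk lands on. The `_Backend` enum is reduced to the members used here, the `cls` and unused parameters are dropped, and the module paths are taken verbatim from the diff:

```python
import enum
import logging

logger = logging.getLogger(__name__)

class _Backend(enum.Enum):
    # Stand-in for vllm.platforms.interface._Backend, reduced to
    # the members this hunk cares about plus one non-TPU example.
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    FLASH_ATTN = enum.auto()

def get_attn_backend_cls(selected_backend: _Backend, use_v1: bool) -> str:
    # A non-Pallas selection is only logged, never raised; the function
    # falls through and overrides the user's choice.
    if selected_backend not in (_Backend.PALLAS, _Backend.PALLAS_VLLM_V1):
        logger.info("Cannot use %s backend on TPU.", selected_backend)

    # use_v1 switches between the V1 engine's Pallas backend and the
    # legacy one; both paths return a dotted import path, not a class.
    if use_v1:
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
    logger.info("Using Pallas backend.")
    return "vllm.attention.backends.pallas.PallasAttentionBackend"
```

Note that with `use_v1=True` this resolves to the V1 backend path regardless of `selected_backend`, since the mismatch is logged but not rejected.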
@@ -45,7 +52,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1
 
     @classmethod
     def inference_mode(cls):
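
Async output processing is now reported as supported only on the legacy (V0) engine. A one-line emulation, assuming `envs.VLLM_USE_V1` is backed by the `VLLM_USE_V1` environment variable:

```python
import os

def is_async_output_supported(enforce_eager=None) -> bool:
    # True only when the V1 engine is not enabled via the environment.
    return os.environ.get("VLLM_USE_V1", "0") != "1"
```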
@@ -60,22 +67,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 16
 
         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."
 
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"
 
-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
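
A runnable sketch of the new coercion: instead of fixing up only `NO_COMPILATION` and asserting against Inductor, every level other than `DYNAMO_ONCE` is now overridden. The numeric levels are assumed stand-ins mirroring the ordering of `vllm.config.CompilationLevel`:

```python
import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CompilationLevel:
    # Assumed values; only the relative ordering matters here.
    NO_COMPILATION = 0
    DYNAMO_ONCE = 2
    PIECEWISE = 3  # Inductor path; previously rejected by an assert

def force_dynamo_once(compilation_config) -> None:
    # Any level other than DYNAMO_ONCE is silently coerced (with a log
    # line) rather than rejected, which is what replaces the old asserts.
    if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
        logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
        compilation_config.level = CompilationLevel.DYNAMO_ONCE

cfg = SimpleNamespace(level=CompilationLevel.PIECEWISE)
force_dynamo_once(cfg)
assert cfg.level == CompilationLevel.DYNAMO_ONCE
```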
@@ -85,8 +88,34 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1:
+            if vllm_config.cache_config.enable_prefix_caching:
+                logger.warning("[V1][TPU] Disable prefix caching")
+                vllm_config.cache_config.enable_prefix_caching = False
+
+            if vllm_config.scheduler_config.chunked_prefill_enabled:
+                logger.warning("[V1][TPU] Disable chunked prefill")
+                vllm_config.scheduler_config.chunked_prefill_enabled = False
+
+        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
+            "Chunked prefill is not yet supported for TPU backend")
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
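
A standalone sketch of the V1 fixups this hunk appends, with `SimpleNamespace` stand-ins for the config objects. It shows why the relocated asserts cannot fire on the V1 path: prefix caching and chunked prefill are warned about and force-disabled before the asserts run:

```python
import logging
from types import SimpleNamespace

logging.basicConfig()
logger = logging.getLogger(__name__)

def adjust_for_v1(vllm_config, use_v1: bool) -> None:
    if use_v1:
        # V1 does not yet support these features on TPU, so they are
        # disabled in place rather than rejected.
        if vllm_config.cache_config.enable_prefix_caching:
            logger.warning("[V1][TPU] Disable prefix caching")
            vllm_config.cache_config.enable_prefix_caching = False
        if vllm_config.scheduler_config.chunked_prefill_enabled:
            logger.warning("[V1][TPU] Disable chunked prefill")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

    # The asserts moved down from the top of check_and_update_config now
    # run after the fixups, so only a V0 config can trip them.
    assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
        "Chunked prefill is not yet supported for TPU backend")

cfg = SimpleNamespace(
    cache_config=SimpleNamespace(enable_prefix_caching=True),
    scheduler_config=SimpleNamespace(chunked_prefill_enabled=True))
adjust_for_v1(cfg, use_v1=True)  # passes; with use_v1=False it would raise
```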