@@ -226,6 +226,9 @@ def __init__(
226226 model_description : Optional ["ModelDescription" ] = None ,
227227 request_limits : Optional [int ] = None ,
228228 xavier_config : Optional [Dict ] = None ,
229+ n_worker : Optional [int ] = 1 ,
230+ shard : Optional [int ] = 0 ,
231+ driver_info : Optional [dict ] = None , # for model across workers
229232 ):
230233 super ().__init__ ()
231234 from ..model .llm .lmdeploy .core import LMDeployModel
@@ -263,6 +266,10 @@ def __init__(
263266 "quantization" : self ._model_description .get ("quantization" , "none" ),
264267 }
265268 self ._loop : Optional [asyncio .AbstractEventLoop ] = None
269+ # model across workers
270+ self ._n_worker = n_worker
271+ self ._shard = shard
272+ self ._driver_info = driver_info
266273
267274 self ._scheduler_ref = None
268275 self ._text_to_image_scheduler_ref = None
@@ -455,6 +462,8 @@ async def load(self):
455462 i += 1
456463 try :
457464 self ._model .load ()
465+ if hasattr (self ._model , "driver_info" ):
466+ self ._driver_info = self ._model .driver_info
458467 break
459468 except Exception as e :
460469 if (
@@ -477,6 +486,10 @@ async def load(self):
477486 )
478487 logger .info (f"{ self } loaded" )
479488
async def wait_for_load(self):
    """Wait until the wrapped model reports that loading has finished.

    Delegates to ``self._model.wait_for_load`` when the model defines that
    hook; for models without it this coroutine returns immediately.

    The hook is invoked as a plain (non-awaited) callable, so it is treated
    as blocking: it runs in the event loop's default executor and is awaited
    from here, which keeps the loop free to service other requests while the
    model finishes loading. Any exception raised by the hook propagates to
    the caller.
    """
    if not hasattr(self._model, "wait_for_load"):
        return
    # Offload the blocking wait instead of calling it inline on the loop.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, self._model.wait_for_load)
492+
480493 def model_uid (self ):
481494 return (
482495 self ._model .model_uid
@@ -488,6 +501,12 @@ def model_uid(self):
488501 )
489502 )
490503
def get_driver_info(self):
    """Return the driver information stored on this actor.

    Driver information is used when a model is spread across workers: the
    driver model actor (always the first worker) holds it, and it includes
    things such as the dist store.
    """
    return self._driver_info
509+
491510 async def _handle_oom_error (self , ex ):
492511 error_message = (
493512 f"Model actor is out of memory, model id: { self .model_uid ()} , error: { ex } "
0 commit comments