Skip to content

Commit 8f86c1d

Browse files
authored
FEAT: support distributed inference for sglang (#2877)
1 parent ce8991a commit 8f86c1d

File tree

5 files changed

+371
-32
lines changed

5 files changed

+371
-32
lines changed

xinference/core/model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ def __init__(
226226
model_description: Optional["ModelDescription"] = None,
227227
request_limits: Optional[int] = None,
228228
xavier_config: Optional[Dict] = None,
229+
n_worker: Optional[int] = 1,
230+
shard: Optional[int] = 0,
231+
driver_info: Optional[dict] = None, # for model across workers
229232
):
230233
super().__init__()
231234
from ..model.llm.lmdeploy.core import LMDeployModel
@@ -263,6 +266,10 @@ def __init__(
263266
"quantization": self._model_description.get("quantization", "none"),
264267
}
265268
self._loop: Optional[asyncio.AbstractEventLoop] = None
269+
# model across workers
270+
self._n_worker = n_worker
271+
self._shard = shard
272+
self._driver_info = driver_info
266273

267274
self._scheduler_ref = None
268275
self._text_to_image_scheduler_ref = None
@@ -455,6 +462,8 @@ async def load(self):
455462
i += 1
456463
try:
457464
self._model.load()
465+
if hasattr(self._model, "driver_info"):
466+
self._driver_info = self._model.driver_info
458467
break
459468
except Exception as e:
460469
if (
@@ -477,6 +486,10 @@ async def load(self):
477486
)
478487
logger.info(f"{self} loaded")
479488

489+
async def wait_for_load(self):
490+
if hasattr(self._model, "wait_for_load"):
491+
self._model.wait_for_load()
492+
480493
def model_uid(self):
481494
return (
482495
self._model.model_uid
@@ -488,6 +501,12 @@ def model_uid(self):
488501
)
489502
)
490503

504+
def get_driver_info(self):
505+
# driver info is used for model across workers,
506+
# the driver model actor (always the first worker)
507+
# will hold driver information, including the dist store etc.
508+
return self._driver_info
509+
491510
async def _handle_oom_error(self, ex):
492511
error_message = (
493512
f"Model actor is out of memory, model id: {self.model_uid()}, error: {ex}"

xinference/core/status_guard.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class InstanceInfo(BaseModel):
3939
replica: int
4040
status: str
4141
instance_created_ts: int
42+
n_worker: Optional[int] = 1
4243

4344
def update(self, **kwargs):
4445
for field, value in kwargs.items():

0 commit comments

Comments
 (0)