diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py
index d901baac82fc..5390806a9122 100644
--- a/cacheflow/http_frontend/fastapi_frontend.py
+++ b/cacheflow/http_frontend/fastapi_frontend.py
@@ -17,8 +17,10 @@
 from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
 
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
 
+
 class FastAPIFrontend:
     def __init__(
         self,
@@ -30,7 +32,7 @@ def __init__(
         dtype: str,
         seed: int,
         swap_space: int,
-        max_batch_size: int,
+        max_num_batched_tokens: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
@@ -51,7 +53,7 @@ def __init__(
             dtype=dtype,
             seed=seed,
             swap_space=swap_space,
-            max_batch_size=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
@@ -68,12 +70,14 @@ async def server_step(self):
         self.is_server_running = True
         updated_seq_groups = await self.server.step.remote()
         self.is_server_running = False
+        # Notify the waiting coroutines that there are new outputs ready.
         for seq_group in updated_seq_groups:
             group_id = seq_group.group_id
             self.running_seq_groups[group_id] = seq_group
             self.sequence_group_events[group_id].set()
 
     async def generate(self, request_dict: Dict):
+        # Preprocess the request.
         prompt = request_dict["prompt"]
         sampling_params = SamplingParams.from_dict(request_dict)
         sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
@@ -87,15 +91,27 @@ async def generate(self, request_dict: Dict):
         arrival_time = time.time()
         group_id = next(self.seq_group_counter)
         seq_group = SequenceGroup(group_id, seqs, arrival_time)
+        # Create an event to notify us that there is new output from the
+        # cacheflow server.
         group_event = asyncio.Event()
+        self.running_seq_groups[group_id] = seq_group
         self.sequence_group_events[group_id] = group_event
+        # Add the request into the cacheflow server's waiting queue.
         await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
+        # The cacheflow server does not have a background loop that keeps
+        # processing incoming requests. Therefore, we need to keep kicking
+        # the server to process the requests.
         while True:
+            # Kick the server if the server is not running.
            if not self.is_server_running:
                 await self.server_step()
-            # Wait for new output. Add a 1s timeout to prevent dead lock.
-            await asyncio.wait_for(group_event.wait(), timeout=1)
+            # Wait for new output. The group_event will be set in server_step
+            # when there is new output available for the sequence group.
+            # Added a timeout to prevent deadlock.
+            await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
+            # Reset the event to wait for the next output.
             group_event.clear()
+            # Decode and return the new outputs.
             seq_group = self.running_seq_groups[group_id]
             all_outputs = []
             for seq in seq_group.seqs:
@@ -107,7 +123,16 @@ async def generate(self, request_dict: Dict):
                 "error": 0,
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")
+
+            # Once finished, release the resources of the sequence group.
             if seq_group.is_finished():
+                del self.running_seq_groups[group_id]
+                del self.sequence_group_events[group_id]
+                # Kick the server if the server is not running. This makes
+                # sure that requests still in the server's waiting queue
+                # get executed.
+                if not self.is_server_running:
+                    await self.server_step()
                 break
 
 
@@ -143,7 +168,7 @@ async def generate_stream(request: Request):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
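The control flow this patch introduces can be exercised outside of Ray and FastAPI. Below is a minimal, self-contained sketch of the same pattern: one `asyncio.Event` per sequence group, a `generate` coroutine that kicks `server_step` whenever no step is in flight, and a bounded wait so a coroutine never blocks forever if an event fired before it started waiting. The names (`ToyFrontend`, `waiting`, `outputs`) are hypothetical, and the `try/except` around `wait_for` is an extra robustness measure not present in the patch, where a `TimeoutError` would propagate.

```python
import asyncio
from typing import Dict, List

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds, as in the patch


class ToyFrontend:
    """Hypothetical stand-in for FastAPIFrontend's event-driven wait loop."""

    def __init__(self) -> None:
        self.is_server_running = False
        self.waiting: List[int] = []                # stand-in for the server queue
        self.events: Dict[int, asyncio.Event] = {}  # one event per group
        self.outputs: Dict[int, str] = {}

    async def server_step(self) -> None:
        # One scheduling iteration: produce output for every queued group,
        # then set its event so the waiting coroutine wakes up.
        self.is_server_running = True
        await asyncio.sleep(0.01)                   # pretend to run a model step
        for group_id in self.waiting:
            self.outputs[group_id] = f"output for group {group_id}"
            self.events[group_id].set()
        self.waiting.clear()
        self.is_server_running = False

    async def generate(self, group_id: int) -> str:
        self.events[group_id] = asyncio.Event()
        self.waiting.append(group_id)
        while True:
            # Kick the server if no other coroutine is stepping it right now.
            if not self.is_server_running:
                await self.server_step()
            try:
                # The timeout bounds the wait in case a concurrent step
                # finished before this coroutine reached the event.
                await asyncio.wait_for(self.events[group_id].wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            self.events[group_id].clear()
            # Release per-request state, mirroring the dels in the patch.
            del self.events[group_id]
            return self.outputs.pop(group_id)


async def main() -> None:
    frontend = ToyFrontend()
    print(await asyncio.gather(*(frontend.generate(i) for i in range(3))))


asyncio.run(main())
```

With several concurrent `generate` calls, only the first one actually drives `server_step`; the others see `is_server_running` set and simply wait on their events, which is why the final kick after `is_finished()` matters: without it, requests queued behind a finishing one could sit idle until another request arrives.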
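On the client side, the `json.dumps(ret) + "\0"` framing in `generate` means a consumer can split the HTTP stream on NUL bytes. Here is a possible client sketch; the route `/generate`, the host and port, and the `"text"` field of `ret` are all assumptions, since the patch shows the `generate_stream` handler and the `"error": 0` key but neither the route decorator nor the full response dict.

```python
import json

import requests  # third-party; pip install requests

# Assumed address and route; the patch does not show the @app.post(...)
# decorator for generate_stream.
response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "The future of AI is"},  # other SamplingParams fields omitted
    stream=True,
)

buffer = b""
for chunk in response.iter_content(chunk_size=None):
    buffer += chunk
    # Frames are "\0"-delimited JSON objects, per the yield in generate.
    while b"\0" in buffer:
        frame, _, buffer = buffer.partition(b"\0")
        ret = json.loads(frame.decode("utf-8"))
        print(ret["text"])  # "text" is assumed to hold the decoded outputs
```

Buffering until a NUL byte arrives is deliberate: `iter_content` yields whatever the transport delivers, so a single JSON frame may span multiple chunks or a chunk may carry several frames.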