Skip to content

Commit eb399a9

Browse files
McPatate and remi-or authored
feat(cb): use context manager in generate_batch (#42190)
* feat(cb): use context manager in `generate_batch`
* refactor(cb): group `with` stmts
* refactor(cb): move log line before `stop` call

Co-authored-by: Rémi Ouazan <[email protected]>

* fix: lint

---

Co-authored-by: Rémi Ouazan <[email protected]>
1 parent 3ea7ecd commit eb399a9

File tree

1 file changed

+60
-41
lines changed

1 file changed

+60
-41
lines changed

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 60 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import torch
2727
from torch import nn
2828
from tqdm import tqdm
29+
from tqdm.contrib.logging import logging_redirect_tqdm
2930

3031
from ...configuration_utils import PretrainedConfig
3132
from ...generation.configuration_utils import GenerationConfig
@@ -809,6 +810,7 @@ def is_running(self) -> bool:
809810
"""Check if the background generation thread is running."""
810811
return self._generation_thread is not None and self._generation_thread.is_alive()
811812

813+
# NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition
812814
def stop(self, block: bool = True, timeout: float | None = None) -> None:
813815
"""Signal the background thread to stop.
814816
@@ -1063,14 +1065,35 @@ class ContinuousMixin:
10631065
"""Mixin class for models to add continuous batching capabilities."""
10641066

10651067
@contextmanager
1066-
def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]:
1067-
manager = self.init_continuous_batching(**kwargs)
1068+
def continuous_batching_context_manager(
1069+
self,
1070+
generation_config: GenerationConfig | None = None,
1071+
manual_eviction: bool = False,
1072+
max_queue_size: int = 0,
1073+
num_q_cuda_graphs: int = 0,
1074+
num_kv_cuda_graphs: int = 0,
1075+
allow_prefix_sharing: bool = True,
1076+
block: bool = True,
1077+
timeout: float | None = None,
1078+
) -> Generator[ContinuousBatchingManager]:
1079+
manager = self.init_continuous_batching(
1080+
generation_config,
1081+
manual_eviction,
1082+
max_queue_size,
1083+
num_q_cuda_graphs,
1084+
num_kv_cuda_graphs,
1085+
allow_prefix_sharing,
1086+
)
10681087
manager.start()
10691088
try:
10701089
yield manager
10711090
finally:
1072-
manager.stop(block=True)
1091+
logger.debug(
1092+
"Continuous batching loop finished"
1093+
) # a dummy log needed for the logs of stop to show. Won't show
1094+
manager.stop(block=block, timeout=timeout)
10731095

1096+
# NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition
10741097
def init_continuous_batching(
10751098
self,
10761099
generation_config: GenerationConfig | None = None,
@@ -1149,45 +1172,41 @@ def generate_batch(
11491172
progress_bar = False
11501173

11511174
# Initialize manager with the batch inputs
1152-
manager = self.init_continuous_batching(
1153-
generation_config=generation_config,
1154-
num_q_cuda_graphs=num_q_cuda_graphs,
1155-
num_kv_cuda_graphs=num_kv_cuda_graphs,
1156-
allow_prefix_sharing=allow_prefix_sharing,
1157-
)
1158-
manager.start()
11591175
results = {}
11601176
num_requests = len(inputs)
1161-
try:
1162-
from tqdm.contrib.logging import logging_redirect_tqdm
1163-
1164-
with logging_redirect_tqdm([logger]):
1165-
with tqdm(
1166-
total=num_requests,
1167-
disable=(not progress_bar),
1168-
desc=f"Solving {num_requests} requests",
1169-
unit="request",
1170-
) as pbar:
1171-
manager.add_requests(
1172-
inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
1173-
)
1174-
finished_count = 0
1175-
while finished_count < num_requests:
1176-
result = manager.get_result(timeout=1)
1177-
if result:
1178-
req_id = result.request_id
1179-
if result.is_finished():
1180-
results[req_id] = result
1181-
finished_count += 1
1182-
pbar.update(1)
1183-
else:
1184-
if not manager.is_running():
1185-
logger.error("Generation thread terminated unexpectedly.")
1186-
break
1177+
with (
1178+
self.continuous_batching_context_manager(
1179+
generation_config=generation_config,
1180+
num_q_cuda_graphs=num_q_cuda_graphs,
1181+
num_kv_cuda_graphs=num_kv_cuda_graphs,
1182+
allow_prefix_sharing=allow_prefix_sharing,
1183+
block=True,
1184+
timeout=5,
1185+
) as manager,
1186+
logging_redirect_tqdm([logger]),
1187+
tqdm(
1188+
total=num_requests,
1189+
disable=(not progress_bar),
1190+
desc=f"Solving {num_requests} requests",
1191+
unit="request",
1192+
) as pbar,
1193+
):
1194+
try:
1195+
manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"))
1196+
finished_count = 0
1197+
while finished_count < num_requests:
1198+
result = manager.get_result(timeout=1)
1199+
if result:
1200+
req_id = result.request_id
1201+
if result.is_finished():
1202+
results[req_id] = result
1203+
finished_count += 1
1204+
pbar.update(1)
1205+
else:
1206+
if not manager.is_running():
1207+
logger.error("Generation thread terminated unexpectedly.")
1208+
break
11871209

1188-
except Exception as e:
1189-
logger.error(f"Error during batch generation: {e}", exc_info=True)
1190-
finally:
1191-
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
1192-
manager.stop(block=True, timeout=5.0)
1210+
except Exception as e:
1211+
logger.error(f"Error during batch generation: {e}", exc_info=True)
11931212
return results

0 commit comments

Comments (0)