|
26 | 26 | import torch |
27 | 27 | from torch import nn |
28 | 28 | from tqdm import tqdm |
| 29 | +from tqdm.contrib.logging import logging_redirect_tqdm |
29 | 30 |
|
30 | 31 | from ...configuration_utils import PretrainedConfig |
31 | 32 | from ...generation.configuration_utils import GenerationConfig |
@@ -809,6 +810,7 @@ def is_running(self) -> bool: |
809 | 810 | """Check if the background generation thread is running.""" |
810 | 811 | return self._generation_thread is not None and self._generation_thread.is_alive() |
811 | 812 |
|
| 813 | + # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition |
812 | 814 | def stop(self, block: bool = True, timeout: float | None = None) -> None: |
813 | 815 | """Signal the background thread to stop. |
814 | 816 |
|
@@ -1063,14 +1065,35 @@ class ContinuousMixin: |
1063 | 1065 | """Mixin class for models to add continuous batching capabilities.""" |
1064 | 1066 |
|
1065 | 1067 | @contextmanager |
1066 | | - def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]: |
1067 | | - manager = self.init_continuous_batching(**kwargs) |
| 1068 | + def continuous_batching_context_manager( |
| 1069 | + self, |
| 1070 | + generation_config: GenerationConfig | None = None, |
| 1071 | + manual_eviction: bool = False, |
| 1072 | + max_queue_size: int = 0, |
| 1073 | + num_q_cuda_graphs: int = 0, |
| 1074 | + num_kv_cuda_graphs: int = 0, |
| 1075 | + allow_prefix_sharing: bool = True, |
| 1076 | + block: bool = True, |
| 1077 | + timeout: float | None = None, |
| 1078 | + ) -> Generator[ContinuousBatchingManager]: |
| 1079 | + manager = self.init_continuous_batching( |
| 1080 | + generation_config, |
| 1081 | + manual_eviction, |
| 1082 | + max_queue_size, |
| 1083 | + num_q_cuda_graphs, |
| 1084 | + num_kv_cuda_graphs, |
| 1085 | + allow_prefix_sharing, |
| 1086 | + ) |
1068 | 1087 | manager.start() |
1069 | 1088 | try: |
1070 | 1089 | yield manager |
1071 | 1090 | finally: |
1072 | | - manager.stop(block=True) |
| 1091 | + logger.debug( |
| 1092 | + "Continuous batching loop finished" |
| 1093 | + ) # a dummy log needed for the logs of stop to show. Won't show. |
| 1094 | + manager.stop(block=block, timeout=timeout) |
1073 | 1095 |
|
| 1096 | + # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition |
1074 | 1097 | def init_continuous_batching( |
1075 | 1098 | self, |
1076 | 1099 | generation_config: GenerationConfig | None = None, |
@@ -1149,45 +1172,41 @@ def generate_batch( |
1149 | 1172 | progress_bar = False |
1150 | 1173 |
|
1151 | 1174 | # Initialize manager with the batch inputs |
1152 | | - manager = self.init_continuous_batching( |
1153 | | - generation_config=generation_config, |
1154 | | - num_q_cuda_graphs=num_q_cuda_graphs, |
1155 | | - num_kv_cuda_graphs=num_kv_cuda_graphs, |
1156 | | - allow_prefix_sharing=allow_prefix_sharing, |
1157 | | - ) |
1158 | | - manager.start() |
1159 | 1175 | results = {} |
1160 | 1176 | num_requests = len(inputs) |
1161 | | - try: |
1162 | | - from tqdm.contrib.logging import logging_redirect_tqdm |
1163 | | - |
1164 | | - with logging_redirect_tqdm([logger]): |
1165 | | - with tqdm( |
1166 | | - total=num_requests, |
1167 | | - disable=(not progress_bar), |
1168 | | - desc=f"Solving {num_requests} requests", |
1169 | | - unit="request", |
1170 | | - ) as pbar: |
1171 | | - manager.add_requests( |
1172 | | - inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps |
1173 | | - ) |
1174 | | - finished_count = 0 |
1175 | | - while finished_count < num_requests: |
1176 | | - result = manager.get_result(timeout=1) |
1177 | | - if result: |
1178 | | - req_id = result.request_id |
1179 | | - if result.is_finished(): |
1180 | | - results[req_id] = result |
1181 | | - finished_count += 1 |
1182 | | - pbar.update(1) |
1183 | | - else: |
1184 | | - if not manager.is_running(): |
1185 | | - logger.error("Generation thread terminated unexpectedly.") |
1186 | | - break |
| 1177 | + with ( |
| 1178 | + self.continuous_batching_context_manager( |
| 1179 | + generation_config=generation_config, |
| 1180 | + num_q_cuda_graphs=num_q_cuda_graphs, |
| 1181 | + num_kv_cuda_graphs=num_kv_cuda_graphs, |
| 1182 | + allow_prefix_sharing=allow_prefix_sharing, |
| 1183 | + block=True, |
| 1184 | + timeout=5, |
| 1185 | + ) as manager, |
| 1186 | + logging_redirect_tqdm([logger]), |
| 1187 | + tqdm( |
| 1188 | + total=num_requests, |
| 1189 | + disable=(not progress_bar), |
| 1190 | + desc=f"Solving {num_requests} requests", |
| 1191 | + unit="request", |
| 1192 | + ) as pbar, |
| 1193 | + ): |
| 1194 | + try: |
| 1195 | + manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens")) |
| 1196 | + finished_count = 0 |
| 1197 | + while finished_count < num_requests: |
| 1198 | + result = manager.get_result(timeout=1) |
| 1199 | + if result: |
| 1200 | + req_id = result.request_id |
| 1201 | + if result.is_finished(): |
| 1202 | + results[req_id] = result |
| 1203 | + finished_count += 1 |
| 1204 | + pbar.update(1) |
| 1205 | + else: |
| 1206 | + if not manager.is_running(): |
| 1207 | + logger.error("Generation thread terminated unexpectedly.") |
| 1208 | + break |
1187 | 1209 |
|
1188 | | - except Exception as e: |
1189 | | - logger.error(f"Error during batch generation: {e}", exc_info=True) |
1190 | | - finally: |
1191 | | - logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show. |
1192 | | - manager.stop(block=True, timeout=5.0) |
| 1210 | + except Exception as e: |
| 1211 | + logger.error(f"Error during batch generation: {e}", exc_info=True) |
1193 | 1212 | return results |
0 commit comments