@@ -236,7 +236,11 @@ def __init__(
         self.on_job_start = on_job_start
         self.on_job_end = on_job_end
         self.after_job_end = after_job_end
-        self.sem = asyncio.BoundedSemaphore(max_jobs)
+
+        self.max_jobs = max_jobs
+        self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
+        self.job_counter: int = 0
+
         self.job_timeout_s = to_seconds(job_timeout)
         self.keep_result_s = to_seconds(keep_result)
         self.keep_result_forever = keep_result_forever
@@ -374,13 +378,13 @@ async def _poll_iteration(self) -> None:
                 return
             count = min(burst_jobs_remaining, count)
         if self.allow_pick_jobs:
-            async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
+            if self.job_counter < self.max_jobs:
                 now = timestamp_ms()
                 job_ids = await self.pool.zrangebyscore(
                     self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
                 )

-            await self.start_jobs(job_ids)
+                await self.start_jobs(job_ids)

         if self.allow_abort_jobs:
             await self._cancel_aborted_jobs()
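The hunk above swaps a blocking semaphore acquisition for a plain counter comparison: a worker already running max_jobs jobs no longer parks its poll loop waiting for a permit, it simply skips picking new jobs this iteration and still reaches the abort handling. A minimal sketch of the two guard styles, using illustrative names only (old_guard and new_guard are not arq methods):

import asyncio

async def old_guard(sem: asyncio.BoundedSemaphore) -> bool:
    # Old behaviour: wait for a free permit before it is worth calling zrangebyscore;
    # with every permit held by a running job, the whole poll iteration stalls here.
    async with sem:
        return True

def new_guard(job_counter: int, max_jobs: int) -> bool:
    # New behaviour: a non-blocking capacity check; if the worker is full,
    # skip picking jobs this time round and try again on the next poll.
    return job_counter < max_jobs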
@@ -419,12 +423,23 @@ async def _cancel_aborted_jobs(self) -> None:
         self.aborting_tasks.update(aborted)
         await self.pool.zrem(abort_jobs_ss, *aborted)

+    def _release_sem_dec_counter_on_complete(self) -> None:
+        self.job_counter = self.job_counter - 1
+        self.sem.release()
+
     async def start_jobs(self, job_ids: List[bytes]) -> None:
         """
         For each job id, get the job definition, check it's not running and start it in a task
         """
         for job_id_b in job_ids:
             await self.sem.acquire()
+
+            if self.job_counter >= self.max_jobs:
+                self.sem.release()
+                return None
+
+            self.job_counter = self.job_counter + 1
+
             job_id = job_id_b.decode()
             in_progress_key = in_progress_key_prefix + job_id
             async with self.pool.pipeline(transaction=True) as pipe:
@@ -433,6 +448,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
                 score = await pipe.zscore(self.queue_name, job_id)
                 if ongoing_exists or not score:
                     # job already started elsewhere, or already finished and removed from queue
+                    self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('job %s already running elsewhere', job_id)
                     continue
@@ -445,11 +461,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
                     await pipe.execute()
                 except (ResponseError, WatchError):
                     # job already started elsewhere since we got 'existing'
+                    self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('multi-exec error, job %s already started elsewhere', job_id)
                 else:
                     t = self.loop.create_task(self.run_job(job_id, int(score)))
-                    t.add_done_callback(lambda _: self.sem.release())
+                    t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
                     self.tasks[job_id] = t

     async def run_job(self, job_id: str, score: int) -> None:  # noqa: C901
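Taken together, the change replaces the single BoundedSemaphore(max_jobs) limit with a semaphore sized max_jobs + 1 plus an explicit job_counter: start_jobs briefly acquires the semaphore, refuses new work once the counter reaches max_jobs, and the task's done callback decrements the counter and releases the permit. A self-contained sketch of that pattern under assumed names (MiniWorker, start and _release_on_complete are illustrative, not part of arq):

import asyncio
from typing import Awaitable, Callable, Dict

class MiniWorker:
    def __init__(self, max_jobs: int) -> None:
        self.max_jobs = max_jobs
        # One spare permit, so acquiring in start() never blocks the poll path
        # even when exactly max_jobs jobs are already running.
        self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
        self.job_counter = 0
        self.tasks: Dict[str, asyncio.Task] = {}

    def _release_on_complete(self) -> None:
        # Mirrors _release_sem_dec_counter_on_complete: free the slot when a job ends.
        self.job_counter -= 1
        self.sem.release()

    async def start(self, job_id: str, job: Callable[[], Awaitable[None]]) -> bool:
        await self.sem.acquire()
        if self.job_counter >= self.max_jobs:
            # At capacity: hand the permit back and refuse the job for now.
            self.sem.release()
            return False
        self.job_counter += 1
        task = asyncio.get_running_loop().create_task(job())
        task.add_done_callback(lambda _: self._release_on_complete())
        self.tasks[job_id] = task
        return True

async def main() -> None:
    async def job() -> None:
        await asyncio.sleep(0.1)

    worker = MiniWorker(max_jobs=2)
    results = [await worker.start(f'job-{n}', job) for n in range(3)]
    print(results)  # [True, True, False] - the third job is refused while two are running
    await asyncio.gather(*worker.tasks.values())

if __name__ == '__main__':
    asyncio.run(main())

The extra permit is what keeps polling responsive: the semaphore now only serialises the brief acquire-check-increment step, while the counter carries the actual max_jobs limit.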