Skip to content

Commit ab2dda2

Browse files
authored
🔨 Adding a job counter to address Semaphore issues (#408)
* 🔨 Adding a job counter to address Semaphore issues
* 🧪 Test function for semaphore blocker
1 parent 9109c2e commit ab2dda2

File tree

2 files changed: +51 −4 lines changed

arq/worker.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,11 @@ def __init__(
236236
self.on_job_start = on_job_start
237237
self.on_job_end = on_job_end
238238
self.after_job_end = after_job_end
239-
self.sem = asyncio.BoundedSemaphore(max_jobs)
239+
240+
self.max_jobs = max_jobs
241+
self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
242+
self.job_counter: int = 0
243+
240244
self.job_timeout_s = to_seconds(job_timeout)
241245
self.keep_result_s = to_seconds(keep_result)
242246
self.keep_result_forever = keep_result_forever
@@ -374,13 +378,13 @@ async def _poll_iteration(self) -> None:
374378
return
375379
count = min(burst_jobs_remaining, count)
376380
if self.allow_pick_jobs:
377-
async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs
381+
if self.job_counter < self.max_jobs:
378382
now = timestamp_ms()
379383
job_ids = await self.pool.zrangebyscore(
380384
self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
381385
)
382386

383-
await self.start_jobs(job_ids)
387+
await self.start_jobs(job_ids)
384388

385389
if self.allow_abort_jobs:
386390
await self._cancel_aborted_jobs()
@@ -419,12 +423,23 @@ async def _cancel_aborted_jobs(self) -> None:
419423
self.aborting_tasks.update(aborted)
420424
await self.pool.zrem(abort_jobs_ss, *aborted)
421425

426+
def _release_sem_dec_counter_on_complete(self) -> None:
427+
self.job_counter = self.job_counter - 1
428+
self.sem.release()
429+
422430
async def start_jobs(self, job_ids: List[bytes]) -> None:
423431
"""
424432
For each job id, get the job definition, check it's not running and start it in a task
425433
"""
426434
for job_id_b in job_ids:
427435
await self.sem.acquire()
436+
437+
if self.job_counter >= self.max_jobs:
438+
self.sem.release()
439+
return None
440+
441+
self.job_counter = self.job_counter + 1
442+
428443
job_id = job_id_b.decode()
429444
in_progress_key = in_progress_key_prefix + job_id
430445
async with self.pool.pipeline(transaction=True) as pipe:
@@ -433,6 +448,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
433448
score = await pipe.zscore(self.queue_name, job_id)
434449
if ongoing_exists or not score:
435450
# job already started elsewhere, or already finished and removed from queue
451+
self.job_counter = self.job_counter - 1
436452
self.sem.release()
437453
logger.debug('job %s already running elsewhere', job_id)
438454
continue
@@ -445,11 +461,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
445461
await pipe.execute()
446462
except (ResponseError, WatchError):
447463
# job already started elsewhere since we got 'existing'
464+
self.job_counter = self.job_counter - 1
448465
self.sem.release()
449466
logger.debug('multi-exec error, job %s already started elsewhere', job_id)
450467
else:
451468
t = self.loop.create_task(self.run_job(job_id, int(score)))
452-
t.add_done_callback(lambda _: self.sem.release())
469+
t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
453470
self.tasks[job_id] = t
454471

455472
async def run_job(self, job_id: str, score: int) -> None: # noqa: C901

tests/test_worker.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,36 @@ async def test(ctx):
984984
assert result['called'] == 4
985985

986986

987+
async def test_job_cancel_on_max_jobs(arq_redis: ArqRedis, worker, caplog):
    """With max_jobs=1, aborting the single running job must free the
    worker so a subsequent main() run does not block on the semaphore."""

    async def longfunc(ctx):
        # Sleeps "forever"; the test aborts it instead of waiting.
        await asyncio.sleep(3600)

    async def abort_after(job, delay=0.1):
        # Give the worker time to start the job, then abort it.
        await asyncio.sleep(delay)
        assert await job.abort() is True

    caplog.set_level(logging.INFO)
    # Pre-seed an unrelated abort entry to exercise the abort scan path.
    await arq_redis.zadd(abort_jobs_ss, {b'foobar': int(1e9)})
    job = await arq_redis.enqueue_job('longfunc', _job_id='testing')

    runner: Worker = worker(
        functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1, max_jobs=1
    )
    assert (runner.jobs_complete, runner.jobs_failed, runner.jobs_retried) == (0, 0, 0)

    # Run the worker while aborting the job concurrently, then run it
    # again — this second run hangs if the semaphore was never released.
    await asyncio.gather(abort_after(job), runner.main())
    await runner.main()

    assert (runner.jobs_complete, runner.jobs_failed, runner.jobs_retried) == (0, 1, 0)

    messages = '\n'.join(record.message for record in caplog.records)
    log = re.sub(r'\d+.\d\ds', 'X.XXs', messages)
    assert 'X.XXs → testing:longfunc()\n X.XXs ⊘ testing:longfunc aborted' in log

    # All per-job bookkeeping must be fully drained.
    assert runner.aborting_tasks == set()
    assert runner.tasks == {}
    assert runner.job_tasks == {}
1016+
9871017
async def test_worker_timezone_defaults_to_system_timezone(worker):
9881018
worker = worker(functions=[func(foobar)])
9891019
assert worker.timezone is not None

0 commit comments

Comments
 (0)