Commit 867e1c0

petitalbychiat35 authored and committed
Add OAR scheduler support
1 parent 0e3d33c commit 867e1c0

File tree

3 files changed (+253 -0 lines changed)


pydra/engine/tests/test_submitter.py (+80)

@@ -8,6 +8,7 @@
 from .utils import (
     need_sge,
     need_slurm,
+    need_oar,
     gen_basic_wf,
     gen_basic_wf_with_threadcount,
     gen_basic_wf_with_threadcount_concurrent,
@@ -573,6 +574,85 @@ def test_sge_no_limit_maxthreads(tmpdir):
     assert job_1_endtime > job_2_starttime


+@need_oar
+def test_oar_wf(tmpdir):
+    wf = gen_basic_wf()
+    wf.cache_dir = tmpdir
+    # submit workflow and every task as oar job
+    with Submitter("oar") as sub:
+        sub(wf)
+
+    res = wf.result()
+    assert res.output.out == 9
+    script_dir = tmpdir / "OarWorker_scripts"
+    assert script_dir.exists()
+    # ensure each task was executed with oar
+    assert len([sd for sd in script_dir.listdir() if sd.isdir()]) == 2
+
+
+@need_oar
+def test_oar_wf_cf(tmpdir):
+    # submit entire workflow as a single job executed with the cf worker
+    wf = gen_basic_wf()
+    wf.cache_dir = tmpdir
+    wf.plugin = "cf"
+    with Submitter("oar") as sub:
+        sub(wf)
+    res = wf.result()
+    assert res.output.out == 9
+    script_dir = tmpdir / "OarWorker_scripts"
+    assert script_dir.exists()
+    # ensure only the workflow was executed with oar
+    sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
+    assert len(sdirs) == 1
+    # oar scripts should be in the dir named after the workflow uid
+    assert sdirs[0].basename == wf.uid
+
+
+@need_oar
+def test_oar_wf_state(tmpdir):
+    wf = gen_basic_wf()
+    wf.split("x", x=[5, 6])
+    wf.cache_dir = tmpdir
+    with Submitter("oar") as sub:
+        sub(wf)
+    res = wf.result()
+    assert res[0].output.out == 9
+    assert res[1].output.out == 10
+    script_dir = tmpdir / "OarWorker_scripts"
+    assert script_dir.exists()
+    sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
+    assert len(sdirs) == 2 * len(wf.inputs.x)
+
+
+@need_oar
+def test_oar_args_1(tmpdir):
+    """testing oarsub_args provided to the submitter"""
+    task = sleep_add_one(x=1)
+    task.cache_dir = tmpdir
+    # submit the task as an oar job
+    with Submitter("oar", oarsub_args="-l nodes=2") as sub:
+        sub(task)
+
+    res = task.result()
+    assert res.output.out == 2
+    script_dir = tmpdir / "OarWorker_scripts"
+    assert script_dir.exists()
+
+
+@need_oar
+def test_oar_args_2(tmpdir):
+    """testing oarsub_args provided to the submitter
+    exception should be raised for invalid options
+    """
+    task = sleep_add_one(x=1)
+    task.cache_dir = tmpdir
+    # submit the task as an oar job
+    with pytest.raises(RuntimeError, match="Error returned from oarsub:"):
+        with Submitter("oar", oarsub_args="-l nodes=2 --invalid") as sub:
+            sub(task)
+
+
 # @pytest.mark.xfail(reason="Not sure")
 def test_wf_with_blocked_tasks(tmpdir):
     wf = Workflow(name="wf_with_blocked_tasks", input_spec=["x"])
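Usage note: the tests above exercise the new "oar" plugin through pytest fixtures. As a minimal standalone sketch (not part of this commit), the same path can be driven from a plain script; it assumes pydra's top-level mark/Submitter API, and the add_two task, the input value, and the /shared/pydra-cache path are hypothetical. It only works on a machine where oarsub and oarstat are available.

# Sketch: submit one task through the new "oar" plugin, mirroring
# what test_oar_wf / test_oar_args_1 exercise. The task name and the
# cache path are hypothetical placeholders.
import pydra


@pydra.mark.task
def add_two(x):
    return x + 2


task = add_two(x=7)
task.cache_dir = "/shared/pydra-cache"  # must be visible to the OAR compute nodes
with pydra.Submitter("oar", oarsub_args="-l nodes=2") as sub:
    sub(task)
print(task.result().output.out)  # 9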

pydra/engine/tests/utils.py (+4)

@@ -30,6 +30,10 @@
     not (bool(shutil.which("qsub")) and bool(shutil.which("qacct"))),
     reason="sge not available",
 )
+need_oar = pytest.mark.skipif(
+    not (bool(shutil.which("oarsub")) and bool(shutil.which("oarstat"))),
+    reason="oar not available",
+)


 def result_no_submitter(shell_task, plugin=None):

pydra/engine/workers.py (+169)

@@ -2,6 +2,7 @@
 import asyncio
 import sys
 import json
+import os
 import re
 from tempfile import gettempdir
 from pathlib import Path
@@ -186,6 +187,173 @@ def close(self):
         self.pool.shutdown()


+class OarWorker(DistributedWorker):
+    """A worker to execute tasks on OAR systems."""
+
+    _cmd = "oarsub"
+
+    def __init__(self, loop=None, max_jobs=None, poll_delay=1, oarsub_args=None):
+        """
+        Initialize OAR Worker.
+
+        Parameters
+        ----------
+        poll_delay : seconds
+            Delay between polls to oar
+        oarsub_args : str
+            Additional oarsub arguments
+        max_jobs : int
+            Maximum number of submitted jobs
+
+        """
+        super().__init__(loop=loop, max_jobs=max_jobs)
+        if not poll_delay or poll_delay < 0:
+            poll_delay = 0
+        self.poll_delay = poll_delay
+        self.oarsub_args = oarsub_args or ""
+        self.error = {}
+
+    def run_el(self, runnable, rerun=False):
+        """Worker submission API."""
+        script_dir, batch_script = self._prepare_runscripts(runnable, rerun=rerun)
+        if (script_dir / script_dir.parts[1]) == gettempdir():
+            logger.warning("Temporary directories may not be shared across computers")
+        if isinstance(runnable, TaskBase):
+            cache_dir = runnable.cache_dir
+            name = runnable.name
+            uid = runnable.uid
+        else:  # runnable is a tuple (ind, pkl file, task)
+            cache_dir = runnable[-1].cache_dir
+            name = runnable[-1].name
+            uid = f"{runnable[-1].uid}_{runnable[0]}"
+
+        return self._submit_job(batch_script, name=name, uid=uid, cache_dir=cache_dir)
+
+    def _prepare_runscripts(self, task, interpreter="/bin/sh", rerun=False):
+        if isinstance(task, TaskBase):
+            cache_dir = task.cache_dir
+            ind = None
+            uid = task.uid
+        else:
+            ind = task[0]
+            cache_dir = task[-1].cache_dir
+            uid = f"{task[-1].uid}_{ind}"
+
+        script_dir = cache_dir / f"{self.__class__.__name__}_scripts" / uid
+        script_dir.mkdir(parents=True, exist_ok=True)
+        if ind is None:
+            if not (script_dir / "_task.pkl").exists():
+                save(script_dir, task=task)
+        else:
+            copyfile(task[1], script_dir / "_task.pklz")
+
+        task_pkl = script_dir / "_task.pklz"
+        if not task_pkl.exists() or not task_pkl.stat().st_size:
+            raise Exception("Missing or empty task!")
+
+        batchscript = script_dir / f"batchscript_{uid}.sh"
+        python_string = (
+            f"""'from pydra.engine.helpers import load_and_run; """
+            f"""load_and_run(task_pkl="{task_pkl}", ind={ind}, rerun={rerun}) '"""
+        )
+        bcmd = "\n".join(
+            (
+                f"#!{interpreter}",
+                f"{sys.executable} -c " + python_string,
+            )
+        )
+        with batchscript.open("wt") as fp:
+            fp.writelines(bcmd)
+        os.chmod(batchscript, 0o544)
+        return script_dir, batchscript
+
+    async def _submit_job(self, batchscript, name, uid, cache_dir):
+        """Coroutine that submits task runscript and polls job until completion or error."""
+        script_dir = cache_dir / f"{self.__class__.__name__}_scripts" / uid
+        sargs = self.oarsub_args.split()
+        jobname = re.search(r"(?<=-n )\S+|(?<=--name=)\S+", self.oarsub_args)
+        if not jobname:
+            jobname = ".".join((name, uid))
+            sargs.append(f"--name={jobname}")
+        output = re.search(r"(?<=-O )\S+|(?<=--stdout=)\S+", self.oarsub_args)
+        if not output:
+            output_file = str(script_dir / "oar-%jobid%.out")
+            sargs.append(f"--stdout={output_file}")
+        error = re.search(r"(?<=-E )\S+|(?<=--stderr=)\S+", self.oarsub_args)
+        if not error:
+            error_file = str(script_dir / "oar-%jobid%.err")
+            sargs.append(f"--stderr={error_file}")
+        else:
+            error_file = None
+        sargs.append(str(batchscript))
+        # TO CONSIDER: add random sleep to avoid overloading calls
+        logger.debug(f"Submitting job {' '.join(sargs)}")
+        rc, stdout, stderr = await read_and_display_async(
+            self._cmd, *sargs, hide_display=True
+        )
+        jobid = re.search(r"OAR_JOB_ID=(\d+)", stdout)
+        if rc:
+            raise RuntimeError(f"Error returned from oarsub: {stderr}")
+        elif not jobid:
+            raise RuntimeError("Could not extract job ID")
+        jobid = jobid.group(1)
+        if error_file:
+            error_file = error_file.replace("%jobid%", jobid)
+        self.error[jobid] = error_file.replace("%jobid%", jobid)
+        # intermittent polling
+        while True:
+            # 4 possibilities
+            # False: job is still pending/working
+            # Terminated: job is complete
+            # Error + idempotent: job has been stopped and resubmitted with another jobid
+            # Error: job failure
+            done = await self._poll_job(jobid)
+            if not done:
+                await asyncio.sleep(self.poll_delay)
+            elif done == "Terminated":
+                return True
+            elif done == "Error" and "idempotent" in self.oarsub_args:
+                logger.debug(
+                    f"Job {jobid} has been stopped. Looking for its resubmission..."
+                )
+                # loading info about task with a specific uid
+                info_file = cache_dir / f"{uid}_info.json"
+                if info_file.exists():
+                    checksum = json.loads(info_file.read_text())["checksum"]
+                    if (cache_dir / f"{checksum}.lock").exists():
+                        # for Python >= 3.8 we could use missing_ok=True
+                        (cache_dir / f"{checksum}.lock").unlink()
+                cmd_re = ("oarstat", "-J", "--sql", f"resubmit_job_id='{jobid}'")
+                _, stdout, _ = await read_and_display_async(*cmd_re, hide_display=True)
+                if not stdout:
+                    raise RuntimeError(
+                        f"Job information about resubmission of job {jobid} not found"
+                    )
+                jobid = next(iter(json.loads(stdout).keys()), None)
+            else:
+                error_file = self.error[jobid]
+                error_line = Path(error_file).read_text().split("\n")[-2]
+                if "Exception" in error_line:
+                    error_message = error_line.replace("Exception: ", "")
+                elif "Error" in error_line:
+                    error_message = error_line.replace("Error: ", "")
+                else:
+                    error_message = "Job failed (unknown reason - TODO)"
+                raise Exception(error_message)
+        return True
+
+    async def _poll_job(self, jobid):
+        cmd = ("oarstat", "-J", "-s", "-j", jobid)
+        logger.debug(f"Polling job {jobid}")
+        _, stdout, _ = await read_and_display_async(*cmd, hide_display=True)
+        if not stdout:
+            raise RuntimeError("Job information not found")
+        status = json.loads(stdout)[jobid]
+        if status in ["Waiting", "Launching", "Running", "Finishing"]:
+            return False
+        return status
+
+
 class SlurmWorker(DistributedWorker):
     """A worker to execute tasks on SLURM systems."""

@@ -1042,6 +1210,7 @@ def close(self):
     "slurm": SlurmWorker,
     "dask": DaskWorker,
     "sge": SGEWorker,
+    "oar": OarWorker,
     **{
         "psij-" + subtype: lambda subtype=subtype: PsijWorker(subtype=subtype)
         for subtype in ["local", "slurm"]
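Usage note: the keyword arguments documented in OarWorker.__init__ (poll_delay, max_jobs, oarsub_args) are forwarded by Submitter through the WORKERS mapping above. A hedged sketch follows; the double task, the input value, the cache path, and the walltime value are hypothetical, and "-l walltime=..." / "-t idempotent" are standard oarsub options rather than anything introduced by this diff. Note that the resubmission branch in _submit_job only triggers when "idempotent" appears in oarsub_args.

# Sketch (assumption, not from this commit): tuning the OAR worker via
# Submitter keyword arguments, which are passed to OarWorker.__init__.
import pydra


@pydra.mark.task
def double(x):
    return 2 * x


task = double(x=21)
task.cache_dir = "/shared/pydra-cache"  # hypothetical path shared with the OAR nodes
with pydra.Submitter(
    "oar",
    poll_delay=5,  # seconds between oarstat polls
    max_jobs=50,  # cap on concurrently submitted OAR jobs
    oarsub_args="-l walltime=01:00:00 -t idempotent",
) as sub:
    sub(task)
print(task.result().output.out)  # 42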
