Commit 8fbd58b

Add support for OAR Scheduler (merged with the new structure 1.0a1)
1 parent 55c30c1 commit 8fbd58b

File tree

3 files changed: +251 -0 lines changed

pydra/engine/tests/utils.py

+4
@@ -37,6 +37,10 @@
    not (bool(shutil.which("qsub")) and bool(shutil.which("qacct"))),
    reason="sge not available",
)
need_oar = pytest.mark.skipif(
    not (bool(shutil.which("oarsub")) and bool(shutil.which("oarstat"))),
    reason="oar not available",
)


def num_python_cache_roots(cache_path: Path) -> int:
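The new need_oar marker mirrors the existing need_slurm/need_sge guards: it skips OAR-specific tests on machines where the oarsub and oarstat client commands are not on PATH. A minimal sketch of how a test opts in (the test name and body are illustrative, not part of this commit):

@need_oar
def test_something_on_oar(tmpdir):
    ...  # runs only where an OAR client is installed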

pydra/workers/oar.py

+173
@@ -0,0 +1,173 @@
import asyncio
import os
import sys
import json
import re
import typing as ty
from tempfile import gettempdir
from pathlib import Path
from shutil import copyfile
import logging
import attrs
from pydra.engine.job import Job, save
from pydra.workers import base


logger = logging.getLogger("pydra.worker")


if ty.TYPE_CHECKING:
    from pydra.engine.result import Result


@attrs.define
class OarWorker(base.Worker):
    """A worker to execute tasks on OAR systems."""

    _cmd = "oarsub"

    poll_delay: int = attrs.field(default=1, converter=base.ensure_non_negative)
    oarsub_args: str = ""
    error: dict[str, ty.Any] = attrs.field(factory=dict)

    def __getstate__(self) -> dict[str, ty.Any]:
        """Return state for pickling."""
        state = super().__getstate__()
        del state["error"]
        return state

    def __setstate__(self, state: dict[str, ty.Any]):
        """Set state for unpickling."""
        state["error"] = {}
        super().__setstate__(state)

    def _prepare_runscripts(self, job, interpreter="/bin/sh", rerun=False):
        if isinstance(job, Job):
            cache_root = job.cache_root
            ind = None
            uid = job.uid
        else:
            assert isinstance(job, tuple), f"Expecting a job or a tuple, not {job!r}"
            assert len(job) == 2, f"Expecting a tuple of length 2, not {job!r}"
            ind = job[0]
            cache_root = job[-1].cache_root
            uid = f"{job[-1].uid}_{ind}"

        script_dir = cache_root / f"{self.plugin_name()}_scripts" / uid
        script_dir.mkdir(parents=True, exist_ok=True)
        if ind is None:
            if not (script_dir / "_job.pklz").exists():
                save(script_dir, job=job)
        else:
            copyfile(job[1], script_dir / "_job.pklz")

        job_pkl = script_dir / "_job.pklz"
        if not job_pkl.exists() or not job_pkl.stat().st_size:
            raise Exception("Missing or empty job!")

        batchscript = script_dir / f"batchscript_{uid}.sh"
        # the batch script is two lines: a shebang plus a python -c call that
        # re-loads the pickled job and runs it
        python_string = (
            f"""'from pydra.engine.job import load_and_run; """
            f"""load_and_run("{job_pkl}", rerun={rerun}) '"""
        )
        bcmd = "\n".join(
            (
                f"#!{interpreter}",
                f"{sys.executable} -c " + python_string,
            )
        )
        with batchscript.open("wt") as fp:
            fp.writelines(bcmd)
        os.chmod(batchscript, 0o544)
        return script_dir, batchscript

    async def run(self, job: "Job[base.TaskType]", rerun: bool = False) -> "Result":
        """Worker submission API."""
        script_dir, batch_script = self._prepare_runscripts(job, rerun=rerun)
        if (script_dir / script_dir.parts[1]) == gettempdir():
            logger.warning("Temporary directories may not be shared across computers")
        script_dir = job.cache_root / f"{self.plugin_name()}_scripts" / job.uid
        sargs = self.oarsub_args.split()
        jobname = re.search(r"(?<=-n )\S+|(?<=--name=)\S+", self.oarsub_args)
        if not jobname:
            jobname = ".".join((job.name, job.uid))
            sargs.append(f"--name={jobname}")
        output = re.search(r"(?<=-O )\S+|(?<=--stdout=)\S+", self.oarsub_args)
        if not output:
            output_file = str(script_dir / "oar-%jobid%.out")
            sargs.append(f"--stdout={output_file}")
        error = re.search(r"(?<=-e )\S+|(?<=--error=)\S+", self.oarsub_args)
        if not error:
            error_file = str(script_dir / "oar-%jobid%.err")
            sargs.append(f"--stderr={error_file}")
        else:
            error_file = None
        sargs.append(str(batch_script))
        # TO CONSIDER: add random sleep to avoid overloading calls
        rc, stdout, stderr = await base.read_and_display_async(
            self._cmd, *sargs, hide_display=True
        )
        jobid = re.search(r"OAR_JOB_ID=(\d+)", stdout)
        if rc:
            raise RuntimeError(f"Error returned from oarsub: {stderr}")
        elif not jobid:
            raise RuntimeError("Could not extract job ID")
        jobid = jobid.group(1)
        if error_file:
            error_file = error_file.replace("%jobid%", jobid)
            self.error[jobid] = error_file.replace("%jobid%", jobid)
        # intermittent polling
        while True:
            # 4 possibilities
            # False: job is still pending/working
            # Terminated: job is complete
            # Error + idempotent: job has been stopped and resubmitted with another jobid
            # Error: job failure
            done = await self._poll_job(jobid)
            if not done:
                await asyncio.sleep(self.poll_delay)
            elif done == "Terminated":
                return True
            elif done == "Error" and "idempotent" in self.oarsub_args:
                logger.debug(
                    f"Job {jobid} has been stopped. Looking for its resubmission..."
                )
                # loading info about task with a specific uid
                info_file = job.cache_root / f"{job.uid}_info.json"
                if info_file.exists():
                    checksum = json.loads(info_file.read_text())["checksum"]
                    if (job.cache_root / f"{checksum}.lock").exists():
                        # for Python 3.8+ we could use missing_ok=True
                        (job.cache_root / f"{checksum}.lock").unlink()
                cmd_re = ("oarstat", "-J", "--sql", f"resubmit_job_id='{jobid}'")
                _, stdout, _ = await base.read_and_display_async(*cmd_re, hide_display=True)
                if not stdout:
                    raise RuntimeError(
                        f"Job information about resubmission of job {jobid} not found"
                    )
                jobid = next(iter(json.loads(stdout).keys()), None)
            else:
                error_file = self.error[jobid]
                error_line = Path(error_file).read_text().split("\n")[-2]
                if "Exception" in error_line:
                    error_message = error_line.replace("Exception: ", "")
                elif "Error" in error_line:
                    error_message = error_line.replace("Error: ", "")
                else:
                    error_message = "Job failed (unknown reason - TODO)"
                raise Exception(error_message)
        return True

    async def _poll_job(self, jobid):
        # oarstat -J returns JSON keyed by job id; queued/running states map to False
        cmd = ("oarstat", "-J", "-s", "-j", jobid)
        logger.debug(f"Polling job {jobid}")
        _, stdout, _ = await base.read_and_display_async(*cmd, hide_display=True)
        if not stdout:
            raise RuntimeError("Job information not found")
        status = json.loads(stdout)[jobid]
        if status in ["Waiting", "Launching", "Running", "Finishing"]:
            return False
        return status


# Alias so it can be referred to as oar.Worker
Worker = OarWorker
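For orientation, a minimal usage sketch of the new worker (assuming the Submitter API exercised by the tests below; the import path, task, and cache path are illustrative assumptions, not part of this commit):

from pydra.engine.submitter import Submitter

task = SleepAddOne(x=1)  # any pydra task definition
# select the OAR worker by its plugin name and pass scheduler options through oarsub_args
with Submitter(worker="oar", cache_root="/shared/cache", oarsub_args="-l nodes=1") as sub:
    res = sub(task)
print(res.outputs.out)

The worker pickles each job under a <plugin>_scripts/<uid> directory inside cache_root, wraps it in a two-line batch script submitted with oarsub, then polls oarstat -J -s -j <jobid> (JSON output mapping the job id to its state) until the job leaves the Waiting/Launching/Running/Finishing states.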

pydra/workers/tests/test_worker.py

+74
@@ -22,6 +22,7 @@
    need_sge,
    need_slurm,
    need_singularity,
    need_oar,
    BasicWorkflow,
    BasicWorkflowWithThreadCount,
    BasicWorkflowWithThreadCountConcurrent,
@@ -602,6 +603,79 @@ def test_sge_no_limit_maxthreads(tmpdir):
    assert job_1_endtime > job_2_starttime


@need_oar
def test_oar_wf(tmpdir):
    wf = BasicWorkflow(x=1)
    wf.cache_dir = tmpdir
    # submit workflow and every task as an oar job
    with Submitter(worker="oar", cache_root=tmpdir) as sub:
        res = sub(wf)

    outputs = res.outputs
    assert outputs.out == 9
    script_dir = tmpdir / "oar_scripts"
    assert script_dir.exists()
    # ensure each task was executed with oar
    assert len([sd for sd in script_dir.listdir() if sd.isdir()]) == 2


@need_oar
def test_oar_wf_cf(tmpdir):
    # submit entire workflow as a single job executing with the cf worker
    wf = BasicWorkflow(x=1)
    wf.plugin = "cf"
    with Submitter(worker="oar", cache_root=tmpdir) as sub:
        res = sub(wf)

    outputs = res.outputs
    assert outputs.out == 9
    script_dir = tmpdir / "oar_scripts"
    assert script_dir.exists()
    # ensure only the workflow was executed with oar
    sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
    assert len(sdirs) == 1
    # oar scripts should be in the dirs that use the uid in their name
    assert sdirs[0].basename == wf.uid


@need_oar
def test_oar_wf_state(tmpdir):
    wf = BasicWorkflow().split(x=[5, 6])
    with Submitter(worker="oar", cache_root=tmpdir) as sub:
        res = sub(wf)

    assert res.outputs.out == [9, 10]
    script_dir = tmpdir / "OarWorker_scripts"
    assert script_dir.exists()
    sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
    assert len(sdirs) == 2 * len(wf.x)


@need_oar
def test_oar_args_1(tmpdir):
    """testing oarsub_args provided to the submitter"""
    task = SleepAddOne(x=1)
    # submit the task as an oar job
    with Submitter(worker="oar", cache_root=tmpdir, oarsub_args="-l nodes=2") as sub:
        res = sub(task)

    assert res.outputs.out == 2
    script_dir = tmpdir / "oar_scripts"
    assert script_dir.exists()


@need_oar
def test_oar_args_2(tmpdir):
    """testing oarsub_args provided to the submitter
    exception should be raised for invalid options
    """
    task = SleepAddOne(x=1)
    # submit the task as an oar job
    with pytest.raises(RuntimeError, match="Error returned from oarsub:"):
        with Submitter(
            worker="oar", cache_root=tmpdir, oarsub_args="-l nodes=2 --invalid"
        ) as sub:
            sub(task)


def test_hash_changes_in_task_inputs_file(tmp_path):
    @python.define
    def cache_dir_as_input(out_dir: Directory) -> Directory:
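The OAR tests added above only run where an OAR client is installed (the need_oar guard); on such a machine they can be selected by keyword, e.g.:

pytest pydra/workers/tests/test_worker.py -k oar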
