import asyncio
import os
import sys
import json
import re
import typing as ty
from tempfile import gettempdir
from pathlib import Path
from shutil import copyfile
import logging
import attrs
from pydra.engine.job import Job, save
from pydra.workers import base


logger = logging.getLogger("pydra.worker")

if ty.TYPE_CHECKING:
    from pydra.engine.result import Result


@attrs.define
class OarWorker(base.Worker):
    """A worker to execute tasks on OAR systems."""

    _cmd = "oarsub"

    poll_delay: int = attrs.field(default=1, converter=base.ensure_non_negative)
    oarsub_args: str = ""
    error: dict[str, ty.Any] = attrs.field(factory=dict)

    def __getstate__(self) -> dict[str, ty.Any]:
        """Return state for pickling."""
        state = super().__getstate__()
        del state["error"]
        return state

    def __setstate__(self, state: dict[str, ty.Any]):
        """Set state for unpickling."""
        state["error"] = {}
        super().__setstate__(state)
    def _prepare_runscripts(self, job, interpreter="/bin/sh", rerun=False):
        """Pickle the job (if needed) and write the batch script that will run it."""
        if isinstance(job, Job):
            cache_root = job.cache_root
            ind = None
            uid = job.uid
        else:
            assert isinstance(job, tuple), f"Expecting a job or a tuple, not {job!r}"
            assert len(job) == 2, f"Expecting a tuple of length 2, not {job!r}"
            ind = job[0]
            cache_root = job[-1].cache_root
            uid = f"{job[-1].uid}_{ind}"

        script_dir = cache_root / f"{self.plugin_name()}_scripts" / uid
        script_dir.mkdir(parents=True, exist_ok=True)
        if ind is None:
            if not (script_dir / "_job.pklz").exists():
                save(script_dir, job=job)
        else:
            copyfile(job[1], script_dir / "_job.pklz")

        job_pkl = script_dir / "_job.pklz"
        if not job_pkl.exists() or not job_pkl.stat().st_size:
            raise Exception("Missing or empty job!")

        batchscript = script_dir / f"batchscript_{uid}.sh"
        python_string = (
            f"""'from pydra.engine.job import load_and_run; """
            f"""load_and_run("{job_pkl}", rerun={rerun}) '"""
        )
        bcmd = "\n".join(
            (
                f"#!{interpreter}",
                f"{sys.executable} -c " + python_string,
            )
        )
        with batchscript.open("wt") as fp:
            fp.write(bcmd)
        os.chmod(batchscript, 0o544)
        return script_dir, batchscript
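
    # For reference, the batch script written above has roughly the following
    # shape (a sketch only: the interpreter defaults to /bin/sh, and the python
    # executable and pickle path are placeholders, not literal values):
    #
    #   #!/bin/sh
    #   /path/to/python -c 'from pydra.engine.job import load_and_run; load_and_run("<script_dir>/_job.pklz", rerun=False) '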

    async def run(self, job: "Job[base.TaskType]", rerun: bool = False) -> "Result":
        """Worker submission API."""
        script_dir, batch_script = self._prepare_runscripts(job, rerun=rerun)
        if script_dir.is_relative_to(gettempdir()):
            logger.warning("Temporary directories may not be shared across computers")
        script_dir = job.cache_root / f"{self.plugin_name()}_scripts" / job.uid
        sargs = self.oarsub_args.split()
        jobname = re.search(r"(?<=-n )\S+|(?<=--name=)\S+", self.oarsub_args)
        if not jobname:
            jobname = ".".join((job.name, job.uid))
            sargs.append(f"--name={jobname}")
        output = re.search(r"(?<=-O )\S+|(?<=--stdout=)\S+", self.oarsub_args)
        if not output:
            output_file = str(script_dir / "oar-%jobid%.out")
            sargs.append(f"--stdout={output_file}")
        error = re.search(r"(?<=-e )\S+|(?<=--error=)\S+", self.oarsub_args)
        if not error:
            error_file = str(script_dir / "oar-%jobid%.err")
            sargs.append(f"--stderr={error_file}")
        else:
            error_file = None
        sargs.append(str(batch_script))
        # TO CONSIDER: add random sleep to avoid overloading calls
        rc, stdout, stderr = await base.read_and_display_async(
            self._cmd, *sargs, hide_display=True
        )
        jobid = re.search(r"OAR_JOB_ID=(\d+)", stdout)
        if rc:
            raise RuntimeError(f"Error returned from oarsub: {stderr}")
        elif not jobid:
            raise RuntimeError("Could not extract job ID")
        jobid = jobid.group(1)
        if error_file:
            error_file = error_file.replace("%jobid%", jobid)
        self.error[jobid] = error_file
        # intermittent polling
        while True:
            # 4 possibilities
            # False: job is still pending/working
            # Terminated: job is complete
            # Error + idempotent: job has been stopped and resubmitted with another jobid
            # Error: job failure
            done = await self._poll_job(jobid)
            if not done:
                await asyncio.sleep(self.poll_delay)
            elif done == "Terminated":
                return True
            elif done == "Error" and "idempotent" in self.oarsub_args:
                logger.debug(
                    f"Job {jobid} has been stopped. Looking for its resubmission..."
                )
                # loading info about task with a specific uid
                info_file = job.cache_root / f"{job.uid}_info.json"
                if info_file.exists():
                    checksum = json.loads(info_file.read_text())["checksum"]
                    # remove the stale lock so the resubmitted job can acquire it
                    (job.cache_root / f"{checksum}.lock").unlink(missing_ok=True)
                cmd_re = ("oarstat", "-J", "--sql", f"resubmit_job_id='{jobid}'")
                _, stdout, _ = await base.read_and_display_async(*cmd_re, hide_display=True)
                if not stdout:
                    raise RuntimeError(
                        f"Job information about resubmission of job {jobid} not found"
                    )
                jobid = next(iter(json.loads(stdout).keys()), None)
            else:
                # job failed: try to extract a meaningful message from its stderr file
                error_file = self.error[jobid]
                error_lines = (
                    Path(error_file).read_text().splitlines()
                    if error_file and Path(error_file).exists()
                    else []
                )
                error_line = error_lines[-1] if error_lines else ""
                if "Exception" in error_line:
                    error_message = error_line.replace("Exception: ", "")
                elif "Error" in error_line:
                    error_message = error_line.replace("Error: ", "")
                else:
                    error_message = "Job failed (unknown reason - TODO)"
                raise Exception(error_message)

    async def _poll_job(self, jobid):
        cmd = ("oarstat", "-J", "-s", "-j", jobid)
        logger.debug(f"Polling job {jobid}")
        _, stdout, _ = await base.read_and_display_async(*cmd, hide_display=True)
        if not stdout:
            raise RuntimeError("Job information not found")
        status = json.loads(stdout)[jobid]
        if status in ["Waiting", "Launching", "Running", "Finishing"]:
            return False
        return status


# Alias so it can be referred to as oar.Worker
Worker = OarWorker
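
# A minimal usage sketch (an assumption, not taken from this module): with a
# recent pydra API where workers are selected by plugin name, this worker could
# be requested as "oar" through the Submitter. The exact Submitter signature and
# the example task `my_task` are hypothetical and may differ between versions.
#
#     from pydra.engine.submitter import Submitter
#
#     with Submitter(worker="oar", oarsub_args="-l walltime=01:00:00") as sub:
#         result = sub(my_task)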