"""
`Map` and `Reduce` functions for a Dask distributed `Client`.

Example usage:

    >>> from util._dask import dask_map, dask_reduce

    # map and keep track of future <-> workitem mapping
    >>> futures, futurekey2item = dask_map(
    ...     Processor().process,
    ...     workitems,
    ...     client=client,
    ...     NanoEventsFactory_kwargs={
    ...         "preload": lambda b: b.name in {"Jet_pt", "Jet_eta"},
    ...         "schemaclass": NtupleSchema,
    ...     },
    ... )

    # perform reduction and track failures
    >>> final_future, failed_items = dask_reduce(
    ...     futures,
    ...     futurekey2item=futurekey2item,
    ...     client=client,
    ...     treereduction=16,
    ... )

See the `if __name__ == "__main__":` block below for a complete, runnable example.
"""

from __future__ import annotations

import dataclasses
import typing as tp
from collections import Counter, defaultdict
from functools import partial

import awkward as ak
import uproot
from coffea.nanoevents import NanoEventsFactory
from coffea.processor import Accumulatable, accumulate
from coffea.processor.executor import WorkItem
from coffea.util import coffea_console, rich_bar
from dask.distributed import Client
from dask.tokenize import tokenize
from rich.console import Group
from rich.live import Live
from rich.progress import Progress

from util._futures import DynamicAsCompleted, FutureLike

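# sentinel keys for the two non-dataset progress bars (overall processing and final merge)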
_processing_sentinel = object()
_final_merge_sentinel = object()


# group of progress bars for dask/future executor
def pbar_group(datasets: list[str]) -> tuple[Live, dict[tp.Any, Progress]]:
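    """Build one overall processing bar, one merge bar per dataset, and a final
    cross-dataset merge bar, all wrapped in a single ``Live`` display."""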
    pbars = {_processing_sentinel: rich_bar()}
    pbars.update({ds: rich_bar() for ds in datasets})
    pbars[_final_merge_sentinel] = rich_bar()
    return Live(Group(*pbars.values()), console=coffea_console), pbars


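# a single chunk's outcome: either an accumulatable payload or the exception it raised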
Result: tp.TypeAlias = Accumulatable | BaseException


@dataclasses.dataclass(frozen=True, slots=True)
class Failure:
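    """A work item together with the exception it raised during processing."""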
    item: WorkItem
    reason: BaseException


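# raised when the reduce bookkeeping ends up in an inconsistent state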
class ReduceSchedulingError(RuntimeError): ...


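# merge a list of partial results into one (no pre-existing accumulator)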
acc = partial(accumulate, accum=None)


def failed_future(future: FutureLike) -> bool:
    # a future counts as failed if its task errored, or if it returned an exception
    # as a value (wrapped_process catches Exception and returns it; anything harsher
    # still propagates)
    return issubclass(future.type, BaseException) or future.status == "error"


def wrapped_process(
    process_func: tp.Callable[[ak.Array], Result],
    workitem: WorkItem,
    /,
    *,
    NanoEventsFactory_kwargs: dict[str, tp.Any] | None = None,
) -> Result:
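    """Open the work item's file, build NanoEvents for its entry range, and run ``process_func``.

    On success, return ``{"out": <result>, "report": <I/O metrics>}``; if
    ``process_func`` raises an ``Exception``, return that exception as a value so
    the reduce step can record it as a failure instead of crashing.
    """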
    f = uproot.open(workitem.filename)
    if NanoEventsFactory_kwargs is None:
        NanoEventsFactory_kwargs = {}
    events = NanoEventsFactory.from_root(
        f,
        treepath=workitem.treename,
        mode="virtual",
        access_log=(access_log := []),
        entry_start=workitem.entrystart,
        entry_stop=workitem.entrystop,
        **NanoEventsFactory_kwargs,
    ).events()
    events.metadata.update(workitem.usermeta)
    try:
        out = process_func(events)
    except Exception as err:
        # return err as value, no metrics
        return err
    bytesread = f.file.source.num_requested_bytes
    report = {
        "bytesread": bytesread,
        "columns": access_log,
        "bytesread_per_chunk": {
            (workitem.filename, workitem.entrystart, workitem.entrystop): bytesread
        },
    }
    return {"out": out, "report": report}


def dask_map(
    process_func: tp.Callable[[ak.Array], Result],
    workitems: tp.Iterable[WorkItem],
    /,
    *,
    client: Client,
    NanoEventsFactory_kwargs: dict[str, tp.Any] | None = None,
) -> tuple[list[FutureLike], dict[str, WorkItem]]:
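    """Submit one ``wrapped_process`` task per work item to the client.

    Returns the list of futures together with a mapping from future key to its
    ``WorkItem``, which ``dask_reduce`` needs to group results by dataset and to
    attribute failures.
    """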
    workitems = list(workitems)  # materialize so we can zip against the returned futures
    futures = client.map(
        partial(
            wrapped_process,
            process_func,
            NanoEventsFactory_kwargs=NanoEventsFactory_kwargs,
        ),
        workitems,
        pure=True,
        key="process",
        priority=0,
    )
    return futures, {f.key: wi for f, wi in zip(futures, workitems)}


def dask_reduce(
    futures: tp.Iterable[FutureLike],
    *,
    futurekey2item: tp.Mapping[str, WorkItem],
    client: Client,
    treereduction: int = 1 << 4,
) -> tuple[FutureLike, defaultdict[str, list[Failure]]]:
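    """Accumulate the mapped futures, first within each dataset and then across datasets.

    Partial merges are submitted back to the cluster as ``accumulate`` tasks (at
    most ``treereduction`` futures per merge). Returns the future of the final
    accumulated result and a mapping from dataset name to the work items that
    failed during processing; failures in the merge tasks themselves are raised.
    """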
    futures = list(futures)  # materialize so len() and repeated iteration are safe
    items = list(futurekey2item.values())
    datasets = [it.dataset for it in items]
    unique_datasets = sorted(set(datasets))

    live, pbars = pbar_group(unique_datasets)

    with live:
        # prepare some metadata for merging
        # dataset -> number of items to do
        ds2todo = Counter(datasets)
        # create a buffer for each dataset (what we merge)
        ds2buf = defaultdict(list)
        # future.key -> dataset name
        key2ds = {fk: wi.dataset for fk, wi in futurekey2item.items()}

        # initialize progress bars
        processing_task = pbars[_processing_sentinel].add_task(
            "Processing", total=len(futures), unit="chunk"
        )
        dataset_merge_tasks = {}
        for ds in unique_datasets:
            total = ds2todo[ds]
            dataset_merge_tasks[ds] = pbars[ds].add_task(
                f"[cyan]Merging {total} [italic]{ds}[/italic] chunks into 1",
                total=total,
                unit="merge",
            )

        failed_items: defaultdict[str, list[Failure]] = defaultdict(list)
        dynac = DynamicAsCompleted(futures)

        # in-dataset merging loop: merge within each dataset first to avoid holding large accumulators in memory
        # a reasonable batch size: at most `treereduction`, at least 1 item per batch,
        # and roughly 1% of the total number of chunks
        batch_size = min(
            treereduction, max(int(len(futures) / 100), 1)
        )  # this is a heuristic and can be tuned
        for batch in dynac.iter_batches(batch_size=batch_size):
            for future in batch:
                ds = key2ds[future.key]

                # subtract from todo
                if not future.key.startswith("accumulate-"):
                    ds2todo[ds] -= 1

                # get buffer
                buf = ds2buf[ds]

                if failed_future(future):
                    # let merge failures raise right away
                    if future.key.startswith("accumulate-"):
                        raise future.exception() from None

                    # just collect bad futures coming from the processing step, do not merge them
                    reason = future.result()
                    item = futurekey2item[future.key]
                    failure = Failure(item=item, reason=reason)

                    # append to failed items
                    failed_items[ds].append(failure)

                    # all failed for this dataset
                    if len(buf) == 0 and ds2todo[ds] == 0:
                        del ds2todo[ds]
                        del ds2buf[ds]

                    # nothing to merge, skip
                    continue

                # update progress bars only for successful items
                if future.key.startswith("accumulate-"):
                    # merging task
                    pbars[ds].update(dataset_merge_tasks[ds], advance=1)
                else:
                    pbars[_processing_sentinel].update(processing_task, advance=1)

                # add future to buffer for merging
                if future in buf:
                    raise ReduceSchedulingError("Future already in buffer!")
                buf.append(future)

                # if this is the last item for this dataset, skip merging
                # as we schedule it for final cross-dataset merging
                if len(buf) == 1 and ds2todo[ds] == 0:
                    continue

                # submit treereduction merge if we have enough items
                if len(buf) >= min(ds2todo[ds], treereduction) and len(buf) > 1:
                    work = client.submit(
                        acc,
                        buf,
                        key=f"accumulate-{tokenize(buf)}",
                        priority=1,
                    )

                    # release explicit retention
                    for f in buf:
                        f.release()

                    # record the dataset of the merged future so its result
                    # can still be attributed to this dataset
                    key2ds[work.key] = ds

                    # reset buffer
                    buf.clear()

                    # add back to the as_completed set so merges recurse
                    dynac.add(work)

        del dynac

        # make sure there's only 1 future per dataset in the buffer for the final merge
        final_merge_futures = {}
        for ds, todo in ds2todo.items():
            buf = ds2buf[ds]
            if todo != 0 or len(buf) != 1:
                msg = f"dataset {ds} has {len(buf)} items in merge-buffer (should only be 1); chunks left to merge: {todo}"
                raise ReduceSchedulingError(msg)
            pbars[ds].update(dataset_merge_tasks[ds], advance=1)
            final_merge_futures[ds] = buf[0]

        final_total = 0
        final_merge_task = pbars[_final_merge_sentinel].add_task(
            f"[cyan]Merging {len(final_merge_futures)} merged datasets [italic](final)",
            total=final_total,
            unit="merge",
        )

        # not needed anymore
        del ds2buf, ds2todo

        # final merge across datasets
        buf = []

        dynac = DynamicAsCompleted(final_merge_futures.values())
        for future in dynac:
            if failed_future(future):
                raise future.exception()

            if future not in final_merge_futures.values():
                # final merge progress
                pbars[_final_merge_sentinel].update(
                    final_merge_task,
                    advance=1,
                    total=final_total,
                )

            buf.append(future)
            if len(buf) >= min(len(buf), treereduction) and len(buf) > 1:
                future = client.submit(
                    acc,
                    buf,
                    key=f"accumulate-{tokenize(buf)}",
                    priority=2,
                )

                # release explicit retention
                for f in buf:
                    f.release()

                # reset buffer
                buf.clear()

                # add back to the as_completed set so merges recurse
                dynac.add(future)

                # one more merge submitted: grow the final bar's total
                final_total += 1

        del dynac

    # final result
    assert len(buf) == 1
    future = buf[0]
    return future, failed_items


if __name__ == "__main__":
    # Run with: `python -m util._dask`
    from coffea.nanoevents import NanoAODSchema
    from dask.distributed import Client, LocalCluster

    workitems = [
        WorkItem(
            filename="../coffea/tests/samples/nano_dy.root",
            treename="Events",
            entrystart=i * 10,
            entrystop=(i + 1) * 10,
            dataset="dy",
            usermeta={"dataset": "dy"},
            fileuuid="1234abcd",
        )
        for i in range(4)
    ] + [
        WorkItem(
            filename="../coffea/tests/samples/nano_dimuon.root",
            treename="Events",
            entrystart=i * 10,
            entrystop=(i + 1) * 10,
            dataset="data",
            usermeta={"dataset": "data"},
            fileuuid="5678efgh",
        )
        for i in range(4)
    ]

    def process(events: ak.Array) -> ak.Array:
        import random

        if random.random() < 0.4:
            raise ValueError("Random failure during processing!")
        return ak.mean(events.Jet.pt)

    with (
        LocalCluster(n_workers=4, threads_per_worker=1) as cluster,
        Client(cluster) as client,
    ):
        # map and keep track of future <-> workitem mapping
        futures, futurekey2item = dask_map(
            process,
            workitems,
            client=client,
            NanoEventsFactory_kwargs={
                "preload": lambda b: b.name in {"Jet_pt", "Jet_eta"},
                "schemaclass": NanoAODSchema,
            },
        )

        # perform reduction and track failures
        final_future, failed_items = dask_reduce(
            futures,
            futurekey2item=futurekey2item,
            client=client,
            treereduction=3,
        )

        coffea_console.print("Failed items:", failed_items)
        result = final_future.result()
        coffea_console.print("Output:", result["out"])
        coffea_console.print("Metrics:", result["report"])