
Commit 5eb3b08

pfackeldey authored and alexander-held committed
feat: custom dask workflow
1 parent 24b41b8 commit 5eb3b08

File tree

3 files changed: +510 −0 lines changed


util/__init__.py

Whitespace-only changes.

util/_dask.py

Lines changed: 392 additions & 0 deletions
@@ -0,0 +1,392 @@
"""
`Map` and `Reduce` functions for a Dask distributed client.

Example usage:

>>> from util._dask import dask_map, dask_reduce

# map and keep track of future <-> workitem mapping
>>> futures, futurekey2item = dask_map(
...     Processor().process,
...     workitems,
...     client=client,
...     NanoEventsFactory_kwargs={
...         "preload": lambda b: b.name in {'Jet_pt', 'Jet_eta'},
...         "schemaclass": NtupleSchema,
...     },
... )

# perform reduction and track failures
>>> final_future, failed_items = dask_reduce(
...     futures,
...     futurekey2item=futurekey2item,
...     client=client,
...     treereduction=16,
... )

See also the `if __name__ == "__main__":` block for a complete example.
"""

from __future__ import annotations

import dataclasses
import typing as tp
from collections import Counter, defaultdict
from functools import partial

import awkward as ak
import uproot
from coffea.nanoevents import NanoEventsFactory
from coffea.processor import Accumulatable, accumulate
from coffea.processor.executor import WorkItem
from coffea.util import coffea_console, rich_bar
from dask.distributed import Client
from dask.tokenize import tokenize
from rich.console import Group
from rich.live import Live
from rich.progress import Progress

from util._futures import DynamicAsCompleted, FutureLike

_processing_sentinel = object()
_final_merge_sentinel = object()


# group of progress bars for dask/future executor
def pbar_group(datasets: list[str]) -> tuple[Live, dict[tp.Any, Progress]]:
    pbars = {_processing_sentinel: rich_bar()}
    pbars.update({ds: rich_bar() for ds in datasets})
    pbars[_final_merge_sentinel] = rich_bar()
    return Live(Group(*pbars.values()), console=coffea_console), pbars


Result: tp.TypeAlias = Accumulatable | BaseException


@dataclasses.dataclass(frozen=True, slots=True)
class Failure:
    item: WorkItem
    reason: BaseException


class ReduceSchedulingError(RuntimeError): ...


acc = partial(accumulate, accum=None)
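# NOTE: coffea's `accumulate` merges an iterable of partial results (nested dicts,
# histograms, counters, plain numbers, ...) into a single accumulator; with
# `accum=None` the reduction is seeded from the first item. `acc(buf)` is the
# payload of every merge task submitted below.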


def failed_future(future: FutureLike) -> bool:
    # if we return an exception as a value, consider it failed (see wrapped_process)
    # we catch any exception, but no RuntimeError. Maybe the user wants to raise that?
    return issubclass(future.type, BaseException) or future.status == "error"


def wrapped_process(
    process_func: tp.Callable[[ak.Array], Result],
    workitem: WorkItem,
    /,
    *,
    NanoEventsFactory_kwargs: dict[str, tp.Any] | None = None,
) -> Result:
    f = uproot.open(workitem.filename)
    if NanoEventsFactory_kwargs is None:
        NanoEventsFactory_kwargs = {}
    events = NanoEventsFactory.from_root(
        f,
        treepath=workitem.treename,
        mode="virtual",
        access_log=(access_log := []),
        entry_start=workitem.entrystart,
        entry_stop=workitem.entrystop,
        **NanoEventsFactory_kwargs,
    ).events()
    events.metadata.update(workitem.usermeta)
    try:
        out = process_func(events)
    except Exception as err:
        # return err as value, no metrics
        return err
    bytesread = f.file.source.num_requested_bytes
    report = {
        "bytesread": bytesread,
        "columns": access_log,
        "bytesread_per_chunk": {
            (workitem.filename, workitem.entrystart, workitem.entrystop): bytesread
        },
    }
    return {"out": out, "report": report}
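
# NOTE: on success `wrapped_process` returns {"out": <user result>, "report": {...}};
# `accumulate` later merges both the user outputs and the per-chunk I/O reports, so
# the final future of `dask_reduce` resolves to the same {"out", "report"} structure.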


def dask_map(
    process_func: tp.Callable[[ak.Array], Result],
    workitems: tp.Iterable[WorkItem],
    /,
    *,
    client: Client,
    NanoEventsFactory_kwargs: dict[str, tp.Any] | None = None,
) -> tuple[list[FutureLike], dict[str, WorkItem]]:
    futures = client.map(
        partial(
            wrapped_process,
            process_func,
            NanoEventsFactory_kwargs=NanoEventsFactory_kwargs,
        ),
        workitems,
        pure=True,
        key="process",
        priority=0,
    )
    return futures, {f.key: wi for f, wi in zip(futures, workitems)}
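
# NOTE on scheduling: processing futures get keys with a "process" prefix
# (key="process", pure=True), while every merge submitted in `dask_reduce` is keyed
# "accumulate-<token>"; `dask_reduce` relies on that prefix to tell processing
# results apart from merge results. Priorities rise from processing (0) over
# in-dataset merges (1) to the final cross-dataset merge (2), so the scheduler
# prefers draining partial results over opening new chunks, which should keep the
# number of unmerged accumulators in memory small.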


def dask_reduce(
    futures: tp.Iterable[FutureLike],
    *,
    futurekey2item: tp.Mapping[str, WorkItem],
    client: Client,
    treereduction: int = 1 << 4,
) -> tuple[FutureLike, defaultdict[str, list[Failure]]]:
    items = list(futurekey2item.values())
    datasets = [it.dataset for it in items]
    unique_datasets = sorted(set(datasets))

    live, pbars = pbar_group(unique_datasets)

    with live:
        # prepare some metadata for merging
        # dataset -> number of items to do
        ds2todo = Counter(datasets)
        # create a buffer for each dataset (what we merge)
        ds2buf = defaultdict(list)
        # future.key -> dataset item
        key2ds = {fk: wi.dataset for fk, wi in futurekey2item.items()}

        # initialize progress bars
        processing_task = pbars[_processing_sentinel].add_task(
            "Processing", total=len(futures), unit="chunk"
        )
        dataset_merge_tasks = {}
        for ds in unique_datasets:
            total = ds2todo[ds]
            dataset_merge_tasks[ds] = pbars[ds].add_task(
                f"[cyan]Merging {total} [italic]{ds}[/italic] chunks into 1",
                total=total,
                unit="merge",
            )

        failed_items: defaultdict[str, list[Failure]] = defaultdict(list)
        dynac = DynamicAsCompleted(futures)

        # in-dataset merging loop: merge within datasets first to avoid holding
        # large accumulators in memory
        # some reasonable value for the batch_size:
        # yield in batches of treereduction, and we want at least 1 item per batch
        batch_size = min(
            treereduction, max(int(len(futures) / 100), 1)
        )  # this is heuristic, can be tuned
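        # e.g. 1000 chunks with treereduction=16 -> batch_size = min(16, max(10, 1)) = 10;
        # 50 chunks -> batch_size = min(16, max(0, 1)) = 1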
        for batch in dynac.iter_batches(batch_size=batch_size):
            for future in batch:
                ds = key2ds[future.key]

                # subtract from todo
                if not future.key.startswith("accumulate-"):
                    ds2todo[ds] -= 1

                # get buffer
                buf = ds2buf[ds]

                if failed_future(future):
                    # let merge failures raise right away
                    if future.key.startswith("accumulate-"):
                        raise future.exception() from None

                    # just collect bad futures coming from the processing step, do not merge them
                    reason = future.result()
                    item = futurekey2item[future.key]
                    failure = Failure(item=item, reason=reason)

                    # append to failed items
                    failed_items[ds].append(failure)

                    # all failed for this dataset
                    if len(buf) == 0 and ds2todo[ds] == 0:
                        del ds2todo[ds]
                        del ds2buf[ds]

                    # nothing to merge, skip
                    continue

                # update progress bars only for successful items
                if future.key.startswith("accumulate-"):
                    # merging task
                    pbars[ds].update(dataset_merge_tasks[ds], advance=1)
                else:
                    pbars[_processing_sentinel].update(processing_task, advance=1)

                # add future to buffer for merging
                if future in buf:
                    raise ReduceSchedulingError("Future already in buffer!")
                buf.append(future)

                # if this is the last item for this dataset, skip merging,
                # as we schedule it for the final cross-dataset merging
                if len(buf) == 1 and ds2todo[ds] == 0:
                    continue

                # submit treereduction merge if we have enough items
                if len(buf) >= min(ds2todo[ds], treereduction) and len(buf) > 1:
                    work = client.submit(
                        acc,
                        buf,
                        key=f"accumulate-{tokenize(buf)}",
                        priority=1,
                    )

                    # release explicit retention
                    for f in buf:
                        f.release()

                    # remember which dataset the merged future belongs to,
                    # so key2ds can resolve it when the merge completes
                    key2ds[work.key] = ds

                    # reset buffer
                    buf.clear()

                    # add back to the as-completed set to merge recursively
                    dynac.add(work)

        del dynac

        # make sure there's only 1 future per dataset in the buffer for the final merge
        final_merge_futures = {}
        for ds, todo in ds2todo.items():
            buf = ds2buf[ds]
            if todo != 0 or len(buf) != 1:
                msg = f"dataset {ds} has {len(buf)} items in merge-buffer (should only be 1); chunks left to merge: {todo}"
                raise ReduceSchedulingError(msg)
            pbars[ds].update(dataset_merge_tasks[ds], advance=1)
            final_merge_futures[ds] = buf[0]

        final_total = 0
        final_merge_task = pbars[_final_merge_sentinel].add_task(
            f"[cyan]Merging {len(final_merge_futures)} merged datasets [italic](final)",
            total=final_total,  # grows as intermediate merges are scheduled below
            unit="merge",
        )

        # not needed anymore
        del ds2buf, ds2todo

        # final merge across datasets
        buf = []

        dynac = DynamicAsCompleted(final_merge_futures.values())
        for future in dynac:
            if failed_future(future):
                raise future.exception()

            if future not in final_merge_futures.values():
                # final merge progress
                pbars[_final_merge_sentinel].update(
                    final_merge_task,
                    advance=1,
                    total=final_total,
                )

            buf.append(future)
            if len(buf) >= min(len(buf), treereduction) and len(buf) > 1:
                future = client.submit(
                    acc,
                    buf,
                    key=f"accumulate-{tokenize(buf)}",
                    priority=2,
                )

                # release explicit retention
                for f in buf:
                    f.release()

                # add merged item to key2ds (bookkeeping; not read again at this stage)
                key2ds[future.key] = ds

                # reset buffer
                buf.clear()

                # add back to the as-completed set to merge recursively
                dynac.add(future)

                # add one to the pbar total
                final_total += 1

        del dynac

    # final result
    assert len(buf) == 1
    future = buf[0]
    return future, failed_items


if __name__ == "__main__":
    # Run with: `python -m util._dask`
    from coffea.nanoevents import NanoAODSchema
    from dask.distributed import Client, LocalCluster

    workitems = [
        WorkItem(
            filename="../coffea/tests/samples/nano_dy.root",
            treename="Events",
            entrystart=i * 10,
            entrystop=(i + 1) * 10,
            dataset="dy",
            usermeta={"dataset": "dy"},
            fileuuid="1234abcd",
        )
        for i in range(4)
    ] + [
        WorkItem(
            filename="../coffea/tests/samples/nano_dimuon.root",
            treename="Events",
            entrystart=i * 10,
            entrystop=(i + 1) * 10,
            dataset="data",
            usermeta={"dataset": "data"},
            fileuuid="5678efgh",
        )
        for i in range(4)
    ]

    def process(events: ak.Array) -> ak.Array:
        import random

        if random.random() < 0.4:
            raise ValueError("Random failure during processing!")
        return ak.mean(events.Jet.pt)

    with (
        LocalCluster(n_workers=4, threads_per_worker=1) as cluster,
        Client(cluster) as client,
    ):
        # map and keep track of future <-> workitem mapping
        futures, futurekey2item = dask_map(
            process,
            workitems,
            client=client,
            NanoEventsFactory_kwargs={
                "preload": lambda b: b.name in {"Jet_pt", "Jet_eta"},
                "schemaclass": NanoAODSchema,
            },
        )

        # perform reduction and track failures
        final_future, failed_items = dask_reduce(
            futures,
            futurekey2item=futurekey2item,
            client=client,
            treereduction=3,
        )

        coffea_console.print("Failed items:", failed_items)
        result = final_future.result()
        coffea_console.print("Output:", result["out"])
        coffea_console.print("Metrics:", result["report"])
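
        # Example of inspecting `failed_items` (dataset -> list[Failure]); each
        # Failure carries the original WorkItem and the exception raised in `process`:
        for ds, failures in failed_items.items():
            for failure in failures:
                coffea_console.print(
                    f"{ds}: chunk [{failure.item.entrystart}, {failure.item.entrystop}) "
                    f"of {failure.item.filename} failed with {failure.reason!r}"
                )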
