Skip to content

Commit 24b41b8

Browse files
feat: progress bar for custom dask.distributed setups (#32)
* add progress bar for xrdcp notebook
* add progress bar for main notebook
1 parent 16b91ad commit 24b41b8

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

atlas/analysis.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,11 @@
267267
"\n",
268268
"if USE_CUSTOM_PROCESSING:\n",
269269
" # configure here whether to preload branches\n",
270-
" columns_to_preload = json.load(pathlib.Path(\"columns_to_preload.json\").open())[\"JET_JER_Effective\"]\n",
271-
" columns_to_preload = []\n",
270+
" columns_to_preload = json.load(pathlib.Path(\"columns_to_preload.json\").open())[\"JET_JER_Effective\"] # or []\n",
272271
"\n",
273272
" with performance_report(filename=\"process_custom.html\"):\n",
274273
" out, report = utils.custom_process(preprocess_output, processor_class=Analysis, schema=run.schema, client=client, preload=columns_to_preload)\n",
275-
" print(f\"preloaded columns: {len(columns_to_preload)}, {columns_to_preload} {\"etc.\" if len(columns_to_preload) > 4 else \"\"}\")\n",
274+
" print(f\"preloaded columns: {len(columns_to_preload)}, {columns_to_preload[:4]} {\"etc.\" if len(columns_to_preload) > 4 else \"\"}\")\n",
276275
" print(f\"preloaded but unused columns: {len([c for c in columns_to_preload if c not in report[\"columns\"]])}\")\n",
277276
" print(f\"used but not preloaded columns: {len([c for c in report[\"columns\"] if c not in columns_to_preload])}\")\n",
278277
"\n",

atlas/ntuple_production/distributed_xrdcp.ipynb

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"import numpy as np\n",
1919
"import matplotlib.dates as mdates\n",
2020
"import matplotlib.pyplot as plt\n",
21+
"import tqdm.notebook\n",
2122
"\n",
2223
"client = Client(\"tls://localhost:8786\")"
2324
]
@@ -65,7 +66,13 @@
6566
"\n",
6667
"t0 = time.time()\n",
6768
"tasks = [dask.delayed(run_xrdcp)(fname, size) for fname, size in zip(all_files, all_sizes_GB)]\n",
68-
"res = dask.compute(*tasks)\n",
69+
"futures = client.compute(tasks)\n",
70+
"\n",
71+
"with tqdm.notebook.tqdm(total=len(futures)) as pbar:\n",
72+
" for future in dask.distributed.as_completed(futures):\n",
73+
" pbar.update(1)\n",
74+
"\n",
75+
"res = [f.result() for f in futures]\n",
6976
"t1 = time.time()"
7077
]
7178
},

atlas/utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import matplotlib.dates as mdates
2020
import matplotlib.pyplot as plt
2121
import numpy as np
22+
import tqdm.notebook
2223
import uproot
2324

2425

@@ -311,13 +312,16 @@ def extract_metadata(fname_and_treename: str, custom_func) -> dict:
311312
meta.update({"custom_meta": custom_func(f)})
312313
return {fname: meta}
313314

314-
preprocess_input = dask.bag.from_sequence(files_to_preprocess, partition_size=1)
315315
print(f"pre-processing {len(files_to_preprocess)} file(s)")
316-
futures = preprocess_input.map(functools.partial(extract_metadata, custom_func=custom_func))
317-
result = client.compute(futures).result()
316+
tasks = client.map(functools.partial(extract_metadata, custom_func=custom_func), files_to_preprocess)
317+
futures = client.compute(tasks)
318+
319+
with tqdm.notebook.tqdm(total=len(futures)) as pbar:
320+
for _ in dask.distributed.as_completed(futures):
321+
pbar.update(1)
318322

319323
# turn into dict for easier use
320-
result_dict = {k: v for res in result for k, v in res.items()}
324+
result_dict = {k: v for res in [f.result() for f in futures] for k, v in res.items()}
321325

322326
# join back together per-file information with fileset-level information and turn into WorkItem list for coffea
323327
workitems = []
@@ -410,5 +414,13 @@ def sum_output(a, b):
410414
)
411415

412416
workitems_bag = dask.bag.from_sequence(workitems, partition_size=1)
413-
futures = workitems_bag.map(run_analysis).fold(sum_output)
414-
return client.compute(futures).result()
417+
tasks = workitems_bag.map(run_analysis).to_delayed()
418+
futures = client.compute(tasks)
419+
workitems_bag = dask.bag.from_delayed(futures)
420+
res = client.compute(workitems_bag.fold(sum_output))
421+
422+
with tqdm.notebook.tqdm(total=len(futures)) as pbar:
423+
for _ in dask.distributed.as_completed(futures):
424+
pbar.update(1)
425+
426+
return res.result()

0 commit comments

Comments (0)