
Commit 075ff91

feat: extend metrics and workload (#19)
* save Dask reports
* switch to mjj as observable
* add worker tracking and extend metrics
1 parent ca1d6e8 commit 075ff91

File tree: 4 files changed, +142 -49 lines changed


atlas/.gitignore

Lines changed: 4 additions & 0 deletions
@@ -212,3 +212,7 @@ __marimo__/
  ntuple_production/production_status.json
  # preprocess json
  preprocess_output.json
+ # Dask reports
+ *html
+ # figures
+ *png

atlas/analysis.ipynb

Lines changed: 72 additions & 35 deletions
@@ -7,11 +7,8 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# ! pip install --upgrade atlas_schema\n",
- "# ! pip install --upgrade git+https://github.com/scikit-hep/mplhep.git\n",
- "\n",
- "# import importlib\n",
- "# importlib.reload(utils)"
+ "! pip install --upgrade --quiet atlas_schema\n",
+ "! pip install --upgrade --quiet --pre mplhep"
  ]
 },
 {
@@ -27,18 +24,19 @@
  "import time\n",
  "\n",
  "import awkward as ak\n",
+ "import cloudpickle\n",
  "import dask\n",
- "import vector\n",
  "import hist\n",
  "import matplotlib.pyplot as plt\n",
  "import mplhep\n",
  "import numpy as np\n",
  "import uproot\n",
+ "import vector\n",
  "\n",
  "from atlas_schema.schema import NtupleSchema\n",
  "from coffea import processor\n",
  "from coffea.nanoevents import NanoEventsFactory\n",
- "from dask.distributed import Client, PipInstall\n",
+ "from dask.distributed import Client, PipInstall, performance_report\n",
  "\n",
  "\n",
  "import utils\n",
@@ -49,7 +47,9 @@
  "client = Client(\"tls://localhost:8786\")\n",
  "\n",
  "plugin = PipInstall(packages=[\"atlas_schema\"], pip_options=[\"--upgrade\"])\n",
- "client.register_plugin(plugin)"
+ "client.register_plugin(plugin)\n",
+ "\n",
+ "cloudpickle.register_pickle_by_value(utils)"
  ]
 },
 {
@@ -74,6 +74,7 @@
  "\n",
  "# construct fileset\n",
  "fileset = {}\n",
+ "input_size_GB = 0\n",
  "for containers_for_category in dataset_info.values():\n",
  " for container, metadata in containers_for_category.items():\n",
  " if metadata[\"files_output\"] is None:\n",
@@ -82,13 +83,16 @@
  "\n",
  " dsid, _, campaign = utils.dsid_rtag_campaign(container)\n",
  "\n",
- " # debugging shortcuts\n",
- " # if campaign not in [\"mc20a\", \"data15\", \"data16\"]: continue\n",
- " # if \"601352\" not in dsid: continue\n",
+ " # debugging shortcuts, use one or both of the following to reduce workload\n",
+ " if campaign not in [\"mc23a\", \"data22\"]: continue\n",
+ " # if \"601229\" not in dsid: continue\n",
  "\n",
  " weight_xs = utils.sample_xs(campaign, dsid)\n",
  " lumi = utils.integrated_luminosity(campaign)\n",
  " fileset[container] = {\"files\": dict((path, \"reco\") for path in metadata[\"files_output\"]), \"metadata\": {\"dsid\": dsid, \"campaign\": campaign, \"weight_xs\": weight_xs, \"lumi\": lumi}}\n",
+ " input_size_GB += metadata[\"size_output_GB\"]\n",
+ "\n",
+ "print(f\"fileset has {len(fileset)} categories with {sum([len(f[\"files\"]) for f in fileset.values()])} files total, size is {input_size_GB:.2f} GB\")\n",
  "\n",
  "# minimal fileset for debugging\n",
  "# fileset = {\"mc20_13TeV.601352.PhPy8EG_tW_dyn_DR_incl_antitop.deriv.DAOD_PHYSLITE.e8547_s4231_r13144_p6697\": fileset[\"mc20_13TeV.601352.PhPy8EG_tW_dyn_DR_incl_antitop.deriv.DAOD_PHYSLITE.e8547_s4231_r13144_p6697\"]}\n",
@@ -148,16 +152,18 @@
  "outputs": [],
  "source": [
  "run = processor.Runner(\n",
- " executor = processor.DaskExecutor(client=client),\n",
- " # executor = processor.IterativeExecutor(),\n",
+ " executor = processor.DaskExecutor(client=client, treereduction=4),\n",
+ " # executor = processor.IterativeExecutor(), # to run locally\n",
  " schema=NtupleSchema,\n",
  " savemetrics=True,\n",
- " chunksize=100_000,\n",
+ " chunksize=50_000,\n",
  " skipbadfiles=True,\n",
- " # maxchunks=1\n",
+ " align_clusters=False,\n",
+ " # maxchunks=1 # for debugging only\n",
  ")\n",
  "\n",
- "preprocess_output = run.preprocess(fileset)\n",
+ "with performance_report(filename=\"preprocess.html\"):\n",
+ " preprocess_output = run.preprocess(fileset)\n",
  "\n",
  "# write to disk\n",
  "with open(\"preprocess_output.json\", \"w\") as f:\n",
@@ -187,7 +193,7 @@
  "source": [
  "class Analysis(processor.ProcessorABC):\n",
  " def __init__(self):\n",
- " self.h = hist.new.Regular(30, 0, 300, label=\"leading electron $p_T$\").\\\n",
+ " self.h = hist.new.Regular(20, 0, 1_000, label=\"$m_{jj}$ [GeV]\").\\\n",
  " StrCat([], name=\"dsid_and_campaign\", growth=True).\\\n",
  " StrCat([], name=\"variation\", growth=True).\\\n",
  " Weight()\n",
@@ -213,12 +219,14 @@
  " sumw = None # no normalization for data\n",
  "\n",
  " for variation in events.systematic_names:\n",
- " if variation != \"NOSYS\" and \"EG_SCALE_ALL\" not in variation:\n",
+ " if variation not in [\"NOSYS\"] + [name for name in events.systematic_names if \"JET_JER_Effective\" in name]:\n",
  " continue\n",
  "\n",
  " cut = events[variation][\"pass\"][\"ejets\"] == 1\n",
+ " # TODO: remaining weights\n",
  " weight = (events[variation][cut==1].weight.mc if events.metadata[\"dsid\"] != \"data\" else 1.0) * events.metadata[\"weight_xs\"] * events.metadata[\"lumi\"]\n",
- " self.h.fill(events[variation][cut==1].el.pt[:, 0] / 1_000, dsid_and_campaign=dsid_and_campaign, variation=variation, weight=weight)\n",
+ " mjj = (events[variation][cut==1].jet[:, 0] + events[variation][cut==1].jet[:, 1]).mass\n",
+ " self.h.fill(mjj / 1_000, dsid_and_campaign=dsid_and_campaign, variation=variation, weight=weight)\n",
  "\n",
  " return {\n",
  " \"hist\": self.h,\n",
@@ -239,18 +247,31 @@
  " accumulator[\"hist\"][:, dsid_and_campaign, :] = np.stack([count_normalized, variance_normalized], axis=-1)\n",
  "\n",
  "\n",
- "t0 = time.perf_counter()\n",
- "out, report = run(preprocess_output, processor_instance=Analysis())\n",
+ "client.run_on_scheduler(utils.start_tracking) # track worker count on scheduler\n",
+ "t0 = time.perf_counter() # track walltime\n",
+ "\n",
+ "with performance_report(filename=\"process.html\"):\n",
+ " out, report = run(preprocess_output, processor_instance=Analysis())\n",
+ "\n",
  "t1 = time.perf_counter()\n",
- "report"
+ "worker_count_dict = client.run_on_scheduler(utils.stop_tracking) # stop tracking, read out data, get average\n",
+ "nworker_avg = utils.get_avg_num_workers(worker_count_dict)\n",
+ "\n",
+ "print(f\"histogram size: {out[\"hist\"].view(True).nbytes / 1_000 / 1_000:.2f} GB\\n\")\n",
+ "\n",
+ "# shortened version of report, dropping extra columns\n",
+ "dict((k, v) for k, v in report.items() if k != \"columns\") | ({\"columns\": report[\"columns\"][0:10] + [\"...\"]})"
  ]
 },
 {
  "cell_type": "markdown",
  "id": "8663e9ff-f8bb-43a0-8978-f2d430d2bbbd",
  "metadata": {},
  "source": [
- "track XCache egress: [link](https://grafana.mwt2.org/d/EKefjM-Sz/af-network-200gbps-challenge?var-cnode=c111_af_uchicago_edu&var-cnode=c112_af_uchicago_edu&var-cnode=c113_af_uchicago_edu&var-cnode=c114_af_uchicago_edu&var-cnode=c115_af_uchicago_edu&viewPanel=195&kiosk=true&orgId=1&from=now-1h&to=now&timezone=browser&refresh=5s)"
+ "track XCache egress: [link](https://grafana.mwt2.org/d/EKefjM-Sz/af-network-200gbps-challenge?var-cnode=c111_af_uchicago_edu&var-cnode=c112_af_uchicago_edu&var-cnode=c113_af_uchicago_edu&var-cnode=c114_af_uchicago_edu&var-cnode=c115_af_uchicago_edu&viewPanel=195&kiosk=true&orgId=1&from=now-1h&to=now&timezone=browser&refresh=5s)\n",
+ "\n",
+ "**to-do for metrics:**\n",
+ "- data rate by tracking `bytesread` per chunk"
  ]
 },
 {
@@ -260,13 +281,25 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "print(f\"data read: {report[\"bytesread\"] / 1000**3:.2f} GB in {report[\"chunks\"]} chunks\")\n",
+ "print(f\"walltime: {t1 - t0:.2f} sec ({(t1 - t0) / 60:.2f} min)\")\n",
+ "print(f\"average worker count: {nworker_avg:.1f}\")\n",
+ "print(f\"number of events processed: {report[\"entries\"]:,}\\n\")\n",
+ "\n",
+ "print(f\"data read: {report[\"bytesread\"] / 1000**3:.2f} GB in {report[\"chunks\"]} chunks (average {report[\"bytesread\"] / 1000**3 / report[\"chunks\"]:.2f} GB per chunk)\")\n",
+ "print(f\"average total data rate: {report[\"bytesread\"] / 1000**3 * 8 / (t1 - t0):.2f} Gbps\")\n",
+ "print(f\"fraction of input files read: {report[\"bytesread\"] / 1000**3 / input_size_GB:.1%}\")\n",
+ "print(f\"number of branches read: {len(report[\"columns\"])}\\n\")\n",
+ "\n",
+ "print(f\"worker-average event rate using \\'processtime\\': {report[\"entries\"] / 1000 / report[\"processtime\"]:.2f} kHz\")\n",
+ "print(f\"worker-average data rate using \\'processtime\\': {report[\"bytesread\"] / 1000**3 * 8 / report[\"processtime\"]:.2f} Gbps\\n\")\n",
  "\n",
- "print(f\"core-average event rate using \\'processtime\\': {report[\"entries\"] / 1000 / report[\"processtime\"]:.2f} kHz\")\n",
- "print(f\"core-average data rate using \\'processtime\\': {report[\"bytesread\"] / 1000**3 * 8 / report[\"processtime\"]:.2f} Gbps\")\n",
+ "print(f\"average event rate using walltime and time-averaged worker count: {report[\"entries\"] / 1000 / (t1 - t0) / nworker_avg:.2f} kHz\")\n",
+ "print(f\"average data rate using walltime and time-averaged worker count: {report[\"bytesread\"] / 1000**3 * 8 / (t1 - t0) / nworker_avg:.2f} Gbps\\n\")\n",
  "\n",
- "print(f\"average event rate using walltime: {report[\"entries\"] / 1000 / (t1 - t0):.2f} kHz\")\n",
- "print(f\"average data rate using walltime: {report[\"bytesread\"] / 1000**3 * 8 / (t1 - t0):.2f} Gbps\")"
+ "print(f\"fraction of time spent in processing: {report[\"processtime\"] / ((t1 - t0) * nworker_avg):.1%}\")\n",
+ "print(f\"average process task length: {report[\"processtime\"] / report[\"chunks\"]:.1f} sec\")\n",
+ "\n",
+ "_ = utils.plot_worker_count(worker_count_dict)"
  ]
 },
 {
@@ -285,16 +318,22 @@
  "\n",
  " dsids = sorted(set(dsids))\n",
  " dsids_in_hist = [dc for dc in out[\"hist\"].axes[1] if dc.split(\"_\")[0] in dsids]\n",
- " print(f\"{key}:\\n - expect {dsids}\\n - have {dsids_in_hist}\")\n",
+ " # print(f\"{key}:\\n - expect {dsids}\\n - have {dsids_in_hist}\")\n",
  "\n",
  " if key in [\"data\", \"ttbar_H7\", \"ttbar_hdamp\", \"ttbar_pthard\", \"Wt_DS\", \"Wt_H7\", \"Wt_pthard\"] or len(dsids_in_hist) == 0:\n",
  " continue # data drawn separately, skip MC modeling variations and skip empty categories\n",
  "\n",
  " mc_stack.append(out[\"hist\"][:, :, \"NOSYS\"].integrate(\"dsid_and_campaign\", dsids_in_hist))\n",
  " labels.append(key)\n",
  "\n",
- "fig, ax1, ax2 = mplhep.data_model(\n",
- " data_hist=out[\"hist\"].integrate(\"dsid_and_campaign\", [dc for dc in out[\"hist\"].axes[1] if \"data\" in dc])[:, \"NOSYS\"],\n",
+ "try:\n",
+ " data_hist = out[\"hist\"].integrate(\"dsid_and_campaign\", [dc for dc in out[\"hist\"].axes[1] if \"data\" in dc])[:, \"NOSYS\"]\n",
+ "except ValueError:\n",
+ " print(\"falling back to plotting first entry of categorical axes as \\\"data\\\"\")\n",
+ " data_hist = out[\"hist\"][:, 0, 0]\n",
+ "\n",
+ "fig, ax1, ax2 = mplhep.comp.data_model(\n",
+ " data_hist=data_hist,\n",
  " stacked_components=mc_stack,\n",
  " stacked_labels=labels,\n",
  " # https://scikit-hep.org/mplhep/gallery/model_with_stacked_and_unstacked_histograms_components/\n",
@@ -309,7 +348,7 @@
  "ax2.set_ylim([0.5, 1.5])\n",
  "\n",
  "# compare to e.g. https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PAPERS/HDBS-2020-11/fig_02a.png\n",
- "fig.savefig(\"el_pt.png\")"
+ "fig.savefig(\"mjj.png\")"
  ]
 },
 {
@@ -326,9 +365,7 @@
  " f.write(json.dumps(out[\"hist\"], default=uhi.io.json.default).encode(\"utf-8\"))\n",
  "\n",
  "with gzip.open(\"hist.json.gz\") as f:\n",
- " h = hist.Hist(json.loads(f.read(), object_hook=uhi.io.json.object_hook))\n",
- "\n",
- "h[:, \"data_data15\", \"NOSYS\"]"
+ " h = hist.Hist(json.loads(f.read(), object_hook=uhi.io.json.object_hook))"
  ]
 }
 ],
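Note on the metrics cell above: the printed rates are derived from the coffea report dict (its "entries", "bytesread", "processtime", and "chunks" fields) plus the measured walltime and the time-averaged worker count. A minimal standalone sketch of those formulas, with made-up placeholder numbers in place of a real run (none of the values below are measurements):

# sketch of the derived throughput metrics; every input value here is illustrative only
report = {"entries": 25_000_000, "bytesread": 180 * 1000**3, "processtime": 9_000.0, "chunks": 500}
walltime = 600.0     # seconds, corresponds to t1 - t0 in the notebook
nworker_avg = 90.0   # time-averaged worker count from utils.get_avg_num_workers

gb_read = report["bytesread"] / 1000**3
print(f"data read: {gb_read:.2f} GB in {report['chunks']} chunks ({gb_read / report['chunks']:.2f} GB per chunk)")
print(f"average total data rate: {gb_read * 8 / walltime:.2f} Gbps")

# per-worker rates divide by the summed task time ("processtime"),
# walltime-based per-worker rates divide by walltime times the average worker count
print(f"worker-average event rate: {report['entries'] / 1_000 / report['processtime']:.2f} kHz")
print(f"event rate per worker from walltime: {report['entries'] / 1_000 / walltime / nworker_avg:.2f} kHz")
print(f"fraction of time spent in processing: {report['processtime'] / (walltime * nworker_avg):.1%}")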

atlas/ntuple_production/ntuple_summary.ipynb

Lines changed: 20 additions & 14 deletions
Large diffs are not rendered by default.

atlas/utils.py

Lines changed: 46 additions & 0 deletions
@@ -1,10 +1,56 @@
  # saving preprocessing output
+ import asyncio
  import base64
  import dataclasses
+ import datetime
  import re
  import urllib.request

  import coffea
+ import matplotlib.pyplot as plt
+
+
+ ##################################################
+ ### Dask task tracking
+ ##################################################
+
+ def start_tracking(dask_scheduler) -> None:
+     """"run on scheduler to track worker count"""
+     dask_scheduler.worker_counts = {}
+     dask_scheduler.track_count = True
+
+     async def track_count() -> None:
+         while dask_scheduler.track_count:
+             dask_scheduler.worker_counts[datetime.datetime.now()] = len(dask_scheduler.workers)
+             await asyncio.sleep(1)
+
+     asyncio.create_task(track_count())
+
+
+ def stop_tracking(dask_scheduler) -> dict:
+     """obtain worker count and stop tracking"""
+     dask_scheduler.track_count = False
+     return dask_scheduler.worker_counts
+
+
+ def get_avg_num_workers(worker_count_dict: dict) -> float:
+     """get time-averaged worker count"""
+     worker_info = list(worker_count_dict.items())
+     nworker_dt = 0
+     for (t0, nw0), (t1, nw1) in zip(worker_info[:-1], worker_info[1:]):
+         nworker_dt += (nw1 + nw0) / 2 * (t1 - t0).total_seconds()
+     return nworker_dt / (worker_info[-1][0] - worker_info[0][0]).total_seconds()
+
+
+ def plot_worker_count(worker_count_dict: dict):
+     """plot worker count over time"""
+     fig, ax = plt.subplots()
+     ax.plot(worker_count_dict.keys(), worker_count_dict.values())
+     ax.tick_params(axis="x", labelrotation=45)
+     ax.set_ylim([0, ax.get_ylim()[1]])
+     ax.set_xlabel("time")
+     ax.set_ylabel("number of workers")
+     return fig, ax


  ##################################################
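The tracking helpers added above are driven from the client side in analysis.ipynb: start_tracking and stop_tracking run on the scheduler via Client.run_on_scheduler, and get_avg_num_workers time-weights the 1 Hz samples (sum of (n_i + n_{i+1}) / 2 * dt_i over consecutive samples, divided by the total tracked time). A rough usage sketch, assuming a reachable scheduler at the address used in the notebook and with a placeholder sleep standing in for the actual processing:

import time
from dask.distributed import Client

import utils  # the module shown above

client = Client("tls://localhost:8786")  # scheduler address assumed, as in the notebook

client.run_on_scheduler(utils.start_tracking)   # begin sampling the worker count on the scheduler
time.sleep(30)                                  # placeholder for the coffea run
worker_count_dict = client.run_on_scheduler(utils.stop_tracking)

nworker_avg = utils.get_avg_num_workers(worker_count_dict)
print(f"average worker count: {nworker_avg:.1f}")

fig, ax = utils.plot_worker_count(worker_count_dict)
fig.savefig("worker_count.png")  # *png is ignored via the .gitignore change above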
