Skip to content

Commit 84f20db

Browse files
committed
handle non-serializable values in merge_json_on_disk()
1 parent 31bd925 commit 84f20db

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

aviary/wrenformer/utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import time
33
from contextlib import contextmanager
4-
from typing import Generator
4+
from typing import Generator, Literal
55

66

77
def _int_keys(dct: dict) -> dict:
@@ -21,12 +21,20 @@ def recursive_dict_merge(d1: dict, d2: dict) -> dict:
2121
return d1
2222

2323

24-
def merge_json_on_disk(dct: dict, file_path: str) -> None:
24+
def merge_json_on_disk(
25+
dct: dict,
26+
file_path: str,
27+
on_non_serializable: Literal["annotate", "error"] = "annotate",
28+
) -> None:
2529
"""Merge a dict into a (possibly) existing JSON file.
2630
2731
Args:
2832
file_path (str): Path to JSON file. File will be created if not exist.
2933
dct (dict): Dictionary to merge into JSON file.
34+
on_non_serializable ('annotate' | 'error'): What to do with non-serializable values
35+
encountered in dct. 'annotate' will replace the offending object with a string
36+
indicating the type, e.g. '<not serializable: function>'. 'error' will raise
37+
'TypeError: Object of type function is not JSON serializable'. Defaults to 'annotate'.
3038
"""
3139
try:
3240
with open(file_path) as json_file:
@@ -36,8 +44,15 @@ def merge_json_on_disk(dct: dict, file_path: str) -> None:
3644
except (FileNotFoundError, json.decoder.JSONDecodeError): # file missing or empty
3745
pass
3846

47+
def non_serializable_handler(obj: object) -> str:
48+
# replace functions and classes in dct with string indicating a non-serializable type
49+
return f"<not serializable: {type(obj).__qualname__}>"
50+
3951
with open(file_path, "w") as file:
40-
json.dump(dct, file)
52+
default = (
53+
non_serializable_handler if on_non_serializable == "annotate" else None
54+
)
55+
json.dump(dct, file, default=default)
4156

4257

4358
@contextmanager

examples/wrenformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def run_wrenformer(
236236
"warmup_steps": warmup_steps,
237237
"robust": robust,
238238
"embedding_len": embedding_len,
239-
"losses": str(loss_dict),
239+
"losses": loss_dict,
240240
"training_samples": len(train_df),
241241
"test_samples": len(test_df),
242242
"trainable_params": model.num_params,

0 commit comments

Comments
 (0)