Skip to content

Commit 34e75ed

Browse files
committed
sqlite_utils.utils.flatten() function, closes #500
1 parent d792dad commit 34e75ed

File tree

5 files changed

+40
-23
lines changed

5 files changed

+40
-23
lines changed

docs/reference.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,10 @@ sqlite_utils.utils.chunks
101101
-------------------------
102102

103103
.. autofunction:: sqlite_utils.utils.chunks
104+
105+
.. _reference_utils_flatten:
106+
107+
sqlite_utils.utils.flatten
108+
--------------------------
109+
110+
.. autofunction:: sqlite_utils.utils.flatten

sqlite_utils/cli.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
chunks,
2525
file_progress,
2626
find_spatialite,
27+
flatten as _flatten,
2728
sqlite3,
2829
decode_base64_values,
2930
progressbar,
@@ -997,7 +998,7 @@ def insert_upsert_implementation(
997998
"Invalid JSON - use --csv for CSV or --tsv for TSV files"
998999
)
9991000
if flatten:
1000-
docs = (dict(_flatten(doc)) for doc in docs)
1001+
docs = (_flatten(doc) for doc in docs)
10011002

10021003
if convert:
10031004
variable = "row"
@@ -1079,15 +1080,6 @@ def insert_upsert_implementation(
10791080
db[table].transform(types=tracker.types)
10801081

10811082

1082-
def _flatten(d):
1083-
for key, value in d.items():
1084-
if isinstance(value, dict):
1085-
for key2, value2 in _flatten(value):
1086-
yield key + "_" + key2, value2
1087-
else:
1088-
yield key, value
1089-
1090-
10911083
def _find_variables(tb, vars):
10921084
to_find = list(vars)
10931085
found = {}
@@ -1845,7 +1837,7 @@ def memory(
18451837
tracker = TypeTracker()
18461838
rows = tracker.wrap(rows)
18471839
if flatten:
1848-
rows = (dict(_flatten(row)) for row in rows)
1840+
rows = (_flatten(row) for row in rows)
18491841
db[csv_table].insert_all(rows, alter=True)
18501842
if tracker is not None:
18511843
db[csv_table].transform(types=tracker.types)

sqlite_utils/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,21 @@ def hash_record(record: Dict, keys: Optional[Iterable[str]] = None):
513513
"utf8"
514514
)
515515
).hexdigest()
516+
517+
518+
def _flatten(d):
519+
for key, value in d.items():
520+
if isinstance(value, dict):
521+
for key2, value2 in _flatten(value):
522+
yield key + "_" + key2, value2
523+
else:
524+
yield key, value
525+
526+
527+
def flatten(row: dict) -> dict:
528+
"""
529+
Turn a nested dict e.g. ``{"a": {"b": 1}}`` into a flat dict: ``{"a_b": 1}``
530+
531+
:param row: A Python dictionary, optionally with nested dictionaries
532+
"""
533+
return dict(_flatten(row))

tests/test_cli.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,18 +2176,6 @@ def test_upsert_detect_types(tmpdir, option):
21762176
]
21772177

21782178

2179-
@pytest.mark.parametrize(
2180-
"input,expected",
2181-
(
2182-
({"foo": {"bar": 1}}, {"foo_bar": 1}),
2183-
({"foo": {"bar": [1, 2, {"baz": 3}]}}, {"foo_bar": [1, 2, {"baz": 3}]}),
2184-
({"foo": {"bar": 1, "baz": {"three": 3}}}, {"foo_bar": 1, "foo_baz_three": 3}),
2185-
),
2186-
)
2187-
def test_flatten_helper(input, expected):
2188-
assert dict(cli._flatten(input)) == expected
2189-
2190-
21912179
def test_integer_overflow_error(tmpdir):
21922180
db_path = str(tmpdir / "test.db")
21932181
result = CliRunner().invoke(

tests/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,15 @@ def test_maximize_csv_field_size_limit():
7171
assert len(rows_list2) == 1
7272
assert rows_list2[0]["id"] == "1"
7373
assert rows_list2[0]["text"] == long_value
74+
75+
76+
@pytest.mark.parametrize(
77+
"input,expected",
78+
(
79+
({"foo": {"bar": 1}}, {"foo_bar": 1}),
80+
({"foo": {"bar": [1, 2, {"baz": 3}]}}, {"foo_bar": [1, 2, {"baz": 3}]}),
81+
({"foo": {"bar": 1, "baz": {"three": 3}}}, {"foo_bar": 1, "foo_baz_three": 3}),
82+
),
83+
)
84+
def test_flatten(input, expected):
85+
assert utils.flatten(input) == expected

0 commit comments

Comments
 (0)