Skip to content

Commit 662c93f

Browse files
authored
feat: support overriding the amber serializer class (#683)
1 parent 4ca0716 commit 662c93f

File tree

6 files changed

+173
-58
lines changed

6 files changed

+173
-58
lines changed

src/syrupy/extensions/amber/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
Any,
66
Optional,
77
Set,
8+
Type,
89
)
910

1011
from syrupy.data import SnapshotCollection
1112
from syrupy.exceptions import TaintedSnapshotError
1213
from syrupy.extensions.base import AbstractSyrupyExtension
1314

15+
from .serializer import AmberDataSerializerSorted # noqa: F401 # re-exported
1416
from .serializer import AmberDataSerializer
1517

1618
if TYPE_CHECKING:
@@ -24,12 +26,14 @@ class AmberSnapshotExtension(AbstractSyrupyExtension):
2426

2527
_file_extension = "ambr"
2628

29+
serializer_class: Type["AmberDataSerializer"] = AmberDataSerializer
30+
2731
def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
2832
"""
2933
Returns the serialized form of 'data' to be compared
3034
with the snapshot data written to disk.
3135
"""
32-
return AmberDataSerializer.serialize(data, **kwargs)
36+
return self.serializer_class.serialize(data, **kwargs)
3337

3438
def delete_snapshots(
3539
self, snapshot_location: str, snapshot_names: Set[str]
@@ -39,19 +43,19 @@ def delete_snapshots(
3943
snapshot_collection_to_update.remove(snapshot_name)
4044

4145
if snapshot_collection_to_update.has_snapshots:
42-
AmberDataSerializer.write_file(snapshot_collection_to_update)
46+
self.serializer_class.write_file(snapshot_collection_to_update)
4347
else:
4448
Path(snapshot_location).unlink()
4549

4650
def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
47-
return AmberDataSerializer.read_file(snapshot_location)
51+
return self.serializer_class.read_file(snapshot_location)
4852

49-
@staticmethod
53+
@classmethod
5054
@lru_cache()
5155
def __cacheable_read_snapshot(
52-
snapshot_location: str, cache_key: str
56+
cls, snapshot_location: str, cache_key: str
5357
) -> "SnapshotCollection":
54-
return AmberDataSerializer.read_file(snapshot_location)
58+
return cls.serializer_class.read_file(snapshot_location)
5559

5660
def _read_snapshot_data_from_location(
5761
self, snapshot_location: str, snapshot_name: str, session_id: str
@@ -70,7 +74,7 @@ def _read_snapshot_data_from_location(
7074
def _write_snapshot_collection(
7175
cls, *, snapshot_collection: "SnapshotCollection"
7276
) -> None:
73-
AmberDataSerializer.write_file(snapshot_collection, merge=True)
77+
cls.serializer_class.write_file(snapshot_collection, merge=True)
7478

7579

7680
__all__ = ["AmberSnapshotExtension", "AmberDataSerializer"]

src/syrupy/extensions/amber/serializer.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
Dict,
1212
Generator,
1313
Iterable,
14-
List,
1514
NamedTuple,
1615
Optional,
1716
Set,
@@ -77,7 +76,13 @@ class MissingVersionError(Exception):
7776

7877

7978
class AmberDataSerializer:
80-
VERSION = 1
79+
"""
80+
If extending the serializer, change the VERSION property to some unique value
81+
for your iteration of the serializer so as to force invalidation of existing
82+
snapshots.
83+
"""
84+
85+
VERSION = "1"
8186

8287
_indent: str = " "
8388
_max_depth: int = 99
@@ -89,23 +94,8 @@ class Marker:
8994
Divider = "---"
9095

9196
@classmethod
92-
def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]:
93-
try:
94-
# cast to int only if the string is the exact representation of the int
95-
# for example, '012' != str(int('012'))
96-
i = int(part)
97-
if str(i) == part:
98-
return (1, i)
99-
return (0, part)
100-
except ValueError:
101-
# the nested tuple is to prevent comparing a str to an int
102-
return (0, part)
103-
104-
@classmethod
105-
def __snapshot_sort_key(
106-
cls, snapshot: "Snapshot"
107-
) -> List[Tuple[int, Union[str, int]]]:
108-
return [cls.__maybe_int(part) for part in snapshot.name.split(".")]
97+
def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any:
98+
return snapshot.name
10999

110100
@classmethod
111101
def write_file(
@@ -123,7 +113,7 @@ def write_file(
123113
with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
124114
f.write(f"{cls._marker_prefix}{cls.Marker.Version}: {cls.VERSION}\n")
125115
for snapshot in sorted(
126-
snapshot_collection, key=lambda s: cls.__snapshot_sort_key(s)
116+
snapshot_collection, key=lambda s: cls._snapshot_sort_key(s) # type: ignore # noqa: E501
127117
):
128118
snapshot_data = str(snapshot.data)
129119
if snapshot_data is not None:
@@ -152,14 +142,14 @@ def __read_file_with_markers(
152142
":", maxsplit=1
153143
)
154144
marker_key = marker_key.rstrip(" \r\n")
155-
marker_value = marker_rest[0] if marker_rest else None
145+
marker_value = marker_rest[0].strip() if marker_rest else None
156146

157147
if marker_key == cls.Marker.Version:
158148
if line_no:
159149
raise MalformedAmberFile(
160150
"Version must be specified at the top of the file."
161151
)
162-
if not marker_value or int(marker_value) != cls.VERSION:
152+
if not marker_value or marker_value != cls.VERSION:
163153
tainted = True
164154
continue
165155
missing_version = False
@@ -457,3 +447,28 @@ def __serialize_lines(
457447
formatted_open_tag = cls.with_indent(f"{maybe_obj_type}{open_tag}", depth)
458448
formatted_close_tag = cls.with_indent(close_tag, depth)
459449
return f"{formatted_open_tag}\n{lines}{lines_end}{formatted_close_tag}"
450+
451+
452+
class AmberDataSerializerSorted(AmberDataSerializer):
453+
"""
454+
This is an experimental serializer with known performance issues.
455+
"""
456+
457+
VERSION = f"{AmberDataSerializer.VERSION}-sorted"
458+
459+
@classmethod
460+
def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]:
461+
try:
462+
# cast to int only if the string is the exact representation of the int
463+
# for example, '012' != str(int('012'))
464+
i = int(part)
465+
if str(i) == part:
466+
return (1, i)
467+
return (0, part)
468+
except ValueError:
469+
# the nested tuple is to prevent comparing a str to an int
470+
return (0, part)
471+
472+
@classmethod
473+
def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any:
474+
return [cls.__maybe_int(part) for part in snapshot.name.split(".")]

tests/syrupy/__snapshots__/test_doctest.ambr

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
obj_attr='test class attr',
55
)
66
# ---
7+
# name: DocTestClass.1
8+
DocTestClass(
9+
obj_attr='test class attr',
10+
)
11+
# ---
712
# name: DocTestClass.NestedDocTestClass
813
NestedDocTestClass(
914
nested_obj_attr='nested doc test class attr',
@@ -15,11 +20,6 @@
1520
# name: DocTestClass.doctest_method
1621
'doc test method return value'
1722
# ---
18-
# name: DocTestClass.1
19-
DocTestClass(
20-
obj_attr='test class attr',
21-
)
22-
# ---
2323
# name: doctest_fn
2424
'doc test fn return value'
2525
# ---

tests/syrupy/extensions/amber/__snapshots__/test_amber_serializer.ambr

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -253,30 +253,6 @@
253253
# name: test_many_sorted.1
254254
1
255255
# ---
256-
# name: test_many_sorted.2
257-
2
258-
# ---
259-
# name: test_many_sorted.3
260-
3
261-
# ---
262-
# name: test_many_sorted.4
263-
4
264-
# ---
265-
# name: test_many_sorted.5
266-
5
267-
# ---
268-
# name: test_many_sorted.6
269-
6
270-
# ---
271-
# name: test_many_sorted.7
272-
7
273-
# ---
274-
# name: test_many_sorted.8
275-
8
276-
# ---
277-
# name: test_many_sorted.9
278-
9
279-
# ---
280256
# name: test_many_sorted.10
281257
10
282258
# ---
@@ -307,6 +283,9 @@
307283
# name: test_many_sorted.19
308284
19
309285
# ---
286+
# name: test_many_sorted.2
287+
2
288+
# ---
310289
# name: test_many_sorted.20
311290
20
312291
# ---
@@ -322,6 +301,27 @@
322301
# name: test_many_sorted.24
323302
24
324303
# ---
304+
# name: test_many_sorted.3
305+
3
306+
# ---
307+
# name: test_many_sorted.4
308+
4
309+
# ---
310+
# name: test_many_sorted.5
311+
5
312+
# ---
313+
# name: test_many_sorted.6
314+
6
315+
# ---
316+
# name: test_many_sorted.7
317+
7
318+
# ---
319+
# name: test_many_sorted.8
320+
8
321+
# ---
322+
# name: test_many_sorted.9
323+
9
324+
# ---
325325
# name: test_multiline_string_in_dict
326326
dict({
327327
'value': '''
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# serializer version: 1-sorted
2+
# name: test_many_sorted
3+
0
4+
# ---
5+
# name: test_many_sorted.1
6+
1
7+
# ---
8+
# name: test_many_sorted.2
9+
2
10+
# ---
11+
# name: test_many_sorted.3
12+
3
13+
# ---
14+
# name: test_many_sorted.4
15+
4
16+
# ---
17+
# name: test_many_sorted.5
18+
5
19+
# ---
20+
# name: test_many_sorted.6
21+
6
22+
# ---
23+
# name: test_many_sorted.7
24+
7
25+
# ---
26+
# name: test_many_sorted.8
27+
8
28+
# ---
29+
# name: test_many_sorted.9
30+
9
31+
# ---
32+
# name: test_many_sorted.10
33+
10
34+
# ---
35+
# name: test_many_sorted.11
36+
11
37+
# ---
38+
# name: test_many_sorted.12
39+
12
40+
# ---
41+
# name: test_many_sorted.13
42+
13
43+
# ---
44+
# name: test_many_sorted.14
45+
14
46+
# ---
47+
# name: test_many_sorted.15
48+
15
49+
# ---
50+
# name: test_many_sorted.16
51+
16
52+
# ---
53+
# name: test_many_sorted.17
54+
17
55+
# ---
56+
# name: test_many_sorted.18
57+
18
58+
# ---
59+
# name: test_many_sorted.19
60+
19
61+
# ---
62+
# name: test_many_sorted.20
63+
20
64+
# ---
65+
# name: test_many_sorted.21
66+
21
67+
# ---
68+
# name: test_many_sorted.22
69+
22
70+
# ---
71+
# name: test_many_sorted.23
72+
23
73+
# ---
74+
# name: test_many_sorted.24
75+
24
76+
# ---
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
3+
from syrupy.extensions.amber import (
4+
AmberDataSerializerSorted,
5+
AmberSnapshotExtension,
6+
)
7+
8+
9+
class AmberSortedSnapshotExtension(AmberSnapshotExtension):
10+
serializer_class = AmberDataSerializerSorted
11+
12+
13+
@pytest.fixture
14+
def snapshot(snapshot):
15+
return snapshot.use_extension(AmberSortedSnapshotExtension)
16+
17+
18+
def test_many_sorted(snapshot):
19+
for i in range(25):
20+
assert i == snapshot

0 commit comments

Comments
 (0)