Skip to content

Commit a09c358

Browse files
noahnuNoah Negin-Ulster
authored andcommitted
fix: defer snapshot writes until end of session (#606)
1 parent 4986aad commit a09c358

File tree

6 files changed

+122
-42
lines changed

6 files changed

+122
-42
lines changed

src/syrupy/assertion.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ def _assert(self, data: "SerializableData") -> bool:
262262
)
263263
assertion_success = matches
264264
if not matches and self.update_snapshots:
265-
self.extension.write_snapshot(
265+
self.session.queue_snapshot_write(
266+
extension=self.extension,
266267
data=serialized_data,
267268
index=self.index,
268269
)
@@ -297,6 +298,8 @@ def _post_assert(self) -> None:
297298

298299
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
299300
try:
300-
return self.extension.read_snapshot(index=index)
301+
return self.extension.read_snapshot(
302+
index=index, session_id=str(id(self.session))
303+
)
301304
except SnapshotDoesNotExist:
302305
return None

src/syrupy/extensions/amber/__init__.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import lru_cache
12
from pathlib import Path
23
from typing import (
34
TYPE_CHECKING,
@@ -46,16 +47,23 @@ def _file_extension(self) -> str:
4647
def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil":
4748
return DataSerializer.read_file(snapshot_location)
4849

50+
@lru_cache()
51+
def __cacheable_read_snapshot(
52+
self, snapshot_location: str, cache_key: str
53+
) -> "SnapshotFossil":
54+
return DataSerializer.read_file(snapshot_location)
55+
4956
def _read_snapshot_data_from_location(
50-
self, snapshot_location: str, snapshot_name: str
57+
self, snapshot_location: str, snapshot_name: str, session_id: str
5158
) -> Optional["SerializableData"]:
52-
snapshot = self._read_snapshot_fossil(snapshot_location).get(snapshot_name)
59+
snapshots = self.__cacheable_read_snapshot(
60+
snapshot_location=snapshot_location, cache_key=session_id
61+
)
62+
snapshot = snapshots.get(snapshot_name)
5363
return snapshot.data if snapshot else None
5464

5565
def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
56-
snapshot_fossil_to_update = DataSerializer.read_file(snapshot_fossil.location)
57-
snapshot_fossil_to_update.merge(snapshot_fossil)
58-
DataSerializer.write_file(snapshot_fossil_to_update)
66+
DataSerializer.write_file(snapshot_fossil, merge=True)
5967

6068

6169
__all__ = ["AmberSnapshotExtension", "DataSerializer"]

src/syrupy/extensions/amber/serializer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import functools
21
import os
32
from types import (
43
GeneratorType,
@@ -71,11 +70,16 @@ class DataSerializer:
7170
_marker_crn: str = "\r\n"
7271

7372
@classmethod
74-
def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None:
73+
def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None:
7574
"""
76-
Writes the snapshot data into the snapshot file that be read later.
75+
Writes the snapshot data into the snapshot file that can be read later.
7776
"""
7877
filepath = snapshot_fossil.location
78+
if merge:
79+
base_snapshot = cls.read_file(filepath)
80+
base_snapshot.merge(snapshot_fossil)
81+
snapshot_fossil = base_snapshot
82+
7983
with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
8084
for snapshot in sorted(snapshot_fossil, key=lambda s: s.name):
8185
snapshot_data = str(snapshot.data)
@@ -86,7 +90,6 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None:
8690
f.write(f"\n{cls._marker_divider}\n")
8791

8892
@classmethod
89-
@functools.lru_cache()
9093
def read_file(cls, filepath: str) -> "SnapshotFossil":
9194
"""
9295
Read the raw snapshot data (str) from the snapshot file into a dict

src/syrupy/extensions/base.py

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@
33
ABC,
44
abstractmethod,
55
)
6+
from collections import defaultdict
67
from difflib import ndiff
78
from gettext import gettext
89
from itertools import zip_longest
910
from pathlib import Path
1011
from typing import (
1112
TYPE_CHECKING,
1213
Callable,
14+
DefaultDict,
1315
Dict,
1416
Iterator,
1517
List,
1618
Optional,
1719
Set,
20+
Tuple,
1821
)
1922

2023
from syrupy.constants import (
@@ -115,7 +118,9 @@ def discover_snapshots(self) -> "SnapshotFossils":
115118

116119
return discovered
117120

118-
def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
121+
def read_snapshot(
122+
self, *, index: "SnapshotIndex", session_id: str
123+
) -> "SerializedData":
119124
"""
120125
Utility method for reading the contents of a snapshot assertion.
121126
Will call `_pre_read`, then perform `read` and finally `post_read`,
@@ -129,7 +134,9 @@ def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
129134
snapshot_location = self.get_location(index=index)
130135
snapshot_name = self.get_snapshot_name(index=index)
131136
snapshot_data = self._read_snapshot_data_from_location(
132-
snapshot_location=snapshot_location, snapshot_name=snapshot_name
137+
snapshot_location=snapshot_location,
138+
snapshot_name=snapshot_name,
139+
session_id=session_id,
133140
)
134141
if snapshot_data is None:
135142
raise SnapshotDoesNotExist()
@@ -145,33 +152,66 @@ def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> N
145152
This method is _final_, do not override. You can override
146153
`_write_snapshot_fossil` in a subclass to change behaviour.
147154
"""
148-
self._pre_write(data=data, index=index)
149-
snapshot_location = self.get_location(index=index)
150-
if not self.test_location.matches_snapshot_location(snapshot_location):
151-
warning_msg = gettext(
152-
"{line_end}Can not relate snapshot location '{}' to the test location."
153-
"{line_end}Consider adding '{}' to the generated location."
154-
).format(
155-
snapshot_location,
156-
self.test_location.filename,
157-
line_end="\n",
158-
)
159-
warnings.warn(warning_msg)
160-
snapshot_name = self.get_snapshot_name(index=index)
161-
if not self.test_location.matches_snapshot_name(snapshot_name):
162-
warning_msg = gettext(
163-
"{line_end}Can not relate snapshot name '{}' to the test location."
164-
"{line_end}Consider adding '{}' to the generated name."
165-
).format(
166-
snapshot_name,
167-
self.test_location.testname,
168-
line_end="\n",
169-
)
170-
warnings.warn(warning_msg)
171-
snapshot_fossil = SnapshotFossil(location=snapshot_location)
172-
snapshot_fossil.add(Snapshot(name=snapshot_name, data=data))
173-
self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil)
174-
self._post_write(data=data, index=index)
155+
self.write_snapshot_batch(snapshots=[(data, index)])
156+
157+
def write_snapshot_batch(
158+
self, *, snapshots: List[Tuple["SerializedData", "SnapshotIndex"]]
159+
) -> None:
160+
"""
161+
Utility method for writing the contents of multiple snapshot assertions.
162+
Will call `_pre_write` per snapshot, then perform `write` per snapshot
163+
and finally `_post_write`.
164+
165+
This method is _final_, do not override. You can override
166+
`_write_snapshot_fossil` in a subclass to change behaviour.
167+
"""
168+
# First we group by location since it'll let us batch by file on disk.
169+
# Not as useful for single file snapshots, but useful for the standard
170+
# Amber extension.
171+
locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list)
172+
for data, index in snapshots:
173+
location = self.get_location(index=index)
174+
snapshot_name = self.get_snapshot_name(index=index)
175+
locations[location].append(Snapshot(name=snapshot_name, data=data))
176+
177+
# Is there a better place to do the pre-writes?
178+
# Or can we remove the pre-write concept altogether?
179+
self._pre_write(data=data, index=index)
180+
181+
for location, location_snapshots in locations.items():
182+
snapshot_fossil = SnapshotFossil(location=location)
183+
184+
if not self.test_location.matches_snapshot_location(location):
185+
warning_msg = gettext(
186+
"{line_end}Can not relate snapshot location '{}' "
187+
"to the test location.{line_end}"
188+
"Consider adding '{}' to the generated location."
189+
).format(
190+
location,
191+
self.test_location.filename,
192+
line_end="\n",
193+
)
194+
warnings.warn(warning_msg)
195+
196+
for snapshot in location_snapshots:
197+
snapshot_fossil.add(snapshot)
198+
199+
if not self.test_location.matches_snapshot_name(snapshot.name):
200+
warning_msg = gettext(
201+
"{line_end}Can not relate snapshot name '{}' "
202+
"to the test location.{line_end}"
203+
"Consider adding '{}' to the generated name."
204+
).format(
205+
snapshot.name,
206+
self.test_location.testname,
207+
line_end="\n",
208+
)
209+
warnings.warn(warning_msg)
210+
211+
self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil)
212+
213+
for data, index in snapshots:
214+
self._post_write(data=data, index=index)
175215

176216
@abstractmethod
177217
def delete_snapshots(
@@ -206,7 +246,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
206246

207247
@abstractmethod
208248
def _read_snapshot_data_from_location(
209-
self, *, snapshot_location: str, snapshot_name: str
249+
self, *, snapshot_location: str, snapshot_name: str, session_id: str
210250
) -> Optional["SerializedData"]:
211251
"""
212252
Get only the snapshot data from location for assertion

src/syrupy/extensions/single_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
7777
return snapshot_fossil
7878

7979
def _read_snapshot_data_from_location(
80-
self, *, snapshot_location: str, snapshot_name: str
80+
self, *, snapshot_location: str, snapshot_name: str, session_id: str
8181
) -> Optional["SerializableData"]:
8282
try:
8383
with open(

src/syrupy/session.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
List,
1414
Optional,
1515
Set,
16+
Tuple,
1617
)
1718

1819
import pytest
1920

2021
from .constants import EXIT_STATUS_FAIL_UNUSED
2122
from .data import SnapshotFossils
2223
from .report import SnapshotReport
24+
from .types import (
25+
SerializedData,
26+
SnapshotIndex,
27+
)
2328

2429
if TYPE_CHECKING:
2530
from .assertion import SnapshotAssertion
@@ -43,6 +48,26 @@ class SnapshotSession:
4348
default_factory=lambda: defaultdict(set)
4449
)
4550

51+
_queued_snapshot_writes: Dict[
52+
"AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]]
53+
] = field(default_factory=dict)
54+
55+
def queue_snapshot_write(
56+
self,
57+
extension: "AbstractSyrupyExtension",
58+
data: "SerializedData",
59+
index: "SnapshotIndex",
60+
) -> None:
61+
queue = self._queued_snapshot_writes.get(extension, [])
62+
queue.append((data, index))
63+
self._queued_snapshot_writes[extension] = queue
64+
65+
def flush_snapshot_write_queue(self) -> None:
66+
for extension, queued_write in self._queued_snapshot_writes.items():
67+
if queued_write:
68+
extension.write_snapshot_batch(snapshots=queued_write)
69+
self._queued_snapshot_writes = {}
70+
4671
@property
4772
def update_snapshots(self) -> bool:
4873
return bool(self.pytest_session.config.option.update_snapshots)
@@ -72,6 +97,7 @@ def ran_item(self, nodeid: str) -> None:
7297

7398
def finish(self) -> int:
7499
exitstatus = 0
100+
self.flush_snapshot_write_queue()
75101
self.report = SnapshotReport(
76102
base_dir=self.pytest_session.config.rootdir,
77103
collected_items=self._collected_items,

0 commit comments

Comments
 (0)