3
3
ABC ,
4
4
abstractmethod ,
5
5
)
6
+ from collections import defaultdict
6
7
from difflib import ndiff
7
8
from gettext import gettext
8
9
from itertools import zip_longest
9
10
from pathlib import Path
10
11
from typing import (
11
12
TYPE_CHECKING ,
12
13
Callable ,
14
+ DefaultDict ,
13
15
Dict ,
14
16
Iterator ,
15
17
List ,
16
18
Optional ,
17
19
Set ,
20
+ Tuple ,
18
21
)
19
22
20
23
from syrupy .constants import (
@@ -115,7 +118,9 @@ def discover_snapshots(self) -> "SnapshotFossils":
115
118
116
119
return discovered
117
120
118
- def read_snapshot (self , * , index : "SnapshotIndex" ) -> "SerializedData" :
121
+ def read_snapshot (
122
+ self , * , index : "SnapshotIndex" , session_id : str
123
+ ) -> "SerializedData" :
119
124
"""
120
125
Utility method for reading the contents of a snapshot assertion.
121
126
Will call `_pre_read`, then perform `read` and finally `post_read`,
@@ -129,7 +134,9 @@ def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
129
134
snapshot_location = self .get_location (index = index )
130
135
snapshot_name = self .get_snapshot_name (index = index )
131
136
snapshot_data = self ._read_snapshot_data_from_location (
132
- snapshot_location = snapshot_location , snapshot_name = snapshot_name
137
+ snapshot_location = snapshot_location ,
138
+ snapshot_name = snapshot_name ,
139
+ session_id = session_id ,
133
140
)
134
141
if snapshot_data is None :
135
142
raise SnapshotDoesNotExist ()
@@ -145,33 +152,66 @@ def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> N
145
152
This method is _final_, do not override. You can override
146
153
`_write_snapshot_fossil` in a subclass to change behaviour.
147
154
"""
148
- self ._pre_write (data = data , index = index )
149
- snapshot_location = self .get_location (index = index )
150
- if not self .test_location .matches_snapshot_location (snapshot_location ):
151
- warning_msg = gettext (
152
- "{line_end}Can not relate snapshot location '{}' to the test location."
153
- "{line_end}Consider adding '{}' to the generated location."
154
- ).format (
155
- snapshot_location ,
156
- self .test_location .filename ,
157
- line_end = "\n " ,
158
- )
159
- warnings .warn (warning_msg )
160
- snapshot_name = self .get_snapshot_name (index = index )
161
- if not self .test_location .matches_snapshot_name (snapshot_name ):
162
- warning_msg = gettext (
163
- "{line_end}Can not relate snapshot name '{}' to the test location."
164
- "{line_end}Consider adding '{}' to the generated name."
165
- ).format (
166
- snapshot_name ,
167
- self .test_location .testname ,
168
- line_end = "\n " ,
169
- )
170
- warnings .warn (warning_msg )
171
- snapshot_fossil = SnapshotFossil (location = snapshot_location )
172
- snapshot_fossil .add (Snapshot (name = snapshot_name , data = data ))
173
- self ._write_snapshot_fossil (snapshot_fossil = snapshot_fossil )
174
- self ._post_write (data = data , index = index )
155
+ self .write_snapshot_batch (snapshots = [(data , index )])
156
+
157
+ def write_snapshot_batch (
158
+ self , * , snapshots : List [Tuple ["SerializedData" , "SnapshotIndex" ]]
159
+ ) -> None :
160
+ """
161
+ Utility method for writing the contents of multiple snapshot assertions.
162
+ Will call `_pre_write` per snapshot, then perform `write` per snapshot
163
+ and finally `_post_write`.
164
+
165
+ This method is _final_, do not override. You can override
166
+ `_write_snapshot_fossil` in a subclass to change behaviour.
167
+ """
168
+ # First we group by location since it'll let us batch by file on disk.
169
+ # Not as useful for single file snapshots, but useful for the standard
170
+ # Amber extension.
171
+ locations : DefaultDict [str , List ["Snapshot" ]] = defaultdict (list )
172
+ for data , index in snapshots :
173
+ location = self .get_location (index = index )
174
+ snapshot_name = self .get_snapshot_name (index = index )
175
+ locations [location ].append (Snapshot (name = snapshot_name , data = data ))
176
+
177
+ # Is there a better place to do the pre-writes?
178
+ # Or can we remove the pre-write concept altogether?
179
+ self ._pre_write (data = data , index = index )
180
+
181
+ for location , location_snapshots in locations .items ():
182
+ snapshot_fossil = SnapshotFossil (location = location )
183
+
184
+ if not self .test_location .matches_snapshot_location (location ):
185
+ warning_msg = gettext (
186
+ "{line_end}Can not relate snapshot location '{}' "
187
+ "to the test location.{line_end}"
188
+ "Consider adding '{}' to the generated location."
189
+ ).format (
190
+ location ,
191
+ self .test_location .filename ,
192
+ line_end = "\n " ,
193
+ )
194
+ warnings .warn (warning_msg )
195
+
196
+ for snapshot in location_snapshots :
197
+ snapshot_fossil .add (snapshot )
198
+
199
+ if not self .test_location .matches_snapshot_name (snapshot .name ):
200
+ warning_msg = gettext (
201
+ "{line_end}Can not relate snapshot name '{}' "
202
+ "to the test location.{line_end}"
203
+ "Consider adding '{}' to the generated name."
204
+ ).format (
205
+ snapshot .name ,
206
+ self .test_location .testname ,
207
+ line_end = "\n " ,
208
+ )
209
+ warnings .warn (warning_msg )
210
+
211
+ self ._write_snapshot_fossil (snapshot_fossil = snapshot_fossil )
212
+
213
+ for data , index in snapshots :
214
+ self ._post_write (data = data , index = index )
175
215
176
216
@abstractmethod
177
217
def delete_snapshots (
@@ -206,7 +246,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
206
246
207
247
@abstractmethod
208
248
def _read_snapshot_data_from_location (
209
- self , * , snapshot_location : str , snapshot_name : str
249
+ self , * , snapshot_location : str , snapshot_name : str , session_id : str
210
250
) -> Optional ["SerializedData" ]:
211
251
"""
212
252
Get only the snapshot data from location for assertion
0 commit comments