12
12
Dict ,
13
13
List ,
14
14
Optional ,
15
+ Tuple ,
15
16
Type ,
16
17
)
17
18
18
- from .exceptions import SnapshotDoesNotExist
19
+ from .exceptions import (
20
+ SnapshotDoesNotExist ,
21
+ TaintedSnapshotError ,
22
+ )
19
23
from .extensions .amber .serializer import Repr
20
24
21
25
if TYPE_CHECKING :
@@ -125,13 +129,15 @@ def __repr(self) -> "SerializableData":
125
129
SnapshotAssertionRepr = namedtuple ( # type: ignore
126
130
"SnapshotAssertion" , ["name" , "num_executions" ]
127
131
)
128
- assertion_result = self . executions . get (
129
- ( self ._custom_index and self ._execution_name_index .get (self ._custom_index ) )
130
- or self .num_executions - 1
131
- )
132
+ execution_index = (
133
+ self ._custom_index and self ._execution_name_index .get (self ._custom_index )
134
+ ) or self .num_executions - 1
135
+ assertion_result = self . executions . get ( execution_index )
132
136
return (
133
137
Repr (str (assertion_result .final_data ))
134
- if assertion_result
138
+ if execution_index in self .executions
139
+ and assertion_result
140
+ and assertion_result .final_data is not None
135
141
else SnapshotAssertionRepr (
136
142
name = self .name ,
137
143
num_executions = self .num_executions ,
@@ -179,15 +185,23 @@ def _serialize(self, data: "SerializableData") -> "SerializedData":
179
185
def get_assert_diff (self ) -> List [str ]:
180
186
assertion_result = self ._execution_results [self .num_executions - 1 ]
181
187
if assertion_result .exception :
182
- lines = [
183
- line
184
- for lines in traceback .format_exception (
185
- assertion_result .exception .__class__ ,
186
- assertion_result .exception ,
187
- assertion_result .exception .__traceback__ ,
188
- )
189
- for line in lines .splitlines ()
190
- ]
188
+ if isinstance (assertion_result .exception , (TaintedSnapshotError ,)):
189
+ lines = [
190
+ gettext (
191
+ "This snapshot needs to be regenerated. "
192
+ "This is typically due to a major Syrupy update."
193
+ )
194
+ ]
195
+ else :
196
+ lines = [
197
+ line
198
+ for lines in traceback .format_exception (
199
+ assertion_result .exception .__class__ ,
200
+ assertion_result .exception ,
201
+ assertion_result .exception .__traceback__ ,
202
+ )
203
+ for line in lines .splitlines ()
204
+ ]
191
205
# Rotate to place exception with message at first line
192
206
return lines [- 1 :] + lines [:- 1 ]
193
207
snapshot_data = assertion_result .recalled_data
@@ -232,7 +246,7 @@ def __call__(
232
246
return self
233
247
234
248
def __repr__ (self ) -> str :
235
- return str (self ._serialize ( self . __repr ) )
249
+ return str (self .__repr )
236
250
237
251
def __eq__ (self , other : "SerializableData" ) -> bool :
238
252
return self ._assert (other )
@@ -250,29 +264,36 @@ def _assert(self, data: "SerializableData") -> bool:
250
264
assertion_success = False
251
265
assertion_exception = None
252
266
try :
253
- snapshot_data = self ._recall_data (index = self .index )
267
+ snapshot_data , tainted = self ._recall_data (index = self .index )
254
268
serialized_data = self ._serialize (data )
255
269
snapshot_diff = getattr (self , "_snapshot_diff" , None )
256
270
if snapshot_diff is not None :
257
- snapshot_data_diff = self ._recall_data (index = snapshot_diff )
271
+ snapshot_data_diff , _ = self ._recall_data (index = snapshot_diff )
258
272
if snapshot_data_diff is None :
259
273
raise SnapshotDoesNotExist ()
260
274
serialized_data = self .extension .diff_snapshots (
261
275
serialized_data = serialized_data ,
262
276
snapshot_data = snapshot_data_diff ,
263
277
)
264
- matches = snapshot_data is not None and self .extension .matches (
265
- serialized_data = serialized_data , snapshot_data = snapshot_data
278
+ matches = (
279
+ not tainted
280
+ and snapshot_data is not None
281
+ and self .extension .matches (
282
+ serialized_data = serialized_data , snapshot_data = snapshot_data
283
+ )
266
284
)
267
285
assertion_success = matches
268
- if not matches and self .update_snapshots :
269
- self .session .queue_snapshot_write (
270
- extension = self .extension ,
271
- test_location = self .test_location ,
272
- data = serialized_data ,
273
- index = self .index ,
274
- )
275
- assertion_success = True
286
+ if not matches :
287
+ if self .update_snapshots :
288
+ self .session .queue_snapshot_write (
289
+ extension = self .extension ,
290
+ test_location = self .test_location ,
291
+ data = serialized_data ,
292
+ index = self .index ,
293
+ )
294
+ assertion_success = True
295
+ elif tainted :
296
+ raise TaintedSnapshotError
276
297
return assertion_success
277
298
except Exception as e :
278
299
assertion_exception = e
@@ -301,12 +322,19 @@ def _post_assert(self) -> None:
301
322
while self ._post_assert_actions :
302
323
self ._post_assert_actions .pop ()()
303
324
304
- def _recall_data (self , index : "SnapshotIndex" ) -> Optional ["SerializableData" ]:
325
+ def _recall_data (
326
+ self , index : "SnapshotIndex"
327
+ ) -> Tuple [Optional ["SerializableData" ], bool ]:
305
328
try :
306
- return self .extension .read_snapshot (
307
- test_location = self .test_location ,
308
- index = index ,
309
- session_id = str (id (self .session )),
329
+ return (
330
+ self .extension .read_snapshot (
331
+ test_location = self .test_location ,
332
+ index = index ,
333
+ session_id = str (id (self .session )),
334
+ ),
335
+ False ,
310
336
)
311
337
except SnapshotDoesNotExist :
312
- return None
338
+ return None , False
339
+ except TaintedSnapshotError as e :
340
+ return e .snapshot_data , True
0 commit comments