@@ -200,10 +200,14 @@ def test_fuzz(self):
200
200
except Exception :
201
201
pass
202
202
203
- def test_loads_recursion (self ):
203
+ def test_loads_2x_code (self ):
204
204
s = b'c' + (b'X' * 4 * 4 ) + b'{' * 2 ** 20
205
205
self .assertRaises (ValueError , marshal .loads , s )
206
206
207
+ def test_loads_recursion (self ):
208
+ s = b'c' + (b'X' * 4 * 5 ) + b'{' * 2 ** 20
209
+ self .assertRaises (ValueError , marshal .loads , s )
210
+
207
211
def test_recursion_limit (self ):
208
212
# Create a deeply nested structure.
209
213
head = last = []
@@ -323,6 +327,122 @@ def test_frozenset(self, size):
323
327
def test_bytearray (self , size ):
324
328
self .check_unmarshallable (bytearray (size ))
325
329
330
+ def CollectObjectIDs (ids , obj ):
331
+ """Collect object ids seen in a structure"""
332
+ if id (obj ) in ids :
333
+ return
334
+ ids .add (id (obj ))
335
+ if isinstance (obj , (list , tuple , set , frozenset )):
336
+ for e in obj :
337
+ CollectObjectIDs (ids , e )
338
+ elif isinstance (obj , dict ):
339
+ for k , v in obj .items ():
340
+ CollectObjectIDs (ids , k )
341
+ CollectObjectIDs (ids , v )
342
+ return len (ids )
343
+
344
+ class InstancingTestCase (unittest .TestCase , HelperMixin ):
345
+ intobj = 123321
346
+ floatobj = 1.2345
347
+ strobj = "abcde" * 3
348
+ dictobj = {"hello" :floatobj , "goodbye" :floatobj , floatobj :"hello" }
349
+
350
+ def helper3 (self , rsample , recursive = False , simple = False ):
351
+ #we have two instances
352
+ sample = (rsample , rsample )
353
+
354
+ n0 = CollectObjectIDs (set (), sample )
355
+
356
+ s3 = marshal .dumps (sample , 3 )
357
+ n3 = CollectObjectIDs (set (), marshal .loads (s3 ))
358
+
359
+ #same number of instances generated
360
+ self .assertEqual (n3 , n0 )
361
+
362
+ if not recursive :
363
+ #can compare with version 2
364
+ s2 = marshal .dumps (sample , 2 )
365
+ n2 = CollectObjectIDs (set (), marshal .loads (s2 ))
366
+ #old format generated more instances
367
+ self .assertGreater (n2 , n0 )
368
+
369
+ #if complex objects are in there, old format is larger
370
+ if not simple :
371
+ self .assertGreater (len (s2 ), len (s3 ))
372
+ else :
373
+ self .assertGreaterEqual (len (s2 ), len (s3 ))
374
+
375
+ def testInt (self ):
376
+ self .helper (self .intobj )
377
+ self .helper3 (self .intobj , simple = True )
378
+
379
+ def testFloat (self ):
380
+ self .helper (self .floatobj )
381
+ self .helper3 (self .floatobj )
382
+
383
+ def testStr (self ):
384
+ self .helper (self .strobj )
385
+ self .helper3 (self .strobj )
386
+
387
+ def testDict (self ):
388
+ self .helper (self .dictobj )
389
+ self .helper3 (self .dictobj )
390
+
391
+ def testModule (self ):
392
+ with open (__file__ , "rb" ) as f :
393
+ code = f .read ()
394
+ if __file__ .endswith (".py" ):
395
+ code = compile (code , __file__ , "exec" )
396
+ self .helper (code )
397
+ self .helper3 (code )
398
+
399
+ def testRecursion (self ):
400
+ d = dict (self .dictobj )
401
+ d ["self" ] = d
402
+ self .helper3 (d , recursive = True )
403
+ l = [self .dictobj ]
404
+ l .append (l )
405
+ self .helper3 (l , recursive = True )
406
+
407
+ class CompatibilityTestCase (unittest .TestCase ):
408
+ def _test (self , version ):
409
+ with open (__file__ , "rb" ) as f :
410
+ code = f .read ()
411
+ if __file__ .endswith (".py" ):
412
+ code = compile (code , __file__ , "exec" )
413
+ data = marshal .dumps (code , version )
414
+ marshal .loads (data )
415
+
416
+ def test0To3 (self ):
417
+ self ._test (0 )
418
+
419
+ def test1To3 (self ):
420
+ self ._test (1 )
421
+
422
+ def test2To3 (self ):
423
+ self ._test (2 )
424
+
425
+ def test3To3 (self ):
426
+ self ._test (3 )
427
+
428
+ class InterningTestCase (unittest .TestCase , HelperMixin ):
429
+ strobj = "this is an interned string"
430
+ strobj = sys .intern (strobj )
431
+
432
+ def testIntern (self ):
433
+ s = marshal .loads (marshal .dumps (self .strobj ))
434
+ self .assertEqual (s , self .strobj )
435
+ self .assertEqual (id (s ), id (self .strobj ))
436
+ s2 = sys .intern (s )
437
+ self .assertEqual (id (s2 ), id (s ))
438
+
439
+ def testNoIntern (self ):
440
+ s = marshal .loads (marshal .dumps (self .strobj , 2 ))
441
+ self .assertEqual (s , self .strobj )
442
+ self .assertNotEqual (id (s ), id (self .strobj ))
443
+ s2 = sys .intern (s )
444
+ self .assertNotEqual (id (s2 ), id (s ))
445
+
326
446
327
447
def test_main ():
328
448
support .run_unittest (IntTestCase ,
0 commit comments