@@ -14,15 +14,19 @@ class FakeData(data.Dataset):
14
14
and returns a transformed version. E.g, ``transforms.RandomCrop``
15
15
target_transform (callable, optional): A function/transform that takes in the
16
16
target and transforms it.
17
+ random_offset (int): Offsets the index-based random seed used to
18
+ generate each image. Default: 0
17
19
18
20
"""
19
21
20
- def __init__ (self , size = 1000 , image_size = (3 , 224 , 224 ), num_classes = 10 , transform = None , target_transform = None ):
22
+ def __init__ (self , size = 1000 , image_size = (3 , 224 , 224 ), num_classes = 10 ,
23
+ transform = None , target_transform = None , random_offset = 0 ):
21
24
self .size = size
22
25
self .num_classes = num_classes
23
26
self .image_size = image_size
24
27
self .transform = transform
25
28
self .target_transform = target_transform
29
+ self .random_offset = random_offset
26
30
27
31
def __getitem__ (self , index ):
28
32
"""
@@ -34,7 +38,7 @@ def __getitem__(self, index):
34
38
"""
35
39
# create random image that is consistent with the index id
36
40
rng_state = torch .get_rng_state ()
37
- torch .manual_seed (index )
41
+ torch .manual_seed (index + self . random_offset )
38
42
img = torch .randn (* self .image_size )
39
43
target = torch .Tensor (1 ).random_ (0 , self .num_classes )[0 ]
40
44
torch .set_rng_state (rng_state )
0 commit comments