Skip to content

Commit aafaa2a

Browse files
Erotemicalykhantejani
authored andcommitted
param to cause FakeData to generate different images (#358)
* param to cause FakeData to generate different images
1 parent 4927309 commit aafaa2a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torchvision/datasets/fakedata.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ class FakeData(data.Dataset):
1414
and returns a transformed version. E.g, ``transforms.RandomCrop``
1515
target_transform (callable, optional): A function/transform that takes in the
1616
target and transforms it.
17+
random_offset (int): Offsets the index-based random seed used to
18+
generate each image. Default: 0
1719
1820
"""
1921

20-
def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None):
22+
def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
23+
transform=None, target_transform=None, random_offset=0):
2124
self.size = size
2225
self.num_classes = num_classes
2326
self.image_size = image_size
2427
self.transform = transform
2528
self.target_transform = target_transform
29+
self.random_offset = random_offset
2630

2731
def __getitem__(self, index):
2832
"""
@@ -34,7 +38,7 @@ def __getitem__(self, index):
3438
"""
3539
# create random image that is consistent with the index id
3640
rng_state = torch.get_rng_state()
37-
torch.manual_seed(index)
41+
torch.manual_seed(index + self.random_offset)
3842
img = torch.randn(*self.image_size)
3943
target = torch.Tensor(1).random_(0, self.num_classes)[0]
4044
torch.set_rng_state(rng_state)

0 commit comments

Comments
 (0)