Skip to content

Commit ec20315

Browse files
TheCodezfmassa
authored andcommitted
Use joint transform in Cityscapes (#1024)
* Use joint transform in Cityscapes * Add transforms doc
1 parent ae2cb6e commit ec20315

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

torchvision/datasets/cityscapes.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class Cityscapes(VisionDataset):
2121
and returns a transformed version. E.g, ``transforms.RandomCrop``
2222
target_transform (callable, optional): A function/transform that takes in the
2323
target and transforms it.
24+
transforms (callable, optional): A function/transform that takes input sample and its target as entry
25+
and returns a transformed version.
2426
2527
Examples:
2628
@@ -95,8 +97,8 @@ class Cityscapes(VisionDataset):
9597
]
9698

9799
def __init__(self, root, split='train', mode='fine', target_type='instance',
98-
transform=None, target_transform=None):
99-
super(Cityscapes, self).__init__(root)
100+
transform=None, target_transform=None, transforms=None):
101+
super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
100102
self.transform = transform
101103
self.target_transform = target_transform
102104
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
@@ -163,11 +165,8 @@ def __getitem__(self, index):
163165

164166
target = tuple(targets) if len(targets) > 1 else targets[0]
165167

166-
if self.transform:
167-
image = self.transform(image)
168-
169-
if self.target_transform:
170-
target = self.target_transform(target)
168+
if self.transforms is not None:
169+
image, target = self.transforms(image, target)
171170

172171
return image, target
173172

0 commit comments

Comments
 (0)