@@ -21,6 +21,8 @@ class Cityscapes(VisionDataset):
21
21
and returns a transformed version. E.g, ``transforms.RandomCrop``
22
22
target_transform (callable, optional): A function/transform that takes in the
23
23
target and transforms it.
24
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
25
+ and returns a transformed version.
24
26
25
27
Examples:
26
28
@@ -95,8 +97,8 @@ class Cityscapes(VisionDataset):
95
97
]
96
98
97
99
def __init__ (self , root , split = 'train' , mode = 'fine' , target_type = 'instance' ,
98
- transform = None , target_transform = None ):
99
- super (Cityscapes , self ).__init__ (root )
100
+ transform = None , target_transform = None , transforms = None ):
101
+ super (Cityscapes , self ).__init__ (root , transforms , transform , target_transform )
100
102
self .transform = transform
101
103
self .target_transform = target_transform
102
104
self .mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
@@ -163,11 +165,8 @@ def __getitem__(self, index):
163
165
164
166
target = tuple (targets ) if len (targets ) > 1 else targets [0 ]
165
167
166
- if self .transform :
167
- image = self .transform (image )
168
-
169
- if self .target_transform :
170
- target = self .target_transform (target )
168
+ if self .transforms is not None :
169
+ image , target = self .transforms (image , target )
171
170
172
171
return image , target
173
172
0 commit comments