Commit 5147d8b

Add segmentation
1 parent c00a181 commit 5147d8b

6 files changed: +119 additions, -45 deletions

references/detection/coco_utils.py

Lines changed: 0 additions & 1 deletion
@@ -207,7 +207,6 @@ def get_coco(root, image_set, transforms, mode="instances"):
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)
 
-
     dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
     dataset = wrap_dataset_for_transforms_v2(dataset)
 
references/detection/presets.py

Lines changed: 7 additions & 5 deletions
@@ -1,11 +1,12 @@
 from collections import defaultdict
 
 import torch
-import transforms as reference_transforms
 import torchvision
+import transforms as reference_transforms
+
 torchvision.disable_beta_transforms_warning()
-from torchvision import datapoints
 import torchvision.transforms.v2 as T
+from torchvision import datapoints
 
 
 # TODO: Should we provide a transforms that filters-out keys?
@@ -64,7 +65,9 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104
         transforms += [
             T.ConvertImageDtype(torch.float),
             T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
-            T.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]) # TODO: sad it's not the default!
+            T.SanitizeBoundingBoxes(
+                labels_getter=lambda sample: sample[1]["labels"]
+            ),  # TODO: sad it's not the default!
         ]
 
         super().__init__(transforms)
@@ -78,9 +81,8 @@ def __init__(self, backend="pil"):
         backend = backend.lower()
         if backend == "tensor":
             transforms.append(T.PILToTensor())
-        else: # for datapoint **and** PIL
+        else:  # for datapoint **and** PIL
             transforms.append(T.ToImageTensor())
 
-
         transforms.append(T.ConvertImageDtype(torch.float))
         super().__init__(transforms)
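The labels_getter passed to T.SanitizeBoundingBoxes above tells the transform where to find the labels that must be dropped in sync with any degenerate boxes it removes. A minimal sketch of what that lambda resolves to, assuming the usual (image, target) detection sample layout (the tensors below are made-up placeholders):

import torch

# hypothetical (image, target) sample as produced by the detection reference datasets
target = {
    "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 5.0, 5.0]]),  # second box is degenerate
    "labels": torch.tensor([1, 2]),
}
sample = ("image placeholder", target)

labels_getter = lambda sample: sample[1]["labels"]
print(labels_getter(sample))  # tensor([1, 2]) -> filtered together with the boxes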

references/segmentation/presets.py

Lines changed: 77 additions & 34 deletions
@@ -1,39 +1,82 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+import torchvision
+
+torchvision.disable_beta_transforms_warning()
+import torchvision.transforms.v2 as T
+from torchvision import datapoints
+from transforms import PadIfSmaller, WrapIntoFeatures
+
+
+class SegmentationPresetTrain(T.Compose):
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+    ):
+
+        transforms = []
 
+        transforms.append(WrapIntoFeatures())
 
-class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        transforms.append(T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True))
 
-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
-        self.transforms = T.Compose(trans)
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
-
-class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
+
+        transforms += [
+            # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+            # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+            PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
+            T.RandomCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            transforms.append(T.ToImageTensor())
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+
+        super().__init__(transforms)
+
+
+class SegmentationPresetEval(T.Compose):
+    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil"):
+        transforms = []
+
+        transforms.append(WrapIntoFeatures())
+
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        transforms.append(T.Resize(base_size, antialias=True))
+
+        if backend == "pil":
+            transforms.append(T.ToImageTensor())
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        super().__init__(transforms)
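A minimal usage sketch of the new train preset, assuming a (PIL image, PIL mask) pair like the ones the segmentation reference datasets yield; the sizes and solid-color inputs below are placeholders for illustration:

from PIL import Image
import presets  # references/segmentation/presets.py

transform = presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend="pil")

img = Image.new("RGB", (640, 480))   # placeholder image
mask = Image.new("L", (640, 480))    # placeholder segmentation mask
img_t, mask_t = transform(img, mask)
# img_t: normalized float tensor, mask_t: int64 datapoints.Mask, both cropped/padded to 480x480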

references/segmentation/train.py

Lines changed: 3 additions & 2 deletions
@@ -31,7 +31,7 @@ def sbd(*args, **kwargs):
 
 def get_transform(train, args):
     if train:
-        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend)
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
@@ -44,7 +44,7 @@ def preprocessing(img, target):
 
         return preprocessing
     else:
-        return presets.SegmentationPresetEval(base_size=520)
+        return presets.SegmentationPresetEval(base_size=520, backend=args.backend)
 
 
 def criterion(inputs, target):
@@ -306,6 +306,7 @@ def get_args_parser(add_help=True):
 
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+    parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
 
     return parser
 
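A quick sketch of how the new flag reaches the presets; the parsed value below is hypothetical, and the presets lower-case it, which is why the help text says "case insensitive":

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
args = parser.parse_args(["--backend", "Datapoint"])

# get_transform(train=True, args=args) then builds
# presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend),
# and the preset calls backend.lower() -> "datapoint"
print(args.backend.lower())
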
references/segmentation/transforms.py

Lines changed: 32 additions & 1 deletion
@@ -2,10 +2,41 @@
 
 import numpy as np
 import torch
-from torchvision import transforms as T
+import torchvision.transforms.v2 as PT
+import torchvision.transforms.v2.functional as PF
+from torchvision import datapoints, transforms as T
 from torchvision.transforms import functional as F
 
 
+class WrapIntoFeatures(PT.Transform):
+    def forward(self, sample):
+        image, mask = sample
+        # return PF.to_image_tensor(image), datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+        return image, datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+
+
+class PadIfSmaller(PT.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = PT._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = PT.utils.query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = PT._utils._convert_fill_arg(fill)
+
+        return PF.pad(inpt, padding=params["padding"], fill=fill)
+
+
 def pad_if_smaller(img, size, fill=0):
     min_size = min(img.size)
     if min_size < size:
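The fill argument that PadIfSmaller receives from the train preset is a per-type mapping; a small sketch of how that lookup behaves (255 is the ignore index used for segmentation masks, 0 pads images with black):

from collections import defaultdict

import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints

fill = defaultdict(lambda: 0, {datapoints.Mask: 255})
print(fill[datapoints.Mask])   # 255 - masks are padded with the ignore index
print(fill[datapoints.Image])  # 0   - everything else (images, plain tensors) gets 0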

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 0 additions & 2 deletions
@@ -253,8 +253,6 @@ def wrapper(idx, sample):
                 len(batched_target["keypoints"]), -1, 3
             )
 
-
-
         return image, batched_target
 
     return wrapper
