Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 48 additions & 32 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@


class FlowDataset(ABC, VisionDataset):
# Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what `valid` should be.
# Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False

def __init__(self, root, transforms=None):
Expand All @@ -38,11 +38,14 @@ def __init__(self, root, transforms=None):
self._image_list = []

def _read_img(self, file_name):
return Image.open(file_name)
img = Image.open(file_name)
if img.mode != "RGB":
img = img.convert("RGB")
return img

@abstractmethod
def _read_flow(self, file_name):
# Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True
# Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
pass

def __getitem__(self, index):
Expand All @@ -53,23 +56,27 @@ def __getitem__(self, index):
if self._flow_list: # it will be empty for some dataset when split="test"
flow = self._read_flow(self._flow_list[index])
if self._has_builtin_flow_mask:
flow, valid = flow
flow, valid_flow_mask = flow
else:
valid = None
valid_flow_mask = None
else:
flow = valid = None
flow = valid_flow_mask = None

if self.transforms is not None:
img1, img2, flow, valid = self.transforms(img1, img2, flow, valid)
img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)

if self._has_builtin_flow_mask:
return img1, img2, flow, valid
if self._has_builtin_flow_mask or valid_flow_mask is not None:
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
return img1, img2, flow, valid_flow_mask
else:
return img1, img2, flow

def __len__(self):
return len(self._image_list)

def __rmul__(self, v):
return torch.utils.data.ConcatDataset([self] * v)


class Sintel(FlowDataset):
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
Expand Down Expand Up @@ -107,8 +114,8 @@ class Sintel(FlowDataset):
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

Expand Down Expand Up @@ -140,9 +147,11 @@ def __getitem__(self, index):
index(int): The index of the example to retrieve

Returns:
tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
3-tuple with ``(img1, img2, None)`` is returned.
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return super().__getitem__(index)

Expand All @@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
root (string): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True
Expand Down Expand Up @@ -199,11 +208,11 @@ def __getitem__(self, index):
index(int): The index of the example to retrieve

Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned.
shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
``split="test"``.
"""
return super().__getitem__(index)

Expand Down Expand Up @@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
root (string): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

Expand Down Expand Up @@ -269,6 +278,9 @@ def __getitem__(self, index):
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="val"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return super().__getitem__(index)

Expand Down Expand Up @@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
details on the different passes.
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

Expand Down Expand Up @@ -357,6 +369,9 @@ def __getitem__(self, index):
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return super().__getitem__(index)

Expand All @@ -382,7 +397,7 @@ class HD1K(FlowDataset):
root (string): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True
Expand Down Expand Up @@ -422,11 +437,11 @@ def __getitem__(self, index):
index(int): The index of the example to retrieve

Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned.
shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
``split="test"``.
"""
return super().__getitem__(index)

Expand All @@ -451,11 +466,12 @@ def _read_flo(file_name):
def _read_16bits_png_with_flow_and_valid_mask(file_name):

flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive
valid_flow_mask = valid_flow_mask.bool()

# For consistency with other datasets, we convert to numpy
return flow.numpy(), valid.numpy()
return flow.numpy(), valid_flow_mask.numpy()


def _read_pfm(file_name):
Expand Down