"""
=====================================================
Optical Flow: Predicting movement with the RAFT model
=====================================================

Optical flow is the task of predicting movement between two images, usually two
consecutive frames of a video. Optical flow models take two images as input, and
predict a flow: the flow indicates the displacement of every single pixel in the
first image, and maps it to its corresponding pixel in the second image. Flows
are (2, H, W)-dimensional tensors, where the first axis corresponds to the
predicted horizontal and vertical displacements.

The following example illustrates how torchvision can be used to predict flows
using our implementation of the RAFT model. We will also see how to convert the
predicted flows to RGB images for visualization.
"""

import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = "tight"
# sphinx_gallery_thumbnail_number = 2

def plot(imgs, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            img = F.to_pil_image(img.to("cpu"))
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    plt.tight_layout()

###################################
# Reading Videos Using Torchvision
# --------------------------------
# We will first read a video using :func:`~torchvision.io.read_video`.
# Alternatively one can use the new :class:`~torchvision.io.VideoReader` API (if
# torchvision is built from source); a commented sketch is shown below.
# The video we will use here is free to use from `pexels.com
# <https://www.pexels.com/video/a-man-playing-a-game-of-basketball-5192157/>`_,
# credits go to `Pavel Danilyuk <https://www.pexels.com/@pavel-danilyuk>`_.


import tempfile
from pathlib import Path
from urllib.request import urlretrieve


video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
_ = urlretrieve(video_url, video_path)

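#########################
# As an aside, here is a rough sketch of the :class:`~torchvision.io.VideoReader`
# alternative mentioned above. It is commented out because it requires a
# torchvision build with video support:

# from torchvision.io import VideoReader
# reader = VideoReader(str(video_path), "video")
# first_frame = next(reader)["data"]  # (C, H, W) uint8 tensor of the first frame
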
#########################
# :func:`~torchvision.io.read_video` returns the video frames, audio frames and
# the metadata associated with the video. In our case, we only need the video
# frames.
#
# Here we will just make 2 predictions between 2 pre-selected pairs of frames,
# namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a
# single model input.

from torchvision.io import read_video
frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

img1_batch = torch.stack([frames[100], frames[150]])
img2_batch = torch.stack([frames[101], frames[151]])

plot(img1_batch)

#########################
# The RAFT model that we will use accepts RGB float images with pixel values in
# [-1, 1]. The frames we got from :func:`~torchvision.io.read_video` are uint8
# images with values in [0, 255], so we will have to pre-process them. We also
# reduce the image sizes for the example to run faster. Image dimensions must be
# divisible by 8; see the sketch after the ``preprocess`` function for frames of
# arbitrary sizes.


def preprocess(batch):
    transforms = T.Compose(
        [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.Resize(size=(520, 960)),
        ]
    )
    batch = transforms(batch)
    return batch

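####################################
# RAFT requires the spatial dimensions to be divisible by 8. The fixed
# (520, 960) resize above already satisfies this; for frames of arbitrary
# sizes, here is a minimal sketch (``fit_to_raft`` is a hypothetical helper,
# not part of torchvision):


def fit_to_raft(batch):
    # Round H and W down to the nearest multiple of 8 before feeding RAFT
    h, w = batch.shape[-2:]
    return F.resize(batch, size=[(h // 8) * 8, (w // 8) * 8])
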
# If you can, run this example on a GPU; it will be a lot faster.
device = "cuda" if torch.cuda.is_available() else "cpu"

img1_batch = preprocess(img1_batch).to(device)
img2_batch = preprocess(img2_batch).to(device)

print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")

####################################
# Estimating Optical Flow using RAFT
# ----------------------------------
# We will use our RAFT implementation from
# :func:`~torchvision.models.optical_flow.raft_large`, which follows the same
# architecture as the one described in the `original paper <https://arxiv.org/abs/2003.12039>`_.
# We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
# builder, which is smaller and faster to run, sacrificing a bit of accuracy.

from torchvision.models.optical_flow import raft_large

model = raft_large(pretrained=True, progress=False).to(device)
model = model.eval()

list_of_flows = model(img1_batch, img2_batch)
print(f"type = {type(list_of_flows)}")
print(f"length = {len(list_of_flows)} = number of iterations of the model")

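####################################
# As a side note, since we only run inference here, the forward pass could also
# be wrapped in ``torch.no_grad()`` to avoid building the autograd graph and
# save memory (a minor, optional optimization):

# with torch.no_grad():
#     list_of_flows = model(img1_batch, img2_batch)
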
####################################
# The RAFT model outputs a list of predicted flows, where each entry is a
# (N, 2, H, W) batch of predicted flows that corresponds to a given "iteration"
# in the model. For more details on the iterative nature of the model, please
# refer to the `original paper <https://arxiv.org/abs/2003.12039>`_. Here, we
# are only interested in the final predicted flows (they are the most accurate
# ones), so we will just retrieve the last item in the list.
#
# As described above, a flow is a tensor with dimensions (2, H, W) (or (N, 2, H,
# W) for batches of flows) where each entry corresponds to the horizontal and
# vertical displacement of each pixel from the first image to the second image.
# Note that the predicted flows are in "pixel" units; they are not normalized
# w.r.t. the dimensions of the images.
predicted_flows = list_of_flows[-1]
print(f"dtype = {predicted_flows.dtype}")
print(f"shape = {predicted_flows.shape} = (N, 2, H, W)")
print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")

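####################################
# As a concrete illustration, here is how one could read the displacement of a
# single pixel from the first flow in the batch (the pixel coordinates below
# are an arbitrary choice):

dx, dy = predicted_flows[0, :, 100, 200]  # horizontal and vertical displacement
print(f"Pixel (row=100, col=200) moves by ({dx.item():.2f}, {dy.item():.2f}) pixels")
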

####################################
# Visualizing predicted flows
# ---------------------------
# Torchvision provides the :func:`~torchvision.utils.flow_to_image` utility to
# convert a flow into an RGB image. It also supports batches of flows. Each
# "direction" in the flow will be mapped to a given RGB color. In the
# images below, pixels with similar colors are assumed by the model to be moving
# in similar directions. The model is properly able to predict the movement of
# the ball and the player. Note in particular the different predicted direction
# of the ball in the first image (going to the left) and in the second image
# (going up).

from torchvision.utils import flow_to_image

flow_imgs = flow_to_image(predicted_flows)

# The images have been mapped into [-1, 1] but for plotting we want them in [0, 1]
img1_batch = [(img1 + 1) / 2 for img1 in img1_batch]

grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)]
plot(grid)
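
####################################
# Beyond visualization, the raw flow values can be used directly. As a rough
# sketch (the 1-pixel threshold is an arbitrary choice), here is how one could
# compute per-pixel motion magnitudes and estimate which pixels are moving:

magnitude = torch.linalg.norm(predicted_flows, dim=1)  # (N, H, W)
moving = magnitude > 1.0  # pixels displaced by more than 1 pixel
print(f"{moving.float().mean().item():.1%} of pixels move by more than 1 pixel")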

####################################
# Bonus: Creating GIFs of predicted flows
# ---------------------------------------
# In the example above we have only shown the predicted flows of 2 pairs of
# frames. A fun way to apply the optical flow models is to run the model on an
# entire video, and create a new video from all the predicted flows. Below is a
# snippet that can get you started with this. We comment out the code, because
# this example is being rendered on a machine without a GPU, and it would take
# too long to run it.

# from torchvision.io import write_jpeg
# for i, (img1, img2) in enumerate(zip(frames, frames[1:])):
#     # Note: it would be faster to predict batches of flows instead of individual flows
#     img1 = preprocess(img1[None]).to(device)
#     img2 = preprocess(img2[None]).to(device)

#     list_of_flows = model(img1, img2)
#     predicted_flow = list_of_flows[-1][0]
#     flow_img = flow_to_image(predicted_flow).to("cpu")
#     output_folder = "/tmp/"  # Update this to the folder of your choice
#     write_jpeg(flow_img, output_folder + f"predicted_flow_{i}.jpg")

####################################
# Once the .jpg flow images are saved, you can convert them into a video or a
# GIF using ffmpeg with e.g.:
#
#     ffmpeg -f image2 -framerate 30 -i predicted_flow_%d.jpg -loop -1 flow.gif
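
####################################
# Alternatively, the GIF can be assembled directly in Python with PIL. This is
# a rough sketch (commented out like the snippet above), assuming the .jpg
# files were written to ``output_folder``:

# from PIL import Image
# files = sorted(Path(output_folder).glob("predicted_flow_*.jpg"),
#                key=lambda p: int(p.stem.split("_")[-1]))
# gif_frames = [Image.open(f) for f in files]
# gif_frames[0].save(output_folder + "flow.gif", save_all=True,
#                    append_images=gif_frames[1:], duration=33, loop=0)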