|
import warnings
from typing import Any, Dict, Iterator

import torch

from ..utils import _log_api_usage_once

try:
    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
    _HAS_GPU_VIDEO_DECODER = False
from ._video_opt import (
    _HAS_VIDEO_OPT,
)
| 14 | + |
# Snapshot the backend flag at import time; the answer never changes for the
# lifetime of the process, so a single function reading a module constant is
# equivalent to defining two separate functions under an if/else.
_VIDEO_OPT_AVAILABLE = bool(_HAS_VIDEO_OPT)


def _has_video_opt() -> bool:
    """Return whether torchvision was compiled with the video_reader backend."""
    return _VIDEO_OPT_AVAILABLE
| 26 | + |
class VideoReader:
    """
    Fine-grained video-reading API.
    Supports frame-by-frame reading of various streams from a single video
    container.

    Example:
        The following example creates a :mod:`VideoReader` object, seeks into 2s
        point, and returns a single frame::

            import torchvision
            video_path = "path_to_a_test_video"
            reader = torchvision.io.VideoReader(video_path, "video")
            reader.seek(2.0)
            frame = next(reader)

        :mod:`VideoReader` implements the iterable API, which makes it suitable to
        using it in conjunction with :mod:`itertools` for more advanced reading.
        As such, we can use a :mod:`VideoReader` instance inside for loops::

            reader.seek(2)
            for frame in reader:
                frames.append(frame['data'])
            # additionally, `seek` implements a fluent API, so we can do
            for frame in reader.seek(2):
                frames.append(frame['data'])

        With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
        following code::

            for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
                frames.append(frame['data'])

        and similarly, reading 10 frames after the 2s timestamp can be achieved
        as follows::

            for frame in itertools.islice(reader.seek(2), 10):
                frames.append(frame['data'])

    .. note::

        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
        a unique stream id (which are determined by the video encoding).
        In this way, if the video container contains multiple
        streams of the same type, users can access the one they want.
        If only stream type is passed, the decoder auto-detects first stream of that type.

    Args:

        path (string): Path to the video file in supported format

        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
            Currently available options include ``['video', 'audio']``

        num_threads (int, optional): number of threads used by the codec to decode video.
            Default value (0) enables multithreading with codec-dependent heuristic. The performance
            will depend on the version of FFMPEG codecs supported.

        device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.

    """

    def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
        _log_api_usage_once(self)
        # Records which decode path is active; __next__'s output keys depend on it.
        self.is_cuda = False
        device = torch.device(device)
        if device.type == "cuda":
            if not _HAS_GPU_VIDEO_DECODER:
                raise RuntimeError("Not compiled with GPU decoder support.")
            self.is_cuda = True
            # GPU decoding requires an explicit device index (e.g. "cuda:0").
            if device.index is None:
                raise RuntimeError("Invalid cuda device!")
            self._c = torch.classes.torchvision.GPUDecoder(path, device.index)
            return
        if not _has_video_opt():
            raise RuntimeError(
                "Not compiled with video_reader support, "
                + "to enable video_reader support, please install "
                + "ffmpeg (version 4.2 is currently supported) and "
                + "build torchvision from source."
            )

        self._c = torch.classes.torchvision.Video(path, stream, num_threads)

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
        Frames are encoded as a dict with mandatory
        data and pts fields, where data is a tensor, and pts is a
        presentation timestamp of the frame expressed in seconds
        as a float.

        Returns:
            (dict): a dictionary containing decoded frame (``data``)
            and corresponding timestamp (``pts``) in seconds

        """
        if self.is_cuda:
            frame = self._c.next()
            if frame.numel() == 0:
                raise StopIteration
            # The GPU decoder does not report timestamps, so no ``pts`` key here.
            return {"data": frame}
        frame, pts = self._c.next()
        # An empty tensor from the backend signals end of stream.
        if frame.numel() == 0:
            raise StopIteration
        return {"data": frame, "pts": pts}

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
        """Seek within current stream.

        Args:
            time_s (float): seek time in seconds
            keyframes_only (bool): allow to seek only to keyframes

        .. note::
            Current implementation is the so-called precise seek. This
            means following seek, call to :mod:`next()` will return the
            frame with the exact timestamp if it exists or
            the first frame with timestamp larger than ``time_s``.
        """
        self._c.seek(time_s, keyframes_only)
        # Fluent API: return self so iteration can chain directly off a seek.
        return self

    def get_metadata(self) -> Dict[str, Any]:
        """Returns video metadata

        Returns:
            (dict): dictionary containing duration and frame rate for every stream
        """
        return self._c.get_metadata()

    def set_current_stream(self, stream: str) -> bool:
        """Set current stream.
        Explicitly define the stream we are operating on.

        Args:
            stream (string): descriptor of the required stream. Defaults to ``"video:0"``
                Currently available stream types include ``['video', 'audio']``.
                Each descriptor consists of two parts: stream type (e.g. 'video') and
                a unique stream id (which are determined by video encoding).
                In this way, if the video container contains multiple
                streams of the same type, users can access the one they want.
                If only stream type is passed, the decoder auto-detects first stream
                of that type and returns it.

        Returns:
            (bool): True on success, False otherwise
        """
        if self.is_cuda:
            # Use the warnings machinery instead of print() so the message goes
            # through standard warning filters rather than silently to stdout.
            warnings.warn("GPU decoding only works with video stream.")
        return self._c.set_current_stream(stream)