Skip to content

Commit 7786076

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Added typing annotations to io/__init__ (#4224)
Summary: * style: Added typing annotations * Specified types for iter and seek. Reviewed By: fmassa Differential Revision: D30793319 fbshipit-source-id: b5d3a220639d239f64cee6712aa07e19fdaaf875 Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent bcb51e7 commit 7786076

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

torchvision/io/__init__.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from typing import Any, Dict, Iterator
23

34
from ._video_opt import (
45
Timebase,
@@ -33,13 +34,13 @@
3334

3435
if _HAS_VIDEO_OPT:
3536

36-
def _has_video_opt():
37+
def _has_video_opt() -> bool:
3738
return True
3839

3940

4041
else:
4142

42-
def _has_video_opt():
43+
def _has_video_opt() -> bool:
4344
return False
4445

4546

@@ -99,7 +100,7 @@ class VideoReader:
99100
Currently available options include ``['video', 'audio']``
100101
"""
101102

102-
def __init__(self, path, stream="video"):
103+
def __init__(self, path: str, stream: str = "video") -> None:
103104
if not _has_video_opt():
104105
raise RuntimeError(
105106
"Not compiled with video_reader support, "
@@ -109,7 +110,7 @@ def __init__(self, path, stream="video"):
109110
)
110111
self._c = torch.classes.torchvision.Video(path, stream)
111112

112-
def __next__(self):
113+
def __next__(self) -> Dict[str, Any]:
113114
"""Decodes and returns the next frame of the current stream.
114115
Frames are encoded as a dict with mandatory
115116
data and pts fields, where data is a tensor, and pts is a
@@ -126,10 +127,10 @@ def __next__(self):
126127
raise StopIteration
127128
return {"data": frame, "pts": pts}
128129

129-
def __iter__(self):
130+
def __iter__(self) -> Iterator['VideoReader']:
130131
return self
131132

132-
def seek(self, time_s: float):
133+
def seek(self, time_s: float) -> 'VideoReader':
133134
"""Seek within current stream.
134135
135136
Args:
@@ -144,15 +145,15 @@ def seek(self, time_s: float):
144145
self._c.seek(time_s)
145146
return self
146147

147-
def get_metadata(self):
148+
def get_metadata(self) -> Dict[str, Any]:
148149
"""Returns video metadata
149150
150151
Returns:
151152
(dict): dictionary containing duration and frame rate for every stream
152153
"""
153154
return self._c.get_metadata()
154155

155-
def set_current_stream(self, stream: str):
156+
def set_current_stream(self, stream: str) -> bool:
156157
"""Set current stream.
157158
Explicitly define the stream we are operating on.
158159

0 commit comments

Comments
 (0)