33# pyre-strict
44
55import sys
6+ import typing
67import warnings
78from time import time
8- from typing import Any , cast , Iterable , Literal , Optional , Sized , TextIO
9+ from types import TracebackType
10+ from typing import (
11+ Any ,
12+ Callable ,
13+ cast ,
14+ Generic ,
15+ Iterable ,
16+ Iterator ,
17+ Literal ,
18+ Optional ,
19+ Sized ,
20+ TextIO ,
21+ Type ,
22+ TypeVar ,
23+ Union ,
24+ )
925
1026try :
1127 from tqdm .auto import tqdm
1228except ImportError :
1329 tqdm = None
1430
31+ T = TypeVar ("T" )
32+ IterableType = TypeVar ("IterableType" )
33+
1534
1635class DisableErrorIOWrapper (object ):
1736 def __init__ (self , wrapped : TextIO ) -> None :
@@ -21,15 +40,13 @@ def __init__(self, wrapped: TextIO) -> None:
2140 """
2241 self ._wrapped = wrapped
2342
24- # pyre-fixme[3]: Return type must be annotated.
25- # pyre-fixme[2]: Parameter must be annotated.
26- def __getattr__ (self , name ):
43+ def __getattr__ (self , name : str ) -> object :
2744 return getattr (self ._wrapped , name )
2845
2946 @staticmethod
30- # pyre-fixme[3]: Return type must be annotated.
31- # pyre-fixme[2]: Parameter must be annotated.
32- def _wrapped_run ( func , * args , ** kwargs ) :
47+ def _wrapped_run (
48+ func : Callable [..., T ], * args : object , ** kwargs : object
49+ ) -> Union [ T , None ] :
3350 try :
3451 return func (* args , ** kwargs )
3552 except OSError as e :
@@ -38,19 +55,16 @@ def _wrapped_run(func, *args, **kwargs):
3855 except ValueError as e :
3956 if "closed" not in str (e ):
4057 raise
58+ return None
4159
42- # pyre-fixme[3]: Return type must be annotated.
43- # pyre-fixme[2]: Parameter must be annotated.
44- def write (self , * args , ** kwargs ):
60+ def write (self , * args : object , ** kwargs : object ) -> Optional [int ]:
4561 return self ._wrapped_run (self ._wrapped .write , * args , ** kwargs )
4662
47- # pyre-fixme[3]: Return type must be annotated.
48- # pyre-fixme[2]: Parameter must be annotated.
49- def flush (self , * args , ** kwargs ):
63+ def flush (self , * args : object , ** kwargs : object ) -> None :
5064 return self ._wrapped_run (self ._wrapped .flush , * args , ** kwargs )
5165
5266
53- class NullProgress :
67+ class NullProgress ( Iterable [ IterableType ]) :
5468 """Passthrough class that implements the progress API.
5569
5670 This class implements the tqdm and SimpleProgressBar api but
@@ -61,27 +75,28 @@ class NullProgress:
6175
6276 def __init__ (
6377 self ,
64- # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
65- iterable : Optional [Iterable ] = None ,
78+ iterable : Optional [Iterable [IterableType ]] = None ,
6679 * args : Any ,
6780 ** kwargs : Any ,
6881 ) -> None :
6982 del args , kwargs
7083 self .iterable = iterable
7184
72- def __enter__ (self ) -> "NullProgress" :
85+ def __enter__ (self ) -> "NullProgress[IterableType] " :
7386 return self
7487
75- # pyre-fixme[2]: Parameter must be annotated.
76- def __exit__ (self , exc_type , exc_value , exc_traceback ) -> Literal [False ]:
88+ def __exit__ (
89+ self ,
90+ exc_type : Union [Type [BaseException ], None ],
91+ exc_value : Union [BaseException , None ],
92+ exc_traceback : Union [TracebackType , None ],
93+ ) -> Literal [False ]:
7794 return False
7895
79- # pyre-fixme[3]: Return type must be annotated.
80- def __iter__ (self ):
96+ def __iter__ (self ) -> Iterator [IterableType ]:
8197 if not self .iterable :
8298 return
83- # pyre-fixme[16]: `Optional` has no attribute `__iter__`.
84- for it in self .iterable :
99+ for it in cast (Iterable [IterableType ], self .iterable ):
85100 yield it
86101
87102 def update (self , amount : int = 1 ) -> None :
@@ -91,11 +106,10 @@ def close(self) -> None:
91106 pass
92107
93108
94- class SimpleProgress :
109+ class SimpleProgress ( Iterable [ IterableType ]) :
95110 def __init__ (
96111 self ,
97- # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
98- iterable : Optional [Iterable ] = None ,
112+ iterable : Optional [Iterable [IterableType ]] = None ,
99113 desc : Optional [str ] = None ,
100114 total : Optional [int ] = None ,
101115 file : Optional [TextIO ] = None ,
@@ -117,34 +131,33 @@ def __init__(
117131
118132 self .desc = desc
119133
120- # pyre-fixme[9]: file has type `Optional[TextIO]`; used as
121- # `DisableErrorIOWrapper`.
122- file = DisableErrorIOWrapper (file if file else sys .stderr )
123- cast (TextIO , file )
124- self .file = file
134+ file_wrapper = DisableErrorIOWrapper (file if file else sys .stderr )
135+ self .file : DisableErrorIOWrapper = file_wrapper
125136
126137 self .mininterval = mininterval
127138 self .last_print_t = 0.0
128139 self .closed = False
129140 self ._is_parent = False
130141
131- def __enter__ (self ) -> "SimpleProgress" :
142+ def __enter__ (self ) -> "SimpleProgress[IterableType] " :
132143 self ._is_parent = True
133144 self ._refresh ()
134145 return self
135146
136- # pyre-fixme[2]: Parameter must be annotated.
137- def __exit__ (self , exc_type , exc_value , exc_traceback ) -> Literal [False ]:
147+ def __exit__ (
148+ self ,
149+ exc_type : Union [Type [BaseException ], None ],
150+ exc_value : Union [BaseException , None ],
151+ exc_traceback : Union [TracebackType , None ],
152+ ) -> Literal [False ]:
138153 self .close ()
139154 return False
140155
141- # pyre-fixme[3]: Return type must be annotated.
142- def __iter__ (self ):
156+ def __iter__ (self ) -> Iterator [IterableType ]:
143157 if self .closed or not self .iterable :
144158 return
145159 self ._refresh ()
146- # pyre-fixme[16]: `Optional` has no attribute `__iter__`.
147- for it in self .iterable :
160+ for it in cast (Iterable [IterableType ], self .iterable ):
148161 yield it
149162 self .update ()
150163 self .close ()
@@ -153,9 +166,7 @@ def _refresh(self) -> None:
153166 progress_str = self .desc + ": " if self .desc else ""
154167 if self .total :
155168 # e.g., progress: 60% 3/5
156- # pyre-fixme[58]: `//` is not supported for operand types `int` and
157- # `Optional[int]`.
158- progress_str += f"{ 100 * self .cur // self .total } % { self .cur } /{ self .total } "
169+ progress_str += f"{ 100 * self .cur // cast (int , self .total )} % { self .cur } /{ cast (int , self .total )} "
159170 else :
160171 # e.g., progress: .....
161172 progress_str += "." * self .cur
@@ -179,18 +190,39 @@ def close(self) -> None:
179190 self .closed = True
180191
181192
182- # pyre-fixme[3]: Return type must be annotated.
193+ @typing .overload
194+ def progress (
195+ iterable : None = None ,
196+ desc : Optional [str ] = None ,
197+ total : Optional [int ] = None ,
198+ use_tqdm : bool = True ,
199+ file : Optional [TextIO ] = None ,
200+ mininterval : float = 0.5 ,
201+ ** kwargs : object ,
202+ ) -> Union [SimpleProgress [None ], tqdm ]: ...
203+
204+
205+ @typing .overload
206+ def progress (
207+ iterable : Iterable [IterableType ],
208+ desc : Optional [str ] = None ,
209+ total : Optional [int ] = None ,
210+ use_tqdm : bool = True ,
211+ file : Optional [TextIO ] = None ,
212+ mininterval : float = 0.5 ,
213+ ** kwargs : object ,
214+ ) -> Union [SimpleProgress [IterableType ], tqdm ]: ...
215+
216+
183217def progress (
184- # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
185- iterable : Optional [Iterable ] = None ,
218+ iterable : Optional [Iterable [IterableType ]] = None ,
186219 desc : Optional [str ] = None ,
187220 total : Optional [int ] = None ,
188221 use_tqdm : bool = True ,
189222 file : Optional [TextIO ] = None ,
190223 mininterval : float = 0.5 ,
191- # pyre-fixme[2]: Parameter must be annotated.
192- ** kwargs ,
193- ):
224+ ** kwargs : object ,
225+ ) -> Union [SimpleProgress [IterableType ], tqdm ]:
194226 # Try to use tqdm is possible. Fall back to simple progress print
195227 if tqdm and use_tqdm :
196228 return tqdm (
0 commit comments