Skip to content

Commit a77ae43

Browse files
authored
Merge pull request #25 from kbsriram/add-types-3
Update type annotations for itertools extras.
2 parents 750de7a + 3bd2dd9 commit a77ae43

File tree

3 files changed

+356
-25
lines changed

3 files changed

+356
-25
lines changed

adafruit_itertools/adafruit_itertools_extras.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,54 @@
4141

4242
import adafruit_itertools as it
4343

44+
try:
45+
from typing import (
46+
Any,
47+
Callable,
48+
Iterable,
49+
Iterator,
50+
List,
51+
Optional,
52+
Tuple,
53+
Type,
54+
TypeVar,
55+
Union,
56+
)
57+
from typing_extensions import TypeAlias
58+
59+
_T = TypeVar("_T")
60+
_N: TypeAlias = Union[int, float, complex]
61+
_Predicate: TypeAlias = Callable[[_T], bool]
62+
except ImportError:
63+
pass
64+
65+
4466
__version__ = "0.0.0+auto.0"
4567
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Itertools.git"
4668

4769

48-
def all_equal(iterable):
70+
def all_equal(iterable: Iterable[Any]) -> bool:
4971
"""Returns True if all the elements are equal to each other.
5072
5173
:param iterable: source of values
5274
5375
"""
5476
g = it.groupby(iterable)
55-
next(g) # should succeed, value isn't relevant
5677
try:
57-
next(g) # should fail: only 1 group
78+
next(g) # value isn't relevant
79+
except StopIteration:
80+
# Empty iterable, return True to match cpython behavior.
81+
return True
82+
try:
83+
next(g)
84+
# more than one group, so we have different elements.
5885
return False
5986
except StopIteration:
87+
# Only one group - all elements must be equal.
6088
return True
6189

6290

63-
def dotproduct(vec1, vec2):
91+
def dotproduct(vec1: Iterable[_N], vec2: Iterable[_N]) -> _N:
6492
"""Compute the dot product of two vectors.
6593
6694
:param vec1: the first vector
@@ -71,7 +99,11 @@ def dotproduct(vec1, vec2):
7199
return sum(map(lambda x, y: x * y, vec1, vec2))
72100

73101

74-
def first_true(iterable, default=False, pred=None):
102+
def first_true(
103+
iterable: Iterable[_T],
104+
default: Union[bool, _T] = False,
105+
pred: Optional[_Predicate[_T]] = None,
106+
) -> Union[bool, _T]:
75107
"""Returns the first true value in the iterable.
76108
77109
If no true value is found, returns *default*
@@ -94,7 +126,7 @@ def first_true(iterable, default=False, pred=None):
94126
return default
95127

96128

97-
def flatten(iterable_of_iterables):
129+
def flatten(iterable_of_iterables: Iterable[Iterable[_T]]) -> Iterator[_T]:
98130
"""Flatten one level of nesting.
99131
100132
:param iterable_of_iterables: a sequence of iterables to flatten
@@ -104,7 +136,9 @@ def flatten(iterable_of_iterables):
104136
return it.chain_from_iterable(iterable_of_iterables)
105137

106138

107-
def grouper(iterable, n, fillvalue=None):
139+
def grouper(
140+
iterable: Iterable[_T], n: int, fillvalue: Optional[_T] = None
141+
) -> Iterator[Tuple[_T, ...]]:
108142
"""Collect data into fixed-length chunks or blocks.
109143
110144
:param iterable: source of values
@@ -118,7 +152,7 @@ def grouper(iterable, n, fillvalue=None):
118152
return it.zip_longest(*args, fillvalue=fillvalue)
119153

120154

121-
def iter_except(func, exception):
155+
def iter_except(func: Callable[[], _T], exception: Type[BaseException]) -> Iterator[_T]:
122156
"""Call a function repeatedly, yielding the results, until exception is raised.
123157
124158
Converts a call-until-exception interface to an iterator interface.
@@ -143,7 +177,7 @@ def iter_except(func, exception):
143177
pass
144178

145179

146-
def ncycles(iterable, n):
180+
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]:
147181
"""Returns the sequence elements a number of times.
148182
149183
:param iterable: the source of values
@@ -153,7 +187,7 @@ def ncycles(iterable, n):
153187
return it.chain_from_iterable(it.repeat(tuple(iterable), n))
154188

155189

156-
def nth(iterable, n, default=None):
190+
def nth(iterable: Iterable[_T], n: int, default: Optional[_T] = None) -> Optional[_T]:
157191
"""Returns the nth item or a default value.
158192
159193
:param iterable: the source of values
@@ -166,7 +200,7 @@ def nth(iterable, n, default=None):
166200
return default
167201

168202

169-
def padnone(iterable):
203+
def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]:
170204
"""Returns the sequence elements and then returns None indefinitely.
171205
172206
Useful for emulating the behavior of the built-in map() function.
@@ -177,13 +211,17 @@ def padnone(iterable):
177211
return it.chain(iterable, it.repeat(None))
178212

179213

180-
def pairwise(iterable):
181-
"""Pair up valuesin the iterable.
214+
def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]:
215+
"""Return successive overlapping pairs from the iterable.
216+
217+
The number of tuples from the output will be one fewer than the
218+
number of values in the input. It will be empty if the input has
219+
fewer than two values.
182220
183221
:param iterable: source of values
184222
185223
"""
186-
# pairwise(range(11)) -> (1, 2), (3, 4), (5, 6), (7, 8), (9, 10)
224+
# pairwise(range(5)) -> (0, 1), (1, 2), (2, 3), (3, 4)
187225
a, b = it.tee(iterable)
188226
try:
189227
next(b)
@@ -192,7 +230,9 @@ def pairwise(iterable):
192230
return zip(a, b)
193231

194232

195-
def partition(pred, iterable):
233+
def partition(
234+
pred: _Predicate[_T], iterable: Iterable[_T]
235+
) -> Tuple[Iterator[_T], Iterator[_T]]:
196236
"""Use a predicate to partition entries into false entries and true entries.
197237
198238
:param pred: the predicate that divides the values
@@ -204,7 +244,7 @@ def partition(pred, iterable):
204244
return it.filterfalse(pred, t1), filter(pred, t2)
205245

206246

207-
def prepend(value, iterator):
247+
def prepend(value: _T, iterator: Iterable[_T]) -> Iterator[_T]:
208248
"""Prepend a single value in front of an iterator
209249
210250
:param value: the value to prepend
@@ -215,7 +255,7 @@ def prepend(value, iterator):
215255
return it.chain([value], iterator)
216256

217257

218-
def quantify(iterable, pred=bool):
258+
def quantify(iterable: Iterable[_T], pred: _Predicate[_T] = bool) -> int:
219259
"""Count how many times the predicate is true.
220260
221261
:param iterable: source of values
@@ -227,7 +267,9 @@ def quantify(iterable, pred=bool):
227267
return sum(map(pred, iterable))
228268

229269

230-
def repeatfunc(func, times=None, *args):
270+
def repeatfunc(
271+
func: Callable[..., _T], times: Optional[int] = None, *args: Any
272+
) -> Iterator[_T]:
231273
"""Repeat calls to func with specified arguments.
232274
233275
Example: repeatfunc(random.random)
@@ -242,7 +284,7 @@ def repeatfunc(func, times=None, *args):
242284
return it.starmap(func, it.repeat(args, times))
243285

244286

245-
def roundrobin(*iterables):
287+
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]:
246288
"""Return an iterable created by repeatedly picking value from each
247289
argument in order.
248290
@@ -263,18 +305,19 @@ def roundrobin(*iterables):
263305
nexts = it.cycle(it.islice(nexts, num_active))
264306

265307

266-
def tabulate(function, start=0):
267-
"""Apply a function to a sequence of consecutive integers.
308+
def tabulate(function: Callable[[int], int], start: int = 0) -> Iterator[int]:
309+
"""Apply a function to a sequence of consecutive numbers.
268310
269-
:param function: the function of one integer argument
311+
:param function: the function of one numeric argument.
270312
:param start: optional value to start at (default is 0)
271313
272314
"""
273315
# take(5, tabulate(lambda x: x * x))) -> 0 1 4 9 16
274-
return map(function, it.count(start))
316+
counter: Iterator[int] = it.count(start) # type: ignore[assignment]
317+
return map(function, counter)
275318

276319

277-
def tail(n, iterable):
320+
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]:
278321
"""Return an iterator over the last n items
279322
280323
:param n: how many values to return
@@ -294,7 +337,7 @@ def tail(n, iterable):
294337
return iter(buf)
295338

296339

297-
def take(n, iterable):
340+
def take(n: int, iterable: Iterable[_T]) -> List[_T]:
298341
"""Return first n items of the iterable as a list
299342
300343
:param n: how many values to take

optional_requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
# SPDX-FileCopyrightText: 2022 Alec Delaney, for Adafruit Industries
22
#
33
# SPDX-License-Identifier: Unlicense
4+
5+
# For comparison when running tests
6+
more-itertools

0 commit comments

Comments
 (0)