Skip to content

Commit 2d52406

Browse files
authored
Improve factor() recipe and fix its tests (GH-100576)
1 parent 984894a commit 2d52406

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

Doc/library/itertools.rst

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -899,18 +899,16 @@ which incur interpreter overhead.
899899

900900
def factor(n):
901901
"Prime factors of n."
902-
# factor(97) --> 97
903-
# factor(98) --> 2 7 7
904902
# factor(99) --> 3 3 11
905-
for prime in sieve(n+1):
906-
while True:
903+
for prime in sieve(math.isqrt(n) + 1):
904+
while n >= prime:
907905
quotient, remainder = divmod(n, prime)
908906
if remainder:
909907
break
910908
yield prime
911909
n = quotient
912-
if n == 1:
913-
return
910+
if n >= 2:
911+
yield n
914912

915913
def flatten(list_of_lists):
916914
"Flatten one level of nesting"
@@ -1266,33 +1264,35 @@ which incur interpreter overhead.
12661264
>>> set(sieve(10_000)).isdisjoint(carmichael)
12671265
True
12681266

1269-
list(factor(0))
1267+
>>> list(factor(0))
12701268
[]
1271-
list(factor(1))
1269+
>>> list(factor(1))
12721270
[]
1273-
list(factor(2))
1271+
>>> list(factor(2))
12741272
[2]
1275-
list(factor(3))
1273+
>>> list(factor(3))
12761274
[3]
1277-
list(factor(4))
1275+
>>> list(factor(4))
12781276
[2, 2]
1279-
list(factor(5))
1277+
>>> list(factor(5))
12801278
[5]
1281-
list(factor(6))
1279+
>>> list(factor(6))
12821280
[2, 3]
1283-
list(factor(7))
1281+
>>> list(factor(7))
12841282
[7]
1285-
list(factor(8))
1283+
>>> list(factor(8))
12861284
[2, 2, 2]
1287-
list(factor(9))
1285+
>>> list(factor(9))
12881286
[3, 3]
1289-
list(factor(10))
1287+
>>> list(factor(10))
12901288
[2, 5]
1291-
all(math.prod(factor(n)) == n for n in range(1, 1000))
1289+
>>> list(factor(999953*999983))
1290+
[999953, 999983]
1291+
>>> all(math.prod(factor(n)) == n for n in range(1, 1000))
12921292
True
1293-
all(set(factor(n)) <= set(sieve(n+1)) for n in range(1, 1000))
1293+
>>> all(set(factor(n)) <= set(sieve(n+1)) for n in range(1, 1000))
12941294
True
1295-
all(list(factor(n)) == sorted(factor(n)) for n in range(1, 1000))
1295+
>>> all(list(factor(n)) == sorted(factor(n)) for n in range(1, 1000))
12961296
True
12971297

12981298
>>> list(flatten([('a', 'b'), (), ('c', 'd', 'e'), ('f',), ('g', 'h', 'i')]))

0 commit comments

Comments
 (0)