Skip to content

Commit ec6b94d

Browse files
committed
[mypy] Fix type annotations in red_black_tree.py
1 parent 0590d73 commit ec6b94d

File tree

1 file changed

+93
-40
lines changed

1 file changed

+93
-40
lines changed

data_structures/binary_tree/red_black_tree.py

Lines changed: 93 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,18 @@ def rotate_left(self) -> RedBlackTree:
5151
"""
5252
parent = self.parent
5353
right = self.right
54+
55+
if right is None:
56+
return self
57+
5458
self.right = right.left
59+
5560
if self.right:
5661
self.right.parent = self
5762
self.parent = right
63+
5864
right.left = self
65+
5966
if parent is not None:
6067
if parent.left == self:
6168
parent.left = right
@@ -69,13 +76,20 @@ def rotate_right(self) -> RedBlackTree:
6976
returns the new root to this subtree.
7077
Performing one rotation can be done in O(1).
7178
"""
79+
if self.left is None:
80+
return self
7281
parent = self.parent
7382
left = self.left
83+
7484
self.left = left.right
85+
7586
if self.left:
7687
self.left.parent = self
88+
7789
self.parent = left
90+
7891
left.right = self
92+
7993
if parent is not None:
8094
if parent.right is self:
8195
parent.right = left
@@ -123,23 +137,34 @@ def _insert_repair(self) -> None:
123137
if color(uncle) == 0:
124138
if self.is_left() and self.parent.is_right():
125139
self.parent.rotate_right()
126-
self.right._insert_repair()
140+
if self.right:
141+
self.right._insert_repair()
127142
elif self.is_right() and self.parent.is_left():
128143
self.parent.rotate_left()
129-
self.left._insert_repair()
144+
if self.left:
145+
self.left._insert_repair()
146+
130147
elif self.is_left():
131-
self.grandparent.rotate_right()
132-
self.parent.color = 0
133-
self.parent.right.color = 1
148+
if self.grandparent:
149+
self.grandparent.rotate_right()
150+
self.parent.color = 0
151+
152+
if self.parent.right:
153+
self.parent.right.color = 1
134154
else:
135-
self.grandparent.rotate_left()
136-
self.parent.color = 0
137-
self.parent.left.color = 1
155+
if self.grandparent:
156+
self.grandparent.rotate_left()
157+
self.parent.color = 0
158+
159+
if self.parent.left:
160+
self.parent.left.color = 1
138161
else:
139162
self.parent.color = 0
140-
uncle.color = 0
141-
self.grandparent.color = 1
142-
self.grandparent._insert_repair()
163+
164+
if uncle and self.grandparent:
165+
uncle.color = 0
166+
self.grandparent.color = 1
167+
self.grandparent._insert_repair()
143168

144169
def remove(self, label: int) -> RedBlackTree:
145170
"""Remove label from this tree."""
@@ -150,7 +175,7 @@ def remove(self, label: int) -> RedBlackTree:
150175
# it and remove that.
151176
value = self.left.get_max()
152177
self.label = value
153-
self.left.remove(value)
178+
self.left.remove(value) if value else None
154179
else:
155180
# This node has at most one non-None child, so we don't
156181
# need to replace
@@ -160,10 +185,12 @@ def remove(self, label: int) -> RedBlackTree:
160185
# The only way this happens to a node with one child
161186
# is if both children are None leaves.
162187
# We can just remove this node and call it a day.
163-
if self.is_left():
164-
self.parent.left = None
165-
else:
166-
self.parent.right = None
188+
if self.parent:
189+
190+
if self.is_left():
191+
self.parent.left = None
192+
else:
193+
self.parent.right = None
167194
else:
168195
# The node is black
169196
if child is None:
@@ -188,7 +215,7 @@ def remove(self, label: int) -> RedBlackTree:
188215
self.left.parent = self
189216
if self.right:
190217
self.right.parent = self
191-
elif self.label > label:
218+
elif self.label is not None and self.label > label:
192219
if self.left:
193220
self.left.remove(label)
194221
else:
@@ -198,6 +225,14 @@ def remove(self, label: int) -> RedBlackTree:
198225

199226
def _remove_repair(self) -> None:
200227
"""Repair the coloring of the tree that may have been messed up."""
228+
if (
229+
self.parent is None
230+
or self.sibling is None
231+
or self.parent.sibling is None
232+
or self.grandparent is None
233+
):
234+
return
235+
201236
if color(self.sibling) == 1:
202237
self.sibling.color = 0
203238
self.parent.color = 1
@@ -231,7 +266,9 @@ def _remove_repair(self) -> None:
231266
):
232267
self.sibling.rotate_right()
233268
self.sibling.color = 0
234-
self.sibling.right.color = 1
269+
270+
if self.sibling.right:
271+
self.sibling.right.color = 1
235272
if (
236273
self.is_right()
237274
and color(self.sibling) == 0
@@ -240,7 +277,9 @@ def _remove_repair(self) -> None:
240277
):
241278
self.sibling.rotate_left()
242279
self.sibling.color = 0
243-
self.sibling.left.color = 1
280+
281+
if self.sibling.left:
282+
self.sibling.left.color = 1
244283
if (
245284
self.is_left()
246285
and color(self.sibling) == 0
@@ -297,7 +336,7 @@ def check_color_properties(self) -> bool:
297336
# All properties were met
298337
return True
299338

300-
def check_coloring(self) -> None:
339+
def check_coloring(self) -> bool:
301340
"""A helper function to recursively check Property 4 of a
302341
Red-Black Tree. See check_color_properties for more info.
303342
"""
@@ -310,12 +349,12 @@ def check_coloring(self) -> None:
310349
return False
311350
return True
312351

313-
def black_height(self) -> int:
352+
def black_height(self) -> int | None:
314353
"""Returns the number of black nodes from this node to the
315354
leaves of the tree, or None if there isn't one such value (the
316355
tree is color incorrectly).
317356
"""
318-
if self is None:
357+
if self is None or self.left is None or self.right is None:
319358
# If we're already at a leaf, there is no path
320359
return 1
321360
left = RedBlackTree.black_height(self.left)
@@ -332,21 +371,22 @@ def black_height(self) -> int:
332371

333372
# Here are functions which are general to all binary search trees
334373

335-
def __contains__(self, label) -> bool:
374+
def __contains__(self, label: int) -> bool:
336375
"""Search through the tree for label, returning True iff it is
337376
found somewhere in the tree.
338377
Guaranteed to run in O(log(n)) time.
339378
"""
340379
return self.search(label) is not None
341380

342-
def search(self, label: int) -> RedBlackTree:
381+
def search(self, label: int) -> RedBlackTree | None:
343382
"""Search through the tree for label, returning its node if
344383
it's found, and None otherwise.
345384
This method is guaranteed to run in O(log(n)) time.
346385
"""
347386
if self.label == label:
348387
return self
349-
elif label > self.label:
388+
389+
elif self.label is not None and label > self.label:
350390
if self.right is None:
351391
return None
352392
else:
@@ -357,12 +397,12 @@ def search(self, label: int) -> RedBlackTree:
357397
else:
358398
return self.left.search(label)
359399

360-
def floor(self, label: int) -> int:
400+
def floor(self, label: int) -> int | None:
361401
"""Returns the largest element in this tree which is at most label.
362402
This method is guaranteed to run in O(log(n)) time."""
363403
if self.label == label:
364404
return self.label
365-
elif self.label > label:
405+
elif self.label is not None and self.label > label:
366406
if self.left:
367407
return self.left.floor(label)
368408
else:
@@ -374,13 +414,13 @@ def floor(self, label: int) -> int:
374414
return attempt
375415
return self.label
376416

377-
def ceil(self, label: int) -> int:
417+
def ceil(self, label: int) -> int | None:
378418
"""Returns the smallest element in this tree which is at least label.
379419
This method is guaranteed to run in O(log(n)) time.
380420
"""
381421
if self.label == label:
382422
return self.label
383-
elif self.label < label:
423+
elif self.label is not None and self.label < label:
384424
if self.right:
385425
return self.right.ceil(label)
386426
else:
@@ -392,7 +432,7 @@ def ceil(self, label: int) -> int:
392432
return attempt
393433
return self.label
394434

395-
def get_max(self) -> int:
435+
def get_max(self) -> int | None:
396436
"""Returns the largest element in this tree.
397437
This method is guaranteed to run in O(log(n)) time.
398438
"""
@@ -402,7 +442,7 @@ def get_max(self) -> int:
402442
else:
403443
return self.label
404444

405-
def get_min(self) -> int:
445+
def get_min(self) -> int | None:
406446
"""Returns the smallest element in this tree.
407447
This method is guaranteed to run in O(log(n)) time.
408448
"""
@@ -413,15 +453,15 @@ def get_min(self) -> int:
413453
return self.label
414454

415455
@property
416-
def grandparent(self) -> RedBlackTree:
456+
def grandparent(self) -> RedBlackTree | None:
417457
"""Get the current node's grandparent, or None if it doesn't exist."""
418458
if self.parent is None:
419459
return None
420460
else:
421461
return self.parent.parent
422462

423463
@property
424-
def sibling(self) -> RedBlackTree:
464+
def sibling(self) -> RedBlackTree | None:
425465
"""Get the current node's sibling, or None if it doesn't exist."""
426466
if self.parent is None:
427467
return None
@@ -432,11 +472,16 @@ def sibling(self) -> RedBlackTree:
432472

433473
def is_left(self) -> bool:
434474
"""Returns true iff this node is the left child of its parent."""
435-
return self.parent and self.parent.left is self
475+
if self.parent is None:
476+
return False
477+
478+
return self.parent.left is self.parent.left is self
436479

437480
def is_right(self) -> bool:
438481
"""Returns true iff this node is the right child of its parent."""
439-
return self.parent and self.parent.right is self
482+
if self.parent is None:
483+
return False
484+
return self.parent.right is self
440485

441486
def __bool__(self) -> bool:
442487
return True
@@ -446,27 +491,28 @@ def __len__(self) -> int:
446491
Return the number of nodes in this tree.
447492
"""
448493
ln = 1
494+
449495
if self.left:
450496
ln += len(self.left)
451497
if self.right:
452498
ln += len(self.right)
453499
return ln
454500

455-
def preorder_traverse(self) -> Iterator[int]:
501+
def preorder_traverse(self) -> Iterator[int | None]:
456502
yield self.label
457503
if self.left:
458504
yield from self.left.preorder_traverse()
459505
if self.right:
460506
yield from self.right.preorder_traverse()
461507

462-
def inorder_traverse(self) -> Iterator[int]:
508+
def inorder_traverse(self) -> Iterator[int | None]:
463509
if self.left:
464510
yield from self.left.inorder_traverse()
465511
yield self.label
466512
if self.right:
467513
yield from self.right.inorder_traverse()
468514

469-
def postorder_traverse(self) -> Iterator[int]:
515+
def postorder_traverse(self) -> Iterator[int | None]:
470516
if self.left:
471517
yield from self.left.postorder_traverse()
472518
if self.right:
@@ -488,15 +534,18 @@ def __repr__(self) -> str:
488534
indent=1,
489535
)
490536

491-
def __eq__(self, other) -> bool:
537+
def __eq__(self, other: object) -> bool:
492538
"""Test if two trees are equal."""
539+
if not isinstance(other, RedBlackTree):
540+
return NotImplemented
541+
493542
if self.label == other.label:
494543
return self.left == other.left and self.right == other.right
495544
else:
496545
return False
497546

498547

499-
def color(node) -> int:
548+
def color(node: RedBlackTree | None) -> int:
500549
"""Returns the color of a node, allowing for None leaves."""
501550
if node is None:
502551
return 0
@@ -514,7 +563,9 @@ def test_rotations() -> bool:
514563
"""Test that the rotate_left and rotate_right functions work."""
515564
# Make a tree to test on
516565
tree = RedBlackTree(0)
566+
517567
tree.left = RedBlackTree(-10, parent=tree)
568+
518569
tree.right = RedBlackTree(10, parent=tree)
519570
tree.left.left = RedBlackTree(-20, parent=tree.left)
520571
tree.left.right = RedBlackTree(-5, parent=tree.left)
@@ -529,8 +580,10 @@ def test_rotations() -> bool:
529580
left_rot.left.left.right = RedBlackTree(-5, parent=left_rot.left.left)
530581
left_rot.right = RedBlackTree(20, parent=left_rot)
531582
tree = tree.rotate_left()
583+
532584
if tree != left_rot:
533585
return False
586+
534587
tree = tree.rotate_right()
535588
tree = tree.rotate_right()
536589
# Make the left rotation

0 commit comments

Comments
 (0)