@@ -51,11 +51,18 @@ def rotate_left(self) -> RedBlackTree:
51
51
"""
52
52
parent = self .parent
53
53
right = self .right
54
+
55
+ if right is None :
56
+ return self
57
+
54
58
self .right = right .left
59
+
55
60
if self .right :
56
61
self .right .parent = self
57
62
self .parent = right
63
+
58
64
right .left = self
65
+
59
66
if parent is not None :
60
67
if parent .left == self :
61
68
parent .left = right
@@ -69,13 +76,20 @@ def rotate_right(self) -> RedBlackTree:
69
76
returns the new root to this subtree.
70
77
Performing one rotation can be done in O(1).
71
78
"""
79
+ if self .left is None :
80
+ return self
72
81
parent = self .parent
73
82
left = self .left
83
+
74
84
self .left = left .right
85
+
75
86
if self .left :
76
87
self .left .parent = self
88
+
77
89
self .parent = left
90
+
78
91
left .right = self
92
+
79
93
if parent is not None :
80
94
if parent .right is self :
81
95
parent .right = left
@@ -123,23 +137,34 @@ def _insert_repair(self) -> None:
123
137
if color (uncle ) == 0 :
124
138
if self .is_left () and self .parent .is_right ():
125
139
self .parent .rotate_right ()
126
- self .right ._insert_repair ()
140
+ if self .right :
141
+ self .right ._insert_repair ()
127
142
elif self .is_right () and self .parent .is_left ():
128
143
self .parent .rotate_left ()
129
- self .left ._insert_repair ()
144
+ if self .left :
145
+ self .left ._insert_repair ()
146
+
130
147
elif self .is_left ():
131
- self .grandparent .rotate_right ()
132
- self .parent .color = 0
133
- self .parent .right .color = 1
148
+ if self .grandparent :
149
+ self .grandparent .rotate_right ()
150
+ self .parent .color = 0
151
+
152
+ if self .parent .right :
153
+ self .parent .right .color = 1
134
154
else :
135
- self .grandparent .rotate_left ()
136
- self .parent .color = 0
137
- self .parent .left .color = 1
155
+ if self .grandparent :
156
+ self .grandparent .rotate_left ()
157
+ self .parent .color = 0
158
+
159
+ if self .parent .left :
160
+ self .parent .left .color = 1
138
161
else :
139
162
self .parent .color = 0
140
- uncle .color = 0
141
- self .grandparent .color = 1
142
- self .grandparent ._insert_repair ()
163
+
164
+ if uncle and self .grandparent :
165
+ uncle .color = 0
166
+ self .grandparent .color = 1
167
+ self .grandparent ._insert_repair ()
143
168
144
169
def remove (self , label : int ) -> RedBlackTree :
145
170
"""Remove label from this tree."""
@@ -150,7 +175,7 @@ def remove(self, label: int) -> RedBlackTree:
150
175
# it and remove that.
151
176
value = self .left .get_max ()
152
177
self .label = value
153
- self .left .remove (value )
178
+ self .left .remove (value ) if value else None
154
179
else :
155
180
# This node has at most one non-None child, so we don't
156
181
# need to replace
@@ -160,10 +185,12 @@ def remove(self, label: int) -> RedBlackTree:
160
185
# The only way this happens to a node with one child
161
186
# is if both children are None leaves.
162
187
# We can just remove this node and call it a day.
163
- if self .is_left ():
164
- self .parent .left = None
165
- else :
166
- self .parent .right = None
188
+ if self .parent :
189
+
190
+ if self .is_left ():
191
+ self .parent .left = None
192
+ else :
193
+ self .parent .right = None
167
194
else :
168
195
# The node is black
169
196
if child is None :
@@ -188,7 +215,7 @@ def remove(self, label: int) -> RedBlackTree:
188
215
self .left .parent = self
189
216
if self .right :
190
217
self .right .parent = self
191
- elif self .label > label :
218
+ elif self .label is not None and self . label > label :
192
219
if self .left :
193
220
self .left .remove (label )
194
221
else :
@@ -198,6 +225,14 @@ def remove(self, label: int) -> RedBlackTree:
198
225
199
226
def _remove_repair (self ) -> None :
200
227
"""Repair the coloring of the tree that may have been messed up."""
228
+ if (
229
+ self .parent is None
230
+ or self .sibling is None
231
+ or self .parent .sibling is None
232
+ or self .grandparent is None
233
+ ):
234
+ return
235
+
201
236
if color (self .sibling ) == 1 :
202
237
self .sibling .color = 0
203
238
self .parent .color = 1
@@ -231,7 +266,9 @@ def _remove_repair(self) -> None:
231
266
):
232
267
self .sibling .rotate_right ()
233
268
self .sibling .color = 0
234
- self .sibling .right .color = 1
269
+
270
+ if self .sibling .right :
271
+ self .sibling .right .color = 1
235
272
if (
236
273
self .is_right ()
237
274
and color (self .sibling ) == 0
@@ -240,7 +277,9 @@ def _remove_repair(self) -> None:
240
277
):
241
278
self .sibling .rotate_left ()
242
279
self .sibling .color = 0
243
- self .sibling .left .color = 1
280
+
281
+ if self .sibling .left :
282
+ self .sibling .left .color = 1
244
283
if (
245
284
self .is_left ()
246
285
and color (self .sibling ) == 0
@@ -297,7 +336,7 @@ def check_color_properties(self) -> bool:
297
336
# All properties were met
298
337
return True
299
338
300
- def check_coloring (self ) -> None :
339
+ def check_coloring (self ) -> bool :
301
340
"""A helper function to recursively check Property 4 of a
302
341
Red-Black Tree. See check_color_properties for more info.
303
342
"""
@@ -310,12 +349,12 @@ def check_coloring(self) -> None:
310
349
return False
311
350
return True
312
351
313
- def black_height (self ) -> int :
352
+ def black_height (self ) -> int | None :
314
353
"""Returns the number of black nodes from this node to the
315
354
leaves of the tree, or None if there isn't one such value (the
316
355
tree is color incorrectly).
317
356
"""
318
- if self is None :
357
+ if self is None or self . left is None or self . right is None :
319
358
# If we're already at a leaf, there is no path
320
359
return 1
321
360
left = RedBlackTree .black_height (self .left )
@@ -332,21 +371,22 @@ def black_height(self) -> int:
332
371
333
372
# Here are functions which are general to all binary search trees
334
373
335
- def __contains__ (self , label ) -> bool :
374
+ def __contains__ (self , label : int ) -> bool :
336
375
"""Search through the tree for label, returning True iff it is
337
376
found somewhere in the tree.
338
377
Guaranteed to run in O(log(n)) time.
339
378
"""
340
379
return self .search (label ) is not None
341
380
342
- def search (self , label : int ) -> RedBlackTree :
381
+ def search (self , label : int ) -> RedBlackTree | None :
343
382
"""Search through the tree for label, returning its node if
344
383
it's found, and None otherwise.
345
384
This method is guaranteed to run in O(log(n)) time.
346
385
"""
347
386
if self .label == label :
348
387
return self
349
- elif label > self .label :
388
+
389
+ elif self .label is not None and label > self .label :
350
390
if self .right is None :
351
391
return None
352
392
else :
@@ -357,12 +397,12 @@ def search(self, label: int) -> RedBlackTree:
357
397
else :
358
398
return self .left .search (label )
359
399
360
- def floor (self , label : int ) -> int :
400
+ def floor (self , label : int ) -> int | None :
361
401
"""Returns the largest element in this tree which is at most label.
362
402
This method is guaranteed to run in O(log(n)) time."""
363
403
if self .label == label :
364
404
return self .label
365
- elif self .label > label :
405
+ elif self .label is not None and self . label > label :
366
406
if self .left :
367
407
return self .left .floor (label )
368
408
else :
@@ -374,13 +414,13 @@ def floor(self, label: int) -> int:
374
414
return attempt
375
415
return self .label
376
416
377
- def ceil (self , label : int ) -> int :
417
+ def ceil (self , label : int ) -> int | None :
378
418
"""Returns the smallest element in this tree which is at least label.
379
419
This method is guaranteed to run in O(log(n)) time.
380
420
"""
381
421
if self .label == label :
382
422
return self .label
383
- elif self .label < label :
423
+ elif self .label is not None and self . label < label :
384
424
if self .right :
385
425
return self .right .ceil (label )
386
426
else :
@@ -392,7 +432,7 @@ def ceil(self, label: int) -> int:
392
432
return attempt
393
433
return self .label
394
434
395
- def get_max (self ) -> int :
435
+ def get_max (self ) -> int | None :
396
436
"""Returns the largest element in this tree.
397
437
This method is guaranteed to run in O(log(n)) time.
398
438
"""
@@ -402,7 +442,7 @@ def get_max(self) -> int:
402
442
else :
403
443
return self .label
404
444
405
- def get_min (self ) -> int :
445
+ def get_min (self ) -> int | None :
406
446
"""Returns the smallest element in this tree.
407
447
This method is guaranteed to run in O(log(n)) time.
408
448
"""
@@ -413,15 +453,15 @@ def get_min(self) -> int:
413
453
return self .label
414
454
415
455
@property
416
- def grandparent (self ) -> RedBlackTree :
456
+ def grandparent (self ) -> RedBlackTree | None :
417
457
"""Get the current node's grandparent, or None if it doesn't exist."""
418
458
if self .parent is None :
419
459
return None
420
460
else :
421
461
return self .parent .parent
422
462
423
463
@property
424
- def sibling (self ) -> RedBlackTree :
464
+ def sibling (self ) -> RedBlackTree | None :
425
465
"""Get the current node's sibling, or None if it doesn't exist."""
426
466
if self .parent is None :
427
467
return None
@@ -432,11 +472,16 @@ def sibling(self) -> RedBlackTree:
432
472
433
473
def is_left (self ) -> bool :
434
474
"""Returns true iff this node is the left child of its parent."""
435
- return self .parent and self .parent .left is self
475
+ if self .parent is None :
476
+ return False
477
+
478
+ return self .parent .left is self .parent .left is self
436
479
437
480
def is_right (self ) -> bool :
438
481
"""Returns true iff this node is the right child of its parent."""
439
- return self .parent and self .parent .right is self
482
+ if self .parent is None :
483
+ return False
484
+ return self .parent .right is self
440
485
441
486
def __bool__ (self ) -> bool :
442
487
return True
@@ -446,27 +491,28 @@ def __len__(self) -> int:
446
491
Return the number of nodes in this tree.
447
492
"""
448
493
ln = 1
494
+
449
495
if self .left :
450
496
ln += len (self .left )
451
497
if self .right :
452
498
ln += len (self .right )
453
499
return ln
454
500
455
- def preorder_traverse (self ) -> Iterator [int ]:
501
+ def preorder_traverse (self ) -> Iterator [int | None ]:
456
502
yield self .label
457
503
if self .left :
458
504
yield from self .left .preorder_traverse ()
459
505
if self .right :
460
506
yield from self .right .preorder_traverse ()
461
507
462
- def inorder_traverse (self ) -> Iterator [int ]:
508
+ def inorder_traverse (self ) -> Iterator [int | None ]:
463
509
if self .left :
464
510
yield from self .left .inorder_traverse ()
465
511
yield self .label
466
512
if self .right :
467
513
yield from self .right .inorder_traverse ()
468
514
469
- def postorder_traverse (self ) -> Iterator [int ]:
515
+ def postorder_traverse (self ) -> Iterator [int | None ]:
470
516
if self .left :
471
517
yield from self .left .postorder_traverse ()
472
518
if self .right :
@@ -488,15 +534,18 @@ def __repr__(self) -> str:
488
534
indent = 1 ,
489
535
)
490
536
491
- def __eq__ (self , other ) -> bool :
537
+ def __eq__ (self , other : object ) -> bool :
492
538
"""Test if two trees are equal."""
539
+ if not isinstance (other , RedBlackTree ):
540
+ return NotImplemented
541
+
493
542
if self .label == other .label :
494
543
return self .left == other .left and self .right == other .right
495
544
else :
496
545
return False
497
546
498
547
499
- def color (node ) -> int :
548
+ def color (node : RedBlackTree | None ) -> int :
500
549
"""Returns the color of a node, allowing for None leaves."""
501
550
if node is None :
502
551
return 0
@@ -514,7 +563,9 @@ def test_rotations() -> bool:
514
563
"""Test that the rotate_left and rotate_right functions work."""
515
564
# Make a tree to test on
516
565
tree = RedBlackTree (0 )
566
+
517
567
tree .left = RedBlackTree (- 10 , parent = tree )
568
+
518
569
tree .right = RedBlackTree (10 , parent = tree )
519
570
tree .left .left = RedBlackTree (- 20 , parent = tree .left )
520
571
tree .left .right = RedBlackTree (- 5 , parent = tree .left )
@@ -529,8 +580,10 @@ def test_rotations() -> bool:
529
580
left_rot .left .left .right = RedBlackTree (- 5 , parent = left_rot .left .left )
530
581
left_rot .right = RedBlackTree (20 , parent = left_rot )
531
582
tree = tree .rotate_left ()
583
+
532
584
if tree != left_rot :
533
585
return False
586
+
534
587
tree = tree .rotate_right ()
535
588
tree = tree .rotate_right ()
536
589
# Make the left rotation
0 commit comments