@@ -24,13 +24,13 @@ import scala.annotation.meta.getter
24
24
*
25
25
* @since 2.10
26
26
*/
27
- private [immutable ]
27
+ private [collection ]
28
28
object RedBlackTree {
29
29
30
30
def isEmpty (tree : Tree [_, _]): Boolean = tree eq null
31
31
32
- def contains [A ](tree : Tree [A , _], x : A )( implicit ordering : Ordering [ A ] ): Boolean = lookup(tree, x) ne null
33
- def get [A , B ](tree : Tree [A , B ], x : A )( implicit ordering : Ordering [ A ] ): Option [B ] = lookup(tree, x) match {
32
+ def contains [A : Ordering ](tree : Tree [A , _], x : A ): Boolean = lookup(tree, x) ne null
33
+ def get [A : Ordering , B ](tree : Tree [A , B ], x : A ): Option [B ] = lookup(tree, x) match {
34
34
case null => None
35
35
case tree => Some (tree.value)
36
36
}
@@ -44,8 +44,27 @@ object RedBlackTree {
44
44
}
45
45
46
46
def count (tree : Tree [_, _]) = if (tree eq null ) 0 else tree.count
47
- def update [A , B , B1 >: B ](tree : Tree [A , B ], k : A , v : B1 , overwrite : Boolean )(implicit ordering : Ordering [A ]): Tree [A , B1 ] = blacken(upd(tree, k, v, overwrite))
48
- def delete [A , B ](tree : Tree [A , B ], k : A )(implicit ordering : Ordering [A ]): Tree [A , B ] = blacken(del(tree, k))
47
+ /**
48
+ * Count all the nodes with keys greater than or equal to the lower bound and less than the upper bound.
49
+ * The two bounds are optional.
50
+ */
51
+ def countInRange [A ](tree : Tree [A , _], from : Option [A ], to: Option [A ])(implicit ordering : Ordering [A ]) : Int =
52
+ if (tree eq null ) 0 else
53
+ (from, to) match {
54
+ // with no bounds use this node's count
55
+ case (None , None ) => tree.count
56
+ // if node is less than the lower bound, try the tree on the right, it might be in range
57
+ case (Some (lb), _) if ordering.lt(tree.key, lb) => countInRange(tree.right, from, to)
58
+ // if node is greater than or equal to the upper bound, try the tree on the left, it might be in range
59
+ case (_, Some (ub)) if ordering.gteq(tree.key, ub) => countInRange(tree.left, from, to)
60
+ // node is in range so the tree on the left will all be less than the upper bound and the tree on the
61
+ // right will all be greater than or equal to the lower bound. So 1 for this node plus
62
+ // count the subtrees by stripping off the bounds that we don't need any more
63
+ case _ => 1 + countInRange(tree.left, from, None ) + countInRange(tree.right, None , to)
64
+
65
+ }
66
+ def update [A : Ordering , B , B1 >: B ](tree : Tree [A , B ], k : A , v : B1 , overwrite : Boolean ): Tree [A , B1 ] = blacken(upd(tree, k, v, overwrite))
67
+ def delete [A : Ordering , B ](tree : Tree [A , B ], k : A ): Tree [A , B ] = blacken(del(tree, k))
49
68
def rangeImpl [A : Ordering , B ](tree : Tree [A , B ], from : Option [A ], until : Option [A ]): Tree [A , B ] = (from, until) match {
50
69
case (Some (from), Some (until)) => this .range(tree, from, until)
51
70
case (Some (from), None ) => this .from(tree, from)
@@ -91,9 +110,9 @@ object RedBlackTree {
91
110
if (tree.right ne null ) _foreachKey(tree.right, f)
92
111
}
93
112
94
- def iterator [A , B ](tree : Tree [A , B ]): Iterator [(A , B )] = new EntriesIterator (tree)
95
- def keysIterator [A , _ ](tree : Tree [A , _]): Iterator [A ] = new KeysIterator (tree)
96
- def valuesIterator [_ , B ](tree : Tree [_ , B ]): Iterator [B ] = new ValuesIterator (tree)
113
+ def iterator [A : Ordering , B ](tree : Tree [A , B ], start : Option [ A ] = None ): Iterator [(A , B )] = new EntriesIterator (tree, start )
114
+ def keysIterator [A : Ordering ](tree : Tree [A , _], start : Option [ A ] = None ): Iterator [A ] = new KeysIterator (tree, start )
115
+ def valuesIterator [A : Ordering , B ](tree : Tree [A , B ], start : Option [ A ] = None ): Iterator [B ] = new ValuesIterator (tree, start )
97
116
98
117
@ tailrec
99
118
def nth [A , B ](tree : Tree [A , B ], n : Int ): Tree [A , B ] = {
@@ -425,32 +444,28 @@ object RedBlackTree {
425
444
def unapply [A , B ](t : BlackTree [A , B ]) = Some ((t.key, t.value, t.left, t.right))
426
445
}
427
446
428
- private [this ] abstract class TreeIterator [A , B , R ](tree : Tree [A , B ]) extends Iterator [R ] {
447
+ private [this ] abstract class TreeIterator [A , B , R ](root : Tree [A , B ], start : Option [ A ])( implicit ordering : Ordering [ A ]) extends Iterator [R ] {
429
448
protected [this ] def nextResult (tree : Tree [A , B ]): R
430
449
431
- override def hasNext : Boolean = next ne null
450
+ override def hasNext : Boolean = lookahead ne null
432
451
433
- override def next : R = next match {
452
+ override def next : R = lookahead match {
434
453
case null =>
435
454
throw new NoSuchElementException (" next on empty iterator" )
436
455
case tree =>
437
- next = findNext( tree.right )
456
+ lookahead = findLeftMostOrPopOnEmpty(goRight( tree) )
438
457
nextResult(tree)
439
458
}
440
459
441
460
@ tailrec
442
- private [this ] def findNext (tree : Tree [A , B ]): Tree [A , B ] = {
443
- if (tree eq null ) popPath ()
461
+ private [this ] def findLeftMostOrPopOnEmpty (tree : Tree [A , B ]): Tree [A , B ] =
462
+ if (tree eq null ) popNext ()
444
463
else if (tree.left eq null ) tree
445
- else {
446
- pushPath(tree)
447
- findNext(tree.left)
448
- }
449
- }
464
+ else findLeftMostOrPopOnEmpty(goLeft(tree))
450
465
451
- private [this ] def pushPath (tree : Tree [A , B ]) {
466
+ private [this ] def pushNext (tree : Tree [A , B ]) {
452
467
try {
453
- path (index) = tree
468
+ stackOfNexts (index) = tree
454
469
index += 1
455
470
} catch {
456
471
case _ : ArrayIndexOutOfBoundsException =>
@@ -462,17 +477,17 @@ object RedBlackTree {
462
477
* An exception handler is used instead of an if-condition to optimize the normal path.
463
478
* This makes a large difference in iteration speed!
464
479
*/
465
- assert(index >= path .length)
466
- path :+= null
467
- pushPath (tree)
480
+ assert(index >= stackOfNexts .length)
481
+ stackOfNexts :+= null
482
+ pushNext (tree)
468
483
}
469
484
}
470
- private [this ] def popPath (): Tree [A , B ] = if (index == 0 ) null else {
485
+ private [this ] def popNext (): Tree [A , B ] = if (index == 0 ) null else {
471
486
index -= 1
472
- path (index)
487
+ stackOfNexts (index)
473
488
}
474
489
475
- private [this ] var path = if (tree eq null ) null else {
490
+ private [this ] var stackOfNexts = if (root eq null ) null else {
476
491
/*
477
492
* According to "Ralf Hinze. Constructing red-black trees" [http://www.cs.ox.ac.uk/ralf.hinze/publications/#P5]
478
493
* the maximum height of a red-black tree is 2*log_2(n + 2) - 2.
@@ -481,22 +496,45 @@ object RedBlackTree {
481
496
*
482
497
* We also don't store the deepest nodes in the path so the maximum path length is further reduced by one.
483
498
*/
484
- val maximumHeight = 2 * (32 - Integer .numberOfLeadingZeros(tree .count + 2 - 1 )) - 2 - 1
499
+ val maximumHeight = 2 * (32 - Integer .numberOfLeadingZeros(root .count + 2 - 1 )) - 2 - 1
485
500
new Array [Tree [A , B ]](maximumHeight)
486
501
}
487
502
private [this ] var index = 0
488
- private [this ] var next : Tree [A , B ] = findNext(tree)
503
+ private [this ] var lookahead : Tree [A , B ] = start map startFrom getOrElse findLeftMostOrPopOnEmpty(root)
504
+
505
+ /**
506
+ * Find the leftmost subtree whose key is equal to the given key, or if no such thing,
507
+ * the leftmost subtree with the key that would be "next" after it according
508
+ * to the ordering. Along the way build up the iterator's path stack so that "next"
509
+ * functionality works.
510
+ */
511
+ private [this ] def startFrom (key : A ) : Tree [A ,B ] = if (root eq null ) null else {
512
+ @ tailrec def find (tree : Tree [A , B ]): Tree [A , B ] =
513
+ if (tree eq null ) popNext
514
+ else find(
515
+ if (ordering.lteq(key, tree.key)) goLeft(tree)
516
+ else goRight(tree)
517
+ )
518
+ find(root)
519
+ }
520
+
521
+ private [this ] def goLeft (tree : Tree [A , B ]) = {
522
+ pushNext(tree)
523
+ tree.left
524
+ }
525
+
526
+ private [this ] def goRight (tree : Tree [A , B ]) = tree.right
489
527
}
490
528
491
- private [this ] class EntriesIterator [A , B ](tree : Tree [A , B ]) extends TreeIterator [A , B , (A , B )](tree) {
529
+ private [this ] class EntriesIterator [A : Ordering , B ](tree : Tree [A , B ], focus : Option [ A ] ) extends TreeIterator [A , B , (A , B )](tree, focus ) {
492
530
override def nextResult (tree : Tree [A , B ]) = (tree.key, tree.value)
493
531
}
494
532
495
- private [this ] class KeysIterator [A , B ](tree : Tree [A , B ]) extends TreeIterator [A , B , A ](tree) {
533
+ private [this ] class KeysIterator [A : Ordering , B ](tree : Tree [A , B ], focus : Option [ A ] ) extends TreeIterator [A , B , A ](tree, focus ) {
496
534
override def nextResult (tree : Tree [A , B ]) = tree.key
497
535
}
498
536
499
- private [this ] class ValuesIterator [A , B ](tree : Tree [A , B ]) extends TreeIterator [A , B , B ](tree) {
537
+ private [this ] class ValuesIterator [A : Ordering , B ](tree : Tree [A , B ], focus : Option [ A ] ) extends TreeIterator [A , B , B ](tree, focus ) {
500
538
override def nextResult (tree : Tree [A , B ]) = tree.value
501
539
}
502
540
}
0 commit comments