Skip to content

Commit bdb3fc8

Browse files
committed
Merge pull request #1 from lpereir4/avl
AvlTree performance improvements
2 parents 5e9dd4a + 06945b6 commit bdb3fc8

File tree

8 files changed

+612
-135
lines changed

8 files changed

+612
-135
lines changed

src/library/scala/collection/mutable/AVLTree.scala

Lines changed: 162 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
package scala.collection
1010
package mutable
1111

12-
import annotation.tailrec
1312

1413
/**
1514
* An immutable AVL Tree implementation used by mutable.TreeSet
@@ -22,185 +21,221 @@ private[mutable] sealed trait AVLTree[+A] extends Serializable {
2221

2322
def depth: Int
2423

25-
}
24+
def iterator[B >: A]: Iterator[B] = Iterator.empty
2625

27-
private case class Node[A](val data: A, val left: AVLTree[A], val right: AVLTree[A]) extends AVLTree[A] {
28-
override val balance: Int = right.depth - left.depth
26+
def contains[B >: A](value: B, ordering: Ordering[B]): Boolean = false
2927

30-
override val depth: Int = math.max(left.depth, right.depth) + 1
28+
/**
29+
* Returns a new tree containing the given element.
30+
* Thows an IllegalArgumentException if element is already present.
31+
*
32+
*/
33+
def insert[B >: A](value: B, ordering: Ordering[B]): AVLTree[B] = Node(value, Leaf, Leaf)
3134

35+
/**
36+
* Return a new tree which not contains given element.
37+
*
38+
*/
39+
def remove[B >: A](value: B, ordering: Ordering[B]): AVLTree[A] =
40+
throw new NoSuchElementException(String.valueOf(value))
41+
42+
/**
43+
* Return a tuple containing the smallest element of the provided tree
44+
* and a new tree from which this element has been extracted.
45+
*
46+
*/
47+
def removeMin[B >: A]: (B, AVLTree[B]) = sys.error("Should not happen.")
48+
49+
/**
50+
* Return a tuple containing the biggest element of the provided tree
51+
* and a new tree from which this element has been extracted.
52+
*
53+
*/
54+
def removeMax[B >: A]: (B, AVLTree[B]) = sys.error("Should not happen.")
55+
56+
def rebalance[B >: A]: AVLTree[B] = this
57+
58+
def leftRotation[B >: A]: Node[B] = sys.error("Should not happen.")
59+
60+
def rightRotation[B >: A]: Node[B] = sys.error("Should not happen.")
61+
62+
def doubleLeftRotation[B >: A]: Node[B] = sys.error("Should not happen.")
63+
64+
def doubleRightRotation[B >: A]: Node[B] = sys.error("Should not happen.")
3265
}
3366

3467
private case object Leaf extends AVLTree[Nothing] {
3568
override val balance: Int = 0
3669

3770
override val depth: Int = -1
38-
3971
}
4072

41-
private[mutable] object AVLTree {
73+
private case class Node[A](val data: A, val left: AVLTree[A], val right: AVLTree[A]) extends AVLTree[A] {
74+
override val balance: Int = right.depth - left.depth
75+
76+
override val depth: Int = math.max(left.depth, right.depth) + 1
77+
78+
override def iterator[B >: A]: Iterator[B] = new AVLIterator(this)
79+
80+
override def contains[B >: A](value: B, ordering: Ordering[B]) = {
81+
val ord = ordering.compare(value, data)
82+
if (0 == ord)
83+
true
84+
else if (ord < 0)
85+
left.contains(value, ordering)
86+
else
87+
right.contains(value, ordering)
88+
}
4289

4390
/**
4491
* Returns a new tree containing the given element.
4592
* Thows an IllegalArgumentException if element is already present.
4693
*
4794
*/
48-
def insert[A](value: A, tree: AVLTree[A], ordering: Ordering[A]): AVLTree[A] = {
49-
@tailrec
50-
def insertTC(value: A, tree: AVLTree[A], reassemble: AVLTree[A] => AVLTree[A]): AVLTree[A] = tree match {
51-
case Leaf => reassemble(Node(value, Leaf, Leaf))
52-
53-
case Node(a, left, right) => if (0 == ordering.compare(value, a)) {
54-
throw new IllegalArgumentException()
55-
} else if (-1 == ordering.compare(value, a)) {
56-
insertTC(value, left, x => reassemble(rebalance(Node(a, x, right))))
57-
} else {
58-
insertTC(value, right, x => reassemble(rebalance(Node(a, left, x))))
59-
}
60-
}
61-
62-
insertTC(value, tree, x => rebalance(x))
63-
}
64-
65-
def contains[A](value: A, tree: AVLTree[A], ordering: Ordering[A]): Boolean = tree match {
66-
case Leaf => false
67-
68-
case Node(a, left, right) => if (0 == ordering.compare(value, a)) {
69-
true
70-
} else if (-1 == ordering.compare(value, a)) {
71-
contains(value, left, ordering)
72-
} else {
73-
contains(value, right, ordering)
74-
}
95+
override def insert[B >: A](value: B, ordering: Ordering[B]) = {
96+
val ord = ordering.compare(value, data)
97+
if (0 == ord)
98+
throw new IllegalArgumentException()
99+
else if (ord < 0)
100+
Node(data, left.insert(value, ordering), right).rebalance
101+
else
102+
Node(data, left, right.insert(value, ordering)).rebalance
75103
}
76104

77105
/**
78106
* Return a new tree which not contains given element.
79107
*
80108
*/
81-
def remove[A](value: A, tree: AVLTree[A], ordering: Ordering[A]): AVLTree[A] = tree match {
82-
case Leaf => throw new NoSuchElementException()
83-
84-
case Node(a, Leaf, Leaf) => if (0 == ordering.compare(value, a)) {
85-
Leaf
86-
} else {
87-
throw new NoSuchElementException()
88-
}
89-
90-
case Node(a, left, right@Node(_, _, _)) => if (0 == ordering.compare(value, a)) {
91-
val (min, newRight) = removeMin(right)
92-
rebalance(Node(min, left, newRight))
93-
} else if (-1 == ordering.compare(value, a)) {
94-
rebalance(Node(a, remove(value, left, ordering), right))
95-
} else {
96-
rebalance(Node(a, left, remove(value, right, ordering)))
97-
}
98-
99-
case Node(a, left@Node(_, _, _), right) => if (0 == ordering.compare(value, a)) {
100-
val (max, newLeft) = removeMax(left)
101-
rebalance(Node(max, newLeft, right))
102-
} else if (-1 == ordering.compare(value, a)) {
103-
rebalance(Node(a, remove(value, left, ordering), right))
109+
override def remove[B >: A](value: B, ordering: Ordering[B]): AVLTree[A] = {
110+
val ord = ordering.compare(value, data)
111+
if(ord == 0) {
112+
if (Leaf == left) {
113+
if (Leaf == right) {
114+
Leaf
115+
} else {
116+
val (min, newRight) = right.removeMin
117+
Node(min, left, newRight).rebalance
118+
}
119+
} else {
120+
val (max, newLeft) = left.removeMax
121+
Node(max, newLeft, right).rebalance
122+
}
123+
} else if (ord < 0) {
124+
Node(data, left.remove(value, ordering), right).rebalance
104125
} else {
105-
rebalance(Node(a, left, remove(value, right, ordering)))
126+
Node(data, left, right.remove(value, ordering)).rebalance
106127
}
107128
}
108129

109130
/**
110-
* Return a tuple containing the biggest element of the provided tree
131+
* Return a tuple containing the smallest element of the provided tree
111132
* and a new tree from which this element has been extracted.
112133
*
113134
*/
114-
def removeMax[A](tree: Node[A]): (A, AVLTree[A]) = {
115-
@tailrec
116-
def removeMaxTC(tree: AVLTree[A], assemble: (A, AVLTree[A]) => (A, AVLTree[A])): (A, AVLTree[A]) = tree match {
117-
case Node(a, Leaf, Leaf) => assemble(a, Leaf)
118-
case Node(a, left, Leaf) => assemble(a, left)
119-
case Node(a, left, right) => removeMaxTC(right,
120-
(max: A, avl: AVLTree[A]) => assemble(max, rebalance(Node(a, left, avl))))
121-
case Leaf => sys.error("Should not happen.")
135+
override def removeMin[B >: A]: (B, AVLTree[B]) = {
136+
if (Leaf == left)
137+
(data, right)
138+
else {
139+
val (min, newLeft) = left.removeMin
140+
(min, Node(data, newLeft, right).rebalance)
122141
}
123-
124-
removeMaxTC(tree, (a, b) => (a, b))
125142
}
126143

127144
/**
128-
* Return a tuple containing the smallest element of the provided tree
145+
* Return a tuple containing the biggest element of the provided tree
129146
* and a new tree from which this element has been extracted.
130147
*
131148
*/
132-
def removeMin[A](tree: Node[A]): (A, AVLTree[A]) = {
133-
@tailrec
134-
def removeMinTC(tree: AVLTree[A], assemble: (A, AVLTree[A]) => (A, AVLTree[A])): (A, AVLTree[A]) = tree match {
135-
case Node(a, Leaf, Leaf) => assemble(a, Leaf)
136-
case Node(a, Leaf, right) => assemble(a, right)
137-
case Node(a, left, right) => removeMinTC(left,
138-
(min: A, avl: AVLTree[A]) => assemble(min, rebalance(Node(a, avl, right))))
139-
case Leaf => sys.error("Should not happen.")
149+
override def removeMax[B >: A]: (B, AVLTree[B]) = {
150+
if (Leaf == right)
151+
(data, left)
152+
else {
153+
val (max, newRight) = right.removeMax
154+
(max, Node(data, left, newRight).rebalance)
140155
}
141-
142-
removeMinTC(tree, (a, b) => (a, b))
143156
}
144-
145-
/**
146-
* Returns a bounded stream of elements in the tree.
147-
*
148-
*/
149-
def toStream[A](tree: AVLTree[A], isLeftAcceptable: A => Boolean, isRightAcceptable: A => Boolean): Stream[A] = tree match {
150-
case Leaf => Stream.empty
151-
152-
case Node(a, left, right) => if (isLeftAcceptable(a)) {
153-
if (isRightAcceptable(a)) {
154-
toStream(left, isLeftAcceptable, isRightAcceptable) ++ Stream(a) ++ toStream(right, isLeftAcceptable, isRightAcceptable)
155-
} else {
156-
toStream(left, isLeftAcceptable, isRightAcceptable)
157-
}
158-
} else if (isRightAcceptable(a)) {
159-
toStream(right, isLeftAcceptable, isRightAcceptable)
157+
158+
override def rebalance[B >: A] = {
159+
if (-2 == balance) {
160+
if (1 == left.balance)
161+
doubleRightRotation
162+
else
163+
rightRotation
164+
} else if (2 == balance) {
165+
if (-1 == right.balance)
166+
doubleLeftRotation
167+
else
168+
leftRotation
160169
} else {
161-
Stream.empty
170+
this
162171
}
163172
}
164173

165-
/**
166-
* Returns a bounded iterator of elements in the tree.
167-
*
168-
*/
169-
def iterator[A](tree: AVLTree[A], isLeftAcceptable: A => Boolean, isRightAcceptable: A => Boolean): Iterator[A] =
170-
toStream(tree, isLeftAcceptable, isRightAcceptable).iterator
171-
172-
def rebalance[A](tree: AVLTree[A]): AVLTree[A] = (tree, tree.balance) match {
173-
case (node@Node(_, left, _), -2) => left.balance match {
174-
case 1 => doubleRightRotation(node)
175-
case _ => rightRotation(node)
176-
}
177-
178-
case (node@Node(_, _, right), 2) => right.balance match {
179-
case -1 => doubleLeftRotation(node)
180-
case _ => leftRotation(node)
181-
}
174+
override def leftRotation[B >: A] = {
175+
if (Leaf != right) {
176+
val r: Node[A] = right.asInstanceOf[Node[A]]
177+
Node(r.data, Node(data, left, r.left), r.right)
178+
} else sys.error("Should not happen.")
179+
}
182180

183-
case _ => tree
181+
override def rightRotation[B >: A] = {
182+
if (Leaf != left) {
183+
val l: Node[A] = left.asInstanceOf[Node[A]]
184+
Node(l.data, l.left, Node(data, l.right, right))
185+
} else sys.error("Should not happen.")
184186
}
185187

186-
def leftRotation[A](tree: Node[A]): AVLTree[A] = tree.right match {
187-
case Node(b, left, right) => Node(b, Node(tree.data, tree.left, left), right)
188-
case _ => sys.error("Should not happen.")
188+
override def doubleLeftRotation[B >: A] = {
189+
if (Leaf != right) {
190+
val r: Node[A] = right.asInstanceOf[Node[A]]
191+
// Let's save an instanceOf by 'inlining' the left rotation
192+
val rightRotated = r.rightRotation
193+
Node(rightRotated.data, Node(data, left, rightRotated.left), rightRotated.right)
194+
} else sys.error("Should not happen.")
189195
}
190196

191-
def rightRotation[A](tree: Node[A]): AVLTree[A] = tree.left match {
192-
case Node(b, left, right) => Node(b, left, Node(tree.data, right, tree.right))
193-
case _ => sys.error("Should not happen.")
197+
override def doubleRightRotation[B >: A] = {
198+
if (Leaf != left) {
199+
val l: Node[A] = left.asInstanceOf[Node[A]]
200+
// Let's save an instanceOf by 'inlining' the right rotation
201+
val leftRotated = l.leftRotation
202+
Node(leftRotated.data, leftRotated.left, Node(data, leftRotated.right, right))
203+
} else sys.error("Should not happen.")
194204
}
205+
}
206+
207+
private class AVLIterator[A](root: Node[A]) extends Iterator[A] {
208+
val stack = mutable.ArrayStack[Node[A]](root)
209+
diveLeft()
195210

196-
def doubleLeftRotation[A](tree: Node[A]): AVLTree[A] = tree.right match {
197-
case right@Node(b, l, r) => leftRotation(Node(tree.data, tree.left, rightRotation(right)))
198-
case _ => sys.error("Should not happen.")
211+
private def diveLeft(): Unit = {
212+
if (Leaf != stack.head.left) {
213+
val left: Node[A] = stack.head.left.asInstanceOf[Node[A]]
214+
stack.push(left)
215+
diveLeft()
216+
}
199217
}
200218

201-
def doubleRightRotation[A](tree: Node[A]): AVLTree[A] = tree.left match {
202-
case left@Node(b, l, r) => rightRotation(Node(tree.data, leftRotation(left), tree.right))
203-
case _ => sys.error("Should not happen.")
219+
private def engageRight(): Unit = {
220+
if (Leaf != stack.head.right) {
221+
val right: Node[A] = stack.head.right.asInstanceOf[Node[A]]
222+
stack.pop
223+
stack.push(right)
224+
diveLeft()
225+
} else
226+
stack.pop
204227
}
205228

229+
override def hasNext: Boolean = !stack.isEmpty
230+
231+
override def next(): A = {
232+
if (stack.isEmpty)
233+
throw new NoSuchElementException()
234+
else {
235+
val result = stack.head.data
236+
// Let's maintain stack for the next invocation
237+
engageRight()
238+
result
239+
}
240+
}
206241
}

src/library/scala/collection/mutable/TreeSet.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class TreeSet[A](implicit val ordering: Ordering[A]) extends SortedSet[A] with S
7979

8080
override def -=(elem: A): this.type = {
8181
try {
82-
resolve.avl = AVLTree.remove(elem, resolve.avl, ordering)
82+
resolve.avl = resolve.avl.remove(elem, ordering)
8383
resolve.cardinality = resolve.cardinality - 1
8484
} catch {
8585
case e: NoSuchElementException => ()
@@ -89,7 +89,7 @@ class TreeSet[A](implicit val ordering: Ordering[A]) extends SortedSet[A] with S
8989

9090
override def +=(elem: A): this.type = {
9191
try {
92-
resolve.avl = AVLTree.insert(elem, resolve.avl, ordering)
92+
resolve.avl = resolve.avl.insert(elem, ordering)
9393
resolve.cardinality = resolve.cardinality + 1
9494
} catch {
9595
case e: IllegalArgumentException => ()
@@ -98,7 +98,7 @@ class TreeSet[A](implicit val ordering: Ordering[A]) extends SortedSet[A] with S
9898
}
9999

100100
/**
101-
* Thanks to the nature immutable of the
101+
* Thanks to the immutable nature of the
102102
* underlying AVL Tree, we can share it with
103103
* the clone. So clone complexity in time is O(1).
104104
*
@@ -113,11 +113,11 @@ class TreeSet[A](implicit val ordering: Ordering[A]) extends SortedSet[A] with S
113113
override def contains(elem: A): Boolean = {
114114
isLeftAcceptable(from, ordering)(elem) &&
115115
isRightAcceptable(until, ordering)(elem) &&
116-
AVLTree.contains(elem, resolve.avl, ordering)
116+
resolve.avl.contains(elem, ordering)
117117
}
118118

119-
override def iterator: Iterator[A] =
120-
AVLTree.iterator(resolve.avl,
121-
isLeftAcceptable(from, ordering),
122-
isRightAcceptable(until, ordering))
119+
override def iterator: Iterator[A] = resolve.avl.iterator
120+
.dropWhile(e => !isLeftAcceptable(from, ordering)(e))
121+
.takeWhile(e => isRightAcceptable(until, ordering)(e))
122+
123123
}

0 commit comments

Comments
 (0)