Skip to content

Commit 30088ac

Browse files
committed
cmd/compile: make CSE faster
To refine a set of possibly equivalent values, the old CSE algorithm picked one value, compared it against all the others, and made two sets out of the results (the values that matched the picked value and the values that didn't). Unfortunately, this leads to O(n^2) behavior in the worst case: when the picked value ends up being equal to no other value, we make size 1 and size n-1 sets, and then recurse on the size n-1 set. Instead, sort the set by the equivalence classes of its arguments. Then we just look for spots in the sorted list where the equivalence classes of the arguments change. This lets us do a multi-way split in O(n lg n) time. This change makes cmpDepth unnecessary. The refinement portion used to call the type comparator. That is unnecessary, as the type was already part of the initial partition. Lowers compile time for issue 16361 from 8 sec to 3 sec. Lowers compile time for issue 15112 from 282 sec to 20 sec. That's kind of unfair, as CL 30257 changed it from 21 sec to 282 sec. But that CL fixed other bad compile times (issue #17127) by large factors, so it is still a big net win. Fixes #15112 Fixes #16361 Change-Id: I351ce111bae446608968c6d48710eeb6a3d8e527 Reviewed-on: https://go-review.googlesource.com/30354 Reviewed-by: Todd Neal <[email protected]>
1 parent bd06d48 commit 30088ac

File tree

1 file changed

+73
-52
lines changed
  • src/cmd/compile/internal/ssa

1 file changed

+73
-52
lines changed

src/cmd/compile/internal/ssa/cse.go

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ import (
99
"sort"
1010
)
1111

12-
const (
13-
cmpDepth = 1
14-
)
15-
1612
// cse does common-subexpression elimination on the Function.
1713
// Values are just relinked, nothing is deleted. A subsequent deadcode
1814
// pass is required to actually remove duplicate expressions.
@@ -60,7 +56,8 @@ func cse(f *Func) {
6056
valueEqClass[v.ID] = -v.ID
6157
}
6258
}
63-
for i, e := range partition {
59+
var pNum ID = 1
60+
for _, e := range partition {
6461
if f.pass.debug > 1 && len(e) > 500 {
6562
fmt.Printf("CSE.large partition (%d): ", len(e))
6663
for j := 0; j < 3; j++ {
@@ -70,60 +67,74 @@ func cse(f *Func) {
7067
}
7168

7269
for _, v := range e {
73-
valueEqClass[v.ID] = ID(i)
70+
valueEqClass[v.ID] = pNum
7471
}
7572
if f.pass.debug > 2 && len(e) > 1 {
76-
fmt.Printf("CSE.partition #%d:", i)
73+
fmt.Printf("CSE.partition #%d:", pNum)
7774
for _, v := range e {
7875
fmt.Printf(" %s", v.String())
7976
}
8077
fmt.Printf("\n")
8178
}
79+
pNum++
8280
}
8381

84-
// Find an equivalence class where some members of the class have
85-
// non-equivalent arguments. Split the equivalence class appropriately.
86-
// Repeat until we can't find any more splits.
82+
// Split equivalence classes at points where they have
83+
// non-equivalent arguments. Repeat until we can't find any
84+
// more splits.
85+
var splitPoints []int
8786
for {
8887
changed := false
8988

9089
// partition can grow in the loop. By not using a range loop here,
9190
// we process new additions as they arrive, avoiding O(n^2) behavior.
9291
for i := 0; i < len(partition); i++ {
9392
e := partition[i]
94-
v := e[0]
95-
// all values in this equiv class that are not equivalent to v get moved
96-
// into another equiv class.
97-
// To avoid allocating while building that equivalence class,
98-
// move the values equivalent to v to the beginning of e
99-
// and other values to the end of e.
100-
allvals := e
101-
eqloop:
102-
for j := 1; j < len(e); {
103-
w := e[j]
104-
equivalent := true
105-
for i := 0; i < len(v.Args); i++ {
106-
if valueEqClass[v.Args[i].ID] != valueEqClass[w.Args[i].ID] {
107-
equivalent = false
93+
94+
// Sort by eq class of arguments.
95+
sort.Sort(partitionByArgClass{e, valueEqClass})
96+
97+
// Find split points.
98+
splitPoints = append(splitPoints[:0], 0)
99+
for j := 1; j < len(e); j++ {
100+
v, w := e[j-1], e[j]
101+
eqArgs := true
102+
for k, a := range v.Args {
103+
b := w.Args[k]
104+
if valueEqClass[a.ID] != valueEqClass[b.ID] {
105+
eqArgs = false
108106
break
109107
}
110108
}
111-
if !equivalent || v.Type.Compare(w.Type) != CMPeq {
112-
// w is not equivalent to v.
113-
// move it to the end and shrink e.
114-
e[j], e[len(e)-1] = e[len(e)-1], e[j]
115-
e = e[:len(e)-1]
116-
valueEqClass[w.ID] = ID(len(partition))
117-
changed = true
118-
continue eqloop
109+
if !eqArgs {
110+
splitPoints = append(splitPoints, j)
119111
}
120-
// v and w are equivalent. Keep w in e.
121-
j++
122112
}
123-
partition[i] = e
124-
if len(e) < len(allvals) {
125-
partition = append(partition, allvals[len(e):])
113+
if len(splitPoints) == 1 {
114+
continue // no splits, leave equivalence class alone.
126115
}
116+
117+
// Move another equivalence class down in place of e.
118+
partition[i] = partition[len(partition)-1]
119+
partition = partition[:len(partition)-1]
120+
i--
121+
122+
// Add new equivalence classes for the parts of e we found.
123+
splitPoints = append(splitPoints, len(e))
124+
for j := 0; j < len(splitPoints)-1; j++ {
125+
f := e[splitPoints[j]:splitPoints[j+1]]
126+
if len(f) == 1 {
127+
// Don't add singletons.
128+
valueEqClass[f[0].ID] = -f[0].ID
129+
continue
130+
}
131+
for _, v := range f {
132+
valueEqClass[v.ID] = pNum
133+
}
134+
pNum++
135+
partition = append(partition, f)
136+
}
137+
changed = true
127138
}
128139

129140
if !changed {
@@ -253,7 +264,7 @@ func partitionValues(a []*Value, auxIDs auxmap) []eqclass {
253264
j := 1
254265
for ; j < len(a); j++ {
255266
w := a[j]
256-
if cmpVal(v, w, auxIDs, cmpDepth) != CMPeq {
267+
if cmpVal(v, w, auxIDs) != CMPeq {
257268
break
258269
}
259270
}
@@ -274,7 +285,7 @@ func lt2Cmp(isLt bool) Cmp {
274285

275286
type auxmap map[interface{}]int32
276287

277-
func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp {
288+
func cmpVal(v, w *Value, auxIDs auxmap) Cmp {
278289
// Try to order these comparison by cost (cheaper first)
279290
if v.Op != w.Op {
280291
return lt2Cmp(v.Op < w.Op)
@@ -308,18 +319,6 @@ func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp {
308319
return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux])
309320
}
310321

311-
if depth > 0 {
312-
for i := range v.Args {
313-
if v.Args[i] == w.Args[i] {
314-
// skip comparing equal args
315-
continue
316-
}
317-
if ac := cmpVal(v.Args[i], w.Args[i], auxIDs, depth-1); ac != CMPeq {
318-
return ac
319-
}
320-
}
321-
}
322-
323322
return CMPeq
324323
}
325324

@@ -334,7 +333,7 @@ func (sv sortvalues) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
334333
func (sv sortvalues) Less(i, j int) bool {
335334
v := sv.a[i]
336335
w := sv.a[j]
337-
if cmp := cmpVal(v, w, sv.auxIDs, cmpDepth); cmp != CMPeq {
336+
if cmp := cmpVal(v, w, sv.auxIDs); cmp != CMPeq {
338337
return cmp == CMPlt
339338
}
340339

@@ -354,3 +353,25 @@ func (sv partitionByDom) Less(i, j int) bool {
354353
w := sv.a[j]
355354
return sv.sdom.domorder(v.Block) < sv.sdom.domorder(w.Block)
356355
}
356+
357+
type partitionByArgClass struct {
358+
a []*Value // array of values
359+
eqClass []ID // equivalence class IDs of values
360+
}
361+
362+
func (sv partitionByArgClass) Len() int { return len(sv.a) }
363+
func (sv partitionByArgClass) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
364+
func (sv partitionByArgClass) Less(i, j int) bool {
365+
v := sv.a[i]
366+
w := sv.a[j]
367+
for i, a := range v.Args {
368+
b := w.Args[i]
369+
if sv.eqClass[a.ID] < sv.eqClass[b.ID] {
370+
return true
371+
}
372+
if sv.eqClass[a.ID] > sv.eqClass[b.ID] {
373+
return false
374+
}
375+
}
376+
return false
377+
}

0 commit comments

Comments
 (0)