Skip to content

Commit 4711299

Browse files
randall77gopherbot
authored andcommitted
cmd/compile: use jump tables for large type switches
For large interface -> concrete type switches, we can use a jump table on some bits of the type hash instead of a binary search on the type hash. name old time/op new time/op delta SwitchTypePredictable-24 1.99ns ± 2% 1.78ns ± 5% -10.87% (p=0.000 n=10+10) SwitchTypeUnpredictable-24 11.0ns ± 1% 9.1ns ± 2% -17.55% (p=0.000 n=7+9) Change-Id: Ida4768e5d62c3ce1c2701288b72664aaa9e64259 Reviewed-on: https://go-review.googlesource.com/c/go/+/521497 Reviewed-by: Keith Randall <[email protected]> Auto-Submit: Keith Randall <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Cherry Mui <[email protected]> Run-TryBot: Keith Randall <[email protected]>
1 parent 556e9c5 commit 4711299

File tree

3 files changed

+142
-1
lines changed

3 files changed

+142
-1
lines changed

src/cmd/compile/internal/test/switch_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,48 @@ func benchmarkSwitchString(b *testing.B, predictable bool) {
120120
sink = n
121121
}
122122

123+
func BenchmarkSwitchTypePredictable(b *testing.B) {
124+
benchmarkSwitchType(b, true)
125+
}
126+
func BenchmarkSwitchTypeUnpredictable(b *testing.B) {
127+
benchmarkSwitchType(b, false)
128+
}
129+
func benchmarkSwitchType(b *testing.B, predictable bool) {
130+
a := []any{
131+
int8(1),
132+
int16(2),
133+
int32(3),
134+
int64(4),
135+
uint8(5),
136+
uint16(6),
137+
uint32(7),
138+
uint64(8),
139+
}
140+
n := 0
141+
rng := newRNG()
142+
for i := 0; i < b.N; i++ {
143+
rng = rng.next(predictable)
144+
switch a[rng.value()&7].(type) {
145+
case int8:
146+
n += 1
147+
case int16:
148+
n += 2
149+
case int32:
150+
n += 3
151+
case int64:
152+
n += 4
153+
case uint8:
154+
n += 5
155+
case uint16:
156+
n += 6
157+
case uint32:
158+
n += 7
159+
case uint64:
160+
n += 8
161+
}
162+
}
163+
}
164+
123165
// A simple random number generator used to make switches conditionally predictable.
124166
type rng uint64
125167

src/cmd/compile/internal/walk/switch.go

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package walk
77
import (
88
"go/constant"
99
"go/token"
10+
"math/bits"
1011
"sort"
1112

1213
"cmd/compile/internal/base"
@@ -617,7 +618,9 @@ func (s *typeSwitch) flush() {
617618
}
618619
cc = merged
619620

620-
// TODO: figure out if we could use a jump table using some low bits of the type hashes.
621+
if s.tryJumpTable(cc, &s.done) {
622+
return
623+
}
621624
binarySearch(len(cc), &s.done,
622625
func(i int) ir.Node {
623626
return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
@@ -632,6 +635,83 @@ func (s *typeSwitch) flush() {
632635
)
633636
}
634637

638+
// Try to implement the clauses with a jump table. Returns true if successful.
639+
func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
640+
const minCases = 5 // have at least minCases cases in the switch
641+
if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
642+
return false
643+
}
644+
if len(cc) < minCases {
645+
return false // not enough cases for it to be worth it
646+
}
647+
hashes := make([]uint32, len(cc))
648+
// b = # of bits to use. Start with the minimum number of
649+
// bits possible, but try a few larger sizes if needed.
650+
b0 := bits.Len(uint(len(cc) - 1))
651+
for b := b0; b < b0+3; b++ {
652+
pickI:
653+
for i := 0; i <= 32-b; i++ { // starting bit position
654+
// Compute the hash we'd get from all the cases,
655+
// selecting b bits starting at bit i.
656+
hashes = hashes[:0]
657+
for _, c := range cc {
658+
h := c.hash >> i & (1<<b - 1)
659+
hashes = append(hashes, h)
660+
}
661+
// Order by increasing hash.
662+
sort.Slice(hashes, func(j, k int) bool {
663+
return hashes[j] < hashes[k]
664+
})
665+
for j := 1; j < len(hashes); j++ {
666+
if hashes[j] == hashes[j-1] {
667+
// There is a duplicate hash; try a different b/i pair.
668+
continue pickI
669+
}
670+
}
671+
672+
// All hashes are distinct. Use these values of b and i.
673+
h := s.hashname
674+
if i != 0 {
675+
h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
676+
}
677+
h = ir.NewBinaryExpr(base.Pos, ir.OAND, h, ir.NewInt(base.Pos, int64(1<<b-1)))
678+
h = typecheck.Expr(h)
679+
680+
// Build jump table.
681+
jt := ir.NewJumpTableStmt(base.Pos, h)
682+
jt.Cases = make([]constant.Value, 1<<b)
683+
jt.Targets = make([]*types.Sym, 1<<b)
684+
out.Append(jt)
685+
686+
// Start with all hashes going to the didn't-match target.
687+
noMatch := typecheck.AutoLabel(".s")
688+
for j := 0; j < 1<<b; j++ {
689+
jt.Cases[j] = constant.MakeInt64(int64(j))
690+
jt.Targets[j] = noMatch
691+
}
692+
// This statement is not reachable, but it will make it obvious that we don't
693+
// fall through to the first case.
694+
out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
695+
696+
// Emit each of the actual cases.
697+
for _, c := range cc {
698+
h := c.hash >> i & (1<<b - 1)
699+
label := typecheck.AutoLabel(".s")
700+
jt.Targets[h] = label
701+
out.Append(ir.NewLabelStmt(base.Pos, label))
702+
out.Append(c.body...)
703+
// We reach here if the hash matches but the type equality test fails.
704+
out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
705+
}
706+
// Emit point to go to if type doesn't match any case.
707+
out.Append(ir.NewLabelStmt(base.Pos, noMatch))
708+
return true
709+
}
710+
}
711+
// Couldn't find a perfect hash. Fall back to binary search.
712+
return false
713+
}
714+
635715
// binarySearch constructs a binary search tree for handling n cases,
636716
// and appends it to out. It's used for efficiently implementing
637717
// switch statements.

test/codegen/switch.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,22 @@ func mimetype(ext string) string {
9999
return ""
100100
}
101101
}
102+
103+
// use jump tables for type switches to concrete types.
104+
func typeSwitch(x any) int {
105+
// amd64:`JMP\s\(.*\)\(.*\)$`
106+
// arm64:`MOVD\s\(R.*\)\(R.*<<3\)`,`JMP\s\(R.*\)$`
107+
switch x.(type) {
108+
case int:
109+
return 0
110+
case int8:
111+
return 1
112+
case int16:
113+
return 2
114+
case int32:
115+
return 3
116+
case int64:
117+
return 4
118+
}
119+
return 7
120+
}

0 commit comments

Comments
 (0)