Skip to content

Commit ea2e55f

Browse files
If f calls Do, panic
Change-Id: Ib2674eca814b56db5408009cefc1f8544f5a717c
1 parent 31d2cc3 commit ea2e55f

File tree

5 files changed

+94
-20
lines changed

5 files changed

+94
-20
lines changed

src/runtime/mainthread/mainthread.go

+5-8
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,22 @@
1313
// yield the main thread temporarily to other [Do] calls by calling [Yield].
1414
//
1515
// Each package's initialization functions always run on the main thread,
16-
// as if by successive calls to Do(init).
16+
// so call Do in init is unnecessary.
1717
//
1818
// For compatibility with earlier versions of Go, if an init function calls [runtime.LockOSThread],
19-
// then package main's func main also runs on the main thread, as if by Do(main).
19+
// then package main's func main also runs on the main thread,
20+
// until the thread is unlocked using [runtime.UnlockOSThread].
2021
// In this situation, main must explicitly yield the main thread
21-
// to allow other calls to Do are to proceed.
22+
// to allow other thread calls to Do are to proceed.
2223
// See the documentation for [Waiting] for examples.
23-
package mainthread // imported as "runtime/mainthread"
24+
package mainthread
2425

2526
import _ "unsafe"
2627

2728
// Do calls f on the main thread.
2829
// Nothing else runs on the main thread until f returns or calls [Yield].
2930
// If f calls Do, the nested call panics.
3031
//
31-
// Package initialization functions run as if by Do(init).
32-
// If an init function calls [runtime.LockOSThread], then package main's func main
33-
// runs as if by Do(main), until the thread is unlocked using [runtime.UnlockOSThread].
34-
//
3532
//go:linkname Do runtime.mainThreadDo
3633
func Do(f func())
3734

src/runtime/proc.go

+79-10
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,57 @@ var modinfo string
115115

116116
var (
117117
m0 m
118+
mtPtr atomic.Pointer[mainThreadOnce]
118119
g0 g
119120
mcache0 *mcache
120121
raceprocctx0 uintptr
121122
raceFiniLock mutex
122123
)
123124

125+
func init() {
126+
mtPtr.Store(newmainThreadInfo())
127+
}
128+
129+
type mainThreadOnce struct {
130+
//m0once make sure to only send m0wait once if mainthread.Yield is not called.
131+
m0once Once
132+
}
133+
134+
// copy from sync package.
135+
type Once struct {
136+
done atomic.Uint32
137+
m mutex
138+
}
139+
140+
func (o *Once) Do(f func()) {
141+
if o.done.Load() == 0 {
142+
o.doSlow(f)
143+
}
144+
}
145+
146+
func (o *Once) doSlow(f func()) {
147+
lock(&o.m)
148+
defer unlock(&o.m)
149+
if o.done.Load() == 0 {
150+
defer o.done.Store(1)
151+
f()
152+
}
153+
}
154+
124155
var (
156+
// m0func pass f from a call to mainthread.Do on non-main thread to the main thread.
125157
m0func = make(chan func())
126-
waitm0 = make(chan struct{})
127-
m0Exec = make(chan struct{}, 1)
158+
// m0wait send a signal that the non-main thread is waiting for mainthread.Yield.
159+
m0wait = make(chan struct{})
160+
// m0exec notifies mainthread.Do when the f passed from Do on the non-main thread
161+
// to Yield on the main thread has completed.
162+
m0exec = make(chan struct{}, 1)
128163
)
129164

165+
func newmainThreadInfo() *mainThreadOnce {
166+
return new(mainThreadOnce)
167+
}
168+
130169
// This slice records the initializing tasks that need to be
131170
// done to start up the runtime. It is built by the linker.
132171
var runtime_inittasks []*initTask
@@ -150,38 +189,68 @@ var runtimeInitTime int64
150189
var initSigmask sigset
151190

152191
func mainThreadDo(f func()) {
153-
gp := getg()
154-
if gp.m == &m0 {
192+
g := getg()
193+
if g.inMainThradDo {
194+
panic("runtime: nested call mainthread.Do")
195+
}
196+
g.inMainThradDo = true
197+
if g.m == &m0 {
155198
// lock os thread ensure that the main thread always
156199
// run only f during a call to f.
157200
lockOSThread()
158201
defer unlockOSThread()
159202
f()
160203
return
161204
}
162-
waitm0 <- struct{}{}
205+
mt := mtPtr.Load()
206+
mt.m0once.Do(func() {
207+
m0wait <- struct{}{}
208+
})
163209
m0func <- f
164-
_ = <-m0Exec
210+
_ = <-m0exec
165211
}
166212

167213
func mainThreadYield() {
168214
g := getg()
169215
if g.m != &m0 {
170216
panic("runtime: call mainthread.Yield must on main thread")
171217
}
218+
if g.inMainThradDo {
219+
panic("runtime: nested call mainthread.Do")
220+
}
221+
g.inMainThradDo = true
172222
// lock os thread ensure that the main thread always
173223
// run only f during a call to f.
174224
lockOSThread()
175225
defer func() {
176226
unlockOSThread()
177-
m0Exec <- struct{}{}
178227
}()
179-
f := <-m0func
180-
f()
228+
for {
229+
select {
230+
case f := <-m0func:
231+
f()
232+
m0exec <- struct{}{}
233+
default:
234+
// because there is only one main thread,
235+
// can use Store directly without considering concurrent call Yield.
236+
mtPtr.Store(newmainThreadInfo())
237+
for {
238+
select {
239+
// if there is a new send from m0func before the mtPtr is updated after
240+
// all the send from m0func has been received.
241+
case f := <-m0func:
242+
f()
243+
m0exec <- struct{}{}
244+
default:
245+
return
246+
}
247+
}
248+
}
249+
}
181250
}
182251

183252
func mainThreadWaiting() <-chan struct{} {
184-
return waitm0
253+
return m0wait
185254
}
186255

187256
// The main goroutine.

src/runtime/proc_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ func TestMainThread(t *testing.T) {
11711171
runtime.Releasem()
11721172
if m != runtime.M0 {
11731173
// don`t use t.Fatal, because it call
1174-
// Goexit, cause TestMain goroutine exit,
1174+
// Goexit, cause TestMain goroutine on main thread exit,
11751175
// and test timeout.
11761176
t.Fail()
11771177
t.Log("mainthread.Do.f must on main thread")

src/runtime/runtime2.go

+2
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,8 @@ type g struct {
489489

490490
coroarg *coro // argument during coroutine transfers
491491

492+
inMainThradDo bool
493+
492494
// Per-G tracer state.
493495
trace gTraceState
494496

src/runtime/testdata/testprog/mainthread.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package main
66

77
import (
8+
"fmt"
89
"runtime"
910
"runtime/mainthread"
1011
"sync"
@@ -29,13 +30,18 @@ func init() {
2930
MainThread()
3031
})
3132
register("MainThread2", func() {
32-
println("expect: hello,world")
33+
println("expect: hello,runtime: nested call mainthread.Do")
3334
MainThread2()
3435
})
3536
}
3637

3738
func MainThread2() {
3839
var wg sync.WaitGroup
40+
defer func() {
41+
if err := recover(); err != nil {
42+
fmt.Print(err)
43+
}
44+
}()
3945
runtime.LockOSThread()
4046
wg.Add(1)
4147
go func() {

0 commit comments

Comments
 (0)