Skip to content

Commit bdb480f

Browse files
committed
cmd/compile: fix mishandling of unsafe-uintptr arguments in go/defer
Currently, the statement: go g(uintptr(f())) gets rewritten into: tmp := f() newproc(8, g, uintptr(tmp)) runtime.KeepAlive(tmp) which doesn't guarantee that tmp is still alive by time the g call is scheduled to run. This CL fixes the issue, by wrapping g call in a closure: go func(p unsafe.Pointer) { g(uintptr(p)) }(f()) then this will be rewritten into: tmp := f() go func(p unsafe.Pointer) { g(uintptr(p)) runtime.KeepAlive(p) }(tmp) runtime.KeepAlive(tmp) // superfluous, but harmless So the unsafe.Pointer p will be kept alive at the time g call runs. Updates #24491 Change-Id: Ic10821251cbb1b0073daec92b82a866c6ebaf567 Reviewed-on: https://go-review.googlesource.com/c/go/+/253457 Run-TryBot: Cuong Manh Le <[email protected]> Reviewed-by: Matthew Dempsky <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent 1e6ad65 commit bdb480f

File tree

4 files changed

+117
-24
lines changed

4 files changed

+117
-24
lines changed

src/cmd/compile/internal/gc/order.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ func (o *Order) call(n *Node) {
502502
x := o.copyExpr(arg.Left, arg.Left.Type, false)
503503
x.Name.SetKeepalive(true)
504504
arg.Left = x
505+
n.SetNeedsWrapper(true)
505506
}
506507
}
507508

src/cmd/compile/internal/gc/syntax.go

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,20 @@ const (
141141
nodeInitorder, _ // tracks state during init1; two bits
142142
_, _ // second nodeInitorder bit
143143
_, nodeHasBreak
144-
_, nodeNoInline // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
145-
_, nodeImplicit // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
146-
_, nodeIsDDD // is the argument variadic
147-
_, nodeDiag // already printed error about this
148-
_, nodeColas // OAS resulting from :=
149-
_, nodeNonNil // guaranteed to be non-nil
150-
_, nodeTransient // storage can be reused immediately after this statement
151-
_, nodeBounded // bounds check unnecessary
152-
_, nodeHasCall // expression contains a function call
153-
_, nodeLikely // if statement condition likely
154-
_, nodeHasVal // node.E contains a Val
155-
_, nodeHasOpt // node.E contains an Opt
156-
_, nodeEmbedded // ODCLFIELD embedded type
144+
_, nodeNoInline // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
145+
_, nodeImplicit // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
146+
_, nodeIsDDD // is the argument variadic
147+
_, nodeDiag // already printed error about this
148+
_, nodeColas // OAS resulting from :=
149+
_, nodeNonNil // guaranteed to be non-nil
150+
_, nodeTransient // storage can be reused immediately after this statement
151+
_, nodeBounded // bounds check unnecessary
152+
_, nodeHasCall // expression contains a function call
153+
_, nodeLikely // if statement condition likely
154+
_, nodeHasVal // node.E contains a Val
155+
_, nodeHasOpt // node.E contains an Opt
156+
_, nodeEmbedded // ODCLFIELD embedded type
157+
_, nodeNeedsWrapper // OCALLxxx node that needs to be wrapped
157158
)
158159

159160
func (n *Node) Class() Class { return Class(n.flags.get3(nodeClass)) }
@@ -286,6 +287,20 @@ func (n *Node) SetIota(x int64) {
286287
n.Xoffset = x
287288
}
288289

290+
func (n *Node) NeedsWrapper() bool {
291+
return n.flags&nodeNeedsWrapper != 0
292+
}
293+
294+
// SetNeedsWrapper indicates that OCALLxxx node needs to be wrapped by a closure.
295+
func (n *Node) SetNeedsWrapper(b bool) {
296+
switch n.Op {
297+
case OCALLFUNC, OCALLMETH, OCALLINTER:
298+
default:
299+
Fatalf("Node.SetNeedsWrapper %v", n.Op)
300+
}
301+
n.flags.set(nodeNeedsWrapper, b)
302+
}
303+
289304
// mayBeShared reports whether n may occur in multiple places in the AST.
290305
// Extra care must be taken when mutating such a node.
291306
func (n *Node) mayBeShared() bool {

src/cmd/compile/internal/gc/walk.go

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,11 @@ func walkstmt(n *Node) *Node {
232232
n.Left = copyany(n.Left, &n.Ninit, true)
233233

234234
default:
235-
n.Left = walkexpr(n.Left, &n.Ninit)
235+
if n.Left.NeedsWrapper() {
236+
n.Left = wrapCall(n.Left, &n.Ninit)
237+
} else {
238+
n.Left = walkexpr(n.Left, &n.Ninit)
239+
}
236240
}
237241

238242
case OFOR, OFORUNTIL:
@@ -3857,6 +3861,14 @@ func candiscard(n *Node) bool {
38573861
// builtin(a1, a2, a3)
38583862
// }(x, y, z)
38593863
// for print, println, and delete.
3864+
//
3865+
// Rewrite
3866+
// go f(x, y, uintptr(unsafe.Pointer(z)))
3867+
// into
3868+
// go func(a1, a2, a3) {
3869+
// builtin(a1, a2, uintptr(a3))
3870+
// }(x, y, unsafe.Pointer(z))
3871+
// for function contains unsafe-uintptr arguments.
38603872

38613873
var wrapCall_prgen int
38623874

@@ -3868,33 +3880,53 @@ func wrapCall(n *Node, init *Nodes) *Node {
38683880
init.AppendNodes(&n.Ninit)
38693881
}
38703882

3883+
isBuiltinCall := n.Op != OCALLFUNC && n.Op != OCALLMETH && n.Op != OCALLINTER
3884+
// origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion.
3885+
origArgs := make([]*Node, n.List.Len())
38713886
t := nod(OTFUNC, nil, nil)
38723887
for i, arg := range n.List.Slice() {
38733888
s := lookupN("a", i)
3889+
if !isBuiltinCall && arg.Op == OCONVNOP && arg.Type.Etype == TUINTPTR && arg.Left.Type.Etype == TUNSAFEPTR {
3890+
origArgs[i] = arg
3891+
arg = arg.Left
3892+
n.List.SetIndex(i, arg)
3893+
}
38743894
t.List.Append(symfield(s, arg.Type))
38753895
}
38763896

38773897
wrapCall_prgen++
38783898
sym := lookupN("wrap·", wrapCall_prgen)
38793899
fn := dclfunc(sym, t)
38803900

3881-
a := nod(n.Op, nil, nil)
3882-
a.List.Set(paramNnames(t.Type))
3883-
a = typecheck(a, ctxStmt)
3884-
fn.Nbody.Set1(a)
3901+
args := paramNnames(t.Type)
3902+
for i, origArg := range origArgs {
3903+
if origArg == nil {
3904+
continue
3905+
}
3906+
arg := nod(origArg.Op, args[i], nil)
3907+
arg.Type = origArg.Type
3908+
args[i] = arg
3909+
}
3910+
call := nod(n.Op, nil, nil)
3911+
if !isBuiltinCall {
3912+
call.Op = OCALL
3913+
call.Left = n.Left
3914+
}
3915+
call.List.Set(args)
3916+
fn.Nbody.Set1(call)
38853917

38863918
funcbody()
38873919

38883920
fn = typecheck(fn, ctxStmt)
38893921
typecheckslice(fn.Nbody.Slice(), ctxStmt)
38903922
xtop = append(xtop, fn)
38913923

3892-
a = nod(OCALL, nil, nil)
3893-
a.Left = fn.Func.Nname
3894-
a.List.Set(n.List.Slice())
3895-
a = typecheck(a, ctxStmt)
3896-
a = walkexpr(a, init)
3897-
return a
3924+
call = nod(OCALL, nil, nil)
3925+
call.Left = fn.Func.Nname
3926+
call.List.Set(n.List.Slice())
3927+
call = typecheck(call, ctxStmt)
3928+
call = walkexpr(call, init)
3929+
return call
38983930
}
38993931

39003932
// substArgTypes substitutes the given list of types for

test/fixedbugs/issue24491.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// run
2+
3+
// Copyright 2020 The Go Authors. All rights reserved.
4+
// Use of this source code is governed by a BSD-style
5+
// license that can be found in the LICENSE file.
6+
7+
// This test makes sure unsafe-uintptr arguments are handled correctly.
8+
9+
package main
10+
11+
import (
12+
"runtime"
13+
"unsafe"
14+
)
15+
16+
var done = make(chan bool, 1)
17+
18+
func setup() unsafe.Pointer {
19+
s := "ok"
20+
runtime.SetFinalizer(&s, func(p *string) { *p = "FAIL" })
21+
return unsafe.Pointer(&s)
22+
}
23+
24+
//go:noinline
25+
//go:uintptrescapes
26+
func test(s string, p uintptr) {
27+
runtime.GC()
28+
if *(*string)(unsafe.Pointer(p)) != "ok" {
29+
panic(s + " return unexpected result")
30+
}
31+
done <- true
32+
}
33+
34+
func main() {
35+
test("normal", uintptr(setup()))
36+
<-done
37+
38+
go test("go", uintptr(setup()))
39+
<-done
40+
41+
func() {
42+
defer test("defer", uintptr(setup()))
43+
}()
44+
<-done
45+
}

0 commit comments

Comments
 (0)