Skip to content

Commit 8bb5da3

Browse files
h9jianggopherbot
authored andcommitted
gopls/internal/golang: special handling for input context.Context
- Read and determine whether the first input parameter of function or constructor is a context.Context through package name and name comparison. - Call the function or constructor with context.Background() but honor is there is any renaming in foo.go or foo_test.go. - Fix the issue where the constructor param is added to function or method call. For golang/vscode-go#1594 Change-Id: Ic1d145e65bc4b7cb34f637bab8ebdeccd36a33f9 Reviewed-on: https://go-review.googlesource.com/c/tools/+/627355 Reviewed-by: Alan Donovan <[email protected]> Auto-Submit: Hongxiang Jiang <[email protected]> Reviewed-by: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent b4332e0 commit 8bb5da3

File tree

2 files changed

+111
-3
lines changed

2 files changed

+111
-3
lines changed

gopls/internal/golang/addtest.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,12 +457,21 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
457457

458458
errorType := types.Universe.Lookup("error").Type()
459459

460-
// TODO(hxjiang): handle special case for ctx.Context input.
460+
var isContextType = func(t types.Type) bool {
461+
named, ok := t.(*types.Named)
462+
if !ok {
463+
return false
464+
}
465+
return named.Obj().Pkg().Path() == "context" && named.Obj().Name() == "Context"
466+
}
467+
461468
for i := range sig.Params().Len() {
462469
param := sig.Params().At(i)
463470
name, typ := param.Name(), param.Type()
464471
f := field{Type: types.TypeString(typ, qf)}
465-
if name == "" || name == "_" {
472+
if i == 0 && isContextType(typ) {
473+
f.Value = qf(types.NewPackage("context", "context")) + ".Background()"
474+
} else if name == "" || name == "_" {
466475
f.Value = typesinternal.ZeroString(typ, qf)
467476
} else {
468477
f.Name = name
@@ -594,7 +603,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
594603
param := constructor.Signature().Params().At(i)
595604
name, typ := param.Name(), param.Type()
596605
f := field{Type: types.TypeString(typ, qf)}
597-
if name == "" || name == "_" {
606+
if i == 0 && isContextType(typ) {
607+
f.Value = qf(types.NewPackage("context", "context")) + ".Background()"
608+
} else if name == "" || name == "_" {
598609
f.Value = typesinternal.ZeroString(typ, qf)
599610
} else {
600611
f.Name = name

gopls/internal/test/marker/testdata/codeaction/addtest.txt

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,3 +1319,100 @@ func (r *BarInputFunction) Method(one string, _ func(time.Time) *time.Time) {} /
13191319
+ })
13201320
+ }
13211321
+}
1322+
-- contextinput/contextinput.go --
1323+
package main
1324+
1325+
import "context"
1326+
1327+
func Function(ctx context.Context, _, _ string) (out, out1, out2 string) {return "", "", ""} //@codeactionedit("Function", "source.addTest", function_context)
1328+
1329+
type Foo struct {}
1330+
1331+
func NewFoo(ctx context.Context) (*Foo, error) {return nil, nil}
1332+
1333+
func (*Foo) Method(ctx context.Context, _, _ string) (out, out1, out2 string) {return "", "", ""} //@codeactionedit("Method", "source.addTest", method_context)
1334+
-- contextinput/contextinput_test.go --
1335+
package main_test
1336+
1337+
import renamedctx "context"
1338+
1339+
var local renamedctx.Context
1340+
1341+
-- @function_context/contextinput/contextinput_test.go --
1342+
@@ -3 +3,3 @@
1343+
-import renamedctx "context"
1344+
+import (
1345+
+ renamedctx "context"
1346+
+ "testing"
1347+
@@ -5 +7,3 @@
1348+
+ "golang.org/lsptests/addtest/contextinput"
1349+
+)
1350+
+
1351+
@@ -7 +12,26 @@
1352+
+
1353+
+func TestFunction(t *testing.T) {
1354+
+ tests := []struct {
1355+
+ name string // description of this test case
1356+
+ want string
1357+
+ want2 string
1358+
+ want3 string
1359+
+ }{
1360+
+ // TODO: Add test cases.
1361+
+ }
1362+
+ for _, tt := range tests {
1363+
+ t.Run(tt.name, func(t *testing.T) {
1364+
+ got, got2, got3 := main.Function(renamedctx.Background(), "", "")
1365+
+ // TODO: update the condition below to compare got with tt.want.
1366+
+ if true {
1367+
+ t.Errorf("Function() = %v, want %v", got, tt.want)
1368+
+ }
1369+
+ if true {
1370+
+ t.Errorf("Function() = %v, want %v", got2, tt.want2)
1371+
+ }
1372+
+ if true {
1373+
+ t.Errorf("Function() = %v, want %v", got3, tt.want3)
1374+
+ }
1375+
+ })
1376+
+ }
1377+
+}
1378+
-- @method_context/contextinput/contextinput_test.go --
1379+
@@ -3 +3,3 @@
1380+
-import renamedctx "context"
1381+
+import (
1382+
+ renamedctx "context"
1383+
+ "testing"
1384+
@@ -5 +7,3 @@
1385+
+ "golang.org/lsptests/addtest/contextinput"
1386+
+)
1387+
+
1388+
@@ -7 +12,30 @@
1389+
+
1390+
+func TestFoo_Method(t *testing.T) {
1391+
+ tests := []struct {
1392+
+ name string // description of this test case
1393+
+ want string
1394+
+ want2 string
1395+
+ want3 string
1396+
+ }{
1397+
+ // TODO: Add test cases.
1398+
+ }
1399+
+ for _, tt := range tests {
1400+
+ t.Run(tt.name, func(t *testing.T) {
1401+
+ f, err := main.NewFoo(renamedctx.Background())
1402+
+ if err != nil {
1403+
+ t.Fatalf("could not contruct receiver type: %v", err)
1404+
+ }
1405+
+ got, got2, got3 := f.Method(renamedctx.Background(), "", "")
1406+
+ // TODO: update the condition below to compare got with tt.want.
1407+
+ if true {
1408+
+ t.Errorf("Method() = %v, want %v", got, tt.want)
1409+
+ }
1410+
+ if true {
1411+
+ t.Errorf("Method() = %v, want %v", got2, tt.want2)
1412+
+ }
1413+
+ if true {
1414+
+ t.Errorf("Method() = %v, want %v", got3, tt.want3)
1415+
+ }
1416+
+ })
1417+
+ }
1418+
+}

0 commit comments

Comments
 (0)