Skip to content

Commit 691997a

Browse files
h9jianggopherbot
authored andcommitted
gopls/internal/golang: consolidate imports from both file in qualifier
This commit improves qualifier by consolidating imports from both the main file (x.go) and its corresponding test file (x_test.go). An imports map is used to track all import paths and their local renames. Imports from x_test.go are prioritized over x.go as gopls is generating test in x_test.go. This ensures that the generated qualifier correctly reflects any necessary renames, improving accuracy and consistency. For golang/vscode-go#1594 Change-Id: I457d5f22f7de4fe86006b57487f243494c8e7f6f Reviewed-on: https://go-review.googlesource.com/c/tools/+/622320 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]> Auto-Submit: Hongxiang Jiang <[email protected]>
1 parent 0c792f1 commit 691997a

File tree

2 files changed

+134
-11
lines changed

2 files changed

+134
-11
lines changed

gopls/internal/golang/addtest.go

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"html/template"
1818
"os"
1919
"path/filepath"
20+
"strconv"
2021
"strings"
2122

2223
"golang.org/x/tools/go/ast/astutil"
@@ -115,6 +116,33 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
115116
return nil, fmt.Errorf("package has type errors: %v", errors[0])
116117
}
117118

119+
// imports is a map from package path to local package name.
120+
var imports = make(map[string]string)
121+
122+
var collectImports = func(file *ast.File) error {
123+
for _, spec := range file.Imports {
124+
// TODO(hxjiang): support dot imports.
125+
if spec.Name != nil && spec.Name.Name == "." {
126+
return fmt.Errorf("\"add a test for FUNC\" does not support files containing dot imports")
127+
}
128+
path, err := strconv.Unquote(spec.Path.Value)
129+
if err != nil {
130+
return err
131+
}
132+
if spec.Name != nil && spec.Name.Name != "_" {
133+
imports[path] = spec.Name.Name
134+
} else {
135+
imports[path] = filepath.Base(path)
136+
}
137+
}
138+
return nil
139+
}
140+
141+
// Collect all the imports from the x.go, keep track of the local package name.
142+
if err := collectImports(pgf.File); err != nil {
143+
return nil, err
144+
}
145+
118146
testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go"
119147
goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.Dir().Path(), testBase))
120148

@@ -192,6 +220,26 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
192220
if err != nil {
193221
return nil, err
194222
}
223+
224+
// Collect all the imports from the x_test.go, overwrite the local pakcage
225+
// name collected from x.go.
226+
if err := collectImports(testPGF.File); err != nil {
227+
return nil, err
228+
}
229+
}
230+
231+
// qf qualifier returns the local package name need to use in x_test.go by
232+
// consulting the consolidated imports map.
233+
qf := func(p *types.Package) string {
234+
// When generating test in x packages, any type/function defined in the same
235+
// x package can emit package name.
236+
if !xtest && p == pkg.Types() {
237+
return ""
238+
}
239+
if local, ok := imports[p.Path()]; ok {
240+
return local
241+
}
242+
return p.Name()
195243
}
196244

197245
// TODO(hxjiang): modify existing imports or add new imports.
@@ -231,16 +279,6 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
231279
// the option to drop the return value if the type is unexported.
232280
}
233281

234-
// TODO(hxjiang): qualifier should consolidate existing imports from x
235-
// package and existing x_test package. The existing x_test package imports
236-
// should overwrite x package imports.
237-
var qf types.Qualifier
238-
if xtest {
239-
qf = (*types.Package).Name
240-
} else {
241-
qf = typesinternal.NameRelativeTo(pkg.Types())
242-
}
243-
244282
testName, err := testName(fn)
245283
if err != nil {
246284
return nil, err
@@ -251,7 +289,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
251289
}
252290

253291
if sig.Recv() == nil && xtest {
254-
data.PackageName = pkg.Types().Name()
292+
data.PackageName = qf(pkg.Types())
255293
}
256294

257295
for i := range sig.Params().Len() {

gopls/internal/test/marker/testdata/codeaction/addtest.txt

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,88 @@ func Foo(in, in1, in2, in3 string) (out, out1, out2 string) {return in, in, in}
388388
+ }
389389
+ }
390390
+}
391+
-- xpackagerename/xpackagerename.go --
392+
package main
393+
394+
import (
395+
mytime "time"
396+
myast "go/ast"
397+
)
398+
399+
func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeactionedit("Foo", "source.addTest", xpackage_rename)
400+
401+
-- @xpackage_rename/xpackagerename/xpackagerename_test.go --
402+
@@ -0,0 +1,26 @@
403+
+package main_test
404+
+
405+
+func TestFoo(t *testing.T) {
406+
+ type args struct {
407+
+ in mytime.Time
408+
+ in2 *myast.Node
409+
+ }
410+
+ tests := []struct {
411+
+ name string // description of this test case
412+
+ args args
413+
+ want mytime.Time
414+
+ want2 *myast.Node
415+
+ }{
416+
+ // TODO: Add test cases.
417+
+ }
418+
+ for _, tt := range tests {
419+
+ got, got2 := main.Foo(tt.args.in, tt.args.in2)
420+
+ // TODO: update the condition below to compare got with tt.want.
421+
+ if true {
422+
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got, tt.want)
423+
+ }
424+
+ if true {
425+
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got2, tt.want2)
426+
+ }
427+
+ }
428+
+}
429+
-- xtestpackagerename/xtestpackagerename.go --
430+
package main
431+
432+
import (
433+
mytime "time"
434+
myast "go/ast"
435+
)
436+
437+
func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeactionedit("Foo", "source.addTest", xtest_package_rename)
438+
439+
-- xtestpackagerename/xtestpackagerename_test.go --
440+
package main_test
441+
442+
import (
443+
yourtime "time"
444+
yourast "go/ast"
445+
)
446+
447+
var fooTime = yourtime.Time{}
448+
var fooNode = yourast.Node{}
449+
450+
-- @xtest_package_rename/xtestpackagerename/xtestpackagerename_test.go --
451+
@@ -11 +11,24 @@
452+
+func TestFoo(t *testing.T) {
453+
+ type args struct {
454+
+ in yourtime.Time
455+
+ in2 *yourast.Node
456+
+ }
457+
+ tests := []struct {
458+
+ name string // description of this test case
459+
+ args args
460+
+ want yourtime.Time
461+
+ want2 *yourast.Node
462+
+ }{
463+
+ // TODO: Add test cases.
464+
+ }
465+
+ for _, tt := range tests {
466+
+ got, got2 := main.Foo(tt.args.in, tt.args.in2)
467+
+ // TODO: update the condition below to compare got with tt.want.
468+
+ if true {
469+
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got, tt.want)
470+
+ }
471+
+ if true {
472+
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got2, tt.want2)
473+
+ }
474+
+ }
475+
+}

0 commit comments

Comments
 (0)