Skip to content

Commit 0c792f1

Browse files
committed
gopls/internal/golang: support generating test for functions
Add test for method is not fully implemented yet. For golang/vscode-go#1594 Change-Id: I4e18183baf96242c209e31e02a1b5cd642844c1d Reviewed-on: https://go-review.googlesource.com/c/tools/+/621057 Reviewed-by: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 27e1a3a commit 0c792f1

File tree

2 files changed

+393
-64
lines changed

2 files changed

+393
-64
lines changed

gopls/internal/golang/addtest.go

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"go/ast"
1515
"go/token"
1616
"go/types"
17+
"html/template"
1718
"os"
1819
"path/filepath"
1920
"strings"
@@ -26,6 +27,79 @@ import (
2627
"golang.org/x/tools/internal/typesinternal"
2728
)
2829

30+
const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
31+
{{- /* Functions/methods input parameters struct declaration. */}}
32+
{{- if gt (len .Args) 1}}
33+
type args struct {
34+
{{- range .Args}}
35+
{{.Name}} {{.Type}}
36+
{{- end}}
37+
}
38+
{{- end}}
39+
{{- /* Test cases struct declaration and empty initialization. */}}
40+
tests := []struct {
41+
name string // description of this test case
42+
{{- if gt (len .Args) 1}}
43+
args args
44+
{{- end}}
45+
{{- if eq (len .Args) 1}}
46+
arg {{(index .Args 0).Type}}
47+
{{- end}}
48+
{{- range $index, $res := .Results}}
49+
{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}} {{$res.Type}}
50+
{{- /* TODO(hxjiang): check whether the last return type is error and handle it using field "wantErr". */}}
51+
{{- end}}
52+
}{
53+
// TODO: Add test cases.
54+
}
55+
{{- /* Loop over all the test cases. */}}
56+
for _, tt := range tests {
57+
{{/* Got variables. */}}
58+
{{- if .Results}}{{fieldNames .Results ""}} := {{end}}
59+
60+
{{- /* Call expression. In xtest package test, call function by PACKAGE.FUNC. */}}
61+
{{- /* TODO(hxjiang): consider any renaming in existing xtest package imports. E.g. import renamedfoo "foo". */}}
62+
{{- /* TODO(hxjiang): support add test for methods by calling the right constructor. */}}
63+
{{- if .PackageName}}{{.PackageName}}.{{end}}{{.FuncName}}
64+
65+
{{- /* Input parameters. */ -}}
66+
({{if eq (len .Args) 1}}tt.arg{{end}}{{if gt (len .Args) 1}}{{fieldNames .Args "tt.args."}}{{end}})
67+
68+
{{- if .Results}}
69+
// TODO: update the condition below to compare got with tt.want.
70+
{{- range $index, $res := .Results}}
71+
if true {
72+
t.Errorf("%s: {{$.FuncName}}() = %v, want %v", tt.name, {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}})
73+
}
74+
{{- end}}
75+
{{- end}}
76+
}
77+
}
78+
`
79+
80+
type field struct {
81+
Name, Type string
82+
}
83+
84+
type testInfo struct {
85+
PackageName string
86+
FuncName string
87+
TestFuncName string
88+
Args []field
89+
Results []field
90+
}
91+
92+
var testTmpl = template.Must(template.New("test").Funcs(template.FuncMap{
93+
"add": func(a, b int) int { return a + b },
94+
"fieldNames": func(fields []field, qualifier string) (res string) {
95+
var names []string
96+
for _, f := range fields {
97+
names = append(names, qualifier+f.Name)
98+
}
99+
return strings.Join(names, ", ")
100+
},
101+
}).Parse(testTmplString))
102+
29103
// AddTestForFunc adds a test for the function enclosing the given input range.
30104
// It creates a _test.go file if one does not already exist.
31105
func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.Location) (changes []protocol.DocumentChange, _ error) {
@@ -138,6 +212,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
138212
}
139213

140214
fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func)
215+
sig := fn.Signature()
141216

142217
if xtest {
143218
// Reject if function/method is unexported.
@@ -146,30 +221,77 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
146221
}
147222

148223
// Reject if receiver is unexported.
149-
if fn.Signature().Recv() != nil {
224+
if sig.Recv() != nil {
150225
if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); !ident.IsExported() {
151226
return nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", ident.Name, decl.Name)
152227
}
153228
}
154-
155229
// TODO(hxjiang): reject if the any input parameter type is unexported.
156230
// TODO(hxjiang): reject if any return value type is unexported. Explore
157231
// the option to drop the return value if the type is unexported.
158232
}
159233

234+
// TODO(hxjiang): qualifier should consolidate existing imports from x
235+
// package and existing x_test package. The existing x_test package imports
236+
// should overwrite x package imports.
237+
var qf types.Qualifier
238+
if xtest {
239+
qf = (*types.Package).Name
240+
} else {
241+
qf = typesinternal.NameRelativeTo(pkg.Types())
242+
}
243+
160244
testName, err := testName(fn)
161245
if err != nil {
162246
return nil, err
163247
}
164-
// TODO(hxjiang): replace test function with table-driven test.
248+
data := testInfo{
249+
FuncName: fn.Name(),
250+
TestFuncName: testName,
251+
}
252+
253+
if sig.Recv() == nil && xtest {
254+
data.PackageName = pkg.Types().Name()
255+
}
256+
257+
for i := range sig.Params().Len() {
258+
if i == 0 {
259+
data.Args = append(data.Args, field{
260+
Name: "in",
261+
Type: types.TypeString(sig.Params().At(i).Type(), qf),
262+
})
263+
} else {
264+
data.Args = append(data.Args, field{
265+
Name: fmt.Sprintf("in%d", i+1),
266+
Type: types.TypeString(sig.Params().At(i).Type(), qf),
267+
})
268+
}
269+
}
270+
271+
for i := range sig.Results().Len() {
272+
if i == 0 {
273+
data.Results = append(data.Results, field{
274+
Name: "got",
275+
Type: types.TypeString(sig.Results().At(i).Type(), qf),
276+
})
277+
} else {
278+
data.Results = append(data.Results, field{
279+
Name: fmt.Sprintf("got%d", i+1),
280+
Type: types.TypeString(sig.Results().At(i).Type(), qf),
281+
})
282+
}
283+
}
284+
285+
var test bytes.Buffer
286+
if err := testTmpl.Execute(&test, data); err != nil {
287+
return nil, err
288+
}
289+
165290
edits = append(edits, protocol.TextEdit{
166-
Range: eofRange,
167-
NewText: fmt.Sprintf(`
168-
func %s(*testing.T) {
169-
// TODO: implement test
170-
}
171-
`, testName),
291+
Range: eofRange,
292+
NewText: test.String(),
172293
})
294+
173295
return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil
174296
}
175297

0 commit comments

Comments
 (0)