|
9 | 9 | "fmt"
|
10 | 10 | "io"
|
11 | 11 | "path"
|
12 |
| - "strings" |
13 | 12 |
|
14 | 13 | "golang.org/x/tools/internal/event"
|
15 | 14 | "golang.org/x/tools/internal/lsp/debug/tag"
|
@@ -97,17 +96,16 @@ func (s *Server) executeCommand(ctx context.Context, params *protocol.ExecuteCom
|
97 | 96 | switch command {
|
98 | 97 | case source.CommandTest:
|
99 | 98 | var uri protocol.DocumentURI
|
100 |
| - var flag string |
101 |
| - var funcName string |
102 |
| - if err := source.UnmarshalArgs(params.Arguments, &uri, &flag, &funcName); err != nil { |
| 99 | + var tests, benchmarks []string |
| 100 | + if err := source.UnmarshalArgs(params.Arguments, &uri, &tests, &benchmarks); err != nil { |
103 | 101 | return nil, err
|
104 | 102 | }
|
105 | 103 | snapshot, _, ok, release, err := s.beginFileRequest(ctx, uri, source.UnknownKind)
|
106 | 104 | defer release()
|
107 | 105 | if !ok {
|
108 | 106 | return nil, err
|
109 | 107 | }
|
110 |
| - go s.runTest(ctx, snapshot, []string{flag, funcName}, params.WorkDoneToken) |
| 108 | + go s.runTests(ctx, snapshot, uri, params.WorkDoneToken, tests, benchmarks) |
111 | 109 | case source.CommandGenerate:
|
112 | 110 | var uri protocol.DocumentURI
|
113 | 111 | var recursive bool
|
@@ -193,26 +191,74 @@ func (s *Server) directGoModCommand(ctx context.Context, uri protocol.DocumentUR
|
193 | 191 | return snapshot.RunGoCommandDirect(ctx, verb, args)
|
194 | 192 | }
|
195 | 193 |
|
196 |
| -func (s *Server) runTest(ctx context.Context, snapshot source.Snapshot, args []string, token protocol.ProgressToken) error { |
| 194 | +func (s *Server) runTests(ctx context.Context, snapshot source.Snapshot, uri protocol.DocumentURI, token protocol.ProgressToken, tests, benchmarks []string) error { |
197 | 195 | ctx, cancel := context.WithCancel(ctx)
|
198 | 196 | defer cancel()
|
199 | 197 |
|
| 198 | + pkgs, err := snapshot.PackagesForFile(ctx, uri.SpanURI()) |
| 199 | + if err != nil { |
| 200 | + return err |
| 201 | + } |
| 202 | + if len(pkgs) == 0 { |
| 203 | + return fmt.Errorf("package could not be found for file: %s", uri.SpanURI().Filename()) |
| 204 | + } |
| 205 | + pkgPath := pkgs[0].PkgPath() |
| 206 | + |
| 207 | + // create output |
200 | 208 | ew := &eventWriter{ctx: ctx, operation: "test"}
|
201 |
| - msg := fmt.Sprintf("running `go test %s`", strings.Join(args, " ")) |
202 |
| - wc := s.progress.newWriter(ctx, "test", msg, msg, token, cancel) |
| 209 | + var title string |
| 210 | + if len(tests) > 0 && len(benchmarks) > 0 { |
| 211 | + title = "tests and benchmarks" |
| 212 | + } else if len(tests) > 0 { |
| 213 | + title = "tests" |
| 214 | + } else if len(benchmarks) > 0 { |
| 215 | + title = "benchmarks" |
| 216 | + } else { |
| 217 | + return errors.New("No functions were provided") |
| 218 | + } |
| 219 | + msg := fmt.Sprintf("Running %s...", title) |
| 220 | + wc := s.progress.newWriter(ctx, title, msg, msg, token, cancel) |
203 | 221 | defer wc.Close()
|
204 | 222 |
|
205 |
| - messageType := protocol.Info |
206 |
| - message := "test passed" |
207 | 223 | stderr := io.MultiWriter(ew, wc)
|
208 | 224 |
|
209 |
| - if err := snapshot.RunGoCommandPiped(ctx, "test", args, ew, stderr); err != nil { |
210 |
| - if errors.Is(err, context.Canceled) { |
211 |
| - return err |
| 225 | + // run `go test -run Func` on each test |
| 226 | + var failedTests int |
| 227 | + for _, funcName := range tests { |
| 228 | + args := []string{pkgPath, "-run", fmt.Sprintf("^%s$", funcName)} |
| 229 | + if err := snapshot.RunGoCommandPiped(ctx, "test", args, ew, stderr); err != nil { |
| 230 | + if errors.Is(err, context.Canceled) { |
| 231 | + return err |
| 232 | + } |
| 233 | + failedTests++ |
212 | 234 | }
|
| 235 | + } |
| 236 | + |
| 237 | + // run `go test -run=^$ -bench Func` on each test |
| 238 | + var failedBenchmarks int |
| 239 | + for _, funcName := range tests { |
| 240 | + args := []string{pkgPath, "-run=^$", "-bench", fmt.Sprintf("^%s$", funcName)} |
| 241 | + if err := snapshot.RunGoCommandPiped(ctx, "test", args, ew, stderr); err != nil { |
| 242 | + if errors.Is(err, context.Canceled) { |
| 243 | + return err |
| 244 | + } |
| 245 | + failedBenchmarks++ |
| 246 | + } |
| 247 | + } |
| 248 | + |
| 249 | + messageType := protocol.Info |
| 250 | + message := fmt.Sprintf("all %s passed", title) |
| 251 | + if failedTests > 0 || failedBenchmarks > 0 { |
213 | 252 | messageType = protocol.Error
|
214 |
| - message = "test failed" |
215 | 253 | }
|
| 254 | + if failedTests > 0 && failedBenchmarks > 0 { |
| 255 | + message = fmt.Sprintf("%d / %d tests failed and %d / %d benchmarks failed", failedTests, len(tests), failedBenchmarks, len(benchmarks)) |
| 256 | + } else if failedTests > 0 { |
| 257 | + message = fmt.Sprintf("%d / %d tests failed", failedTests, len(tests)) |
| 258 | + } else if failedBenchmarks > 0 { |
| 259 | + message = fmt.Sprintf("%d / %d benchmarks failed", failedBenchmarks, len(benchmarks)) |
| 260 | + } |
| 261 | + |
216 | 262 | return s.client.ShowMessage(ctx, &protocol.ShowMessageParams{
|
217 | 263 | Type: messageType,
|
218 | 264 | Message: message,
|
|
0 commit comments