diff --git a/README.md b/README.md index 12ff045..5bd3194 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,8 @@ This is an [MCP](https://modelcontextprotocol.io/introduction) server that runs - `hover`: Display documentation, type hints, or other hover information for a given location. - `rename_symbol`: Rename a symbol across a project. - `edit_file`: Allows making multiple text edits to a file based on line numbers. Provides a more reliable and context-economical way to edit files compared to search and replace based edit tools. +- `callers`: Shows all locations that call a given symbol +- `callees`: Shows all functions that a given symbol calls ## About diff --git a/integrationtests/snapshots/go/call_hierarchy/incoming-other-file.snap b/integrationtests/snapshots/go/call_hierarchy/incoming-other-file.snap new file mode 100644 index 0000000..c282e40 --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/incoming-other-file.snap @@ -0,0 +1,14 @@ + +--- +Name: HelperFunction +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • helper.go +/TEST_OUTPUT/workspace/helper.go +Range: L4:C6 - L4:C20 +- Called By: AnotherConsumer + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • another_consumer.go +/TEST_OUTPUT/workspace/another_consumer.go + Range: L6:C6 - L6:C21 +- Called By: ConsumerFunction + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • consumer.go +/TEST_OUTPUT/workspace/consumer.go + Range: L6:C6 - L6:C22 diff --git a/integrationtests/snapshots/go/call_hierarchy/incoming-same-file.snap b/integrationtests/snapshots/go/call_hierarchy/incoming-same-file.snap new file mode 100644 index 0000000..d49c425 --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/incoming-same-file.snap @@ -0,0 +1,10 @@ + +--- +Name: FooBar +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go +/TEST_OUTPUT/workspace/main.go +Range: L6:C6 - L6:C12 +- Called By: main + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go +/TEST_OUTPUT/workspace/main.go + Range: L12:C6 - L12:C10 diff --git a/integrationtests/snapshots/go/call_hierarchy/outgoing-other-file.snap b/integrationtests/snapshots/go/call_hierarchy/outgoing-other-file.snap new file mode 100644 index 0000000..65b93b0 --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/outgoing-other-file.snap @@ -0,0 +1,26 @@ + +--- +Name: ConsumerFunction +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • consumer.go +/TEST_OUTPUT/workspace/consumer.go +Range: L6:C6 - L6:C22 +- Calls: GetName + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go +/TEST_OUTPUT/workspace/types.go + Range: L21:C2 - L21:C9 +- Calls: HelperFunction + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • helper.go +/TEST_OUTPUT/workspace/helper.go + Range: L4:C6 - L4:C20 +- Calls: Method + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go +/TEST_OUTPUT/workspace/types.go + Range: L14:C24 - L14:C30 +- Calls: Println + Detail: fmt • print.go +/GOROOT/src/fmt/print.go + Range: L313:C6 - L313:C13 +- Calls: Process + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go +/TEST_OUTPUT/workspace/types.go + Range: L31:C24 - L31:C31 diff --git a/integrationtests/snapshots/go/call_hierarchy/outgoing-same-file.snap b/integrationtests/snapshots/go/call_hierarchy/outgoing-same-file.snap new file mode 100644 index 0000000..54be4ba --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/outgoing-same-file.snap @@ -0,0 +1,14 @@ + +--- +Name: main +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go +/TEST_OUTPUT/workspace/main.go +Range: L12:C6 - L12:C10 +- Calls: FooBar + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go +/TEST_OUTPUT/workspace/main.go + Range: L6:C6 - L6:C12 +- Calls: Println + Detail: fmt • print.go +/GOROOT/src/fmt/print.go + Range: L313:C6 - L313:C13 diff --git a/integrationtests/tests/common/helpers.go b/integrationtests/tests/common/helpers.go index 3d7243c..84ade98 100644 --- a/integrationtests/tests/common/helpers.go +++ b/integrationtests/tests/common/helpers.go @@ -1,9 +1,11 @@ package common import ( + "bytes" "fmt" "io" "os" + "os/exec" "path/filepath" "strings" "testing" @@ -91,10 +93,25 @@ func CleanupTestSuites(suites ...*TestSuite) { } } +// use instead of runtime.GOROOT which is deprecated +func getGoRoot() string { + cmd := exec.Command("go", "env", "GOROOT") + var out bytes.Buffer + cmd.Stdout = &out + err := cmd.Run() + if err != nil { + panic(err) + } + return strings.TrimSpace(out.String()) +} + // normalizePaths replaces absolute paths in the result with placeholder paths for consistent snapshots func normalizePaths(_ *testing.T, input string) string { // No need to get the repo root - we're just looking for patterns + // But this is useful + goroot := getGoRoot() + // Simple approach: just replace any path segments that contain workspace/ lines := strings.Split(input, "\n") for i, line := range lines { @@ -116,6 +133,13 @@ func normalizePaths(_ *testing.T, input string) string { lines[i] = "/TEST_OUTPUT/workspace/" + parts[1] } } + if strings.Contains(line, goroot) { + parts := strings.Split(line, goroot) + if len(parts) > 1 { + // Replace with a simple placeholder path + lines[i] = "/GOROOT" + parts[1] + } + } } return strings.Join(lines, "\n") diff --git a/integrationtests/tests/go/call_hierarchy/incoming_test.go b/integrationtests/tests/go/call_hierarchy/incoming_test.go new file mode 100644 index 0000000..b0dc0f5 --- /dev/null +++ b/integrationtests/tests/go/call_hierarchy/incoming_test.go @@ -0,0 +1,58 @@ +package callhierarchy_test + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/isaacphi/mcp-language-server/integrationtests/tests/common" + "github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal" + "github.com/isaacphi/mcp-language-server/internal/tools" +) + +func TestIncomingCalls(t *testing.T) { + suite := internal.GetTestSuite(t) + + ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second) + defer cancel() + + tests := []struct { + name string + symbolName string + expectedText string + snapshotName string + }{ + { + name: "Function with calls in same file", + symbolName: "FooBar", + expectedText: ": main", + snapshotName: "incoming-same-file", + }, + { + name: "Function with calls in other file", + symbolName: "HelperFunction", + expectedText: "ConsumerFunction", + snapshotName: "incoming-other-file", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Call the GetIncomingCalls tool + result, err := tools.GetCallers(ctx, suite.Client, tc.symbolName, 1) + if err != nil { + t.Fatalf("Failed to find incoming calls: %v", err) + } + + // Check that the result contains relevant information + if !strings.Contains(result, tc.expectedText) { + t.Errorf("Incoming calls do not contain expected text: %s", tc.expectedText) + } + + // Use snapshot testing to verify exact output + common.SnapshotTest(t, "go", "call_hierarchy", tc.snapshotName, result) + }) + } + +} diff --git a/integrationtests/tests/go/call_hierarchy/outgoing_test.go b/integrationtests/tests/go/call_hierarchy/outgoing_test.go new file mode 100644 index 0000000..d0cb745 --- /dev/null +++ b/integrationtests/tests/go/call_hierarchy/outgoing_test.go @@ -0,0 +1,58 @@ +package callhierarchy_test + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/isaacphi/mcp-language-server/integrationtests/tests/common" + "github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal" + "github.com/isaacphi/mcp-language-server/internal/tools" +) + +func TestOutgoingCalls(t *testing.T) { + suite := internal.GetTestSuite(t) + + ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second) + defer cancel() + + tests := []struct { + name string + symbolName string + expectedText string + snapshotName string + }{ + { + name: "Function with calls in other file", + symbolName: "ConsumerFunction", + expectedText: "HelperFunction", + snapshotName: "outgoing-other-file", + }, + { + name: "Function with calls in same file", + symbolName: "main", + expectedText: "FooBar", + snapshotName: "outgoing-same-file", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Call the GetOutgoingCalls tool + result, err := tools.GetCallees(ctx, suite.Client, tc.symbolName, 1) + if err != nil { + t.Fatalf("Failed to find outgoing calls: %v", err) + } + + // Check that the result contains relevant information + if !strings.Contains(result, tc.expectedText) { + t.Errorf("Outgoing calls do not contain expected text: %s", tc.expectedText) + } + + // Use snapshot testing to verify exact output + common.SnapshotTest(t, "go", "call_hierarchy", tc.snapshotName, result) + }) + } + +} diff --git a/internal/tools/call_hierarchy.go b/internal/tools/call_hierarchy.go new file mode 100644 index 0000000..96d876d --- /dev/null +++ b/internal/tools/call_hierarchy.go @@ -0,0 +1,200 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" +) + +func GetCallers(ctx context.Context, client *lsp.Client, symbolName string, maxDepth int) (string, error) { + return getCallHierarchy(ctx, client, symbolName, maxDepth, recurseIncomingCalls) +} + +func GetCallees(ctx context.Context, client *lsp.Client, symbolName string, maxDepth int) (string, error) { + return getCallHierarchy(ctx, client, symbolName, maxDepth, recurseOutgoingCalls) +} + +func getCallHierarchy( + ctx context.Context, client *lsp.Client, symbolName string, maxDepth int, + recurse func(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int), +) (string, error) { + // First get the symbol location like ReadDefinition does + symbolName, results, err := QuerySymbol(ctx, client, symbolName) + if err != nil { + return "", err + } + + // After this point we just return errors instead of erroring out + var result strings.Builder + + for _, symbol := range results { + var separator string + if strings.Contains(symbolName, ".") { + separator = "." + } else if strings.Contains(symbolName, "::") { + separator = "::" + } + + // Handle different matching strategies based on the search term + if separator != "" { + // For qualified names like "Type.Method", check for various matches + parts := strings.Split(symbolName, separator) + methodName := parts[len(parts)-1] + + // Try matching the unqualified method name for languages that don't use qualified names in symbols + if symbol.GetName() != symbolName && symbol.GetName() != methodName { + continue + } + } else if symbol.GetName() != symbolName { + // For unqualified names, exact match only + continue + } + + result.WriteString("\n---\n") + + // Get the location of the symbol + loc := symbol.GetLocation() + + chParams := protocol.CallHierarchyPrepareParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: loc.URI, + }, + Position: loc.Range.Start, + }, + } + items, err := client.PrepareCallHierarchy(ctx, chParams) + if err != nil { + result.WriteString(fmt.Sprintf("%s: Error: %v\n", symbol.GetName(), err)) + continue + } + + for _, item := range items { + recurse(ctx, client, item, &result, 0, maxDepth) + } + } + + return result.String(), nil +} + +func recurseIncomingCalls(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int) { + + var prefix string + if depth != 0 { + prefix = strings.Repeat(" ", (depth-1)*2+2) + + result.WriteString(strings.Repeat(" ", (depth-1)*2)) + result.WriteRune('-') + result.WriteString(" Called By: ") + } else { + result.WriteString("Name: ") + } + + result.WriteString(item.Name) + result.WriteRune('\n') + + result.WriteString(prefix) + result.WriteString("Detail: ") + result.WriteString(item.Detail) + result.WriteRune('\n') + + result.WriteString(prefix) + result.WriteString("File: ") + result.WriteString(strings.TrimPrefix(string(item.URI), "file://")) + result.WriteRune('\n') + + result.WriteString(prefix) + fmt.Fprintf(result, "Range: L%d:C%d - L%d:C%d\n", + item.Range.Start.Line+1, + item.Range.Start.Character+1, + item.Range.End.Line+1, + item.Range.End.Character+1) + + if depth >= maxDepth { + return + } + + calls, err := client.IncomingCalls(ctx, protocol.CallHierarchyIncomingCallsParams{ + Item: item, + }) + + if err != nil { + result.WriteString(prefix) + result.WriteString("Error: ") + result.WriteString(err.Error()) + result.WriteRune('\n') + return + } + + // ensure output is deterministic for tests + sort.Slice(calls, func(i, j int) bool { + return calls[i].From.Name < calls[j].From.Name + }) + + for _, call := range calls { + recurseIncomingCalls(ctx, client, call.From, result, depth+1, maxDepth) + } +} + +func recurseOutgoingCalls(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int) { + + var prefix string + if depth != 0 { + prefix = strings.Repeat(" ", (depth-1)*2+2) + + result.WriteString(strings.Repeat(" ", (depth-1)*2)) + result.WriteRune('-') + result.WriteString(" Calls: ") + } else { + result.WriteString("Name: ") + } + + result.WriteString(item.Name) + result.WriteRune('\n') + + result.WriteString(prefix) + result.WriteString("Detail: ") + result.WriteString(item.Detail) + result.WriteRune('\n') + + result.WriteString(prefix) + result.WriteString("File: ") + result.WriteString(strings.TrimPrefix(string(item.URI), "file://")) + result.WriteRune('\n') + + result.WriteString(prefix) + fmt.Fprintf(result, "Range: L%d:C%d - L%d:C%d\n", + item.Range.Start.Line+1, + item.Range.Start.Character+1, + item.Range.End.Line+1, + item.Range.End.Character+1) + + if depth >= maxDepth { + return + } + + calls, err := client.OutgoingCalls(ctx, protocol.CallHierarchyOutgoingCallsParams{ + Item: item, + }) + + if err != nil { + result.WriteString(prefix) + result.WriteString("Error: ") + result.WriteString(err.Error()) + result.WriteRune('\n') + return + } + + // ensure output is deterministic for tests + sort.Slice(calls, func(i, j int) bool { + return calls[i].To.Name < calls[j].To.Name + }) + + for _, call := range calls { + recurseOutgoingCalls(ctx, client, call.To, result, depth+1, maxDepth) + } +} diff --git a/tools.go b/tools.go index 4aeca7b..f7de723 100644 --- a/tools.go +++ b/tools.go @@ -336,6 +336,52 @@ func (s *mcpServer) registerTools() error { return mcp.NewToolResultText(text), nil }) + callersTool := mcp.NewTool("callers", + mcp.WithDescription("Determine which functions call the given symbol. Returns a list of the calling functions and the locations of the call sites."), + mcp.WithString("symbolName", + mcp.Required(), + mcp.Description("The name of the symbol whose callers you want to find (e.g. 'mypackage.MyFunction', 'MyType.MyMethod')"), + ), + ) + s.mcpServer.AddTool(callersTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments + symbolName, err := request.RequireString("symbolName") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + coreLogger.Debug("Executing callers for symbol: %s", symbolName) + text, err := tools.GetCallers(s.ctx, s.lspClient, symbolName, 1) + if err != nil { + coreLogger.Error("Failed to find callers: %v", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to find callers: %v", err)), nil + } + return mcp.NewToolResultText(text), nil + }) + + calleesTool := mcp.NewTool("callees", + mcp.WithDescription("Resolve which functions a given symbol calls. Returns a list of the called functions and their locations."), + mcp.WithString("symbolName", + mcp.Required(), + mcp.Description("The name of the symbol whose callees you want to find (e.g. 'mypackage.MyFunction', 'MyType.MyMethod')"), + ), + ) + s.mcpServer.AddTool(calleesTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments + symbolName, err := request.RequireString("symbolName") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + coreLogger.Debug("Executing callees for symbol: %s", symbolName) + text, err := tools.GetCallees(s.ctx, s.lspClient, symbolName, 1) + if err != nil { + coreLogger.Error("Failed to find callees: %v", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to find callees: %v", err)), nil + } + return mcp.NewToolResultText(text), nil + }) + coreLogger.Info("Successfully registered all MCP tools") return nil }