diff --git a/integrationtests/snapshots/go/call_hierarchy/incoming-other-file.snap b/integrationtests/snapshots/go/call_hierarchy/incoming-other-file.snap new file mode 100644 index 0000000..4dae2c5 --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/incoming-other-file.snap @@ -0,0 +1,8 @@ + +--- +Name: HelperFunction +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • helper.go +- Called By: AnotherConsumer + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • another_consumer.go +- Called By: ConsumerFunction + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • consumer.go diff --git a/integrationtests/snapshots/go/call_hierarchy/incoming-same-file.snap b/integrationtests/snapshots/go/call_hierarchy/incoming-same-file.snap new file mode 100644 index 0000000..7c3af8b --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/incoming-same-file.snap @@ -0,0 +1,6 @@ + +--- +Name: FooBar +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go +- Called By: main + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go diff --git a/integrationtests/snapshots/go/call_hierarchy/outgoing-other-file.snap b/integrationtests/snapshots/go/call_hierarchy/outgoing-other-file.snap new file mode 100644 index 0000000..af71293 --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/outgoing-other-file.snap @@ -0,0 +1,18 @@ + +--- +Name: ConsumerFunction +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • consumer.go +- Calls: GetName + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go +- Calls: HelperFunction + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • helper.go +- Calls: Method + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go +- Calls: Println + Detail: fmt • print.go + - Calls: Fprintln + Detail: fmt • print.go +- Calls: Process + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • types.go + - Calls: Printf + Detail: fmt • print.go diff --git a/integrationtests/snapshots/go/call_hierarchy/outgoing-same-file.snap b/integrationtests/snapshots/go/call_hierarchy/outgoing-same-file.snap new file mode 100644 index 0000000..b7442f6 --- /dev/null +++ b/integrationtests/snapshots/go/call_hierarchy/outgoing-same-file.snap @@ -0,0 +1,12 @@ + +--- +Name: main +Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go +- Calls: FooBar + Detail: github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace • main.go + - Calls: Println + Detail: fmt • print.go +- Calls: Println + Detail: fmt • print.go + - Calls: Fprintln + Detail: fmt • print.go diff --git a/integrationtests/tests/go/call_hierarchy/incoming_test.go b/integrationtests/tests/go/call_hierarchy/incoming_test.go new file mode 100644 index 0000000..68fdb36 --- /dev/null +++ b/integrationtests/tests/go/call_hierarchy/incoming_test.go @@ -0,0 +1,61 @@ +package callhierarchy_test + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/isaacphi/mcp-language-server/integrationtests/tests/common" + "github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal" + "github.com/isaacphi/mcp-language-server/internal/tools" +) + +func TestIncomingCalls(t *testing.T) { + suite := internal.GetTestSuite(t) + + ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second) + defer cancel() + + tests := []struct { + name string + symbolName string + depth int + expectedText string + snapshotName string + }{ + { + name: "Function with calls in same file", + symbolName: "FooBar", + expectedText: ": main", + depth: 5, + snapshotName: "incoming-same-file", + }, + { + name: "Function with calls in other file", + symbolName: "HelperFunction", + depth: 5, + expectedText: "ConsumerFunction", + snapshotName: "incoming-other-file", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Call the GetIncomingCalls tool + result, err := tools.GetIncomingCalls(ctx, suite.Client, tc.symbolName, tc.depth) + if err != nil { + t.Fatalf("Failed to find incoming calls: %v", err) + } + + // Check that the result contains relevant information + if !strings.Contains(result, tc.expectedText) { + t.Errorf("Incoming calls do not contain expected text: %s", tc.expectedText) + } + + // Use snapshot testing to verify exact output + common.SnapshotTest(t, "go", "call_hierarchy", tc.snapshotName, result) + }) + } + +} diff --git a/integrationtests/tests/go/call_hierarchy/outgoing_test.go b/integrationtests/tests/go/call_hierarchy/outgoing_test.go new file mode 100644 index 0000000..790a683 --- /dev/null +++ b/integrationtests/tests/go/call_hierarchy/outgoing_test.go @@ -0,0 +1,61 @@ +package callhierarchy_test + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/isaacphi/mcp-language-server/integrationtests/tests/common" + "github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal" + "github.com/isaacphi/mcp-language-server/internal/tools" +) + +func TestOutgoingCalls(t *testing.T) { + suite := internal.GetTestSuite(t) + + ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second) + defer cancel() + + tests := []struct { + name string + symbolName string + depth int + expectedText string + snapshotName string + }{ + { + name: "Function with calls in other file", + symbolName: "ConsumerFunction", + depth: 2, + expectedText: "HelperFunction", + snapshotName: "outgoing-other-file", + }, + { + name: "Function with calls in same file", + symbolName: "main", + depth: 2, + expectedText: "FooBar", + snapshotName: "outgoing-same-file", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Call the GetOutgoingCalls tool + result, err := tools.GetOutgoingCalls(ctx, suite.Client, tc.symbolName, tc.depth) + if err != nil { + t.Fatalf("Failed to find outgoing calls: %v", err) + } + + // Check that the result contains relevant information + if !strings.Contains(result, tc.expectedText) { + t.Errorf("Outgoing calls do not contain expected text: %s", tc.expectedText) + } + + // Use snapshot testing to verify exact output + common.SnapshotTest(t, "go", "call_hierarchy", tc.snapshotName, result) + }) + } + +} diff --git a/internal/tools/call_hierarchy.go b/internal/tools/call_hierarchy.go new file mode 100644 index 0000000..10bcdb3 --- /dev/null +++ b/internal/tools/call_hierarchy.go @@ -0,0 +1,176 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" +) + +func GetIncomingCalls(ctx context.Context, client *lsp.Client, symbolName string, maxDepth int) (string, error) { + return getCallHierarchy(ctx, client, symbolName, maxDepth, recurseIncomingCalls) +} + +func GetOutgoingCalls(ctx context.Context, client *lsp.Client, symbolName string, maxDepth int) (string, error) { + return getCallHierarchy(ctx, client, symbolName, maxDepth, recurseOutgoingCalls) +} + +func getCallHierarchy( + ctx context.Context, client *lsp.Client, symbolName string, maxDepth int, + recurse func(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int), +) (string, error) { + // First get the symbol location like ReadDefinition does + symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{ + Query: symbolName, + }) + if err != nil { + return "", fmt.Errorf("failed to fetch symbol: %v", err) + } + + results, err := symbolResult.Results() + if err != nil { + return "", fmt.Errorf("failed to parse results: %v", err) + } + + // After this point we just return errors instead of erroring out + var result strings.Builder + + for _, symbol := range results { + // Handle different matching strategies based on the search term + if strings.Contains(symbolName, ".") { + // For qualified names like "Type.Method", check for various matches + parts := strings.Split(symbolName, ".") + methodName := parts[len(parts)-1] + + // Try matching the unqualified method name for languages that don't use qualified names in symbols + if symbol.GetName() != symbolName && symbol.GetName() != methodName { + continue + } + } else if symbol.GetName() != symbolName { + // For unqualified names, exact match only + continue + } + + result.WriteString("\n---\n") + + // Get the location of the symbol + loc := symbol.GetLocation() + + chParams := protocol.CallHierarchyPrepareParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: loc.URI, + }, + Position: loc.Range.Start, + }, + } + items, err := client.PrepareCallHierarchy(ctx, chParams) + if err != nil { + result.WriteString(fmt.Sprintf("%s: Error: %v\n", symbol.GetName(), err)) + continue + } + + for _, item := range items { + recurse(ctx, client, item, &result, 0, maxDepth) + } + } + + return result.String(), nil +} + +func recurseIncomingCalls(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int) { + + var prefix string + if depth != 0 { + prefix = strings.Repeat(" ", (depth-1)*2+2) + + result.WriteString(strings.Repeat(" ", (depth-1)*2)) + result.WriteRune('-') + result.WriteString(" Called By: ") + } else { + result.WriteString("Name: ") + } + + result.WriteString(item.Name) + result.WriteRune('\n') + + result.WriteString(prefix) + result.WriteString("Detail: ") + result.WriteString(item.Detail) + result.WriteRune('\n') + + if depth >= maxDepth { + return + } + + calls, err := client.IncomingCalls(ctx, protocol.CallHierarchyIncomingCallsParams{ + Item: item, + }) + + if err != nil { + result.WriteString(prefix) + result.WriteString("Error: ") + result.WriteString(err.Error()) + result.WriteRune('\n') + return + } + + // ensure output is deterministic for tests + sort.Slice(calls, func(i, j int) bool { + return calls[i].From.Name < calls[j].From.Name + }) + + for _, call := range calls { + recurseIncomingCalls(ctx, client, call.From, result, depth+1, maxDepth) + } +} + +func recurseOutgoingCalls(ctx context.Context, client *lsp.Client, item protocol.CallHierarchyItem, result *strings.Builder, depth int, maxDepth int) { + + var prefix string + if depth != 0 { + prefix = strings.Repeat(" ", (depth-1)*2+2) + + result.WriteString(strings.Repeat(" ", (depth-1)*2)) + result.WriteRune('-') + result.WriteString(" Calls: ") + } else { + result.WriteString("Name: ") + } + + result.WriteString(item.Name) + result.WriteRune('\n') + + result.WriteString(prefix) + result.WriteString("Detail: ") + result.WriteString(item.Detail) + result.WriteRune('\n') + + if depth >= maxDepth { + return + } + + calls, err := client.OutgoingCalls(ctx, protocol.CallHierarchyOutgoingCallsParams{ + Item: item, + }) + + if err != nil { + result.WriteString(prefix) + result.WriteString("Error: ") + result.WriteString(err.Error()) + result.WriteRune('\n') + return + } + + // ensure output is deterministic for tests + sort.Slice(calls, func(i, j int) bool { + return calls[i].To.Name < calls[j].To.Name + }) + + for _, call := range calls { + recurseOutgoingCalls(ctx, client, call.To, result, depth+1, maxDepth) + } +} diff --git a/tools.go b/tools.go index 55898ca..de7d232 100644 --- a/tools.go +++ b/tools.go @@ -363,6 +363,88 @@ func (s *mcpServer) registerTools() error { return mcp.NewToolResultText(text), nil }) + incomingCallsTool := mcp.NewTool("incoming_calls", + mcp.WithDescription("Resolve the incoming calls for the given symbol. Returns a tree of incoming calls and their location"), + mcp.WithString("symbolName", + mcp.Required(), + mcp.Description("The name of the symbol whose definition you want to find (e.g. 'mypackage.MyFunction', 'MyType.MyMethod')"), + ), + mcp.WithNumber("maxDepth", + mcp.Description("max depth of call tree"), + mcp.DefaultNumber(5), + ), + ) + s.mcpServer.AddTool(incomingCallsTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments + symbolName, ok := request.Params.Arguments["symbolName"].(string) + if !ok { + return mcp.NewToolResultError("symbolName must be a string"), nil + } + + // default value + maxDepth := 5 + + if arg, ok := request.Params.Arguments["maxDepth"]; ok { + switch v := arg.(type) { + case float64: + maxDepth = int(v) + case int: + maxDepth = v + default: + return mcp.NewToolResultError("maxDepth must be a number"), nil + } + } + + coreLogger.Debug("Executing incoming calls for symbol: %s", symbolName) + text, err := tools.GetIncomingCalls(s.ctx, s.lspClient, symbolName, maxDepth) + if err != nil { + coreLogger.Error("Failed to find incoming calls: %v", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to find incoming calls: %v", err)), nil + } + return mcp.NewToolResultText(text), nil + }) + + outgoingCallsTool := mcp.NewTool("outgoing_calls", + mcp.WithDescription("Resolve the outgoing calls for the given symbol. Returns a tree of outgoing calls and their location"), + mcp.WithString("symbolName", + mcp.Required(), + mcp.Description("The name of the symbol whose definition you want to find (e.g. 'mypackage.MyFunction', 'MyType.MyMethod')"), + ), + mcp.WithNumber("maxDepth", + mcp.Description("max depth of call tree"), + mcp.DefaultNumber(5), + ), + ) + s.mcpServer.AddTool(outgoingCallsTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments + symbolName, ok := request.Params.Arguments["symbolName"].(string) + if !ok { + return mcp.NewToolResultError("symbolName must be a string"), nil + } + + // default value + maxDepth := 5 + + if arg, ok := request.Params.Arguments["maxDepth"]; ok { + switch v := arg.(type) { + case float64: + maxDepth = int(v) + case int: + maxDepth = v + default: + return mcp.NewToolResultError("maxDepth must be a number"), nil + } + } + + coreLogger.Debug("Executing outgoing calls for symbol: %s", symbolName) + text, err := tools.GetOutgoingCalls(s.ctx, s.lspClient, symbolName, maxDepth) + if err != nil { + coreLogger.Error("Failed to find outgoing calls: %v", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to find outgoing calls: %v", err)), nil + } + return mcp.NewToolResultText(text), nil + }) + coreLogger.Info("Successfully registered all MCP tools") return nil }