Skip to content

Commit 152e295

Browse files
committed
completions/inference: perform type inference in completions with generic funcs
When using generic functions, the LSP now tries to infer an instantiation based on the surroundings of the call expression. If successful, it improves completions for parameters of generic functions. It shouldn't collide with any pre-existing code paths. Fixes #69754
1 parent 813e3c7 commit 152e295

File tree

5 files changed

+627
-393
lines changed

5 files changed

+627
-393
lines changed

gopls/internal/golang/completion/completion.go

Lines changed: 215 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,30 @@ func (i *CompletionItem) Snippet() string {
134134
return i.InsertText
135135
}
136136

137+
func (i *CompletionItem) addPrefixSuffix(c *completer, prefix string, suffix string) error {
138+
if prefix != "" {
139+
// If we are in a selector, add an edit to place prefix before selector.
140+
if sel := enclosingSelector(c.path, c.pos); sel != nil {
141+
edits, err := c.editText(sel.Pos(), sel.Pos(), prefix)
142+
if err != nil {
143+
return err
144+
}
145+
i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...)
146+
} else {
147+
// If there is no selector, just stick the prefix at the start.
148+
i.InsertText = prefix + i.InsertText
149+
i.snippet.PrependText(prefix)
150+
}
151+
}
152+
153+
if suffix != "" {
154+
i.InsertText += suffix
155+
i.snippet.WriteText(suffix)
156+
}
157+
158+
return nil
159+
}
160+
137161
// Scoring constants are used for weighting the relevance of different candidates.
138162
const (
139163
// stdScore is the base score for all completion items.
@@ -659,6 +683,7 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p
659683
c.addStatementCandidates()
660684

661685
c.sortItems()
686+
662687
return c.items, c.getSurrounding(), nil
663688
}
664689

@@ -2301,17 +2326,21 @@ Nodes:
23012326

23022327
sig, _ := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature)
23032328

2304-
if sig != nil && sig.TypeParams().Len() > 0 {
2305-
// If we are completing a generic func call, re-check the call expression.
2306-
// This allows type param inference to work in cases like:
2307-
//
2308-
// func foo[T any](T) {}
2309-
// foo[int](<>) // <- get "int" completions instead of "T"
2310-
//
2311-
// TODO: remove this after https://go.dev/issue/52503
2312-
info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)}
2313-
types.CheckExpr(c.pkg.FileSet(), c.pkg.Types(), node.Fun.Pos(), node.Fun, info)
2314-
sig, _ = info.Types[node.Fun].Type.(*types.Signature)
2329+
if sig != nil && sig.TypeParams().Len() > 0 && len(c.path) > i+1 {
2330+
// infer an instantiation for the CallExpr from it's context
2331+
switch c.path[i+1].(type) {
2332+
case *ast.AssignStmt, *ast.ReturnStmt, *ast.SendStmt, *ast.ValueSpec:
2333+
// Defer call to reverseInferExpectedCallParamType so we can provide it the
2334+
// inferences about its parent node.
2335+
defer func(sig *types.Signature) {
2336+
inf = c.reverseInferExpectedCallParam(inf, node, sig)
2337+
}(sig)
2338+
continue Nodes
2339+
case *ast.KeyValueExpr:
2340+
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i:], node.Pos(), c.pkg.TypesInfo())
2341+
inf = c.reverseInferExpectedCallParam(inf, node, sig)
2342+
break Nodes
2343+
}
23152344
}
23162345

23172346
if sig != nil {
@@ -2395,6 +2424,28 @@ Nodes:
23952424
}
23962425

23972426
if ct := expectedConstraint(tv.Type, 0); ct != nil {
2427+
// Infer the type parameters in a function call based on it's context
2428+
if len(c.path) > i+2 {
2429+
if node, ok := c.path[i+1].(*ast.CallExpr); ok {
2430+
if sig, ok := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature); ok && sig.TypeParams().Len() != 0 {
2431+
// skip again to get the parent of the call expression
2432+
i++
2433+
switch c.path[i+1].(type) {
2434+
case *ast.AssignStmt, *ast.ValueSpec, *ast.ReturnStmt, *ast.SendStmt:
2435+
// Defer call to reverseInferExpectedCallParamType so we can provide it the
2436+
// inferences about its parent node.
2437+
defer func() {
2438+
inf = c.reverseInferExpectedTypeParam(inf, ct, 0, sig)
2439+
}()
2440+
continue Nodes
2441+
case *ast.KeyValueExpr:
2442+
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i+2:], node.Pos(), c.pkg.TypesInfo())
2443+
inf = c.reverseInferExpectedTypeParam(inf, ct, 0, sig)
2444+
break Nodes
2445+
}
2446+
}
2447+
}
2448+
}
23982449
inf.objType = ct
23992450
inf.typeName.wantTypeName = true
24002451
inf.typeName.isTypeParam = true
@@ -2405,7 +2456,30 @@ Nodes:
24052456
case *ast.IndexListExpr:
24062457
if node.Lbrack < c.pos && c.pos <= node.Rbrack {
24072458
if tv, ok := c.pkg.TypesInfo().Types[node.X]; ok {
2408-
if ct := expectedConstraint(tv.Type, exprAtPos(c.pos, node.Indices)); ct != nil {
2459+
typeParamIdx := exprAtPos(c.pos, node.Indices)
2460+
if ct := expectedConstraint(tv.Type, typeParamIdx); ct != nil {
2461+
// Infer the type parameters in a function call based on it's context
2462+
if len(c.path) > i+2 {
2463+
if callnode, ok := c.path[i+1].(*ast.CallExpr); ok {
2464+
if sig, ok := c.pkg.TypesInfo().Types[callnode.Fun].Type.(*types.Signature); ok && sig.TypeParams().Len() != 0 {
2465+
// skip again to get the parent of the call expression
2466+
i++
2467+
switch c.path[i+1].(type) {
2468+
case *ast.AssignStmt, *ast.ValueSpec, *ast.ReturnStmt, *ast.SendStmt:
2469+
// Defer call to reverseInferExpectedCallParamType so we can provide it the
2470+
// inferences about its parent node.
2471+
defer func() {
2472+
inf = c.reverseInferExpectedTypeParam(inf, ct, typeParamIdx, sig)
2473+
}()
2474+
continue Nodes
2475+
case *ast.KeyValueExpr:
2476+
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i+2:], callnode.Pos(), c.pkg.TypesInfo())
2477+
inf = c.reverseInferExpectedTypeParam(inf, ct, typeParamIdx, sig)
2478+
break Nodes
2479+
}
2480+
}
2481+
}
2482+
}
24092483
inf.objType = ct
24102484
inf.typeName.wantTypeName = true
24112485
inf.typeName.isTypeParam = true
@@ -2457,6 +2531,118 @@ Nodes:
24572531
return inf
24582532
}
24592533

2534+
func reverseInferSignature(sig *types.Signature, targetType []types.Type) []types.Type {
2535+
if sig.Results().Len() != len(targetType) {
2536+
return nil
2537+
}
2538+
2539+
tparams := []*types.TypeParam{}
2540+
targs := []types.Type{}
2541+
for i := range sig.TypeParams().Len() {
2542+
tparams = append(tparams, sig.TypeParams().At(i))
2543+
targs = append(targs, nil)
2544+
}
2545+
2546+
u := newUnifier(tparams, targs)
2547+
for i, assignee := range targetType {
2548+
// reverseInferSignature instantiates the call site of a generic function
2549+
// based on the expected return types. Returns nil if inference fails or is invalid.
2550+
//
2551+
// targetType is the expected return types of the function after instantiation.
2552+
if !u.unify(sig.Results().At(i).Type(), assignee, unifyMode(unifyModeExact)) {
2553+
return nil
2554+
}
2555+
}
2556+
2557+
substs := []types.Type{}
2558+
for i := 0; i < sig.TypeParams().Len(); i++ {
2559+
if v := u.handles[sig.TypeParams().At(i)]; v != nil && *v != nil {
2560+
substs = append(substs, *v)
2561+
} else {
2562+
substs = append(substs, nil)
2563+
}
2564+
}
2565+
2566+
return substs
2567+
}
2568+
2569+
func (c *completer) reverseInferredSubstitions(inf candidateInference, sig *types.Signature) []types.Type {
2570+
targetType := []types.Type{}
2571+
if inf.assignees != nil {
2572+
targetType = inf.assignees
2573+
inf.assignees = nil
2574+
} else if c.enclosingCompositeLiteral != nil && !c.wantStructFieldCompletions() {
2575+
targetType = append(targetType, c.expectedCompositeLiteralType())
2576+
} else if t := inf.objType; t != nil {
2577+
inf.objType = nil
2578+
targetType = append(targetType, t)
2579+
} else {
2580+
return nil
2581+
}
2582+
return reverseInferSignature(sig, targetType)
2583+
}
2584+
2585+
// reverseInferExpectedTypeParam uses inferences and completion parameters from the parent scope
2586+
// to instantiate the generalized signature of the call node.
2587+
//
2588+
// inf is expected to contain inferences based on the parent of the CallExpr node.
2589+
func (c *completer) reverseInferExpectedTypeParam(inf candidateInference, expectedConstraint types.Type, typeParamIdx int, sig *types.Signature) candidateInference {
2590+
if typeParamIdx >= sig.TypeParams().Len() {
2591+
inf.objType = nil
2592+
inf.assignees = nil
2593+
return inf
2594+
}
2595+
2596+
substs := c.reverseInferredSubstitions(inf, sig)
2597+
if substs != nil && len(substs) > 0 {
2598+
if substs[typeParamIdx] != nil {
2599+
inf.objType = substs[typeParamIdx]
2600+
} else {
2601+
// default to the constraint if no viable substition
2602+
inf.objType = expectedConstraint
2603+
}
2604+
inf.typeName.wantTypeName = true
2605+
inf.typeName.isTypeParam = true
2606+
}
2607+
return inf
2608+
}
2609+
2610+
// reverseInferExpectedCallParam uses inferences and completion parameters from the parent scope
2611+
// to instantiate the generalized signature of the call node.
2612+
//
2613+
// inf is expected to contain inferences based on the parent of the CallExpr node.
2614+
func (c *completer) reverseInferExpectedCallParam(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference {
2615+
substs := c.reverseInferredSubstitions(inf, sig)
2616+
if substs == nil {
2617+
return inf
2618+
}
2619+
2620+
for i := range substs {
2621+
if substs[i] == nil {
2622+
substs[i] = sig.TypeParams().At(i)
2623+
}
2624+
}
2625+
2626+
if inst, err := types.Instantiate(nil, sig, substs, true); err == nil {
2627+
if inst, ok := inst.(*types.Signature); ok {
2628+
inf = c.expectedCallParamType(inf, node, inst)
2629+
2630+
// Interface type variants shouldn't be candidates as arguments if the caller isn't
2631+
// explicitly instantiated
2632+
//
2633+
// func generic[T any](x T) T { return x }
2634+
// var x someInterface = generic(someImplementor{})
2635+
// ^^ wanted generic[someInterface] but got generic[someImplementor]
2636+
// When offering completions, add a conversion if necessary.
2637+
// generic(someInterface(someImplementor{}))
2638+
if types.IsInterface(inf.objType) {
2639+
inf.convertibleTo = inf.objType
2640+
}
2641+
}
2642+
}
2643+
return inf
2644+
}
2645+
24602646
func (c *completer) expectedCallParamType(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference {
24612647
numParams := sig.Params().Len()
24622648
if numParams == 0 {
@@ -2938,6 +3124,12 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool {
29383124
}
29393125

29403126
if ci.convertibleTo != nil && convertibleTo(candType, ci.convertibleTo) {
3127+
// Candidate implements an interface, but needs explicit conversion to the interface
3128+
// type. This happens when passing arguments to a generic function.
3129+
if ci.objType != nil && types.IsInterface(ci.objType) && !types.Identical(candType, ci.convertibleTo) {
3130+
cand.score *= 0.95 // should rank barely lower if it needs a conversion, even though it's perfectly valid
3131+
cand.convertTo = ci.objType
3132+
}
29413133
return true
29423134
}
29433135

@@ -3161,6 +3353,10 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
31613353
return false
31623354
}
31633355

3356+
wantInterfaceTypeParam := c.inference.typeName.isTypeParam &&
3357+
c.inference.typeName.wantTypeName && c.inference.objType != nil &&
3358+
types.IsInterface(c.inference.objType)
3359+
31643360
typeMatches := func(candType types.Type) bool {
31653361
// Take into account any type name modifier prefixes.
31663362
candType = c.inference.applyTypeNameModifiers(candType)
@@ -3179,6 +3375,13 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
31793375
}
31803376
}
31813377

3378+
// When performing reverse type inference
3379+
// x = Foo[<>]()
3380+
// Where x is an interface, only suggest the interface rather than its implementors
3381+
if wantInterfaceTypeParam && types.Identical(candType, c.inference.objType) {
3382+
return true
3383+
}
3384+
31823385
if c.inference.typeName.wantComparable && !types.Comparable(candType) {
31833386
return false
31843387
}

gopls/internal/golang/completion/format.go

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -196,24 +196,9 @@ Suffixes:
196196
}
197197

198198
if cand.convertTo != nil {
199-
typeName := types.TypeString(cand.convertTo, c.qf)
200-
201-
switch t := cand.convertTo.(type) {
202-
// We need extra parens when casting to these types. For example,
203-
// we need "(*int)(foo)", not "*int(foo)".
204-
case *types.Pointer, *types.Signature:
205-
typeName = "(" + typeName + ")"
206-
case *types.Basic:
207-
// If the types are incompatible (as determined by typeMatches), then we
208-
// must need a conversion here. However, if the target type is untyped,
209-
// don't suggest converting to e.g. "untyped float" (golang/go#62141).
210-
if t.Info()&types.IsUntyped != 0 {
211-
typeName = types.TypeString(types.Default(cand.convertTo), c.qf)
212-
}
213-
}
214-
215-
prefix = typeName + "(" + prefix
216-
suffix = ")"
199+
p, s := c.formatConvertTo(cand.convertTo)
200+
prefix = p + prefix
201+
suffix = s
217202
}
218203

219204
if prefix != "" {
@@ -288,6 +273,24 @@ Suffixes:
288273
return item, nil
289274
}
290275

276+
func (c *completer) formatConvertTo(convertTo types.Type) (prefix string, suffix string) {
277+
typeName := types.TypeString(convertTo, c.qf)
278+
switch t := convertTo.(type) {
279+
// We need extra parens when casting to these types. For example,
280+
// we need "(*int)(foo)", not "*int(foo)".
281+
case *types.Pointer, *types.Signature:
282+
typeName = "(" + typeName + ")"
283+
case *types.Basic:
284+
// If the types are incompatible (as determined by typeMatches), then we
285+
// must need a conversion here. However, if the target type is untyped,
286+
// don't suggest converting to e.g. "untyped float" (golang/go#62141).
287+
if t.Info()&types.IsUntyped != 0 {
288+
typeName = types.TypeString(types.Default(convertTo), c.qf)
289+
}
290+
}
291+
return typeName + "(", ")"
292+
}
293+
291294
// importEdits produces the text edits necessary to add the given import to the current file.
292295
func (c *completer) importEdits(imp *importInfo) ([]protocol.TextEdit, error) {
293296
if imp == nil {

0 commit comments

Comments
 (0)