@@ -134,6 +134,30 @@ func (i *CompletionItem) Snippet() string {
134
134
return i .InsertText
135
135
}
136
136
137
+ func (i * CompletionItem ) addPrefixSuffix (c * completer , prefix string , suffix string ) error {
138
+ if prefix != "" {
139
+ // If we are in a selector, add an edit to place prefix before selector.
140
+ if sel := enclosingSelector (c .path , c .pos ); sel != nil {
141
+ edits , err := c .editText (sel .Pos (), sel .Pos (), prefix )
142
+ if err != nil {
143
+ return err
144
+ }
145
+ i .AdditionalTextEdits = append (i .AdditionalTextEdits , edits ... )
146
+ } else {
147
+ // If there is no selector, just stick the prefix at the start.
148
+ i .InsertText = prefix + i .InsertText
149
+ i .snippet .PrependText (prefix )
150
+ }
151
+ }
152
+
153
+ if suffix != "" {
154
+ i .InsertText += suffix
155
+ i .snippet .WriteText (suffix )
156
+ }
157
+
158
+ return nil
159
+ }
160
+
137
161
// Scoring constants are used for weighting the relevance of different candidates.
138
162
const (
139
163
// stdScore is the base score for all completion items.
@@ -659,6 +683,7 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p
659
683
c .addStatementCandidates ()
660
684
661
685
c .sortItems ()
686
+
662
687
return c .items , c .getSurrounding (), nil
663
688
}
664
689
@@ -2301,17 +2326,21 @@ Nodes:
2301
2326
2302
2327
sig , _ := c .pkg .TypesInfo ().Types [node .Fun ].Type .(* types.Signature )
2303
2328
2304
- if sig != nil && sig .TypeParams ().Len () > 0 {
2305
- // If we are completing a generic func call, re-check the call expression.
2306
- // This allows type param inference to work in cases like:
2307
- //
2308
- // func foo[T any](T) {}
2309
- // foo[int](<>) // <- get "int" completions instead of "T"
2310
- //
2311
- // TODO: remove this after https://go.dev/issue/52503
2312
- info := & types.Info {Types : make (map [ast.Expr ]types.TypeAndValue )}
2313
- types .CheckExpr (c .pkg .FileSet (), c .pkg .Types (), node .Fun .Pos (), node .Fun , info )
2314
- sig , _ = info .Types [node .Fun ].Type .(* types.Signature )
2329
+ if sig != nil && sig .TypeParams ().Len () > 0 && len (c .path ) > i + 1 {
2330
+ // infer an instantiation for the CallExpr from it's context
2331
+ switch c .path [i + 1 ].(type ) {
2332
+ case * ast.AssignStmt , * ast.ReturnStmt , * ast.SendStmt , * ast.ValueSpec :
2333
+ // Defer call to reverseInferExpectedCallParamType so we can provide it the
2334
+ // inferences about its parent node.
2335
+ defer func (sig * types.Signature ) {
2336
+ inf = c .reverseInferExpectedCallParam (inf , node , sig )
2337
+ }(sig )
2338
+ continue Nodes
2339
+ case * ast.KeyValueExpr :
2340
+ c .enclosingCompositeLiteral = enclosingCompositeLiteral (c .path [i :], node .Pos (), c .pkg .TypesInfo ())
2341
+ inf = c .reverseInferExpectedCallParam (inf , node , sig )
2342
+ break Nodes
2343
+ }
2315
2344
}
2316
2345
2317
2346
if sig != nil {
@@ -2395,6 +2424,28 @@ Nodes:
2395
2424
}
2396
2425
2397
2426
if ct := expectedConstraint (tv .Type , 0 ); ct != nil {
2427
+ // Infer the type parameters in a function call based on it's context
2428
+ if len (c .path ) > i + 2 {
2429
+ if node , ok := c .path [i + 1 ].(* ast.CallExpr ); ok {
2430
+ if sig , ok := c .pkg .TypesInfo ().Types [node .Fun ].Type .(* types.Signature ); ok && sig .TypeParams ().Len () != 0 {
2431
+ // skip again to get the parent of the call expression
2432
+ i ++
2433
+ switch c .path [i + 1 ].(type ) {
2434
+ case * ast.AssignStmt , * ast.ValueSpec , * ast.ReturnStmt , * ast.SendStmt :
2435
+ // Defer call to reverseInferExpectedCallParamType so we can provide it the
2436
+ // inferences about its parent node.
2437
+ defer func () {
2438
+ inf = c .reverseInferExpectedTypeParam (inf , ct , 0 , sig )
2439
+ }()
2440
+ continue Nodes
2441
+ case * ast.KeyValueExpr :
2442
+ c .enclosingCompositeLiteral = enclosingCompositeLiteral (c .path [i + 2 :], node .Pos (), c .pkg .TypesInfo ())
2443
+ inf = c .reverseInferExpectedTypeParam (inf , ct , 0 , sig )
2444
+ break Nodes
2445
+ }
2446
+ }
2447
+ }
2448
+ }
2398
2449
inf .objType = ct
2399
2450
inf .typeName .wantTypeName = true
2400
2451
inf .typeName .isTypeParam = true
@@ -2405,7 +2456,30 @@ Nodes:
2405
2456
case * ast.IndexListExpr :
2406
2457
if node .Lbrack < c .pos && c .pos <= node .Rbrack {
2407
2458
if tv , ok := c .pkg .TypesInfo ().Types [node .X ]; ok {
2408
- if ct := expectedConstraint (tv .Type , exprAtPos (c .pos , node .Indices )); ct != nil {
2459
+ typeParamIdx := exprAtPos (c .pos , node .Indices )
2460
+ if ct := expectedConstraint (tv .Type , typeParamIdx ); ct != nil {
2461
+ // Infer the type parameters in a function call based on it's context
2462
+ if len (c .path ) > i + 2 {
2463
+ if callnode , ok := c .path [i + 1 ].(* ast.CallExpr ); ok {
2464
+ if sig , ok := c .pkg .TypesInfo ().Types [callnode .Fun ].Type .(* types.Signature ); ok && sig .TypeParams ().Len () != 0 {
2465
+ // skip again to get the parent of the call expression
2466
+ i ++
2467
+ switch c .path [i + 1 ].(type ) {
2468
+ case * ast.AssignStmt , * ast.ValueSpec , * ast.ReturnStmt , * ast.SendStmt :
2469
+ // Defer call to reverseInferExpectedCallParamType so we can provide it the
2470
+ // inferences about its parent node.
2471
+ defer func () {
2472
+ inf = c .reverseInferExpectedTypeParam (inf , ct , typeParamIdx , sig )
2473
+ }()
2474
+ continue Nodes
2475
+ case * ast.KeyValueExpr :
2476
+ c .enclosingCompositeLiteral = enclosingCompositeLiteral (c .path [i + 2 :], callnode .Pos (), c .pkg .TypesInfo ())
2477
+ inf = c .reverseInferExpectedTypeParam (inf , ct , typeParamIdx , sig )
2478
+ break Nodes
2479
+ }
2480
+ }
2481
+ }
2482
+ }
2409
2483
inf .objType = ct
2410
2484
inf .typeName .wantTypeName = true
2411
2485
inf .typeName .isTypeParam = true
@@ -2457,6 +2531,118 @@ Nodes:
2457
2531
return inf
2458
2532
}
2459
2533
2534
+ func reverseInferSignature (sig * types.Signature , targetType []types.Type ) []types.Type {
2535
+ if sig .Results ().Len () != len (targetType ) {
2536
+ return nil
2537
+ }
2538
+
2539
+ tparams := []* types.TypeParam {}
2540
+ targs := []types.Type {}
2541
+ for i := range sig .TypeParams ().Len () {
2542
+ tparams = append (tparams , sig .TypeParams ().At (i ))
2543
+ targs = append (targs , nil )
2544
+ }
2545
+
2546
+ u := newUnifier (tparams , targs )
2547
+ for i , assignee := range targetType {
2548
+ // reverseInferSignature instantiates the call site of a generic function
2549
+ // based on the expected return types. Returns nil if inference fails or is invalid.
2550
+ //
2551
+ // targetType is the expected return types of the function after instantiation.
2552
+ if ! u .unify (sig .Results ().At (i ).Type (), assignee , unifyMode (unifyModeExact )) {
2553
+ return nil
2554
+ }
2555
+ }
2556
+
2557
+ substs := []types.Type {}
2558
+ for i := 0 ; i < sig .TypeParams ().Len (); i ++ {
2559
+ if v := u .handles [sig .TypeParams ().At (i )]; v != nil && * v != nil {
2560
+ substs = append (substs , * v )
2561
+ } else {
2562
+ substs = append (substs , nil )
2563
+ }
2564
+ }
2565
+
2566
+ return substs
2567
+ }
2568
+
2569
+ func (c * completer ) reverseInferredSubstitions (inf candidateInference , sig * types.Signature ) []types.Type {
2570
+ targetType := []types.Type {}
2571
+ if inf .assignees != nil {
2572
+ targetType = inf .assignees
2573
+ inf .assignees = nil
2574
+ } else if c .enclosingCompositeLiteral != nil && ! c .wantStructFieldCompletions () {
2575
+ targetType = append (targetType , c .expectedCompositeLiteralType ())
2576
+ } else if t := inf .objType ; t != nil {
2577
+ inf .objType = nil
2578
+ targetType = append (targetType , t )
2579
+ } else {
2580
+ return nil
2581
+ }
2582
+ return reverseInferSignature (sig , targetType )
2583
+ }
2584
+
2585
+ // reverseInferExpectedTypeParam uses inferences and completion parameters from the parent scope
2586
+ // to instantiate the generalized signature of the call node.
2587
+ //
2588
+ // inf is expected to contain inferences based on the parent of the CallExpr node.
2589
+ func (c * completer ) reverseInferExpectedTypeParam (inf candidateInference , expectedConstraint types.Type , typeParamIdx int , sig * types.Signature ) candidateInference {
2590
+ if typeParamIdx >= sig .TypeParams ().Len () {
2591
+ inf .objType = nil
2592
+ inf .assignees = nil
2593
+ return inf
2594
+ }
2595
+
2596
+ substs := c .reverseInferredSubstitions (inf , sig )
2597
+ if substs != nil && len (substs ) > 0 {
2598
+ if substs [typeParamIdx ] != nil {
2599
+ inf .objType = substs [typeParamIdx ]
2600
+ } else {
2601
+ // default to the constraint if no viable substition
2602
+ inf .objType = expectedConstraint
2603
+ }
2604
+ inf .typeName .wantTypeName = true
2605
+ inf .typeName .isTypeParam = true
2606
+ }
2607
+ return inf
2608
+ }
2609
+
2610
+ // reverseInferExpectedCallParam uses inferences and completion parameters from the parent scope
2611
+ // to instantiate the generalized signature of the call node.
2612
+ //
2613
+ // inf is expected to contain inferences based on the parent of the CallExpr node.
2614
+ func (c * completer ) reverseInferExpectedCallParam (inf candidateInference , node * ast.CallExpr , sig * types.Signature ) candidateInference {
2615
+ substs := c .reverseInferredSubstitions (inf , sig )
2616
+ if substs == nil {
2617
+ return inf
2618
+ }
2619
+
2620
+ for i := range substs {
2621
+ if substs [i ] == nil {
2622
+ substs [i ] = sig .TypeParams ().At (i )
2623
+ }
2624
+ }
2625
+
2626
+ if inst , err := types .Instantiate (nil , sig , substs , true ); err == nil {
2627
+ if inst , ok := inst .(* types.Signature ); ok {
2628
+ inf = c .expectedCallParamType (inf , node , inst )
2629
+
2630
+ // Interface type variants shouldn't be candidates as arguments if the caller isn't
2631
+ // explicitly instantiated
2632
+ //
2633
+ // func generic[T any](x T) T { return x }
2634
+ // var x someInterface = generic(someImplementor{})
2635
+ // ^^ wanted generic[someInterface] but got generic[someImplementor]
2636
+ // When offering completions, add a conversion if necessary.
2637
+ // generic(someInterface(someImplementor{}))
2638
+ if types .IsInterface (inf .objType ) {
2639
+ inf .convertibleTo = inf .objType
2640
+ }
2641
+ }
2642
+ }
2643
+ return inf
2644
+ }
2645
+
2460
2646
func (c * completer ) expectedCallParamType (inf candidateInference , node * ast.CallExpr , sig * types.Signature ) candidateInference {
2461
2647
numParams := sig .Params ().Len ()
2462
2648
if numParams == 0 {
@@ -2938,6 +3124,12 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool {
2938
3124
}
2939
3125
2940
3126
if ci .convertibleTo != nil && convertibleTo (candType , ci .convertibleTo ) {
3127
+ // Candidate implements an interface, but needs explicit conversion to the interface
3128
+ // type. This happens when passing arguments to a generic function.
3129
+ if ci .objType != nil && types .IsInterface (ci .objType ) && ! types .Identical (candType , ci .convertibleTo ) {
3130
+ cand .score *= 0.95 // should rank barely lower if it needs a conversion, even though it's perfectly valid
3131
+ cand .convertTo = ci .objType
3132
+ }
2941
3133
return true
2942
3134
}
2943
3135
@@ -3161,6 +3353,10 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
3161
3353
return false
3162
3354
}
3163
3355
3356
+ wantInterfaceTypeParam := c .inference .typeName .isTypeParam &&
3357
+ c .inference .typeName .wantTypeName && c .inference .objType != nil &&
3358
+ types .IsInterface (c .inference .objType )
3359
+
3164
3360
typeMatches := func (candType types.Type ) bool {
3165
3361
// Take into account any type name modifier prefixes.
3166
3362
candType = c .inference .applyTypeNameModifiers (candType )
@@ -3179,6 +3375,13 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
3179
3375
}
3180
3376
}
3181
3377
3378
+ // When performing reverse type inference
3379
+ // x = Foo[<>]()
3380
+ // Where x is an interface, only suggest the interface rather than its implementors
3381
+ if wantInterfaceTypeParam && types .Identical (candType , c .inference .objType ) {
3382
+ return true
3383
+ }
3384
+
3182
3385
if c .inference .typeName .wantComparable && ! types .Comparable (candType ) {
3183
3386
return false
3184
3387
}
0 commit comments