diff --git a/gopls/internal/lsp/source/implementation.go b/gopls/internal/lsp/source/implementation.go index ca62f4e664d..0d60841ac4a 100644 --- a/gopls/internal/lsp/source/implementation.go +++ b/gopls/internal/lsp/source/implementation.go @@ -23,7 +23,7 @@ func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp pro ctx, done := event.Start(ctx, "source.Implementation") defer done() - impls, err := implementations(ctx, snapshot, f, pp) + impls, err := implementations(ctx, snapshot, f, pp, true) if err != nil { return nil, err } @@ -58,120 +58,168 @@ func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp pro var ErrNotAType = errors.New("not a type name or method") // implementations returns the concrete implementations of the specified -// interface, or the interfaces implemented by the specified concrete type. -// It populates only the definition-related fields of qualifiedObject. -// (Arguably it should return a smaller data type.) -func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) { +// interface, or the interfaces implemented by the specified concrete type, +// or the concrete implementations of a function type. It populates only +// the definition-related fields of qualifiedObject. (Arguably it should +// return a smaller data type.) +func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position, includeFuncs bool) (impls []qualifiedObject, err error) { // Find all named types, even local types // (which can have methods due to promotion). - var ( - allNamed []*types.Named - pkgs = make(map[*types.Package]Package) - ) + qos, err := qualifiedObjsAtProtocolPos(ctx, s, f.URI(), pp) + if err != nil { + return nil, err + } knownPkgs, err := s.KnownPackages(ctx) if err != nil { return nil, err } - for _, pkg := range knownPkgs { - pkgs[pkg.GetTypes()] = pkg - for _, obj := range pkg.GetTypesInfo().Defs { - obj, ok := obj.(*types.TypeName) - // We ignore aliases 'type M = N' to avoid duplicate reporting - // of the Named type N. - if !ok || obj.IsAlias() { - continue + + var objs []types.Object + + for _, qo := range qos { + var queryType types.Type + var queryMethod types.Object + + sig, hasSig := qo.obj.Type().Underlying().(*types.Signature) + // If there's a signature, then qo must be a function. + // If there's no receiver, then search for implementations of the + // function type, or function types that qo's signature matches. + if hasSig && sig.Recv() != nil { + // If there's a receiver, then qo must be a method. + // Query for implementations of the interface's method / + // interfaces that the method fully or partially implements. + queryType = ensurePointer(sig.Recv().Type()) + queryMethod = qo.obj + } else if !hasSig { + // If there's no signature, then qo must be a type. + // Query for implementations of the interface / types that + // implement the interface. + queryType = ensurePointer(qo.obj.Type()) + } + + if queryType != nil { + for _, pkg := range knownPkgs { + objs = append(objs, findInterfaceImplementations(pkg, queryType, queryMethod)...) } - if named, ok := obj.Type().(*types.Named); ok { - allNamed = append(allNamed, named) + } else if hasSig && includeFuncs { + for _, pkg := range knownPkgs { + objs = append(objs, findFunctionImplementations(pkg, sig)...) } + } else { + return nil, ErrNotAType } } - qos, err := qualifiedObjsAtProtocolPos(ctx, s, f.URI(), pp) - if err != nil { - return nil, err + pkgs := make(map[*types.Package]Package, len(knownPkgs)) + for _, pkg := range knownPkgs { + pkgs[pkg.GetTypes()] = pkg } - var ( - impls []qualifiedObject - seen = make(map[token.Position]bool) - ) - for _, qo := range qos { - // Ascertain the query identifier (type or method). - var ( - queryType types.Type - queryMethod *types.Func - ) - switch obj := qo.obj.(type) { - case *types.Func: - queryMethod = obj - if recv := obj.Type().(*types.Signature).Recv(); recv != nil { - queryType = ensurePointer(recv.Type()) - } - case *types.TypeName: - queryType = ensurePointer(obj.Type()) - } + seen := make(map[token.Position]bool) + for _, obj := range objs { + pkg := pkgs[obj.Pkg()] // may be nil (e.g. error) - if queryType == nil { - return nil, ErrNotAType + // TODO(adonovan): the logic below assumes there is only one + // predeclared (pkg=nil) object of interest, the error type. + // That could change in a future version of Go. + + var posn token.Position + if pkg != nil { + posn = pkg.FileSet().Position(obj.Pos()) + } + if seen[posn] { + continue } + seen[posn] = true + impls = append(impls, qualifiedObject{ + obj: obj, + pkg: pkg, + }) + } - if types.NewMethodSet(queryType).Len() == 0 { - return nil, nil + return impls, nil +} + +func findInterfaceImplementations(pkg Package, queryType types.Type, queryMethod types.Object) (objs []types.Object) { + if types.NewMethodSet(queryType).Len() == 0 { + return nil + } + for _, obj := range pkg.GetTypesInfo().Defs { + obj, ok := obj.(*types.TypeName) + // We ignore aliases 'type M = N' to avoid duplicate reporting + // of the Named type N. + if !ok || obj.IsAlias() { + continue + } + named, ok := obj.Type().(*types.Named) + if !ok { + continue } // Find all the named types that match our query. - for _, named := range allNamed { - var ( - candObj types.Object = named.Obj() - candType = ensurePointer(named) - ) - - if !concreteImplementsIntf(candType, queryType) { - continue - } + var ( + candObj types.Object = named.Obj() + candType = ensurePointer(named) + ) - ms := types.NewMethodSet(candType) - if ms.Len() == 0 { - // Skip empty interfaces. - continue - } + if !concreteImplementsIntf(candType, queryType) { + continue + } - // If client queried a method, look up corresponding candType method. - if queryMethod != nil { - sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name()) - if sel == nil { - continue - } - candObj = sel.Obj() - } + ms := types.NewMethodSet(candType) + if ms.Len() == 0 { + // Skip empty interfaces. + continue + } - if candObj == queryMethod { + // If client queried a method, look up corresponding candType method. + if queryMethod != nil { + sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name()) + if sel == nil { continue } + candObj = sel.Obj() + } - pkg := pkgs[candObj.Pkg()] // may be nil (e.g. error) + if candObj == queryMethod { + continue + } - // TODO(adonovan): the logic below assumes there is only one - // predeclared (pkg=nil) object of interest, the error type. - // That could change in a future version of Go. + objs = append(objs, candObj) + } + return objs +} - var posn token.Position - if pkg != nil { - posn = pkg.FileSet().Position(candObj.Pos()) - } - if seen[posn] { - continue +func findFunctionImplementations(pkg Package, sig *types.Signature) (objs []types.Object) { + for _, name := range pkg.GetTypes().Scope().Names() { + o := pkg.GetTypes().Scope().Lookup(name) + + // Look up methods that match the signature. + if obj, isTypeName := o.(*types.TypeName); isTypeName && !obj.IsAlias() { + if named, isNamed := obj.Type().(*types.Named); isNamed { + ms := types.NewMethodSet(ensurePointer(named)) + for i := 0; i < ms.Len(); i++ { + o := ms.At(i).Obj() + if objectImplementsSignature(o, sig) { + objs = append(objs, o) + } + } } - seen[posn] = true + } - impls = append(impls, qualifiedObject{ - obj: candObj, - pkg: pkg, - }) + // Look up functions that match. + if _, isType := o.(*types.TypeName); isType { + continue + } + if objectImplementsSignature(o, sig) { + objs = append(objs, o) } } + return objs +} - return impls, nil +func objectImplementsSignature(o types.Object, sig *types.Signature) bool { + csig, isSig := o.Type().Underlying().(*types.Signature) + return isSig && types.AssignableTo(sig, csig) } // concreteImplementsIntf returns true if a is an interface type implemented by diff --git a/gopls/internal/lsp/source/references.go b/gopls/internal/lsp/source/references.go index d0310560c10..0a219070dfd 100644 --- a/gopls/internal/lsp/source/references.go +++ b/gopls/internal/lsp/source/references.go @@ -258,7 +258,7 @@ func equalOrigin(obj1, obj2 types.Object) bool { // interfaceReferences returns the references to the interfaces implemented by // the type or method at the given position. func interfaceReferences(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]*ReferenceInfo, error) { - implementations, err := implementations(ctx, s, f, pp) + implementations, err := implementations(ctx, s, f, pp, false) if err != nil { if errors.Is(err, ErrNotAType) { return nil, nil diff --git a/gopls/internal/lsp/testdata/implementation/implementation.go b/gopls/internal/lsp/testdata/implementation/implementation.go index b817319d5ef..ffdff3b322a 100644 --- a/gopls/internal/lsp/testdata/implementation/implementation.go +++ b/gopls/internal/lsp/testdata/implementation/implementation.go @@ -29,3 +29,35 @@ type cryer int //@implementations("cryer", Cryer) func (cryer) Cry(other.CryType) {} //@mark(CryImpl, "Cry"),implementations("Cry", Cry) type Empty interface{} //@implementations("Empty") + +type FunctionType func(s string, i int) //@FunctionType,implementations("FunctionType", ImplementationOfFunctionType1, ImplementationOfFunctionType2) + +func ImplementationOfFunctionType1(s string, i int) { //@mark(ImplementationOfFunctionType1, "ImplementationOfFunctionType1") +} + +func ImplementationOfFunctionType2(s string, i int) { //@mark(ImplementationOfFunctionType2, "ImplementationOfFunctionType2") + +func TestFunctionType(f FunctionType) { + f("s", 0) //implementations("f", ImplementationOfFunctionType1, ImplementationOfFunctionType2) +} + +type StructWithFunctionFields struct { + FT FunctionType //implementations("FT", ImplementationOfFunctionType1, ImplementationOfFunctionType2) +} + +func (s StructWithFunctionFields) Test() { + s.FT("s", 0) //implementations("FT", ImplementationOfFunctionType1, ImplementationOfFunctionType2) +} + +func implementationOfAnonymous1(data []byte) error { //@mark(implementationOfAnonymous1, "implementationOfAnonymous1") + return nil +} + +func implementationOfAnonymous2(data []byte) error { //@mark(implementationOfAnonymous2, "implementationOfAnonymous2") + return nil +} + +func TestAnonymousFunction(af func([]byte, cry func(other.CryType)) error) { + af([]byte{0, 1}) //implementations("af", implementationOfAnonymous1, implementationOfAnonymous2) + cry(other.CryType(12)) //implementations("Cry", Cry) +} diff --git a/gopls/internal/lsp/testdata/summary.txt.golden b/gopls/internal/lsp/testdata/summary.txt.golden index cfe8e4a267d..5dc731c28ff 100644 --- a/gopls/internal/lsp/testdata/summary.txt.golden +++ b/gopls/internal/lsp/testdata/summary.txt.golden @@ -27,5 +27,5 @@ SymbolsCount = 1 WorkspaceSymbolsCount = 20 SignaturesCount = 33 LinksCount = 7 -ImplementationsCount = 14 +ImplementationsCount = 15 diff --git a/gopls/internal/lsp/testdata/summary_go1.18.txt.golden b/gopls/internal/lsp/testdata/summary_go1.18.txt.golden index 2b7bf976b2f..b7ec9006fc6 100644 --- a/gopls/internal/lsp/testdata/summary_go1.18.txt.golden +++ b/gopls/internal/lsp/testdata/summary_go1.18.txt.golden @@ -27,5 +27,5 @@ SymbolsCount = 2 WorkspaceSymbolsCount = 20 SignaturesCount = 33 LinksCount = 7 -ImplementationsCount = 14 +ImplementationsCount = 15