|
| 1 | +// Invoke with //go:generate helper/helper -t Server -d protocol/tsserver.go -u lsp -o server_gen.go |
| 2 | +// invoke in internal/lsp |
| 3 | +package main |
| 4 | + |
| 5 | +import ( |
| 6 | + "bytes" |
| 7 | + "flag" |
| 8 | + "fmt" |
| 9 | + "go/ast" |
| 10 | + "go/parser" |
| 11 | + "go/token" |
| 12 | + "log" |
| 13 | + "os" |
| 14 | + "sort" |
| 15 | + "strings" |
| 16 | + "text/template" |
| 17 | +) |
| 18 | + |
| 19 | +var ( |
| 20 | + typ = flag.String("t", "Server", "generate code for this type") |
| 21 | + def = flag.String("d", "", "the file the type is defined in") // this relies on punning |
| 22 | + use = flag.String("u", "", "look for uses in this package") |
| 23 | + out = flag.String("o", "", "where to write the generated file") |
| 24 | +) |
| 25 | + |
| 26 | +func main() { |
| 27 | + log.SetFlags(log.Lshortfile) |
| 28 | + flag.Parse() |
| 29 | + if *typ == "" || *def == "" || *use == "" || *out == "" { |
| 30 | + flag.PrintDefaults() |
| 31 | + return |
| 32 | + } |
| 33 | + // read the type definition and see what methods we're looking for |
| 34 | + doTypes() |
| 35 | + |
| 36 | + // parse the package and see which methods are defined |
| 37 | + doUses() |
| 38 | + |
| 39 | + output() |
| 40 | +} |
| 41 | + |
| 42 | +// replace "\\\n" with nothing before using |
| 43 | +var tmpl = ` |
| 44 | +package lsp |
| 45 | +
|
| 46 | +// code generated by helper. DO NOT EDIT. |
| 47 | +
|
| 48 | +import ( |
| 49 | + "context" |
| 50 | +
|
| 51 | + "golang.org/x/tools/internal/lsp/protocol" |
| 52 | +) |
| 53 | +
|
| 54 | +{{range $key, $v := .Stuff}} |
| 55 | +func (s *{{$.Type}}) {{$v.Name}}({{.Param}}) {{.Result}} { |
| 56 | + {{if ne .Found ""}} return s.{{.Internal}}({{.Invoke}})\ |
| 57 | + {{else}}return {{if lt 1 (len .Results)}}nil, {{end}}notImplemented("{{.Name}}"){{end}} |
| 58 | +} |
| 59 | +{{end}} |
| 60 | +` |
| 61 | + |
| 62 | +func output() { |
| 63 | + // put in empty param names as needed |
| 64 | + for _, t := range types { |
| 65 | + if t.paramnames == nil { |
| 66 | + t.paramnames = make([]string, len(t.paramtypes)) |
| 67 | + } |
| 68 | + for i, p := range t.paramtypes { |
| 69 | + cm := "" |
| 70 | + if i > 0 { |
| 71 | + cm = ", " |
| 72 | + } |
| 73 | + t.Param += fmt.Sprintf("%s%s %s", cm, t.paramnames[i], p) |
| 74 | + t.Invoke += fmt.Sprintf("%s%s", cm, t.paramnames[i]) |
| 75 | + } |
| 76 | + if len(t.Results) > 1 { |
| 77 | + t.Result = "(" |
| 78 | + } |
| 79 | + for i, r := range t.Results { |
| 80 | + cm := "" |
| 81 | + if i > 0 { |
| 82 | + cm = ", " |
| 83 | + } |
| 84 | + t.Result += fmt.Sprintf("%s%s", cm, r) |
| 85 | + } |
| 86 | + if len(t.Results) > 1 { |
| 87 | + t.Result += ")" |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + fd, err := os.Create(*out) |
| 92 | + if err != nil { |
| 93 | + log.Fatal(err) |
| 94 | + } |
| 95 | + t, err := template.New("foo").Parse(tmpl) |
| 96 | + if err != nil { |
| 97 | + log.Fatal(err) |
| 98 | + } |
| 99 | + type par struct { |
| 100 | + Type string |
| 101 | + Stuff []*Function |
| 102 | + } |
| 103 | + p := par{*typ, types} |
| 104 | + if false { // debugging the template |
| 105 | + t.Execute(os.Stderr, &p) |
| 106 | + } |
| 107 | + buf := bytes.NewBuffer(nil) |
| 108 | + err = t.Execute(buf, &p) |
| 109 | + if err != nil { |
| 110 | + log.Fatal(err) |
| 111 | + } |
| 112 | + ans := bytes.Replace(buf.Bytes(), []byte("\\\n"), []byte{}, -1) |
| 113 | + fd.Write(ans) |
| 114 | +} |
| 115 | + |
| 116 | +func doUses() { |
| 117 | + fset := token.NewFileSet() |
| 118 | + pkgs, err := parser.ParseDir(fset, *use, nil, 0) |
| 119 | + if err != nil { |
| 120 | + log.Fatalf("%q:%v", *use, err) |
| 121 | + } |
| 122 | + pkg := pkgs["lsp"] // CHECK |
| 123 | + files := pkg.Files |
| 124 | + for fname, f := range files { |
| 125 | + for _, d := range f.Decls { |
| 126 | + fd, ok := d.(*ast.FuncDecl) |
| 127 | + if !ok { |
| 128 | + continue |
| 129 | + } |
| 130 | + nm := fd.Name.String() |
| 131 | + if isExported(nm) { |
| 132 | + // we're looking for things like didChange |
| 133 | + continue |
| 134 | + } |
| 135 | + if fx, ok := byname[nm]; ok { |
| 136 | + if fx.Found != "" { |
| 137 | + log.Fatalf("found %s in %s and %s", fx.Internal, fx.Found, fname) |
| 138 | + } |
| 139 | + fx.Found = fname |
| 140 | + // and the Paramnames |
| 141 | + ft := fd.Type |
| 142 | + for _, f := range ft.Params.List { |
| 143 | + nm := "" |
| 144 | + if len(f.Names) > 0 { |
| 145 | + nm = f.Names[0].String() |
| 146 | + } |
| 147 | + fx.paramnames = append(fx.paramnames, nm) |
| 148 | + } |
| 149 | + } |
| 150 | + } |
| 151 | + } |
| 152 | + if false { |
| 153 | + for i, f := range types { |
| 154 | + log.Printf("%d %s %s", i, f.Internal, f.Found) |
| 155 | + } |
| 156 | + } |
| 157 | +} |
| 158 | + |
| 159 | +type Function struct { |
| 160 | + Name string |
| 161 | + Internal string // first letter lower case |
| 162 | + paramtypes []string |
| 163 | + paramnames []string |
| 164 | + Results []string |
| 165 | + Param string |
| 166 | + Result string // do it in code, easier than in a template |
| 167 | + Invoke string |
| 168 | + Found string // file it was found in |
| 169 | +} |
| 170 | + |
| 171 | +var types []*Function |
| 172 | +var byname = map[string]*Function{} // internal names |
| 173 | + |
| 174 | +func doTypes() { |
| 175 | + fset := token.NewFileSet() |
| 176 | + f, err := parser.ParseFile(fset, *def, nil, 0) |
| 177 | + if err != nil { |
| 178 | + log.Fatal(err) |
| 179 | + } |
| 180 | + fd, err := os.Create("/tmp/ast") |
| 181 | + if err != nil { |
| 182 | + log.Fatal(err) |
| 183 | + } |
| 184 | + ast.Fprint(fd, fset, f, ast.NotNilFilter) |
| 185 | + ast.Inspect(f, inter) |
| 186 | + sort.Slice(types, func(i, j int) bool { return types[i].Name < types[j].Name }) |
| 187 | + if false { |
| 188 | + for i, f := range types { |
| 189 | + log.Printf("%d %s(%v) %v", i, f.Name, f.paramtypes, f.Results) |
| 190 | + } |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +func inter(n ast.Node) bool { |
| 195 | + x, ok := n.(*ast.TypeSpec) |
| 196 | + if !ok || x.Name.Name != *typ { |
| 197 | + return true |
| 198 | + } |
| 199 | + m := x.Type.(*ast.InterfaceType).Methods.List |
| 200 | + for _, fld := range m { |
| 201 | + fn := fld.Type.(*ast.FuncType) |
| 202 | + p := fn.Params.List |
| 203 | + r := fn.Results.List |
| 204 | + fx := &Function{ |
| 205 | + Name: fld.Names[0].String(), |
| 206 | + } |
| 207 | + fx.Internal = strings.ToLower(fx.Name[:1]) + fx.Name[1:] |
| 208 | + for _, f := range p { |
| 209 | + fx.paramtypes = append(fx.paramtypes, whatis(f.Type)) |
| 210 | + } |
| 211 | + for _, f := range r { |
| 212 | + fx.Results = append(fx.Results, whatis(f.Type)) |
| 213 | + } |
| 214 | + types = append(types, fx) |
| 215 | + byname[fx.Internal] = fx |
| 216 | + } |
| 217 | + return false |
| 218 | +} |
| 219 | + |
| 220 | +func whatis(x ast.Expr) string { |
| 221 | + switch n := x.(type) { |
| 222 | + case *ast.SelectorExpr: |
| 223 | + return whatis(n.X) + "." + n.Sel.String() |
| 224 | + case *ast.StarExpr: |
| 225 | + return "*" + whatis(n.X) |
| 226 | + case *ast.Ident: |
| 227 | + if isExported(n.Name) { |
| 228 | + // these are from package protocol |
| 229 | + return "protocol." + n.Name |
| 230 | + } |
| 231 | + return n.Name |
| 232 | + case *ast.ArrayType: |
| 233 | + return "[]" + whatis(n.Elt) |
| 234 | + case *ast.InterfaceType: |
| 235 | + return "interface{}" |
| 236 | + default: |
| 237 | + log.Fatalf("Fatal %T", x) |
| 238 | + return fmt.Sprintf("%T", x) |
| 239 | + } |
| 240 | +} |
| 241 | + |
| 242 | +func isExported(n string) bool { |
| 243 | + return n[0] >= 'A' && n[0] <= 'Z' |
| 244 | +} |
0 commit comments