|
| 1 | +// Copyright 2018 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +// This file implements type parameter inference given |
| 6 | +// a list of concrete arguments and a parameter list. |
| 7 | + |
| 8 | +package types |
| 9 | + |
| 10 | +import ( |
| 11 | + "go/token" |
| 12 | + "strings" |
| 13 | +) |
| 14 | + |
| 15 | +// infer returns the list of actual type arguments for the given list of type parameters tparams |
| 16 | +// by inferring them from the actual arguments args for the parameters params. If type inference |
| 17 | +// is impossible because unification fails, an error is reported and the resulting types list is |
| 18 | +// nil, and index is 0. Otherwise, types is the list of inferred type arguments, and index is |
| 19 | +// the index of the first type argument in that list that couldn't be inferred (and thus is nil). |
| 20 | +// If all type arguments were inferred successfully, index is < 0. |
| 21 | +func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand) (types []Type, index int) { |
| 22 | + assert(params.Len() == len(args)) |
| 23 | + |
| 24 | + u := newUnifier(check, false) |
| 25 | + u.x.init(tparams) |
| 26 | + |
| 27 | + errorf := func(kind string, tpar, targ Type, arg *operand) { |
| 28 | + // provide a better error message if we can |
| 29 | + targs, failed := u.x.types() |
| 30 | + if failed == 0 { |
| 31 | + // The first type parameter couldn't be inferred. |
| 32 | + // If none of them could be inferred, don't try |
| 33 | + // to provide the inferred type in the error msg. |
| 34 | + allFailed := true |
| 35 | + for _, targ := range targs { |
| 36 | + if targ != nil { |
| 37 | + allFailed = false |
| 38 | + break |
| 39 | + } |
| 40 | + } |
| 41 | + if allFailed { |
| 42 | + check.errorf(arg, 0, "%s %s of %s does not match %s (cannot infer %s)", kind, targ, arg.expr, tpar, typeNamesString(tparams)) |
| 43 | + return |
| 44 | + } |
| 45 | + } |
| 46 | + smap := makeSubstMap(tparams, targs) |
| 47 | + // TODO(rFindley): pass a positioner here, rather than arg.Pos(). |
| 48 | + inferred := check.subst(arg.Pos(), tpar, smap) |
| 49 | + if inferred != tpar { |
| 50 | + check.errorf(arg, 0, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar) |
| 51 | + } else { |
| 52 | + check.errorf(arg, 0, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar) |
| 53 | + } |
| 54 | + } |
| 55 | + |
| 56 | + // Terminology: generic parameter = function parameter with a type-parameterized type |
| 57 | + |
| 58 | + // 1st pass: Unify parameter and argument types for generic parameters with typed arguments |
| 59 | + // and collect the indices of generic parameters with untyped arguments. |
| 60 | + var indices []int |
| 61 | + for i, arg := range args { |
| 62 | + par := params.At(i) |
| 63 | + // If we permit bidirectional unification, this conditional code needs to be |
| 64 | + // executed even if par.typ is not parameterized since the argument may be a |
| 65 | + // generic function (for which we want to infer // its type arguments). |
| 66 | + if isParameterized(tparams, par.typ) { |
| 67 | + if arg.mode == invalid { |
| 68 | + // An error was reported earlier. Ignore this targ |
| 69 | + // and continue, we may still be able to infer all |
| 70 | + // targs resulting in fewer follon-on errors. |
| 71 | + continue |
| 72 | + } |
| 73 | + if targ := arg.typ; isTyped(targ) { |
| 74 | + // If we permit bidirectional unification, and targ is |
| 75 | + // a generic function, we need to initialize u.y with |
| 76 | + // the respective type parameters of targ. |
| 77 | + if !u.unify(par.typ, targ) { |
| 78 | + errorf("type", par.typ, targ, arg) |
| 79 | + return nil, 0 |
| 80 | + } |
| 81 | + } else { |
| 82 | + indices = append(indices, i) |
| 83 | + } |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + // Some generic parameters with untyped arguments may have been given a type |
| 88 | + // indirectly through another generic parameter with a typed argument; we can |
| 89 | + // ignore those now. (This only means that we know the types for those generic |
| 90 | + // parameters; it doesn't mean untyped arguments can be passed safely. We still |
| 91 | + // need to verify that assignment of those arguments is valid when we check |
| 92 | + // function parameter passing external to infer.) |
| 93 | + j := 0 |
| 94 | + for _, i := range indices { |
| 95 | + par := params.At(i) |
| 96 | + // Since untyped types are all basic (i.e., non-composite) types, an |
| 97 | + // untyped argument will never match a composite parameter type; the |
| 98 | + // only parameter type it can possibly match against is a *TypeParam. |
| 99 | + // Thus, only keep the indices of generic parameters that are not of |
| 100 | + // composite types and which don't have a type inferred yet. |
| 101 | + if tpar, _ := par.typ.(*TypeParam); tpar != nil && u.x.at(tpar.index) == nil { |
| 102 | + indices[j] = i |
| 103 | + j++ |
| 104 | + } |
| 105 | + } |
| 106 | + indices = indices[:j] |
| 107 | + |
| 108 | + // 2nd pass: Unify parameter and default argument types for remaining generic parameters. |
| 109 | + for _, i := range indices { |
| 110 | + par := params.At(i) |
| 111 | + arg := args[i] |
| 112 | + targ := Default(arg.typ) |
| 113 | + // The default type for an untyped nil is untyped nil. We must not |
| 114 | + // infer an untyped nil type as type parameter type. Ignore untyped |
| 115 | + // nil by making sure all default argument types are typed. |
| 116 | + if isTyped(targ) && !u.unify(par.typ, targ) { |
| 117 | + errorf("default type", par.typ, targ, arg) |
| 118 | + return nil, 0 |
| 119 | + } |
| 120 | + } |
| 121 | + |
| 122 | + return u.x.types() |
| 123 | +} |
| 124 | + |
| 125 | +// typeNamesString produces a string containing all the |
| 126 | +// type names in list suitable for human consumption. |
| 127 | +func typeNamesString(list []*TypeName) string { |
| 128 | + // common cases |
| 129 | + n := len(list) |
| 130 | + switch n { |
| 131 | + case 0: |
| 132 | + return "" |
| 133 | + case 1: |
| 134 | + return list[0].name |
| 135 | + case 2: |
| 136 | + return list[0].name + " and " + list[1].name |
| 137 | + } |
| 138 | + |
| 139 | + // general case (n > 2) |
| 140 | + var b strings.Builder |
| 141 | + for i, tname := range list[:n-1] { |
| 142 | + if i > 0 { |
| 143 | + b.WriteString(", ") |
| 144 | + } |
| 145 | + b.WriteString(tname.name) |
| 146 | + } |
| 147 | + b.WriteString(", and ") |
| 148 | + b.WriteString(list[n-1].name) |
| 149 | + return b.String() |
| 150 | +} |
| 151 | + |
| 152 | +// IsParameterized reports whether typ contains any of the type parameters of tparams. |
| 153 | +func isParameterized(tparams []*TypeName, typ Type) bool { |
| 154 | + w := tpWalker{ |
| 155 | + seen: make(map[Type]bool), |
| 156 | + tparams: tparams, |
| 157 | + } |
| 158 | + return w.isParameterized(typ) |
| 159 | +} |
| 160 | + |
| 161 | +type tpWalker struct { |
| 162 | + seen map[Type]bool |
| 163 | + tparams []*TypeName |
| 164 | +} |
| 165 | + |
| 166 | +func (w *tpWalker) isParameterized(typ Type) (res bool) { |
| 167 | + // detect cycles |
| 168 | + if x, ok := w.seen[typ]; ok { |
| 169 | + return x |
| 170 | + } |
| 171 | + w.seen[typ] = false |
| 172 | + defer func() { |
| 173 | + w.seen[typ] = res |
| 174 | + }() |
| 175 | + |
| 176 | + switch t := typ.(type) { |
| 177 | + case nil, *Basic: // TODO(gri) should nil be handled here? |
| 178 | + break |
| 179 | + |
| 180 | + case *Array: |
| 181 | + return w.isParameterized(t.elem) |
| 182 | + |
| 183 | + case *Slice: |
| 184 | + return w.isParameterized(t.elem) |
| 185 | + |
| 186 | + case *Struct: |
| 187 | + for _, fld := range t.fields { |
| 188 | + if w.isParameterized(fld.typ) { |
| 189 | + return true |
| 190 | + } |
| 191 | + } |
| 192 | + |
| 193 | + case *Pointer: |
| 194 | + return w.isParameterized(t.base) |
| 195 | + |
| 196 | + case *Tuple: |
| 197 | + n := t.Len() |
| 198 | + for i := 0; i < n; i++ { |
| 199 | + if w.isParameterized(t.At(i).typ) { |
| 200 | + return true |
| 201 | + } |
| 202 | + } |
| 203 | + |
| 204 | + case *Sum: |
| 205 | + return w.isParameterizedList(t.types) |
| 206 | + |
| 207 | + case *Signature: |
| 208 | + // t.tparams may not be nil if we are looking at a signature |
| 209 | + // of a generic function type (or an interface method) that is |
| 210 | + // part of the type we're testing. We don't care about these type |
| 211 | + // parameters. |
| 212 | + // Similarly, the receiver of a method may declare (rather then |
| 213 | + // use) type parameters, we don't care about those either. |
| 214 | + // Thus, we only need to look at the input and result parameters. |
| 215 | + return w.isParameterized(t.params) || w.isParameterized(t.results) |
| 216 | + |
| 217 | + case *Interface: |
| 218 | + if t.allMethods != nil { |
| 219 | + // TODO(rFindley) at some point we should enforce completeness here |
| 220 | + for _, m := range t.allMethods { |
| 221 | + if w.isParameterized(m.typ) { |
| 222 | + return true |
| 223 | + } |
| 224 | + } |
| 225 | + return w.isParameterizedList(unpackType(t.allTypes)) |
| 226 | + } |
| 227 | + |
| 228 | + return t.iterate(func(t *Interface) bool { |
| 229 | + for _, m := range t.methods { |
| 230 | + if w.isParameterized(m.typ) { |
| 231 | + return true |
| 232 | + } |
| 233 | + } |
| 234 | + return w.isParameterizedList(unpackType(t.types)) |
| 235 | + }, nil) |
| 236 | + |
| 237 | + case *Map: |
| 238 | + return w.isParameterized(t.key) || w.isParameterized(t.elem) |
| 239 | + |
| 240 | + case *Chan: |
| 241 | + return w.isParameterized(t.elem) |
| 242 | + |
| 243 | + case *Named: |
| 244 | + return w.isParameterizedList(t.targs) |
| 245 | + |
| 246 | + case *TypeParam: |
| 247 | + // t must be one of w.tparams |
| 248 | + return t.index < len(w.tparams) && w.tparams[t.index].typ == t |
| 249 | + |
| 250 | + case *instance: |
| 251 | + return w.isParameterizedList(t.targs) |
| 252 | + |
| 253 | + default: |
| 254 | + unreachable() |
| 255 | + } |
| 256 | + |
| 257 | + return false |
| 258 | +} |
| 259 | + |
| 260 | +func (w *tpWalker) isParameterizedList(list []Type) bool { |
| 261 | + for _, t := range list { |
| 262 | + if w.isParameterized(t) { |
| 263 | + return true |
| 264 | + } |
| 265 | + } |
| 266 | + return false |
| 267 | +} |
| 268 | + |
| 269 | +// inferB returns the list of actual type arguments inferred from the type parameters' |
| 270 | +// bounds and an initial set of type arguments. If type inference is impossible because |
| 271 | +// unification fails, an error is reported, the resulting types list is nil, and index is 0. |
| 272 | +// Otherwise, types is the list of inferred type arguments, and index is the index of the |
| 273 | +// first type argument in that list that couldn't be inferred (and thus is nil). If all |
| 274 | +// type arguments where inferred successfully, index is < 0. The number of type arguments |
| 275 | +// provided may be less than the number of type parameters, but there must be at least one. |
| 276 | +func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, index int) { |
| 277 | + assert(len(tparams) >= len(targs) && len(targs) > 0) |
| 278 | + |
| 279 | + // Setup bidirectional unification between those structural bounds |
| 280 | + // and the corresponding type arguments (which may be nil!). |
| 281 | + u := newUnifier(check, false) |
| 282 | + u.x.init(tparams) |
| 283 | + u.y = u.x // type parameters between LHS and RHS of unification are identical |
| 284 | + |
| 285 | + // Set the type arguments which we know already. |
| 286 | + for i, targ := range targs { |
| 287 | + if targ != nil { |
| 288 | + u.x.set(i, targ) |
| 289 | + } |
| 290 | + } |
| 291 | + |
| 292 | + // Unify type parameters with their structural constraints, if any. |
| 293 | + for _, tpar := range tparams { |
| 294 | + typ := tpar.typ.(*TypeParam) |
| 295 | + sbound := check.structuralType(typ.bound) |
| 296 | + if sbound != nil { |
| 297 | + if !u.unify(typ, sbound) { |
| 298 | + check.errorf(tpar, 0, "%s does not match %s", tpar, sbound) |
| 299 | + return nil, 0 |
| 300 | + } |
| 301 | + } |
| 302 | + } |
| 303 | + |
| 304 | + // u.x.types() now contains the incoming type arguments plus any additional type |
| 305 | + // arguments for which there were structural constraints. The newly inferred non- |
| 306 | + // nil entries may still contain references to other type parameters. For instance, |
| 307 | + // for [A any, B interface{type []C}, C interface{type *A}], if A == int |
| 308 | + // was given, unification produced the type list [int, []C, *A]. We eliminate the |
| 309 | + // remaining type parameters by substituting the type parameters in this type list |
| 310 | + // until nothing changes anymore. |
| 311 | + types, index = u.x.types() |
| 312 | + if debug { |
| 313 | + for i, targ := range targs { |
| 314 | + assert(targ == nil || types[i] == targ) |
| 315 | + } |
| 316 | + } |
| 317 | + |
| 318 | + // dirty tracks the indices of all types that may still contain type parameters. |
| 319 | + // We know that nil type entries and entries corresponding to provided (non-nil) |
| 320 | + // type arguments are clean, so exclude them from the start. |
| 321 | + var dirty []int |
| 322 | + for i, typ := range types { |
| 323 | + if typ != nil && (i >= len(targs) || targs[i] == nil) { |
| 324 | + dirty = append(dirty, i) |
| 325 | + } |
| 326 | + } |
| 327 | + |
| 328 | + for len(dirty) > 0 { |
| 329 | + // TODO(gri) Instead of creating a new substMap for each iteration, |
| 330 | + // provide an update operation for substMaps and only change when |
| 331 | + // needed. Optimization. |
| 332 | + smap := makeSubstMap(tparams, types) |
| 333 | + n := 0 |
| 334 | + for _, index := range dirty { |
| 335 | + t0 := types[index] |
| 336 | + if t1 := check.subst(token.NoPos, t0, smap); t1 != t0 { |
| 337 | + types[index] = t1 |
| 338 | + dirty[n] = index |
| 339 | + n++ |
| 340 | + } |
| 341 | + } |
| 342 | + dirty = dirty[:n] |
| 343 | + } |
| 344 | + |
| 345 | + return |
| 346 | +} |
| 347 | + |
| 348 | +// structuralType returns the structural type of a constraint, if any. |
| 349 | +func (check *Checker) structuralType(constraint Type) Type { |
| 350 | + if iface, _ := under(constraint).(*Interface); iface != nil { |
| 351 | + check.completeInterface(token.NoPos, iface) |
| 352 | + types := unpackType(iface.allTypes) |
| 353 | + if len(types) == 1 { |
| 354 | + return types[0] |
| 355 | + } |
| 356 | + return nil |
| 357 | + } |
| 358 | + return constraint |
| 359 | +} |
0 commit comments