Skip to content

Commit 1811aea

Browse files
committed
cmd/compile: deal with helper generic types that add methods to T
Deal with cases like: 'type P[T any] T' (used to add methods to an arbitrary type T), In this case, P[T] has kind types.TTYPEPARAM (as does T itself), but requires more code to substitute than a simple TTYPEPARAM T. See the comment near the beginning of subster.typ() in stencil.go. Add new test absdiff.go. This test has a case for complex types (which I've commented out) that will only work when we deal better with Go builtins in generic functions (like real and imag). Remove change in fmt.go for TTYPEPARAMS that is no longer needed (since all TTYPEPARAMS have a sym) and was sometimes causing an extra prefix when formatting method names. Separate out the setting of a TTYPEPARAM bound, since it can reference the TTYPEPARAM being defined, so must be done separately. Also, we don't currently (and may not ever) need bounds after types2 typechecking. Change-Id: Id173057e0c4563b309b95e665e9c1151ead4ba77 Reviewed-on: https://go-review.googlesource.com/c/go/+/300049 Run-TryBot: Dan Scales <[email protected]> TryBot-Result: Go Bot <[email protected]> Trust: Dan Scales <[email protected]> Trust: Robert Griesemer <[email protected]> Reviewed-by: Robert Griesemer <[email protected]>
1 parent 5edab39 commit 1811aea

File tree

5 files changed

+138
-8
lines changed

5 files changed

+138
-8
lines changed

src/cmd/compile/internal/noder/stencil.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,25 @@ func (subst *subster) typ(t *types.Type) *types.Type {
553553
return subst.targs[i].Type()
554554
}
555555
}
556-
return t
556+
// If t is a simple typeparam T, then t has the name/symbol 'T'
557+
// and t.Underlying() == t.
558+
//
559+
// However, consider the type definition: 'type P[T any] T'. We
560+
// might use this definition so we can have a variant of type T
561+
// that we can add new methods to. Suppose t is a reference to
562+
// P[T]. t has the name 'P[T]', but its kind is TTYPEPARAM,
563+
// because P[T] is defined as T. If we look at t.Underlying(), it
564+
// is different, because the name of t.Underlying() is 'T' rather
565+
// than 'P[T]'. But the kind of t.Underlying() is also TTYPEPARAM.
566+
// In this case, we do the needed recursive substitution in the
567+
// case statement below.
568+
if t.Underlying() == t {
569+
// t is a simple typeparam that didn't match anything in tparam
570+
return t
571+
}
572+
// t is a more complex typeparam (e.g. P[T], as above, whose
573+
// definition is just T).
574+
assert(t.Sym() != nil)
557575
}
558576

559577
var newsym *types.Sym
@@ -591,6 +609,15 @@ func (subst *subster) typ(t *types.Type) *types.Type {
591609
var newt *types.Type
592610

593611
switch t.Kind() {
612+
case types.TTYPEPARAM:
613+
if t.Sym() == newsym {
614+
// The substitution did not change the type.
615+
return t
616+
}
617+
// Substitute the underlying typeparam (e.g. T in P[T], see
618+
// the example describing type P[T] above).
619+
newt = subst.typ(t.Underlying())
620+
assert(newt != t)
594621

595622
case types.TARRAY:
596623
elem := t.Elem()

src/cmd/compile/internal/noder/types.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,18 @@ func (g *irgen) typ0(typ types2.Type) *types.Type {
180180
return types.NewInterface(g.tpkg(typ), append(embeddeds, methods...))
181181

182182
case *types2.TypeParam:
183-
tp := types.NewTypeParam(g.tpkg(typ), g.typ(typ.Bound()))
183+
tp := types.NewTypeParam(g.tpkg(typ))
184184
// Save the name of the type parameter in the sym of the type.
185185
// Include the types2 subscript in the sym name
186186
sym := g.pkg(typ.Obj().Pkg()).Lookup(types2.TypeString(typ, func(*types2.Package) string { return "" }))
187187
tp.SetSym(sym)
188+
// Set g.typs[typ] in case the bound methods reference typ.
189+
g.typs[typ] = tp
190+
191+
// TODO(danscales): we don't currently need to use the bounds
192+
// anywhere, so eventually we can probably remove.
193+
bound := g.typ(typ.Bound())
194+
*tp.Methods() = *bound.Methods()
188195
return tp
189196

190197
case *types2.Tuple:

src/cmd/compile/internal/types/fmt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func tconv2(b *bytes.Buffer, t *Type, verb rune, mode fmtMode, visited map[*Type
318318
}
319319

320320
// Unless the 'L' flag was specified, if the type has a name, just print that name.
321-
if verb != 'L' && t.Sym() != nil && t != Types[t.Kind()] && t.Kind() != TTYPEPARAM {
321+
if verb != 'L' && t.Sym() != nil && t != Types[t.Kind()] {
322322
switch mode {
323323
case fmtTypeID, fmtTypeIDName:
324324
if verb == 'S' {

src/cmd/compile/internal/types/type.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,12 +1742,9 @@ func NewInterface(pkg *Pkg, methods []*Field) *Type {
17421742
return t
17431743
}
17441744

1745-
// NewTypeParam returns a new type param with the given constraint (which may
1746-
// not really be needed except for the type checker).
1747-
func NewTypeParam(pkg *Pkg, constraint *Type) *Type {
1745+
// NewTypeParam returns a new type param.
1746+
func NewTypeParam(pkg *Pkg) *Type {
17481747
t := New(TTYPEPARAM)
1749-
constraint.wantEtype(TINTER)
1750-
t.methods = constraint.methods
17511748
t.Extra.(*Interface).pkg = pkg
17521749
t.SetHasTParam(true)
17531750
return t

test/typeparam/absdiff.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// run -gcflags=-G=3
2+
3+
// Copyright 2020 The Go Authors. All rights reserved.
4+
// Use of this source code is governed by a BSD-style
5+
// license that can be found in the LICENSE file.
6+
7+
package main
8+
9+
import (
10+
"fmt"
11+
//"math"
12+
)
13+
14+
type Numeric interface {
15+
type int, int8, int16, int32, int64,
16+
uint, uint8, uint16, uint32, uint64, uintptr,
17+
float32, float64,
18+
complex64, complex128
19+
}
20+
21+
// numericAbs matches numeric types with an Abs method.
22+
type numericAbs[T any] interface {
23+
Numeric
24+
Abs() T
25+
}
26+
27+
// AbsDifference computes the absolute value of the difference of
28+
// a and b, where the absolute value is determined by the Abs method.
29+
func absDifference[T numericAbs[T]](a, b T) T {
30+
d := a - b
31+
return d.Abs()
32+
}
33+
34+
// orderedNumeric matches numeric types that support the < operator.
35+
type orderedNumeric interface {
36+
type int, int8, int16, int32, int64,
37+
uint, uint8, uint16, uint32, uint64, uintptr,
38+
float32, float64
39+
}
40+
41+
// Complex matches the two complex types, which do not have a < operator.
42+
type Complex interface {
43+
type complex64, complex128
44+
}
45+
46+
// orderedAbs is a helper type that defines an Abs method for
47+
// ordered numeric types.
48+
type orderedAbs[T orderedNumeric] T
49+
50+
func (a orderedAbs[T]) Abs() orderedAbs[T] {
51+
// TODO(danscales): orderedAbs[T] conversion shouldn't be needed
52+
if a < orderedAbs[T](0) {
53+
return -a
54+
}
55+
return a
56+
}
57+
58+
// complexAbs is a helper type that defines an Abs method for
59+
// complex types.
60+
// type complexAbs[T Complex] T
61+
62+
// func (a complexAbs[T]) Abs() complexAbs[T] {
63+
// r := float64(real(a))
64+
// i := float64(imag(a))
65+
// d := math.Sqrt(r * r + i * i)
66+
// return complexAbs[T](complex(d, 0))
67+
// }
68+
69+
// OrderedAbsDifference returns the absolute value of the difference
70+
// between a and b, where a and b are of an ordered type.
71+
func orderedAbsDifference[T orderedNumeric](a, b T) T {
72+
return T(absDifference(orderedAbs[T](a), orderedAbs[T](b)))
73+
}
74+
75+
// ComplexAbsDifference returns the absolute value of the difference
76+
// between a and b, where a and b are of a complex type.
77+
// func complexAbsDifference[T Complex](a, b T) T {
78+
// return T(absDifference(complexAbs[T](a), complexAbs[T](b)))
79+
// }
80+
81+
func main() {
82+
if got, want := orderedAbsDifference(1.0, -2.0), 3.0; got != want {
83+
panic(fmt.Sprintf("got = %v, want = %v", got, want))
84+
}
85+
if got, want := orderedAbsDifference(-1.0, 2.0), 3.0; got != want {
86+
panic(fmt.Sprintf("got = %v, want = %v", got, want))
87+
}
88+
if got, want := orderedAbsDifference(-20, 15), 35; got != want {
89+
panic(fmt.Sprintf("got = %v, want = %v", got, want))
90+
}
91+
92+
// Still have to handle built-ins real/abs to make this work
93+
// if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5; got != want {
94+
// panic(fmt.Sprintf("got = %v, want = %v", got, want)
95+
// }
96+
// if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5; got != want {
97+
// panic(fmt.Sprintf("got = %v, want = %v", got, want)
98+
// }
99+
}

0 commit comments

Comments
 (0)