Skip to content

Commit 2845895

Browse files
committed
handle structural types/type members
1 parent b6b409c commit 2845895

7 files changed

+480
-40
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+92-39
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
176176
* code would have two extra parameters for each of the many calls that go from
177177
* one sub-part of isSubType to another.
178178
*/
179-
protected def recur(tp1: Type, tp2: Type): Boolean = trace(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping) {
179+
protected def recur(tp1: Type, tp2: Type): Boolean =
180+
// trace.force(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping)
181+
{
180182

181183
def monitoredIsSubType = {
182184
if (pendingSubTypes == null) {
@@ -2143,40 +2145,31 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
21432145
import config.Printers.debug
21442146
import typer.Inferencing._
21452147

2146-
def incompatibleClasses: Boolean = {
2148+
def compatibleClasses: Boolean = {
21472149
import Flags._
21482150
val tpClassSym = tp.widenSingleton.classSymbol
21492151
val ptClassSym = pt.widenSingleton.classSymbol
21502152
debug.println(i"tpClassSym=$tpClassSym, fin=${tpClassSym.is(Final)}")
21512153
debug.println(i"pt=$pt {${pt.getClass}}, ptClassSym=$ptClassSym, fin=${ptClassSym.is(Final)}")
2152-
tpClassSym.exists && ptClassSym.exists && {
2153-
if (tpClassSym.is(Final)) !tpClassSym.derivesFrom(ptClassSym)
2154-
else if (ptClassSym.is(Final)) !ptClassSym.derivesFrom(tpClassSym)
2154+
!tpClassSym.exists || !ptClassSym.exists || {
2155+
if (tpClassSym.is(Final)) tpClassSym.derivesFrom(ptClassSym)
2156+
else if (ptClassSym.is(Final)) ptClassSym.derivesFrom(tpClassSym)
21552157
else if (!tpClassSym.is(Flags.Trait) && !ptClassSym.is(Flags.Trait))
2156-
!(tpClassSym.derivesFrom(ptClassSym) || ptClassSym.derivesFrom(tpClassSym))
2157-
else false
2158+
tpClassSym.derivesFrom(ptClassSym) || ptClassSym.derivesFrom(tpClassSym)
2159+
else true
21582160
}
21592161
}
21602162

21612163
def loop(tp: Type): Boolean =
21622164
// trace.force(i"loop($tp) // ${tp.toString}")
21632165
{
2164-
if (constrainPatternType(pt, tp)) true
2165-
else if (incompatibleClasses) {
2166-
// println("incompatible classes")
2167-
false
2168-
}
2169-
else tp match {
2170-
case _: ConstantType =>
2171-
// constants cannot possibly intersect with types that aren't their supertypes
2172-
false
2173-
case tp: SingletonType => loop(tp.underlying)
2174-
case tp: TypeRef if tp.symbol.isClass => loop(tp.firstParent)
2166+
val res: Type = tp match {
2167+
case tp: TypeRef if tp.symbol.isClass => tp.firstParent
21752168
case tp @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
21762169
val ptClassSym = pt.classSymbol
21772170
def firstParentSharedWithPt(tp: Type, tpClassSym: ClassSymbol): Symbol =
2178-
// trace.force(i"f($tp)")
2179-
{
2171+
// trace.force(i"f($tp)")
2172+
{
21802173
var parents = tpClassSym.info.parents
21812174
// println(i"parents of $tpClassSym = $parents%, %")
21822175
parents match {
@@ -2195,29 +2188,89 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
21952188
}
21962189
val sym = firstParentSharedWithPt(tycon, tycon.symbol.asClass)
21972190
// println(i"sym=$sym ; tyconsym=${tycon.symbol}")
2198-
if (!sym.exists) true
2199-
else !(sym == tycon.symbol) && loop(tp.baseType(sym))
2191+
if (!sym.exists) return true
2192+
// else !(sym == tycon.symbol) &&
2193+
tp.baseType(sym)
22002194
case tp: TypeProxy =>
2201-
loop(tp.superType)
2202-
case _ => false
2195+
tp.superType
2196+
case _ => return true
22032197
}
2198+
constrainPatternType(pt, res) || loop(res)
22042199
}
22052200

2206-
pt match {
2207-
case AndType(pt1, pt2) =>
2208-
notIntersection(tp, pt1) && notIntersection(tp, pt2)
2209-
case OrType(pt1, pt2) =>
2210-
either(notIntersection(tp, pt1), notIntersection(tp, pt2))
2211-
case _ =>
2212-
tp match {
2213-
case OrType(tp1, tp2) =>
2214-
either(notIntersection(tp1, pt), notIntersection(tp2, pt))
2215-
case AndType(tp1, tp2) =>
2216-
notIntersection(tp1, pt) && notIntersection(tp2, pt)
2217-
case _ =>
2218-
loop(tp)
2219-
}
2220-
}
2201+
tp.dealias match {
2202+
case OrType(tp1, tp2) =>
2203+
either(notIntersection(tp1, pt), notIntersection(tp2, pt))
2204+
case AndType(tp1, tp2) =>
2205+
notIntersection(tp1, pt) && notIntersection(tp2, pt)
2206+
case tp: RefinedOrRecType =>
2207+
def keepInvariantRefinements(tp: Type): Type = tp match {
2208+
case tp: RefinedType =>
2209+
if (tp.refinedName.isTermName) keepInvariantRefinements(tp.parent)
2210+
else {
2211+
// def resolve(tp: Type): Type = tp match {
2212+
// case TypeAlias(tp) => resolve(tp.dealias)
2213+
// case tp => tp
2214+
// }
2215+
// val tpInfo = tp.refinedInfo
2216+
// val tpInfoDealiased = resolve(tpInfo)
2217+
// val ptInfo = pt.member(tp.refinedName).info
2218+
// val ptInfoDealiased = resolve(ptInfo)
2219+
// println(
2220+
// i"""tpInfo = ${tpInfo}
2221+
// |tpInfoDealiased = ${tpInfoDealiased}
2222+
// |ptInfo = ${ptInfo}
2223+
// |ptInfoDealiased = ${ptInfoDealiased}""".stripMargin
2224+
// )
2225+
// println(i"visiting refinement: ${tp.refinedName} : ${tp.refinedInfo}")
2226+
tp.refinedInfo match {
2227+
case TypeAlias(tp1) =>
2228+
val pt1 = pt.member(tp.refinedName).info
2229+
if (pt1.exists && pt1.bounds.contains(tp1) || !pt1.exists)
2230+
keepInvariantRefinements(tp.parent)
2231+
else
2232+
NoType
2233+
case tpb: TypeBounds =>
2234+
pt.member(tp.refinedName).info match {
2235+
case TypeAlias(pt1) =>
2236+
if (tpb.contains(pt1))
2237+
keepInvariantRefinements(tp.parent)
2238+
else
2239+
NoType
2240+
case _ =>
2241+
keepInvariantRefinements(tp.parent)
2242+
}
2243+
}
2244+
}
2245+
case tp: RecType =>
2246+
keepInvariantRefinements(tp.parent)
2247+
case _ =>
2248+
tp
2249+
}
2250+
val tp1 = keepInvariantRefinements(tp)
2251+
if (!tp1.exists) {
2252+
// println(i"noType for $tp")
2253+
false
2254+
} else
2255+
notIntersection(tp1, pt)
2256+
case tp =>
2257+
pt.dealias match {
2258+
case AndType(pt1, pt2) =>
2259+
notIntersection(tp, pt1) && notIntersection(tp, pt2)
2260+
case OrType(pt1, pt2) =>
2261+
either(notIntersection(tp, pt1), notIntersection(tp, pt2))
2262+
case pt: RefinedOrRecType =>
2263+
// note: at this point, we have already extracted the information we wanted from the refinement
2264+
// and it would only interfere in the following subtype check in constrainPatternType
2265+
def stripRefinement(tp: Type): Type = tp match {
2266+
case tp: RefinedOrRecType => stripRefinement(tp.parent)
2267+
case tp => tp
2268+
}
2269+
notIntersection(tp, stripRefinement(pt))
2270+
case pt =>
2271+
constrainPatternType(pt, tp) || compatibleClasses && loop(tp)
2272+
}
2273+
}
22212274
}
22222275
}
22232276

compiler/src/dotty/tools/dotc/typer/Applications.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,11 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
11081108
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
11091109
selType
11101110
} else {
1111-
isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))
1111+
val res = isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))
1112+
if (!res) ctx.warning(
1113+
ex"Pattern type $unapplyArgType does not intersect selector type $selType",
1114+
tree.sourcePos
1115+
)
11121116
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
11131117
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
11141118
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")

tests/neg/structural-gadt.scala

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
object Test {
2+
trait Expr { type T }
3+
trait IntLit extends Expr { type T <: Int }
4+
trait IntExpr extends Expr { type T = Int }
5+
6+
def foo[A](e: Expr { type T = A }) = e match {
7+
case _: IntLit =>
8+
val a: A = 0 // error
9+
val i: Int = ??? : A
10+
11+
case _: Expr { type T <: Int } =>
12+
val a: A = 0 // error
13+
val i: Int = ??? : A
14+
15+
case _: IntExpr =>
16+
val a: A = 0
17+
val i: Int = ??? : A
18+
19+
case _: Expr { type T = Int } =>
20+
val a: A = 0
21+
val i: Int = ??? : A
22+
}
23+
24+
def bar[A](e: Expr { type T <: A }) = e match {
25+
case _: IntLit =>
26+
val a: A = 0 // error
27+
val i: Int = ??? : A // error
28+
29+
case _: Expr { type T <: Int } =>
30+
val a: A = 0 // error
31+
val i: Int = ??? : A // error
32+
33+
case _: IntExpr =>
34+
val a: A = 0
35+
val i: Int = ??? : A // error
36+
37+
case _: Expr { type T = Int } =>
38+
val a: A = 0
39+
val i: Int = ??? : A // error
40+
}
41+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
object Test {
2+
// Some error comments in this file are preceded by // ?
3+
// This indicates that we should actually accept that line,
4+
// but we don't due to limitation of the implementation
5+
//
6+
//
7+
8+
trait Expr { type T }
9+
trait IntLit extends Expr { type T <: Int }
10+
trait IntExpr extends Expr { type T = Int }
11+
12+
type ExprSub[+A] = Expr { type T <: A }
13+
type ExprExact[A] = Expr { type T = A }
14+
15+
trait IndirectIntLit extends Expr { type S <: Int; type T = S }
16+
trait IndirectIntExpr extends Expr { type S = Int; type T = S }
17+
18+
type IndirectExprSub[+A] = Expr { type S <: A; type T = S }
19+
type IndirectExprSub2[A] = Expr { type S = A; type T <: S }
20+
type IndirectExprExact[A] = Expr { type S = A; type T = S }
21+
22+
trait AltIndirectIntLit extends Expr { type U <: Int; type T = U }
23+
trait AltIndirectIntExpr extends Expr { type U = Int; type T = U }
24+
25+
type AltIndirectExprSub[+A] = Expr { type U <: A; type T = U }
26+
type AltIndirectExprSub2[A] = Expr { type U = A; type T <: U }
27+
type AltIndirectExprExact[A] = Expr { type U = A; type T = U }
28+
29+
def foo[A](e: IndirectExprExact[A]) = e match {
30+
case _: AltIndirectIntLit =>
31+
val a: A = 0 // error
32+
val i: Int = ??? : A
33+
34+
case _: AltIndirectExprSub[Int] =>
35+
val a: A = 0 // error
36+
val i: Int = ??? : A
37+
38+
case _: AltIndirectExprSub2[Int] =>
39+
val a: A = 0 // error
40+
val i: Int = ??? : A
41+
42+
case _: AltIndirectIntExpr =>
43+
val a: A = 0
44+
val i: Int = ??? : A
45+
46+
case _: AltIndirectExprExact[Int] =>
47+
val a: A = 0
48+
val i: Int = ??? : A
49+
}
50+
51+
def bar[A](e: IndirectExprSub[A]) = e match {
52+
case _: AltIndirectIntLit =>
53+
val a: A = 0 // error
54+
val i: Int = ??? : A // error
55+
56+
case _: AltIndirectExprSub[Int] =>
57+
val a: A = 0 // error
58+
val i: Int = ??? : A // error
59+
60+
case _: AltIndirectExprSub2[Int] =>
61+
val a: A = 0 // error
62+
val i: Int = ??? : A // error
63+
64+
case _: AltIndirectIntExpr =>
65+
val a: A = 0 // ? // error
66+
val i: Int = ??? : A // error
67+
68+
case _: AltIndirectExprExact[Int] =>
69+
val a: A = 0 // ? // error
70+
val i: Int = ??? : A // error
71+
}
72+
73+
def baz[A](e: IndirectExprSub2[A]) = e match {
74+
case _: AltIndirectIntLit =>
75+
val a: A = 0 // error
76+
val i: Int = ??? : A // error
77+
78+
case _: AltIndirectExprSub[Int] =>
79+
val a: A = 0 // error
80+
val i: Int = ??? : A // error
81+
82+
case _: AltIndirectExprSub2[Int] =>
83+
val a: A = 0 // error
84+
val i: Int = ??? : A // error
85+
86+
case _: AltIndirectIntExpr =>
87+
val a: A = 0
88+
val i: Int = ??? : A // error
89+
90+
case _: AltIndirectExprExact[Int] =>
91+
val a: A = 0
92+
val i: Int = ??? : A // error
93+
}
94+
}

0 commit comments

Comments
 (0)