Skip to content

Commit 161693e

Browse files
author
Aleksander Boruch-Gruszecki
committed
Substitute TypeParams on RHS with TypeVars
Turns out this is quite important for correctness. Tests now pass with -Y-no-deep-subtypes, so GadtTests was simplified accordingly. GADTMap.setBound accepts a single type bound instead of TypeBounds to simplify the implementation. NatsVects tests was rewritten, as class parameters are known to not work with GADTs.
1 parent d65148b commit 161693e

File tree

11 files changed

+265
-152
lines changed

11 files changed

+265
-152
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ trait ConstraintHandling {
2525
protected def isSubType(tp1: Type, tp2: Type): Boolean
2626
protected def isSameType(tp1: Type, tp2: Type): Boolean
2727

28-
// val state: TyperState
2928
protected def constraint: Constraint
3029
protected def constraint_=(c: Constraint): Unit
3130

compiler/src/dotty/tools/dotc/core/Contexts.scala

+89-26
Original file line numberDiff line numberDiff line change
@@ -709,20 +709,19 @@ object Contexts {
709709
}
710710

711711
sealed abstract class GADTMap {
712-
def setBounds(sym: Symbol, b: TypeBounds)(implicit ctx: Context): Unit
712+
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
713+
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
713714
def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds
714715
def contains(sym: Symbol)(implicit ctx: Context): Boolean
715716
def derived: GADTMap
716717
}
717718

718-
class SmartGADTMap(
719+
final class SmartGADTMap(
719720
private[this] var myConstraint: Constraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty),
720-
private[this] var mapping: SimpleIdentityMap[Symbol, TypeVar] = SimpleIdentityMap.Empty
721+
private[this] var mapping: SimpleIdentityMap[Symbol, TypeVar] = SimpleIdentityMap.Empty,
722+
private[this] var reverseMapping: SimpleIdentityMap[TypeVar, Symbol] = SimpleIdentityMap.Empty
721723
) extends GADTMap with ConstraintHandling {
722-
def log(str: String): Unit = {
723-
import dotty.tools.dotc.config.Printers.gadts
724-
gadts.println(s"GADTMap: $str")
725-
}
724+
import dotty.tools.dotc.config.Printers.gadts
726725

727726
// TODO: dirty kludge - should this class be an inner class of TyperState instead?
728727
private[this] var myCtx: Context = null
@@ -739,56 +738,120 @@ object Contexts {
739738
override def isSubType(tp1: Type, tp2: Type): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
740739
override def isSameType(tp1: Type, tp2: Type): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
741740

742-
private[this] def tvar(sym: Symbol)(implicit ctx: Context) = inCtx(ctx) {
741+
private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = inCtx(ctx) {
743742
val res = mapping(sym) match {
744743
case tv: TypeVar => tv
745744
case null =>
746-
log(i"creating tvar for: $sym")
747745
val res = {
748746
import NameKinds.DepParamName
749-
// do not use newTypeVar:
750-
// it registers the TypeVar with TyperState, we don't want that since it instantiates them (TODO: when?)
751-
// (see pos/i3500.scala)
752-
// it registers the TypeVar with TyperState Constraint, which we don't care for but it's needless
753-
val poly = PolyType(DepParamName.fresh().toTypeName :: Nil)(
747+
// avoid registering the TypeVar with TyperState / TyperState#constraint
748+
// TyperState TypeVars get instantiated when we don't want them to (see pos/i3500.scala)
749+
// TyperState#constraint TypeVars can be narrowed in subtype checks - don't want that either
750+
val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)(
754751
pt => TypeBounds.empty :: Nil,
755752
pt => defn.AnyType)
756-
// null out creatorState, we don't need it anyway (and TypeVar can null it too)
753+
// null out creatorState as a precaution
757754
new TypeVar(poly.paramRefs.head, creatorState = null)
758755
}
756+
gadts.println(i"GADTMap: created tvar $sym -> $res")
759757
constraint = constraint.add(res.origin.binder, res :: Nil)
760758
mapping = mapping.updated(sym, res)
759+
reverseMapping = reverseMapping.updated(res, sym)
761760
res
762761
}
763-
log(i"tvar: $sym -> $res")
764762
res
765763
}
766764

767-
override def setBounds(sym: Symbol, b: TypeBounds)(implicit ctx: Context): Unit = inCtx(ctx) {
768-
val tv = tvar(sym)
769-
log(i"setBounds `$sym` `$tv`: `$b`")
770-
addUpperBound(tv.origin, b.hi)
771-
addLowerBound(tv.origin, b.lo)
765+
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)
766+
767+
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = inCtx(ctx) {
768+
def isEmptyBounds(tp: Type) = tp match {
769+
case TypeBounds(lo, hi) => (lo eq defn.NothingType) && (hi eq defn.AnyType)
770+
case _ => false
771+
}
772+
773+
val symTvar = tvar(sym)
774+
775+
def doAddOrdering(bound: TypeParamRef) =
776+
if (isUpper) addLess(symTvar.origin, bound) else addLess(bound, symTvar.origin)
777+
778+
def doAddBound(bound: Type) =
779+
if (isUpper) addUpperBound(symTvar.origin, bound) else addLowerBound(symTvar.origin, bound)
780+
781+
val tvarBound = (new TypeVarInsertingMap)(bound)
782+
val res = tvarBound match {
783+
case boundTvar: TypeVar =>
784+
if (boundTvar eq symTvar) true else doAddOrdering(boundTvar.origin)
785+
// hack to normalize T and T[_]
786+
case AppliedType(boundTvar: TypeVar, args) if args forall isEmptyBounds =>
787+
doAddOrdering(boundTvar.origin)
788+
case tp => doAddBound(tp)
789+
}
790+
791+
gadts.println {
792+
val descr = if (isUpper) "upper" else "lower"
793+
val op = if (isUpper) "<:" else ">:"
794+
i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $tvarBound )"
795+
}
796+
res
772797
}
773798

774799
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = inCtx(ctx) {
775800
mapping(sym) match {
776801
case null => null
777-
case tv => constraint.fullBounds(tv.origin)
802+
case tv =>
803+
val tb = constraint.fullBounds(tv.origin)
804+
val res = {
805+
val tm = new TypeVarRemovingMap
806+
tb.derivedTypeBounds(tm(tb.lo), tm(tb.hi))
807+
}
808+
gadts.println(i"gadt bounds $sym: $res\t( $tv: $tb )")
809+
res
778810
}
779811
}
780812

781-
override def contains(sym: Symbol)(implicit ctx: Context) = mapping(sym) ne null
813+
override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null
782814

783815
override def derived: GADTMap = new SmartGADTMap(
784816
this.myConstraint,
785-
this.mapping
817+
this.mapping,
818+
this.reverseMapping
786819
)
820+
821+
private final class TypeVarInsertingMap extends TypeMap {
822+
override def apply(tp: Type): Type = tp match {
823+
case tp: TypeRef =>
824+
val sym = tp.typeSymbol
825+
if (contains(sym)) tvar(sym) else tp
826+
case _ =>
827+
mapOver(tp)
828+
}
829+
}
830+
831+
private final class TypeVarRemovingMap extends TypeMap {
832+
override def apply(tp: Type): Type = tp match {
833+
case tpr: TypeParamRef =>
834+
constraint.typeVarOfParam(tpr) match {
835+
case tv: TypeVar =>
836+
reverseMapping(tv).typeRef
837+
case unexpected =>
838+
// if we didn't get a TypeVar, it's likely to cause problems
839+
gadts.println(i"GADTMap: unexpected typeVarOfParam($tpr) = `$unexpected` ${unexpected.getClass}")
840+
tpr
841+
}
842+
case tv: TypeVar =>
843+
if (reverseMapping.contains(tv)) reverseMapping(tv).typeRef
844+
else tv
845+
case _ =>
846+
mapOver(tp)
847+
}
848+
}
787849
}
788850

789851
@sharable object EmptyGADTMap extends GADTMap {
790-
override def setBounds(sym: Symbol, b: TypeBounds)(implicit ctx: Context) = unsupported("EmptyGADTMap.setBounds")
791-
override def bounds(sym: Symbol)(implicit ctx: Context) = null
852+
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds")
853+
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound")
854+
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
792855
override def contains(sym: Symbol)(implicit ctx: Context) = false
793856
override def derived = new SmartGADTMap
794857
}

compiler/src/dotty/tools/dotc/core/Symbols.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,11 @@ trait Symbols { this: Context =>
223223
*/
224224
def newPatternBoundSymbol(name: Name, info: Type, pos: Position): Symbol = {
225225
val sym = newSymbol(owner, name, Case, info, coord = pos)
226-
if (name.isTypeName) gadt.setBounds(sym, info.bounds)
226+
if (name.isTypeName) {
227+
val bounds = info.bounds
228+
gadt.addBound(sym, bounds.lo, isUpper = false)
229+
gadt.addBound(sym, bounds.hi, isUpper = true)
230+
}
227231
sym
228232
}
229233

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+13-5
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
106106
true
107107
}
108108

109-
protected def gadtBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = ctx.gadt.bounds(sym)
110-
protected def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = ctx.gadt.setBounds(sym, b)
109+
protected def gadtBounds(sym: Symbol)(implicit ctx: Context) = ctx.gadt.bounds(sym)
110+
protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false)
111+
protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true)
111112

112113
protected def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = tvar.underlying
113114

@@ -1217,8 +1218,10 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
12171218
val newBounds =
12181219
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
12191220
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
1221+
// gadtMap can check its own satisfiability, but the subtype check is still necessary
1222+
// see tests/patmat/gadt-nontrivial2.scala
12201223
isSubType(newBounds.lo, newBounds.hi) &&
1221-
{ gadtSetBounds(tparam, newBounds); true }
1224+
(if (isUpper) gadtAddUpperBound(tparam, bound) else gadtAddLowerBound(tparam, bound))
12221225
}
12231226
}
12241227
}
@@ -1826,9 +1829,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
18261829
super.gadtBounds(sym)
18271830
}
18281831

1829-
override def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = {
1832+
override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = {
18301833
footprint += sym.typeRef
1831-
super.gadtSetBounds(sym, b)
1834+
super.gadtAddLowerBound(sym, b)
1835+
}
1836+
1837+
override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = {
1838+
footprint += sym.typeRef
1839+
super.gadtAddUpperBound(sym, b)
18321840
}
18331841

18341842
override def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = {

compiler/src/dotty/tools/dotc/typer/Typer.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -1007,8 +1007,7 @@ class Typer extends Namer
10071007
def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = {
10081008
val gadtCtx = ctx.fresh.setFreshGADTBounds
10091009
for (sym <- gadtSyms)
1010-
if (!gadtCtx.gadt.contains(sym))
1011-
gadtCtx.gadt.setBounds(sym, TypeBounds.empty)
1010+
if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym)
10121011
gadtCtx
10131012
}
10141013

@@ -1477,8 +1476,11 @@ class Typer extends Namer
14771476
// that their type parameters are aliases of the class type parameters.
14781477
// See pos/i941.scala
14791478
rhsCtx = ctx.fresh.setFreshGADTBounds
1480-
(tparams1, sym.owner.typeParams).zipped.foreach ((tdef, tparam) =>
1481-
rhsCtx.gadt.setBounds(tdef.symbol, TypeAlias(tparam.typeRef)))
1479+
(tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) =>
1480+
val tr = tparam.typeRef
1481+
rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false)
1482+
rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true)
1483+
}
14821484
}
14831485
if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
14841486
val rhs1 = typedExpr(ddef.rhs, tpt1.tpe)(rhsCtx)

compiler/test/dotty/tools/dotc/GadtTests.scala

+1-25
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,9 @@ class GadtTests extends ParallelTesting {
2020
def isInteractive = SummaryReport.isInteractive
2121
def testFilter = Properties.testsFilter
2222

23-
24-
// @Test def posTestFromTasty: Unit = {
25-
// // Can be reproduced with
26-
// // > sbt
27-
// // > dotc -Ythrough-tasty -Ycheck:all <source>
28-
29-
// implicit val testGroup: TestGroup = TestGroup("posTestFromTasty")
30-
// compileTastyInDir("tests/pos", defaultOptions,
31-
// fromTastyFilter = FileFilter.exclude(TestSources.posFromTastyBlacklisted),
32-
// decompilationFilter = FileFilter.exclude(TestSources.posDecompilationBlacklisted),
33-
// recompilationFilter = FileFilter.include(TestSources.posRecompilationWhitelist)
34-
// ).checkCompile()
35-
// }
36-
3723
@Test def compileGadtTests: Unit = {
3824
implicit val testGroup: TestGroup = TestGroup("compileGadtTests")
39-
compileFilesInDir("tests/gadt+noCheckOptions", TestFlags(basicClasspath, noCheckOptions)).checkCompile()
40-
}
41-
42-
@Test def compileGadtCheckOptionsTests: Unit = {
43-
implicit val testGroup: TestGroup = TestGroup("compileGadtCheckOptionsTests")
44-
compileFilesInDir("tests/gadt+checkOptions", TestFlags(basicClasspath, noCheckOptions ++ checkOptions)).checkCompile()
45-
}
46-
47-
@Test def compileGadtDefaultOptionsTests: Unit = {
48-
implicit val testGroup: TestGroup = TestGroup("compileGadtDefaultOptionsTests")
49-
compileFilesInDir("tests/gadt+defaultOptions", defaultOptions).checkCompile()
25+
compileFilesInDir("tests/gadt", defaultOptions).checkCompile()
5026
}
5127
}
5228

tests/gadt/NatsVects.ignore

-90
This file was deleted.

0 commit comments

Comments
 (0)