From b6141884f3a78815827779ab2c477c70c27041c4 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 7 Feb 2019 11:01:08 +0100 Subject: [PATCH 01/20] Move Constraint#fullBounds to ConstraintHandler --- .../dotty/tools/dotc/core/Constraint.scala | 15 +++----- .../tools/dotc/core/ConstraintHandling.scala | 33 ++++++++++++++++- .../src/dotty/tools/dotc/core/Contexts.scala | 34 ++++++++++++++--- .../tools/dotc/core/OrderingConstraint.scala | 9 ----- .../dotty/tools/dotc/core/TypeComparer.scala | 12 ++++++ .../src/dotty/tools/dotc/core/Types.scala | 4 +- .../tools/dotc/printing/Formatting.scala | 4 +- .../tools/dotc/printing/PlainPrinter.scala | 5 ++- .../tools/dotc/typer/ErrorReporting.scala | 4 +- .../dotty/tools/dotc/typer/Implicits.scala | 34 ++++++++++------- .../dotty/tools/dotc/typer/Inferencing.scala | 4 +- .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/pos/gadt-accumulatable.scala | 37 +++++++++++++++++++ 13 files changed, 148 insertions(+), 49 deletions(-) create mode 100644 tests/pos/gadt-accumulatable.scala diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index 91bedf35948b..89a22d0dd0a5 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -45,6 +45,12 @@ abstract class Constraint extends Showable { /** The parameters that are known to be greater wrt <: than `param` */ def upper(param: TypeParamRef): List[TypeParamRef] + /** `lower`, except that `minLower.forall(tpr => !minLower.exists(_ <:< tpr))` */ + def minLower(param: TypeParamRef): List[TypeParamRef] + + /** `upper`, except that `minUpper.forall(tpr => !minUpper.exists(tpr <:< _))` */ + def minUpper(param: TypeParamRef): List[TypeParamRef] + /** lower(param) \ lower(butNot) */ def exclusiveLower(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef] @@ -58,15 +64,6 @@ abstract class Constraint extends Showable { */ def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds - /** The lower bound of `param` including all known-to-be-smaller parameters */ - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type - - /** The upper bound of `param` including all known-to-be-greater parameters */ - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type - - /** The bounds of `param` including all known-to-be-smaller and -greater parameters */ - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds - /** A new constraint which is derived from this constraint by adding * entries for all type parameters of `poly`. * @param tvars A list of type variables associated with the params, diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 0560866a3e6e..fbb673a888c2 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -2,10 +2,13 @@ package dotty.tools package dotc package core -import Types._, Contexts._, Symbols._ +import Types._ +import Contexts._ +import Symbols._ import Decorators._ import config.Config import config.Printers.{constr, typr} +import dotty.tools.dotc.reporting.trace /** Methods for adding constraints and solving them. * @@ -31,6 +34,8 @@ trait ConstraintHandling[AbstractContext] { protected def constraint: Constraint protected def constraint_=(c: Constraint): Unit + protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type + private[this] var addConstraintInvocations = 0 /** If the constraint is frozen we cannot add new bounds to the constraint. */ @@ -66,6 +71,30 @@ trait ConstraintHandling[AbstractContext] { case tp => tp } + def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = + constraint.nonParamBounds(param) match { + case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr)) + case tb => tb + } + + def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).lo /: constraint.minLower(param)) { + (t, u) => t | externalize(u) + } + + def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param)) { + (t, u) => t & externalize(u) + } + + /** Full bounds of `param`, including other lower/upper params. + * + * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` + * of some param when comparing types might lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = + nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) + protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean = !constraint.contains(param) || { def occursIn(bound: Type): Boolean = { @@ -262,7 +291,7 @@ trait ConstraintHandling[AbstractContext] { } constraint.entry(param) match { case _: TypeBounds => - val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param) + val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param) val inst = avoidParam(bound) typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}") inst diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 6e91741ae444..81bc5904a483 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -778,7 +778,15 @@ object Contexts { sealed abstract class GADTMap { def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean + def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. + * + * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` + * of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds def contains(sym: Symbol)(implicit ctx: Context): Boolean def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type def debugBoundsDescription(implicit ctx: Context): String @@ -807,6 +815,12 @@ object Contexts { override protected def constraint = myConstraint override protected def constraint_=(c: Constraint) = myConstraint = c + override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = + reverseMapping(param) match { + case sym: Symbol => sym.typeRef + case null => param + } + override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) @@ -866,12 +880,21 @@ object Contexts { }, gadts) } finally boundAdditionInProgress = false + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = + constraint.isLess(tvar(sym1).origin, tvar(sym2).origin) + + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = + mapping(sym) match { + case null => null + case tv => removeTypeVars(fullBounds(tv.origin)).asInstanceOf[TypeBounds] + } + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { mapping(sym) match { case null => null case tv => def retrieveBounds: TypeBounds = { - val tb = constraint.fullBounds(tv.origin) + val tb = bounds(tv.origin) removeTypeVars(tb).asInstanceOf[TypeBounds] } ( @@ -883,10 +906,7 @@ object Contexts { boundCache = boundCache.updated(sym, bounds) bounds } - ).reporting({ res => - // i"gadt bounds $sym: $res" - "" - }, gadts) + )// .reporting({ res => i"gadt bounds $sym: $res" }, gadts) } } @@ -984,7 +1004,7 @@ object Contexts { sb ++= constraint.show sb += '\n' mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${bounds(sym)}\n" + sb ++= i"$sym: ${fullBounds(sym)}\n" } sb.result } @@ -993,7 +1013,9 @@ object Contexts { @sharable object EmptyGADTMap extends GADTMap { override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds") override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound") + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.isLess") override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null override def contains(sym: Symbol)(implicit ctx: Context) = false override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation") override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 2f568dfe7750..869c8330a5a3 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -196,15 +196,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = entry(param).bounds - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).lo /: minLower(param))(_ | _) - - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).hi /: minUpper(param))(_ & _) - - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = - nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) - def typeVarOfParam(param: TypeParamRef): Type = { val entries = boundsMap(param.binder) if (entries == null) NoType diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 8a69184f0846..f4a85532c1b1 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -33,6 +33,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { def constraint: Constraint = state.constraint def constraint_=(c: Constraint): Unit = state.constraint = c + override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = param + private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null private[this] var recCount = 0 private[this] var monitored = false @@ -442,6 +444,16 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { val gbounds2 = gadtBounds(tp2.symbol) (gbounds2 != null) && (isSubTypeWhenFrozen(tp1, gbounds2.lo) || + (tp1 match { + case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false + }) || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) && GADTusage(tp2.symbol) } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 49908bacdcd6..75fb6629bf39 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -3863,10 +3863,10 @@ object Types { def contextInfo(tp: Type): Type = tp match { case tp: TypeParamRef => val constraint = ctx.typerState.constraint - if (constraint.entry(tp).exists) constraint.fullBounds(tp) + if (constraint.entry(tp).exists) ctx.typeComparer.fullBounds(tp) else NoType case tp: TypeRef => - val bounds = ctx.gadt.bounds(tp.symbol) + val bounds = ctx.gadt.fullBounds(tp.symbol) if (bounds == null) NoType else bounds case tp: TypeVar => tp.underlying diff --git a/compiler/src/dotty/tools/dotc/printing/Formatting.scala b/compiler/src/dotty/tools/dotc/printing/Formatting.scala index 6b6a6845565a..408d369d84a4 100644 --- a/compiler/src/dotty/tools/dotc/printing/Formatting.scala +++ b/compiler/src/dotty/tools/dotc/printing/Formatting.scala @@ -170,7 +170,7 @@ object Formatting { case sym: Symbol => val info = if (ctx.gadt.contains(sym)) - sym.info & ctx.gadt.bounds(sym) + sym.info & ctx.gadt.fullBounds(sym) else sym.info s"is a ${ctx.printer.kindString(sym)}${sym.showExtendedLocation}${addendum("bounds", info)}" @@ -190,7 +190,7 @@ object Formatting { case param: TermParamRef => false case skolem: SkolemType => true case sym: Symbol => - ctx.gadt.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty + ctx.gadt.contains(sym) && ctx.gadt.fullBounds(sym) != TypeBounds.empty case _ => assert(false, "unreachable") false diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 9efd80c1424c..d3c170d50243 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -208,7 +208,10 @@ class PlainPrinter(_ctx: Context) extends Printer { else { val constr = ctx.typerState.constraint val bounds = - if (constr.contains(tp)) constr.fullBounds(tp.origin)(ctx.addMode(Mode.Printing)) + if (constr.contains(tp)) { + val ctx0 = ctx.addMode(Mode.Printing) + ctx0.typeComparer.fullBounds(tp.origin)(ctx0) + } else TypeBounds.empty if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value) else if (ctx.settings.YshowVarBounds.value) "(" ~ toText(tp.origin) ~ "?" ~ toText(bounds) ~ ")" diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 03b67a7a91fe..bd07958cbb91 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -128,8 +128,8 @@ object ErrorReporting { case tp: TypeParamRef => constraint.entry(tp) match { case bounds: TypeBounds => - if (variance < 0) apply(constraint.fullUpperBound(tp)) - else if (variance > 0) apply(constraint.fullLowerBound(tp)) + if (variance < 0) apply(ctx.typeComparer.fullUpperBound(tp)) + else if (variance > 0) apply(ctx.typeComparer.fullLowerBound(tp)) else tp case NoType => tp case instType => apply(instType) diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 1fa53e79582a..52f14d3feb9c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -397,21 +397,29 @@ object Implicits { * what was expected */ override def clarify(tp: Type)(implicit ctx: Context): Type = { - val map = new TypeMap { - def apply(t: Type): Type = t match { - case t: TypeParamRef => - constraint.entry(t) match { - case NoType => t - case bounds: TypeBounds => constraint.fullBounds(t) - case t1 => t1 - } - case t: TypeVar => - t.instanceOpt.orElse(apply(t.origin)) - case _ => - mapOver(t) + val ctx0 = ctx + locally { + implicit val ctx = ctx0.fresh.setTyperState { + val ts = ctx0.typerState.fresh() + ts.constraint_=(constraint)(ctx0) + ts + } + val map = new TypeMap { + def apply(t: Type): Type = t match { + case t: TypeParamRef => + constraint.entry(t) match { + case NoType => t + case bounds: TypeBounds => ctx.typeComparer.fullBounds(t) + case t1 => t1 + } + case t: TypeVar => + t.instanceOpt.orElse(apply(t.origin)) + case _ => + mapOver(t) + } } + map(tp) } - map(tp) } def explanation(implicit ctx: Context): String = diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index eb0674802cf6..8441690dd741 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -263,7 +263,7 @@ object Inferencing { * 0 if unconstrained, or constraint is from below and above. */ private def instDirection(param: TypeParamRef)(implicit ctx: Context): Int = { - val constrained = ctx.typerState.constraint.fullBounds(param) + val constrained = ctx.typeComparer.fullBounds(param) val original = param.binder.paramInfos(param.paramNum) val cmp = ctx.typeComparer val approxBelow = @@ -298,7 +298,7 @@ object Inferencing { if (v == 1) tvar.instantiate(fromBelow = false) else if (v == -1) tvar.instantiate(fromBelow = true) else { - val bounds = ctx.typerState.constraint.fullBounds(tvar.origin) + val bounds = ctx.typeComparer.fullBounds(tvar.origin) if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x) tvar.instantiate(fromBelow = false) else { diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f07563ec49f3..196fdb088323 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1096,7 +1096,7 @@ class Typer extends Namer if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(sym) else ctx.error(new DuplicateBind(b, cdef), b.sourcePos) if (!ctx.isAfterTyper) { - val bounds = ctx.gadt.bounds(sym) + val bounds = ctx.gadt.fullBounds(sym) if (bounds != null) sym.info = bounds } b diff --git a/tests/pos/gadt-accumulatable.scala b/tests/pos/gadt-accumulatable.scala new file mode 100644 index 000000000000..ce4cf347538d --- /dev/null +++ b/tests/pos/gadt-accumulatable.scala @@ -0,0 +1,37 @@ +object `gadt-accumulatable` { + sealed abstract class Or[+G,+B] extends Product with Serializable + final case class Good[+G](g: G) extends Or[G,Nothing] + final case class Bad[+B](b: B) extends Or[Nothing,B] + + sealed trait Validation[+E] extends Product with Serializable + case object Pass extends Validation[Nothing] + case class Fail[E](error: E) extends Validation[E] + + sealed abstract class Every[+T] protected (underlying: Vector[T]) extends /*PartialFunction[Int, T] with*/ Product with Serializable + final case class One[+T](loneElement: T) extends Every[T](Vector(loneElement)) + final case class Many[+T](firstElement: T, secondElement: T, otherElements: T*) extends Every[T](firstElement +: secondElement +: Vector(otherElements: _*)) + + class Accumulatable[G, ERR, EVERY[_]] { } + + def convertOrToAccumulatable[G, ERR, EVERY[b] <: Every[b]](accumulatable: G Or EVERY[ERR]): Accumulatable[G, ERR, EVERY] = { + new Accumulatable[G, ERR, EVERY] { + def when[OTHERERR >: ERR](validations: (G => Validation[OTHERERR])*): G Or Every[OTHERERR] = { + accumulatable match { + case Good(g) => + val results = validations flatMap (_(g) match { case Fail(x) => val z: OTHERERR = x; Seq(x); case Pass => Seq.empty}) + results.length match { + case 0 => Good(g) + case 1 => Bad(One(results.head)) + case _ => + val first = results.head + val tail = results.tail + val second = tail.head + val rest = tail.tail + Bad(Many(first, second, rest: _*)) + } + case Bad(myBad) => Bad(myBad) + } + } + } + } +} From 24731527db9ae5e57d0e62c50a14748d3d784439 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Fri, 8 Feb 2019 15:29:11 +0100 Subject: [PATCH 02/20] Clean up TypeVar insertion/removal in SmartGADTMap If we do not insert TypeVars into the bounds every time, then the only time we need to remove them is when taking the full bounds of some type. Since that logic now resides in ConstraintHandling and replaces all TypeParamRefs internal to SmartGADTMap, we have no need to perform expensive type traversals. This removes the only reason for caching bounds. The addition of HK parameter variance adaptation was necessary to make tests/pos/i6014-gadt.scala pass. --- .../src/dotty/tools/dotc/core/Contexts.scala | 113 +++++------------- 1 file changed, 32 insertions(+), 81 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 81bc5904a483..a5902bbfeb14 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -799,15 +799,13 @@ object Contexts { private var myConstraint: Constraint, private var mapping: SimpleIdentityMap[Symbol, TypeVar], private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private var boundCache: SimpleIdentityMap[Symbol, TypeBounds] ) extends GADTMap with ConstraintHandling[Context] { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} def this() = this( myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), mapping = SimpleIdentityMap.Empty, - reverseMapping = SimpleIdentityMap.Empty, - boundCache = SimpleIdentityMap.Empty + reverseMapping = SimpleIdentityMap.Empty ) implicit override def ctx(implicit ctx: Context): Context = ctx @@ -826,9 +824,7 @@ object Contexts { override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = try { - boundCache = SimpleIdentityMap.Empty - boundAdditionInProgress = true + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = { @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { case tv: TypeVar => val inst = instType(tv) @@ -836,49 +832,44 @@ object Contexts { case _ => tp } - def externalizedSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = { - val externalizedTp1 = removeTypeVars(tp1) - val externalizedTp2 = removeTypeVars(tp2) - - ( - if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2 - else externalizedTp2 frozen_<:< externalizedTp1 - ).reporting({ res => - val descr = i"$externalizedTp1 frozen_${if (isSubtype) "<:<" else ">:>"} $externalizedTp2" - i"$descr = $res" - }, gadts) - } - val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { case tv: TypeVar => tv case inst => - val externalizedInst = removeTypeVars(inst) - gadts.println(i"instantiated: $sym -> $externalizedInst") - return if (isUpper) isSubType(externalizedInst , bound) else isSubType(bound, externalizedInst) + gadts.println(i"instantiated: $sym -> $inst") + return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst) } - val internalizedBound = insertTypeVars(bound) + val internalizedBound = bound match { + case nt: NamedType if contains(nt.symbol) => + stripInternalTypeVar(tvar(nt.symbol)) + case _ => bound + } ( - stripInternalTypeVar(internalizedBound) match { + internalizedBound match { case boundTvar: TypeVar => if (boundTvar eq symTvar) true else if (isUpper) addLess(symTvar.origin, boundTvar.origin) else addLess(boundTvar.origin, symTvar.origin) case bound => - if (externalizedSubtype(symTvar, bound, isSubtype = !isUpper)) { - gadts.println(i"manually unifying $symTvar with $bound") - constraint = constraint.updateEntry(symTvar.origin, bound) - true - } - else if (isUpper) addUpperBound(symTvar.origin, bound) - else addLowerBound(symTvar.origin, bound) + val oldUpperBound = bounds(symTvar.origin) + // If we already have bounds `F >: [t] => List[t] <: [t] => Any` + // and we want to record that `F <: [+A] => List[A]`, we need to adapt + // type parameter variances of the bound. Consider that the following is valid: + // + // class Foo[F[t] >: List[t]] + // type T = Foo[List] + // + // precisely because `Foo[List]` is desugared to `Foo[[A] => List[A]]`. + val bound1 = bound.adaptHkVariances(oldUpperBound) + if (isUpper) addUpperBound(symTvar.origin, bound1) + else addLowerBound(symTvar.origin, bound1) } ).reporting({ res => val descr = if (isUpper) "upper" else "lower" val op = if (isUpper) "<:" else ">:" i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" }, gadts) - } finally boundAdditionInProgress = false + } override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = constraint.isLess(tvar(sym1).origin, tvar(sym2).origin) @@ -886,34 +877,27 @@ object Contexts { override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = mapping(sym) match { case null => null - case tv => removeTypeVars(fullBounds(tv.origin)).asInstanceOf[TypeBounds] + case tv => fullBounds(tv.origin) } override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { mapping(sym) match { case null => null case tv => - def retrieveBounds: TypeBounds = { - val tb = bounds(tv.origin) - removeTypeVars(tb).asInstanceOf[TypeBounds] - } - ( - if (boundAdditionInProgress || ctx.mode.is(Mode.GADTflexible)) retrieveBounds - else boundCache(sym) match { - case tb: TypeBounds => tb - case null => - val bounds = retrieveBounds - boundCache = boundCache.updated(sym, bounds) - bounds + def retrieveBounds: TypeBounds = + bounds(tv.origin) match { + case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => + TypeAlias(reverseMapping(tpr).typeRef) + case tb => tb } - )// .reporting({ res => i"gadt bounds $sym: $res" }, gadts) + retrieveBounds//.reporting({ res => i"gadt bounds $sym: $res" }, gadts) } } override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { - val res = removeTypeVars(approximation(tvar(sym).origin, fromBelow = fromBelow)) + val res = approximation(tvar(sym).origin, fromBelow = fromBelow) gadts.println(i"approximating $sym ~> $res") res } @@ -921,8 +905,7 @@ object Contexts { override def fresh: GADTMap = new SmartGADTMap( myConstraint, mapping, - reverseMapping, - boundCache + reverseMapping ) def restore(other: GADTMap): Unit = other match { @@ -930,7 +913,6 @@ object Contexts { this.myConstraint = other.myConstraint this.mapping = other.mapping this.reverseMapping = other.reverseMapping - this.boundCache = other.boundCache case _ => ; } @@ -964,37 +946,6 @@ object Contexts { } } - private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { - case tp: TypeRef => - val sym = tp.typeSymbol - if (contains(sym)) tvar(sym) else tp - case _ => - (if (map != null) map else new TypeVarInsertingMap()).mapOver(tp) - } - private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap { - override def apply(tp: Type): Type = insertTypeVars(tp, this) - } - - private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { - case tpr: TypeParamRef => - reverseMapping(tpr) match { - case null => tpr - case sym => sym.typeRef - } - case tv: TypeVar => - reverseMapping(tv.origin) match { - case null => tv - case sym => sym.typeRef - } - case _ => - (if (map != null) map else new TypeVarRemovingMap()).mapOver(tp) - } - private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap { - override def apply(tp: Type): Type = removeTypeVars(tp, this) - } - - private[this] var boundAdditionInProgress = false - // ---- Debug ------------------------------------------------------------ override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) From 2ab8a6d21925bda809a0e6653c71e941c05f0b4d Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 11 Feb 2019 16:15:00 +0100 Subject: [PATCH 03/20] Allow constraining type parameters of all enclosing functions gadtSyms/gadtContext became redundant, so they were removed. The logic in typedDefDef was adjusted to only create a fresh context when necessary. --- .../dotty/tools/dotc/core/TypeComparer.scala | 4 +- .../tools/dotc/transform/TreeChecker.scala | 4 +- .../src/dotty/tools/dotc/typer/Inliner.scala | 6 +- .../src/dotty/tools/dotc/typer/Namer.scala | 10 ++- .../src/dotty/tools/dotc/typer/Typer.scala | 76 ++++++++----------- tests/neg/classOf.check | 2 + tests/pos/gadt-all-params.scala | 9 +++ tests/pos/gadt-inference.scala | 44 +++++++++++ tests/run-macros/tasty-extractors-3.check | 2 + 9 files changed, 107 insertions(+), 50 deletions(-) create mode 100644 tests/pos/gadt-all-params.scala create mode 100644 tests/pos/gadt-inference.scala diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index f4a85532c1b1..e842a802708d 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -455,7 +455,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { case _ => false }) || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) && - GADTusage(tp2.symbol) + { tp1.isRef(NothingClass) || GADTusage(tp2.symbol) } } isSubApproxHi(tp1, info2.lo) || compareGADT || fourthTry @@ -714,7 +714,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { (gbounds1 != null) && (isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) && - GADTusage(tp1.symbol) + { tp2.isRef(AnyClass) || GADTusage(tp1.symbol) } } isSubType(hi1, tp2, approx.addLow) || compareGADT case _ => diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index e4d71a68488a..375d11bedc30 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -401,9 +401,9 @@ class TreeChecker extends Phase with SymTransformer { } } - override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = { + override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { withPatSyms(tpd.patVars(tree.pat.asInstanceOf[tpd.Tree])) { - super.typedCase(tree, selType, pt, gadtSyms) + super.typedCase(tree, selType, pt) } } diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 193a92f0000a..0c754eab2570 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -531,9 +531,12 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { /** A utility object offering methods for rewriting inlined code */ object reducer { + import dotty.tools.dotc.core.Contexts.GADTMap + /** An extractor for terms equivalent to `new C(args)`, returning the class `C`, * a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can * follow a reference to an inline value binding to its right hand side. + * * @return optionally, a triple consisting of * - the class `C` * - the arguments `args` @@ -729,7 +732,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = { val isImplicit = scrutinee.isEmpty - val gadtSyms = typer.gadtSyms(scrutType) /** Try to match pattern `pat` against scrutinee reference `scrut`. If successful add * bindings for variables bound in this pattern to `caseBindingMap`. @@ -920,7 +922,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } if (!isImplicit) caseBindingMap += ((NoSymbol, scrutineeBinding)) - val gadtCtx = typer.gadtContext(gadtSyms).addMode(Mode.GADTflexible) + val gadtCtx = ctx.fresh.setFreshGADTBounds.addMode(Mode.GADTflexible) if (reducePattern(caseBindingMap, scrutineeSym.termRef, cdef.pat)(gadtCtx)) { val (caseBindings, from, to) = substBindings(caseBindingMap.toList, mutable.ListBuffer(), Nil, Nil) val guardOK = cdef.guard.isEmpty || { diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index bbf1541fe222..65a9a0789d08 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1335,8 +1335,16 @@ class Namer { typer: Typer => // it would be erased to BoxedUnit. def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp - var rhsCtx = ctx.addMode(Mode.InferringReturnType) + var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) + if (typeParams.nonEmpty) { + rhsCtx.setFreshGADTBounds + typeParams.foreach { tdef => + val TypeBounds(lo, hi) = tdef.info.bounds + rhsCtx.gadt.addBound(tdef, lo, isUpper = false) + rhsCtx.gadt.addBound(tdef, hi, isUpper = true) + } + } def rhsType = typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(rhsCtx).tpe // Approximate a type `tp` with a type that does not contain skolem types. diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 196fdb088323..8cb5133640bd 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1047,37 +1047,8 @@ class Typer extends Namer assignType(cpy.Match(tree)(sel, cases1), sel, cases1) } - /** gadtSyms = "all type parameters of enclosing methods that appear - * non-variantly in the selector type" todo: should typevars - * which appear with variances +1 and -1 (in different - * places) be considered as well? - */ - def gadtSyms(selType: Type)(implicit ctx: Context): Set[Symbol] = trace(i"GADT syms of $selType", gadts) { - val accu = new TypeAccumulator[Set[Symbol]] { - def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = { - val tsyms1 = t match { - case tr: TypeRef if (tr.symbol is TypeParam) && tr.symbol.owner.isTerm && variance == 0 => - tsyms + tr.symbol - case _ => - tsyms - } - foldOver(tsyms1, t) - } - } - accu(Set.empty, selType) - } - - /** Context with fresh GADT bounds for all gadtSyms */ - def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = { - val gadtCtx = ctx.fresh.setFreshGADTBounds - for (sym <- gadtSyms) - if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym) - gadtCtx - } - def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] = { - val gadts = gadtSyms(selType) - cases.mapconserve(typedCase(_, selType, pt, gadts)) + cases.mapconserve(typedCase(_, selType, pt)) } /** - strip all instantiated TypeVars from pattern types. @@ -1105,9 +1076,9 @@ class Typer extends Namer } /** Type a case. */ - def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") { + def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = track("typedCase") { val originalCtx = ctx - val gadtCtx = gadtContext(gadtSyms) + val gadtCtx: Context = ctx.fresh.setFreshGADTBounds def caseRest(pat: Tree)(implicit ctx: Context) = { val pat1 = indexPattern(tree).transform(pat) @@ -1537,19 +1508,38 @@ class Typer extends Namer if (sym is ImplicitOrImplied) checkImplicitConversionDefOK(sym) val tpt1 = checkSimpleKinded(typedType(tpt)) - var rhsCtx = ctx - if (sym.isConstructor && !sym.isPrimaryConstructor && tparams1.nonEmpty) { - // for secondary constructors we need a context that "knows" - // that their type parameters are aliases of the class type parameters. - // See pos/i941.scala - rhsCtx = ctx.fresh.setFreshGADTBounds - (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => - val tr = tparam.typeRef - rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) - rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) + val rhsCtx: Context = { + var _result: FreshContext = null + def resultCtx(): FreshContext = { + if (_result == null) _result = ctx.fresh + _result + } + + if (tparams1.nonEmpty) { + resultCtx().setFreshGADTBounds + if (!sym.isConstructor) { + // if we're _not_ in a constructor, allow constraining type parameters + tparams1.foreach { tdef => + val tb @ TypeBounds(lo, hi) = tdef.symbol.info.bounds + resultCtx().gadt.addBound(tdef.symbol, lo, isUpper = false) + resultCtx().gadt.addBound(tdef.symbol, hi, isUpper = true) + } + } else if (!sym.isPrimaryConstructor) { + // otherwise, for secondary constructors we need a context that "knows" + // that their type parameters are aliases of the class type parameters. + // See pos/i941.scala + (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => + val tr = tparam.typeRef + resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = false) + resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = true) + } + } } + + if (sym.isInlineMethod) resultCtx().addMode(Mode.InlineableBody) + + if (_result ne null) _result else ctx } - if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) val rhs1 = typedExpr(ddef.rhs, tpt1.tpe.widenExpr)(rhsCtx) if (sym.isInlineMethod) { diff --git a/tests/neg/classOf.check b/tests/neg/classOf.check index 7c761b8af5be..b8416e4007e3 100644 --- a/tests/neg/classOf.check +++ b/tests/neg/classOf.check @@ -2,5 +2,7 @@ Test.C{I = String} is not a class type [116..117] in classOf.scala T is not a class type + +where: T is a type in method f2 with bounds <: String [72..73] in classOf.scala T is not a class type diff --git a/tests/pos/gadt-all-params.scala b/tests/pos/gadt-all-params.scala new file mode 100644 index 000000000000..b5d7baecc283 --- /dev/null +++ b/tests/pos/gadt-all-params.scala @@ -0,0 +1,9 @@ +object `gadt-all-params` { + enum Expr[T] { + case UnitLit extends Expr[Unit] + } + + def foo[T >: TT <: TT, TT](e: Expr[T]): T = e match { + case Expr.UnitLit => () + } +} diff --git a/tests/pos/gadt-inference.scala b/tests/pos/gadt-inference.scala new file mode 100644 index 000000000000..e625e4823dc0 --- /dev/null +++ b/tests/pos/gadt-inference.scala @@ -0,0 +1,44 @@ +object `gadt-inference` { + enum Expr[T] { + case StrLit(s: String) extends Expr[String] + case IntLit(i: Int) extends Expr[Int] + } + import Expr._ + + def eval[T](e: Expr[T]) = + e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + + def nested[T](o: Option[Expr[T]]) = + o match { + case Some(e) => e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + case None => ??? + } + + def local[T](e: Expr[T]) = { + def eval[T](e: Expr[T]) = + e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + + eval(e) : T + } +} diff --git a/tests/run-macros/tasty-extractors-3.check b/tests/run-macros/tasty-extractors-3.check index 2e3b9f23e983..35c88a7598f5 100644 --- a/tests/run-macros/tasty-extractors-3.check +++ b/tests/run-macros/tasty-extractors-3.check @@ -10,6 +10,8 @@ Type.SymRef(IsClassDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDe Type.SymRef(IsTypeDefSymbol(), NoPrefix()) +Type.SymRef(IsTypeDefSymbol(), NoPrefix()) + TypeBounds(Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix())))), Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix()))))) Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix())))) From eac0e533dcf225dca716daf85fb3de14508a2e52 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 11 Feb 2019 16:37:17 +0100 Subject: [PATCH 04/20] Do not constrain types under either approx --- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 2 +- tests/neg/gadt-no-approx.scala | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 tests/neg/gadt-no-approx.scala diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index e842a802708d..73ff054866bb 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1376,7 +1376,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { * Test that the resulting bounds are still satisfiable. */ private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { - val boundImprecise = if (isUpper) approx.high else approx.low + val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GADTflexible) && !frozenConstraint && !boundImprecise && { val tparam = tr.symbol gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") diff --git a/tests/neg/gadt-no-approx.scala b/tests/neg/gadt-no-approx.scala new file mode 100644 index 000000000000..eef0d82cba21 --- /dev/null +++ b/tests/neg/gadt-no-approx.scala @@ -0,0 +1,10 @@ +object `gadt-no-approx` { + def fo[U](u: U): U = + (0 : Int) match { + case _: u.type => + val i: Int = (??? : U) // error + // potentially could compile + // val i2: Int = u + u + } +} From f5a1ef2cc6ab9dfc68030416419991152f8ffb5c Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 17 Jan 2019 16:19:24 +0100 Subject: [PATCH 05/20] Use Skolems to infer GADT constraints The rationale for using a Skolem here is: we want to record that there is at least one value that is both of the pattern type and the scrutinee type. All symbols are now considered valid for adding GADT constraints - the rationale is that set of constrainable symbols should be either selected on a per-(sub)pattern basis, or be the same for all matches. Previously, symbols which were only appearing variantly in a scrutinee type could be considered constrainable anyway because of an outer pattern match. --- .../src/dotty/tools/dotc/core/Types.scala | 7 +- .../dotty/tools/dotc/typer/Applications.scala | 2 +- .../dotty/tools/dotc/typer/Inferencing.scala | 48 +++++--------- .../src/dotty/tools/dotc/typer/Typer.scala | 4 +- tests/neg/creative-gadt-constraints.scala | 66 +++++++++++++++++++ tests/neg/int-extractor.scala | 31 +++++++++ tests/neg/invariant-gadt.scala | 27 ++++++++ tests/neg/typeclass-derivation2.scala | 20 ++++-- tests/pos/precise-pattern-type.scala | 16 +++++ tests/run/typeclass-derivation2.scala | 49 ++++++++++---- 10 files changed, 216 insertions(+), 54 deletions(-) create mode 100644 tests/neg/creative-gadt-constraints.scala create mode 100644 tests/neg/int-extractor.scala create mode 100644 tests/neg/invariant-gadt.scala create mode 100644 tests/pos/precise-pattern-type.scala diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 75fb6629bf39..33efe736b3c8 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -3704,7 +3704,12 @@ object Types { // ----- Skolem types ----------------------------------------------- - /** A skolem type reference with underlying type `info` */ + /** A skolem type reference with underlying type `info`. + * + * For Dotty, a skolem type is a singleton type of some unknown value of type `info`. + * Note that care is needed when creating them, since not all types need to be inhabited. + * A skolem is equal to itself and no other type. + */ case class SkolemType(info: Type) extends UncachedProxyType with ValueType with SingletonType { override def underlying(implicit ctx: Context): Type = info def derivedSkolemType(info: Type)(implicit ctx: Context): SkolemType = diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index ca5d94dba23b..c56aef84aebc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1090,7 +1090,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic => * - If a type proxy P is not a reference to a class, P's supertype is in G */ def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean = - if (constrainPatternType(subtp, tp)) true + if (constrainPatternType(subtp, tp, termPattern = true)) true else tp match { case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent) case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType) diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 8441690dd741..f84a7d59e494 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -153,41 +153,22 @@ object Inferencing { def isSkolemFree(tp: Type)(implicit ctx: Context): Boolean = !tp.existsPart(_.isInstanceOf[SkolemType]) - /** Derive information about a pattern type by comparing it with some variant of the - * static scrutinee type. We have the following situation in case of a (dynamic) pattern match: + /** Infer constraints that should be in scope for a case body with given pattern and scrutinee types. * - * StaticScrutineeType PatternType - * \ / - * DynamicScrutineeType + * If `termPattern`, infer constraints from knowing that there exists a value which of both scrutinee + * and pattern types (which is the case for normal pattern matching). If not `termPattern`, instead + * infer constraints from knowing that `tp <: pt`. * - * If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained. - * Now let's say we can prove that `PatternType <: StaticScrutineeType`. + * If a pattern matches during normal pattern matching, we can be certain that there exists a value + * which is of both scrutinee and pattern types (the value we're matching on). If this value + * was in a variable, say `x`, then we could simply infer constraints from `x.type <: pt`. Since we might + * be matching on an expression as well, we take a skolem of the scrutinee, which is essentially an existential + * singleton type (see [[dotty.tools.dotc.core.Types.SkolemType]]). * - * StaticScrutineeType - * | \ - * | \ - * | \ - * | PatternType - * | / - * DynamicScrutineeType - * - * What can we say about the relationship of parameter types between `PatternType` and - * `DynamicScrutineeType`? - * - * - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType` - * in the same way as `PatternType` ("invariant refinement"), the subtype test - * `PatternType <:< StaticScrutineeType` tells us all we need to know. - * - Otherwise, if variant refinement is a possibility we can only make predictions - * about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test - * where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams` - * replaces all type argument of variant parameters with empty bounds. - * - * Invariant refinement can be assumed if `PatternType`'s class(es) are final or - * case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`). - * - * TODO: Update so that GADT symbols can be variant, and we special case final class types in patterns + * Note that we need to sometimes widen type parameters of the scrutinee type to avoid unsoundness - + * see i3989c.scala and related issue discussion on Github. */ - def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = { + def constrainPatternType(tp: Type, pt: Type, termPattern: Boolean)(implicit ctx: Context): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case) case tp: TypeProxy => refinementIsInvariant(tp.underlying) @@ -209,8 +190,9 @@ object Inferencing { } val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt) - trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { - tp <:< widePt + val narrowTp = if (termPattern) SkolemType(tp) else tp + trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { + narrowTp <:< widePt } } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 8cb5133640bd..ca08f8eb7d89 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -604,7 +604,7 @@ class Typer extends Namer def handlePattern: Tree = { val tpt1 = typedTpt if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef) - constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible)) + constrainPatternType(tpt1.tpe, pt, termPattern = true)(ctx.addMode(Mode.GADTflexible)) // special case for an abstract type that comes with a class tag tryWithClassTag(ascription(tpt1, isWildcard = true), pt) } @@ -1104,7 +1104,7 @@ class Typer extends Namer def caseRest(implicit ctx: Context) = { val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern))) if (!ctx.isAfterTyper) - constrainPatternType(pat1.tpe, selType)(ctx.addMode(Mode.GADTflexible)) + constrainPatternType(pat1.tpe, selType, termPattern = false)(ctx.addMode(Mode.GADTflexible)) val pat2 = indexPattern(cdef).transform(pat1) val body1 = typedType(cdef.body, pt) assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1) diff --git a/tests/neg/creative-gadt-constraints.scala b/tests/neg/creative-gadt-constraints.scala new file mode 100644 index 000000000000..a8869de5f75d --- /dev/null +++ b/tests/neg/creative-gadt-constraints.scala @@ -0,0 +1,66 @@ +object buffer { + object EssaInt { + def unapply(i: Int): Some[Int] = Some(i) + } + + case class Inv[T](t: T) + + enum EQ[A, B] { case Refl[T]() extends EQ[T, T] } + enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] } // A <: B + + def test_eq1[A, B](eq: EQ[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) | Sko(Int) + eq match { case EQ.Refl() => // a = b + val success: A = b + val fail: A = 0 // error + 0 // error + } + } + } + + def test_eq2[A, B](eq: EQ[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) + eq match { case EQ.Refl() => // a = b + val success: A = b + val fail: A = 0 // error + 0 // error + } + } + } + + def test_sub1[A, B](sub: SUB[A, B], a: A, b: B): B = + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int) + sub match { case SUB.Refl() => // b >: a + val success: B = a + val fail: A = 0 // error + 0 // error + } + } + } + + def test_sub2[A, B](sub: SUB[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int) + sub match { case SUB.Refl() => // b >: a + val success: B = a + val fail: A = 0 // error + 0 // error + } + } + } + + + def test_sub_eq[A, B, C](sub: SUB[A|B, C], eqA: EQ[A, 5], eqB: EQ[B, 6]): C = + sub match { case SUB.Refl() => // C >: A | B + eqA match { case EQ.Refl() => // A = 5 + eqB match { case EQ.Refl() => // B = 6 + val fail1: A = 0 // error + val fail2: B = 0 // error + 0 // error + } + } + } +} diff --git a/tests/neg/int-extractor.scala b/tests/neg/int-extractor.scala new file mode 100644 index 000000000000..8534c5a1bc00 --- /dev/null +++ b/tests/neg/int-extractor.scala @@ -0,0 +1,31 @@ +object Test { + object EssaInt { + def unapply(i: Int): Some[Int] = Some(i) + } + + def foo1[T](t: T): T = t match { + case EssaInt(_) => + 0 // error + } + + def foo2[T](t: T): T = t match { + case EssaInt(_) => t match { + case EssaInt(_) => + 0 // error + } + } + + case class Inv[T](t: T) + + def bar1[T](t: T): T = Inv(t) match { + case Inv(EssaInt(_)) => + 0 // error + } + + def bar2[T](t: T): T = t match { + case Inv(EssaInt(_)) => t match { + case Inv(EssaInt(_)) => + 0 // error + } + } +} diff --git a/tests/neg/invariant-gadt.scala b/tests/neg/invariant-gadt.scala new file mode 100644 index 000000000000..ac335f57743f --- /dev/null +++ b/tests/neg/invariant-gadt.scala @@ -0,0 +1,27 @@ +object `invariant-gadt` { + case class Invariant[T](value: T) + + def unsound0[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => + (0: Any) // error + } + + def unsound1[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => + 0 // error + } + + def unsound2[T](t: T): T = Invariant(t) match { + case Invariant(value) => value match { + case _: Int => + 0 // error + } + } + + def unsoundTwice[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => Invariant(t) match { + case Invariant(_: Int) => + 0 // error + } + } +} diff --git a/tests/neg/typeclass-derivation2.scala b/tests/neg/typeclass-derivation2.scala index 33c64494e9c5..ddb6517fb869 100644 --- a/tests/neg/typeclass-derivation2.scala +++ b/tests/neg/typeclass-derivation2.scala @@ -111,6 +111,13 @@ object TypeLevel { * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. */ abstract class Shaped[T, S <: Shape] extends Reflected[T] + + // substitute for erasedValue that allows precise matching + final abstract class Type[-A, +B] + type Subtype[t] = Type[_, t] + type Supertype[t] = Type[t, _] + type Exactly[t] = Type[t, t] + erased def typeOf[T]: Type[T, T] = ??? } // An algebraic datatype @@ -203,7 +210,7 @@ trait Show[T] { def show(x: T): String } object Show { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryShow[T](x: T): String = implicit match { @@ -229,9 +236,14 @@ object Show { inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => showCase[T, elems](r, x) - case _ => showCases[T, alts1](r, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _ => + error("invalid call to showCases: one of Alts is not a subtype of T") } case _: Unit => throw new MatchError(x) diff --git a/tests/pos/precise-pattern-type.scala b/tests/pos/precise-pattern-type.scala new file mode 100644 index 000000000000..856672fafbf2 --- /dev/null +++ b/tests/pos/precise-pattern-type.scala @@ -0,0 +1,16 @@ +object `precise-pattern-type` { + class Type { + def isType: Boolean = true + } + + class Tree[-T >: Null] { + def tpe: T @annotation.unchecked.uncheckedVariance = ??? + } + + case class Select[-T >: Null](qual: Tree[T]) extends Tree[T] + + def test[T <: Tree[Type]](tree: T) = tree match { + case Select(q) => + q.tpe.isType + } +} diff --git a/tests/run/typeclass-derivation2.scala b/tests/run/typeclass-derivation2.scala index 8ac7cec4487c..f8812b461d48 100644 --- a/tests/run/typeclass-derivation2.scala +++ b/tests/run/typeclass-derivation2.scala @@ -113,6 +113,13 @@ object TypeLevel { * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. */ abstract class Shaped[T, S <: Shape] extends Reflected[T] + + // substitute for erasedValue that allows precise matching + final abstract class Type[-A, +B] + type Subtype[t] = Type[_, t] + type Supertype[t] = Type[t, _] + type Exactly[t] = Type[t, t] + erased def typeOf[T]: Type[T, T] = ??? } // An algebraic datatype @@ -217,7 +224,7 @@ trait Eq[T] { } object Eq { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryEql[T](x: T, y: T) = implicit match { @@ -239,8 +246,13 @@ object Eq { inline def eqlCases[T, Alts <: Tuple](xm: Mirror, ym: Mirror, ordinal: Int, n: Int): Boolean = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - if (n == ordinal) eqlElems[elems](xm, ym, 0) - else eqlCases[T, alts1](xm, ym, ordinal, n + 1) + inline typeOf[alt] match { + case _: Subtype[T] => + if (n == ordinal) eqlElems[elems](xm, ym, 0) + else eqlCases[T, alts1](xm, ym, ordinal, n + 1) + case _ => + error("invalid call to eqlCases: one of Alts is not a subtype of T") + } case _: Unit => false } @@ -271,7 +283,7 @@ trait Pickler[T] { } object Pickler { - import scala.compiletime.{erasedValue, constValue} + import scala.compiletime.{erasedValue, constValue, error} import TypeLevel._ def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1) @@ -294,12 +306,17 @@ object Pickler { inline def pickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T, n: Int): Unit = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => - buf += n - pickleCase[T, elems](r, buf, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => + buf += n + pickleCase[T, elems](r, buf, x) + case _ => + pickleCases[T, alts1](r, buf, x, n + 1) + } case _ => - pickleCases[T, alts1](r, buf, x, n + 1) + error("invalid pickleCases call: one of Alts is not a subtype of T") } case _: Unit => } @@ -362,7 +379,7 @@ trait Show[T] { def show(x: T): String } object Show { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryShow[T](x: T): String = implicit match { @@ -388,9 +405,15 @@ object Show { inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => showCase[T, elems](r, x) - case _ => showCases[T, alts1](r, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => + showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _ => + error("invalid call to showCases: one of Alts is not a subtype of T") } case _: Unit => throw new MatchError(x) From 3df2c0f549bca9387680b50935006e56d5e47607 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 14 Feb 2019 15:22:25 +0100 Subject: [PATCH 06/20] Adjust behaviour of TypeComparer.either for GADTflexible --- .../src/dotty/tools/dotc/core/Contexts.scala | 13 ++++- compiler/src/dotty/tools/dotc/core/Mode.scala | 2 +- .../dotty/tools/dotc/core/TypeComparer.scala | 56 +++++++++++++++---- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index a5902bbfeb14..69522f041454 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -775,7 +775,7 @@ object Contexts { else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase") } - sealed abstract class GADTMap { + sealed abstract class GADTMap extends Showable { def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean @@ -802,6 +802,14 @@ object Contexts { ) extends GADTMap with ConstraintHandling[Context] { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} + def subsumes(left: GADTMap, right: GADTMap, pre: GADTMap)(implicit ctx: Context): Boolean = { + def extractConstraint(g: GADTMap) = g match { + case s: SmartGADTMap => s.constraint + case EmptyGADTMap => OrderingConstraint.empty + } + subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + } + def this() = this( myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), mapping = SimpleIdentityMap.Empty, @@ -950,6 +958,8 @@ object Contexts { override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) + override def toText(printer: Printer): Texts.Text = constraint.toText(printer) + override def debugBoundsDescription(implicit ctx: Context): String = { val sb = new mutable.StringBuilder sb ++= constraint.show @@ -969,6 +979,7 @@ object Contexts { override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null override def contains(sym: Symbol)(implicit ctx: Context) = false override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation") + override def toText(printer: Printer): Texts.Text = "EmptyGADTMap" override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" override def fresh = new SmartGADTMap override def restore(other: GADTMap): Unit = { diff --git a/compiler/src/dotty/tools/dotc/core/Mode.scala b/compiler/src/dotty/tools/dotc/core/Mode.scala index 430d0b062c84..81b9fc5ea5c4 100644 --- a/compiler/src/dotty/tools/dotc/core/Mode.scala +++ b/compiler/src/dotty/tools/dotc/core/Mode.scala @@ -49,7 +49,7 @@ object Mode { /** We are in a pattern alternative */ val InPatternAlternative: Mode = newMode(7, "InPatternAlternative") - /** Allow GADTFlexType labelled types to have their bounds adjusted */ + /** Infer GADT constraints during type comparisons `A <:< B` */ val GADTflexible: Mode = newMode(8, "GADTflexible") /** Assume -language:strictEquality */ diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 73ff054866bb..7ea8262741a8 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1252,16 +1252,52 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { */ private def either(op1: => Boolean, op2: => Boolean): Boolean = { val preConstraint = constraint - op1 && { - val leftConstraint = constraint - constraint = preConstraint - if (!(op2 && subsumes(leftConstraint, constraint, preConstraint))) { - if (constr != noPrinter && !subsumes(constraint, leftConstraint, preConstraint)) - constr.println(i"CUT - prefer $leftConstraint over $constraint") - constraint = leftConstraint - } - true - } || op2 + + if (ctx.mode.is(Mode.GADTflexible)) { + val preGadt = ctx.gadt.fresh + // if GADTflexible mode is on, we always have a SmartGADTMap + val pre = preGadt.asInstanceOf[SmartGADTMap] + if (op1) { + val leftConstraint = constraint + val leftGadt = ctx.gadt.fresh + constraint = preConstraint + ctx.gadt.restore(preGadt) + if (op2) { + if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) { + gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt") + constr.println(i"CUT - prefer $constraint over $leftConstraint") + true + } else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) { + gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}") + constr.println(i"CUT - prefer $leftConstraint over $constraint") + constraint = leftConstraint + ctx.gadt.restore(leftGadt) + true + } else { + gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt") + constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint") + constraint = preConstraint + ctx.gadt.restore(preGadt) + true + } + } else { + constraint = leftConstraint + ctx.gadt.restore(leftGadt) + true + } + } else op2 + } else { + op1 && { + val leftConstraint = constraint + constraint = preConstraint + if (!(op2 && subsumes(leftConstraint, constraint, preConstraint))) { + if (constr != noPrinter && !subsumes(constraint, leftConstraint, preConstraint)) + constr.println(i"CUT - prefer $leftConstraint over $constraint") + constraint = leftConstraint + } + true + } || op2 + } } /** Does type `tp1` have a member with name `name` whose normalized type is a subtype of From e1df3ae2847bcb589de5cf6b7406c7134a2321af Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Wed, 27 Mar 2019 10:06:20 +0100 Subject: [PATCH 07/20] Move GADT constraint code to a separate file Also rename the classes to better reflect their role, and document and reorder definitions to make more sense. --- .../src/dotty/tools/dotc/core/Contexts.scala | 223 +-------------- .../tools/dotc/core/GadtConstraint.scala | 255 ++++++++++++++++++ .../dotty/tools/dotc/core/TypeComparer.scala | 4 +- .../dotty/tools/dotc/typer/Implicits.scala | 2 +- .../src/dotty/tools/dotc/typer/Inliner.scala | 2 - 5 files changed, 263 insertions(+), 223 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/core/GadtConstraint.scala diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 69522f041454..cbc3a12f0f8b 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -139,9 +139,9 @@ object Contexts { final def importInfo: ImportInfo = _importInfo /** The current bounds in force for type parameters appearing in a GADT */ - private[this] var _gadt: GADTMap = _ - protected def gadt_=(gadt: GADTMap): Unit = _gadt = gadt - final def gadt: GADTMap = _gadt + private[this] var _gadt: GadtConstraint = _ + protected def gadt_=(gadt: GadtConstraint): Unit = _gadt = gadt + final def gadt: GadtConstraint = _gadt /** The history of implicit searches that are currently active */ private[this] var _searchHistory: SearchHistory = null @@ -534,7 +534,7 @@ object Contexts { def setTypeAssigner(typeAssigner: TypeAssigner): this.type = { this.typeAssigner = typeAssigner; this } def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) } def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this } - def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this } + def setGadt(gadt: GadtConstraint): this.type = { this.gadt = gadt; this } def setFreshGADTBounds: this.type = setGadt(gadt.fresh) def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this } def setSource(source: SourceFile): this.type = { this.source = source; this } @@ -617,7 +617,7 @@ object Contexts { store = initialStore.updated(settingsStateLoc, settingsGroup.defaultState) typeComparer = new TypeComparer(this) searchHistory = new SearchRoot - gadt = EmptyGADTMap + gadt = EmptyGadtConstraint } @sharable object NoContext extends Context(null) { @@ -774,217 +774,4 @@ object Contexts { if (thread == null) thread = Thread.currentThread() else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase") } - - sealed abstract class GADTMap extends Showable { - def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit - def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean - def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean - def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds - - /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. - * - * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` - * of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead. - */ - def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds - def contains(sym: Symbol)(implicit ctx: Context): Boolean - def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type - def debugBoundsDescription(implicit ctx: Context): String - def fresh: GADTMap - def restore(other: GADTMap): Unit - def isEmpty: Boolean - } - - final class SmartGADTMap private ( - private var myConstraint: Constraint, - private var mapping: SimpleIdentityMap[Symbol, TypeVar], - private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - ) extends GADTMap with ConstraintHandling[Context] { - import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} - - def subsumes(left: GADTMap, right: GADTMap, pre: GADTMap)(implicit ctx: Context): Boolean = { - def extractConstraint(g: GADTMap) = g match { - case s: SmartGADTMap => s.constraint - case EmptyGADTMap => OrderingConstraint.empty - } - subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) - } - - def this() = this( - myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), - mapping = SimpleIdentityMap.Empty, - reverseMapping = SimpleIdentityMap.Empty - ) - - implicit override def ctx(implicit ctx: Context): Context = ctx - - override protected def constraint = myConstraint - override protected def constraint_=(c: Constraint) = myConstraint = c - - override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = - reverseMapping(param) match { - case sym: Symbol => sym.typeRef - case null => param - } - - override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) - override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) - - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) - - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = { - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } - - val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { - case tv: TypeVar => tv - case inst => - gadts.println(i"instantiated: $sym -> $inst") - return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst) - } - - val internalizedBound = bound match { - case nt: NamedType if contains(nt.symbol) => - stripInternalTypeVar(tvar(nt.symbol)) - case _ => bound - } - ( - internalizedBound match { - case boundTvar: TypeVar => - if (boundTvar eq symTvar) true - else if (isUpper) addLess(symTvar.origin, boundTvar.origin) - else addLess(boundTvar.origin, symTvar.origin) - case bound => - val oldUpperBound = bounds(symTvar.origin) - // If we already have bounds `F >: [t] => List[t] <: [t] => Any` - // and we want to record that `F <: [+A] => List[A]`, we need to adapt - // type parameter variances of the bound. Consider that the following is valid: - // - // class Foo[F[t] >: List[t]] - // type T = Foo[List] - // - // precisely because `Foo[List]` is desugared to `Foo[[A] => List[A]]`. - val bound1 = bound.adaptHkVariances(oldUpperBound) - if (isUpper) addUpperBound(symTvar.origin, bound1) - else addLowerBound(symTvar.origin, bound1) - } - ).reporting({ res => - val descr = if (isUpper) "upper" else "lower" - val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" - }, gadts) - } - - override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = - constraint.isLess(tvar(sym1).origin, tvar(sym2).origin) - - override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = - mapping(sym) match { - case null => null - case tv => fullBounds(tv.origin) - } - - override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { - mapping(sym) match { - case null => null - case tv => - def retrieveBounds: TypeBounds = - bounds(tv.origin) match { - case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => - TypeAlias(reverseMapping(tpr).typeRef) - case tb => tb - } - retrieveBounds//.reporting({ res => i"gadt bounds $sym: $res" }, gadts) - } - } - - override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null - - override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { - val res = approximation(tvar(sym).origin, fromBelow = fromBelow) - gadts.println(i"approximating $sym ~> $res") - res - } - - override def fresh: GADTMap = new SmartGADTMap( - myConstraint, - mapping, - reverseMapping - ) - - def restore(other: GADTMap): Unit = other match { - case other: SmartGADTMap => - this.myConstraint = other.myConstraint - this.mapping = other.mapping - this.reverseMapping = other.reverseMapping - case _ => ; - } - - override def isEmpty: Boolean = mapping.size == 0 - - // ---- Private ---------------------------------------------------------- - - private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { - mapping(sym) match { - case tv: TypeVar => - tv - case null => - val res = { - import NameKinds.DepParamName - // avoid registering the TypeVar with TyperState / TyperState#constraint - // - we don't want TyperState instantiating these TypeVars - // - we don't want TypeComparer constraining these TypeVars - val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( - pt => (sym.info match { - case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb - case _ => TypeBounds.empty - }) :: Nil, - pt => defn.AnyType) - new TypeVar(poly.paramRefs.head, creatorState = null) - } - gadts.println(i"GADTMap: created tvar $sym -> $res") - constraint = constraint.add(res.origin.binder, res :: Nil) - mapping = mapping.updated(sym, res) - reverseMapping = reverseMapping.updated(res.origin, sym) - res - } - } - - // ---- Debug ------------------------------------------------------------ - - override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) - - override def toText(printer: Printer): Texts.Text = constraint.toText(printer) - - override def debugBoundsDescription(implicit ctx: Context): String = { - val sb = new mutable.StringBuilder - sb ++= constraint.show - sb += '\n' - mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${fullBounds(sym)}\n" - } - sb.result - } - } - - @sharable object EmptyGADTMap extends GADTMap { - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds") - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound") - override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.isLess") - override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null - override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null - override def contains(sym: Symbol)(implicit ctx: Context) = false - override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation") - override def toText(printer: Printer): Texts.Text = "EmptyGADTMap" - override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" - override def fresh = new SmartGADTMap - override def restore(other: GADTMap): Unit = { - if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") - } - override def isEmpty: Boolean = true - } } diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala new file mode 100644 index 000000000000..00802c897dd3 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -0,0 +1,255 @@ +package dotty.tools +package dotc +package core + +import Decorators._ +import Contexts._ +import Types._ +import Symbols._ +import util.SimpleIdentityMap +import collection.mutable +import printing._ + +import scala.annotation.internal.sharable + +/** Represents GADT constraints currently in scope */ +sealed abstract class GadtConstraint extends Showable { + /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ + def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. + * + * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` + * of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Is `sym1` ordered to be less than `sym2`? */ + def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean + + def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit + def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean + + /** Is the symbol registered in the constraint? + * + * Note that this is returns `true` even if `sym` is already instantiated to some type, + * unlike [[Constraint.contains]]. + */ + def contains(sym: Symbol)(implicit ctx: Context): Boolean + + def isEmpty: Boolean + + /** See [[ConstraintHandling.approximation]] */ + def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type + + def fresh: GadtConstraint + + /** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */ + def restore(other: GadtConstraint): Unit + + def debugBoundsDescription(implicit ctx: Context): String +} + +final class ProperGadtConstraint private( + private var myConstraint: Constraint, + private var mapping: SimpleIdentityMap[Symbol, TypeVar], + private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], +) extends GadtConstraint with ConstraintHandling[Context] { + import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} + + def this() = this( + myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), + mapping = SimpleIdentityMap.Empty, + reverseMapping = SimpleIdentityMap.Empty + ) + + /** Exposes ConstraintHandling.subsumes */ + def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(implicit ctx: Context): Boolean = { + def extractConstraint(g: GadtConstraint) = g match { + case s: ProperGadtConstraint => s.constraint + case EmptyGadtConstraint => OrderingConstraint.empty + } + subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + } + + override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) + + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = { + @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { + case tv: TypeVar => tv + case inst => + gadts.println(i"instantiated: $sym -> $inst") + return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst) + } + + val internalizedBound = bound match { + case nt: NamedType if contains(nt.symbol) => + stripInternalTypeVar(tvar(nt.symbol)) + case _ => bound + } + ( + internalizedBound match { + case boundTvar: TypeVar => + if (boundTvar eq symTvar) true + else if (isUpper) addLess(symTvar.origin, boundTvar.origin) + else addLess(boundTvar.origin, symTvar.origin) + case bound => + val oldUpperBound = bounds(symTvar.origin) + // If we already have bounds `F >: [t] => List[t] <: [t] => Any` + // and we want to record that `F <: [+A] => List[A]`, we need to adapt + // type parameter variances of the bound. Consider that the following is valid: + // + // class Foo[F[t] >: List[t]] + // type T = Foo[List] + // + // precisely because `Foo[List]` is desugared to `Foo[[A] => List[A]]`. + val bound1 = bound.adaptHkVariances(oldUpperBound) + if (isUpper) addUpperBound(symTvar.origin, bound1) + else addLowerBound(symTvar.origin, bound1) + } + ).reporting({ res => + val descr = if (isUpper) "upper" else "lower" + val op = if (isUpper) "<:" else ">:" + i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" + }, gadts) + } + + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = + constraint.isLess(tvar(sym1).origin, tvar(sym2).origin) + + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = + mapping(sym) match { + case null => null + case tv => fullBounds(tv.origin) + } + + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { + mapping(sym) match { + case null => null + case tv => + def retrieveBounds: TypeBounds = + bounds(tv.origin) match { + case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => + TypeAlias(reverseMapping(tpr).typeRef) + case tb => tb + } + retrieveBounds//.reporting({ res => i"gadt bounds $sym: $res" }, gadts) + } + } + + override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null + + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { + val res = approximation(tvar(sym).origin, fromBelow = fromBelow) + gadts.println(i"approximating $sym ~> $res") + res + } + + override def fresh: GadtConstraint = new ProperGadtConstraint( + myConstraint, + mapping, + reverseMapping + ) + + def restore(other: GadtConstraint): Unit = other match { + case other: ProperGadtConstraint => + this.myConstraint = other.myConstraint + this.mapping = other.mapping + this.reverseMapping = other.reverseMapping + case _ => ; + } + + override def isEmpty: Boolean = mapping.size == 0 + + // ---- Protected/internal ----------------------------------------------- + + implicit override def ctx(implicit ctx: Context): Context = ctx + + override protected def constraint = myConstraint + override protected def constraint_=(c: Constraint) = myConstraint = c + + override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = + reverseMapping(param) match { + case sym: Symbol => sym.typeRef + case null => param + } + + override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) + override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) + + // ---- Private ---------------------------------------------------------- + + private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { + mapping(sym) match { + case tv: TypeVar => + tv + case null => + val res = { + import NameKinds.DepParamName + // avoid registering the TypeVar with TyperState / TyperState#constraint + // - we don't want TyperState instantiating these TypeVars + // - we don't want TypeComparer constraining these TypeVars + val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( + pt => (sym.info match { + case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb + case _ => TypeBounds.empty + }) :: Nil, + pt => defn.AnyType) + new TypeVar(poly.paramRefs.head, creatorState = null) + } + gadts.println(i"GADTMap: created tvar $sym -> $res") + constraint = constraint.add(res.origin.binder, res :: Nil) + mapping = mapping.updated(sym, res) + reverseMapping = reverseMapping.updated(res.origin, sym) + res + } + } + + // ---- Debug ------------------------------------------------------------ + + override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) + + override def toText(printer: Printer): Texts.Text = constraint.toText(printer) + + override def debugBoundsDescription(implicit ctx: Context): String = { + val sb = new mutable.StringBuilder + sb ++= constraint.show + sb += '\n' + mapping.foreachBinding { case (sym, _) => + sb ++= i"$sym: ${fullBounds(sym)}\n" + } + sb.result + } +} + +@sharable object EmptyGadtConstraint extends GadtConstraint { + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + + override def isEmpty: Boolean = true + + override def contains(sym: Symbol)(implicit ctx: Context) = false + + override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGadtConstraint.addEmptyBounds") + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addBound") + + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGadtConstraint.approximation") + + override def fresh = new ProperGadtConstraint + override def restore(other: GadtConstraint): Unit = { + if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") + } + + override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGadtConstraint" + + override def toText(printer: Printer): Texts.Text = "EmptyGadtConstraint" +} diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 7ea8262741a8..92878da97d76 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1255,8 +1255,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { if (ctx.mode.is(Mode.GADTflexible)) { val preGadt = ctx.gadt.fresh - // if GADTflexible mode is on, we always have a SmartGADTMap - val pre = preGadt.asInstanceOf[SmartGADTMap] + // if GADTflexible mode is on, we always have a ProperGadtConstraint + val pre = preGadt.asInstanceOf[ProperGadtConstraint] if (op1) { val leftConstraint = constraint val leftGadt = ctx.gadt.fresh diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 52f14d3feb9c..cdc68f2214b7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -345,7 +345,7 @@ object Implicits { * @param level The level where the reference was found * @param tstate The typer state to be committed if this alternative is chosen */ - case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GADTMap) extends SearchResult with Showable + case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GadtConstraint) extends SearchResult with Showable /** A failed search */ case class SearchFailure(tree: Tree) extends SearchResult { diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 0c754eab2570..09ab7f6d528a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -531,8 +531,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { /** A utility object offering methods for rewriting inlined code */ object reducer { - import dotty.tools.dotc.core.Contexts.GADTMap - /** An extractor for terms equivalent to `new C(args)`, returning the class `C`, * a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can * follow a reference to an inline value binding to its right hand side. From 824108c1f5390601244aa2ef93e1c5778a9a5560 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 28 Mar 2019 14:50:54 +0100 Subject: [PATCH 08/20] Document suspicious code in ProperGadtConstraint#tvar --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 00802c897dd3..488fcd5540c5 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -193,14 +193,20 @@ final class ProperGadtConstraint private( case null => val res = { import NameKinds.DepParamName + // For symbols standing for HK types, we need to preserve the kind information + // (see also usage of adaptHKvariances above) + // Ideally we'd always preserve the bounds, + // but first we need an equivalent of ConstraintHandling#addConstraint + // TODO: implement the above + val initialBounds = sym.info match { + case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb + case _ => TypeBounds.empty + } // avoid registering the TypeVar with TyperState / TyperState#constraint // - we don't want TyperState instantiating these TypeVars // - we don't want TypeComparer constraining these TypeVars val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( - pt => (sym.info match { - case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb - case _ => TypeBounds.empty - }) :: Nil, + pt => initialBounds :: Nil, pt => defn.AnyType) new TypeVar(poly.paramRefs.head, creatorState = null) } From ea1815da45214e1b671da5a39855d80cfc929e63 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Tue, 2 Apr 2019 11:57:58 +0200 Subject: [PATCH 09/20] Improve GadtConstraint comment --- .../tools/dotc/core/GadtConstraint.scala | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 488fcd5540c5..bfa265147974 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -102,14 +102,19 @@ final class ProperGadtConstraint private( else addLess(boundTvar.origin, symTvar.origin) case bound => val oldUpperBound = bounds(symTvar.origin) - // If we already have bounds `F >: [t] => List[t] <: [t] => Any` - // and we want to record that `F <: [+A] => List[A]`, we need to adapt - // type parameter variances of the bound. Consider that the following is valid: + // If we have bounds: + // F >: [t] => List[t] <: [t] => Any + // and we want to record that: + // F <: [+A] => List[A] + // we need to adapt the variance and instead record that: + // F <: [A] => List[A] + // We cannot record the original bound, since it is false that: + // [t] => List[t] <: [+A] => List[A] // - // class Foo[F[t] >: List[t]] - // type T = Foo[List] - // - // precisely because `Foo[List]` is desugared to `Foo[[A] => List[A]]`. + // Note that the following code is accepted: + // class Foo[F[t] >: List[t]] + // type T = Foo[List] + // precisely because Foo[List] is desugared to Foo[[A] => List[A]]. val bound1 = bound.adaptHkVariances(oldUpperBound) if (isUpper) addUpperBound(symTvar.origin, bound1) else addLowerBound(symTvar.origin, bound1) From 3cf8530add67de398377b34a8627ff55a64f91e7 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 11 Apr 2019 17:56:27 +0200 Subject: [PATCH 10/20] Simplify ConstraintHandling#fullBounds --- .../tools/dotc/core/ConstraintHandling.scala | 22 ++++----------- .../tools/dotc/core/GadtConstraint.scala | 28 +++++++++++++++---- .../dotty/tools/dotc/core/TypeComparer.scala | 2 -- .../tools/dotc/printing/PlainPrinter.scala | 2 +- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index fbb673a888c2..4afd55efdefb 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -34,8 +34,6 @@ trait ConstraintHandling[AbstractContext] { protected def constraint: Constraint protected def constraint_=(c: Constraint): Unit - protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type - private[this] var addConstraintInvocations = 0 /** If the constraint is frozen we cannot add new bounds to the constraint. */ @@ -71,28 +69,20 @@ trait ConstraintHandling[AbstractContext] { case tp => tp } - def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = - constraint.nonParamBounds(param) match { - case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr)) - case tb => tb - } + def nonParamBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = constraint.nonParamBounds(param) - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).lo /: constraint.minLower(param)) { - (t, u) => t | externalize(u) - } + def fullLowerBound(param: TypeParamRef)(implicit actx: AbstractContext): Type = + (nonParamBounds(param).lo /: constraint.minLower(param))(_ | _) - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).hi /: constraint.minUpper(param)) { - (t, u) => t & externalize(u) - } + def fullUpperBound(param: TypeParamRef)(implicit actx: AbstractContext): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param))(_ & _) /** Full bounds of `param`, including other lower/upper params. * * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` * of some param when comparing types might lead to infinite recursion. Consider `bounds` instead. */ - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = + def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean = diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index bfa265147974..5eb51f1c27f3 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -180,17 +180,33 @@ final class ProperGadtConstraint private( override protected def constraint = myConstraint override protected def constraint_=(c: Constraint) = myConstraint = c - override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = - reverseMapping(param) match { - case sym: Symbol => sym.typeRef - case null => param - } - override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) + override def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = + constraint.nonParamBounds(param) match { + case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr)) + case tb => tb + } + + override def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).lo /: constraint.minLower(param)) { + (t, u) => t | externalize(u) + } + + override def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param)) { + (t, u) => t & externalize(u) + } + // ---- Private ---------------------------------------------------------- + private[this] def externalize(param: TypeParamRef)(implicit ctx: Context): Type = + reverseMapping(param) match { + case sym: Symbol => sym.typeRef + case null => param + } + private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { mapping(sym) match { case tv: TypeVar => diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 92878da97d76..64e6534be70e 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -33,8 +33,6 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { def constraint: Constraint = state.constraint def constraint_=(c: Constraint): Unit = state.constraint = c - override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = param - private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null private[this] var recCount = 0 private[this] var monitored = false diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index d3c170d50243..844533376725 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -210,7 +210,7 @@ class PlainPrinter(_ctx: Context) extends Printer { val bounds = if (constr.contains(tp)) { val ctx0 = ctx.addMode(Mode.Printing) - ctx0.typeComparer.fullBounds(tp.origin)(ctx0) + ctx0.typeComparer.fullBounds(tp.origin) } else TypeBounds.empty if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value) From 1171db44056ce67c1bafa3d0f0e9d4dd33c6e85e Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 11 Apr 2019 18:25:57 +0200 Subject: [PATCH 11/20] Document Constraint methods --- compiler/src/dotty/tools/dotc/core/Constraint.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index 89a22d0dd0a5..b40b806c85bb 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -45,10 +45,16 @@ abstract class Constraint extends Showable { /** The parameters that are known to be greater wrt <: than `param` */ def upper(param: TypeParamRef): List[TypeParamRef] - /** `lower`, except that `minLower.forall(tpr => !minLower.exists(_ <:< tpr))` */ + /** The lower dominator set. + * + * This is like `lower`, except that each parameter returned is no smaller than every other returned parameter. + */ def minLower(param: TypeParamRef): List[TypeParamRef] - /** `upper`, except that `minUpper.forall(tpr => !minUpper.exists(tpr <:< _))` */ + /** The upper dominator set. + * + * This is like `upper`, except that each parameter returned is no greater than every other returned parameter. + */ def minUpper(param: TypeParamRef): List[TypeParamRef] /** lower(param) \ lower(butNot) */ From 177bd3666601b6bd0d4fc5a4d134d2e0350e78c4 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Fri, 12 Apr 2019 14:37:10 +0200 Subject: [PATCH 12/20] Allow adding multiple symbols to GadtConstraint simultaneously The added symbols can have inter-dependencies in their bounds. --- .../tools/dotc/core/GadtConstraint.scala | 125 ++++++++++++------ .../src/dotty/tools/dotc/core/Symbols.scala | 12 +- .../src/dotty/tools/dotc/core/TypeOps.scala | 2 +- .../dotty/tools/dotc/typer/Inferencing.scala | 8 +- .../src/dotty/tools/dotc/typer/Inliner.scala | 6 +- .../src/dotty/tools/dotc/typer/Namer.scala | 9 +- .../src/dotty/tools/dotc/typer/Typer.scala | 48 +++---- 7 files changed, 117 insertions(+), 93 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 5eb51f1c27f3..f4a30c8391a1 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -27,7 +27,11 @@ sealed abstract class GadtConstraint extends Showable { /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean - def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit + /** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */ + def addToConstraint(syms: List[Symbol])(implicit ctx: Context): Boolean + def addToConstraint(sym: Symbol)(implicit ctx: Context): Boolean = addToConstraint(sym :: Nil) + + /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean /** Is the symbol registered in the constraint? @@ -72,7 +76,54 @@ final class ProperGadtConstraint private( subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) } - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) + override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = { + import NameKinds.DepParamName + + val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })( + pt => params.map { param => + // replace the symbols in bound type `tp` which are in dependent positions + // with their internal TypeParamRefs + def substDependentSyms(tp: Type, isUpper: Boolean)(implicit ctx: Context): Type = { + def loop(tp: Type) = substDependentSyms(tp, isUpper) + tp match { + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp: NamedType => + params.indexOf(tp.symbol) match { + case -1 => + mapping(tp.symbol) match { + case tv: TypeVar => tv.origin + case null => tp + } + case i => pt.paramRefs(i) + } + case tp => tp + } + } + + val tb = param.info.bounds + tb.derivedTypeBounds( + lo = substDependentSyms(tb.lo, isUpper = false), + hi = substDependentSyms(tb.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) => + val tv = new TypeVar(paramRef, creatorState = null) + mapping = mapping.updated(sym, tv) + reverseMapping = reverseMapping.updated(tv.origin, sym) + tv + } + + // the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings + addToConstraint(poly1, tvars).reporting({ _ => + i"added to constraint: $params%, %\n$debugBoundsDescription" + }, gadts) + } override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = { @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { @@ -82,7 +133,7 @@ final class ProperGadtConstraint private( case _ => tp } - val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { + val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { case tv: TypeVar => tv case inst => gadts.println(i"instantiated: $sym -> $inst") @@ -90,8 +141,9 @@ final class ProperGadtConstraint private( } val internalizedBound = bound match { - case nt: NamedType if contains(nt.symbol) => - stripInternalTypeVar(tvar(nt.symbol)) + case nt: NamedType => + val ntTvar = mapping(nt.symbol) + if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound case _ => bound } ( @@ -119,20 +171,22 @@ final class ProperGadtConstraint private( if (isUpper) addUpperBound(symTvar.origin, bound1) else addLowerBound(symTvar.origin, bound1) } - ).reporting({ res => + ).reporting({ res => val descr = if (isUpper) "upper" else "lower" val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" + i"adding $descr bound $sym $op $bound = $res" }, gadts) } override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = - constraint.isLess(tvar(sym1).origin, tvar(sym2).origin) + constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = mapping(sym) match { case null => null - case tv => fullBounds(tv.origin) + case tv => + fullBounds(tv.origin) + .ensuring(containsNoInternalTypes(_)) } override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { @@ -145,14 +199,16 @@ final class ProperGadtConstraint private( TypeAlias(reverseMapping(tpr).typeRef) case tb => tb } - retrieveBounds//.reporting({ res => i"gadt bounds $sym: $res" }, gadts) + retrieveBounds + //.reporting({ res => i"gadt bounds $sym: $res" }, gadts) + .ensuring(containsNoInternalTypes(_)) } } override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { - val res = approximation(tvar(sym).origin, fromBelow = fromBelow) + val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow) gadts.println(i"approximating $sym ~> $res") res } @@ -207,36 +263,21 @@ final class ProperGadtConstraint private( case null => param } - private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { - mapping(sym) match { - case tv: TypeVar => - tv - case null => - val res = { - import NameKinds.DepParamName - // For symbols standing for HK types, we need to preserve the kind information - // (see also usage of adaptHKvariances above) - // Ideally we'd always preserve the bounds, - // but first we need an equivalent of ConstraintHandling#addConstraint - // TODO: implement the above - val initialBounds = sym.info match { - case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb - case _ => TypeBounds.empty - } - // avoid registering the TypeVar with TyperState / TyperState#constraint - // - we don't want TyperState instantiating these TypeVars - // - we don't want TypeComparer constraining these TypeVars - val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( - pt => initialBounds :: Nil, - pt => defn.AnyType) - new TypeVar(poly.paramRefs.head, creatorState = null) - } - gadts.println(i"GADTMap: created tvar $sym -> $res") - constraint = constraint.add(res.origin.binder, res :: Nil) - mapping = mapping.updated(sym, res) - reverseMapping = reverseMapping.updated(res.origin, sym) - res - } + private[this] def tvarOrError(sym: Symbol)(implicit ctx: Context): TypeVar = + mapping(sym).ensuring(_ ne null, i"not a constrainable symbol: $sym") + + private[this] def containsNoInternalTypes( + tp: Type, + acc: TypeAccumulator[Boolean] = null + )(implicit ctx: Context): Boolean = tp match { + case tpr: TypeParamRef => !reverseMapping.contains(tpr) + case tv: TypeVar => !reverseMapping.contains(tv.origin) + case tp => + (if (acc ne null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp) + } + + private[this] class ContainsNoInternalTypesAccumulator(implicit ctx: Context) extends TypeAccumulator[Boolean] { + override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp) } // ---- Debug ------------------------------------------------------------ @@ -266,7 +307,7 @@ final class ProperGadtConstraint private( override def contains(sym: Symbol)(implicit ctx: Context) = false - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGadtConstraint.addEmptyBounds") + override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGadtConstraint.approximation") diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index ef81b8bb3bf9..b3d41754dc87 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -209,16 +209,10 @@ trait Symbols { this: Context => modFlags | PackageCreationFlags, clsFlags | PackageCreationFlags, Nil, decls) - /** Define a new symbol associated with a Bind or pattern wildcard and - * make it gadt narrowable. - */ - def newPatternBoundSymbol(name: Name, info: Type, span: Span): Symbol = { + /** Define a new symbol associated with a Bind or pattern wildcard and, by default, make it gadt narrowable. */ + def newPatternBoundSymbol(name: Name, info: Type, span: Span, addToGadt: Boolean = true): Symbol = { val sym = newSymbol(owner, name, Case, info, coord = span) - if (name.isTypeName) { - val bounds = info.bounds - gadt.addBound(sym, bounds.lo, isUpper = false) - gadt.addBound(sym, bounds.hi, isUpper = true) - } + if (addToGadt && name.isTypeName) gadt.addToConstraint(sym) sym } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index f41a710dcde5..664016b99f38 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -387,7 +387,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object. val bound1 = massage(bound) if (bound1 ne bound) { if (checkCtx eq ctx) checkCtx = ctx.fresh.setFreshGADTBounds - if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addEmptyBounds(sym) + if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addToConstraint(sym) checkCtx.gadt.addBound(sym, bound1, fromBelow) typr.println("install GADT bound $bound1 for when checking F-bounded $sym") } diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index f84a7d59e494..ae13037a2f52 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -284,13 +284,17 @@ object Inferencing { if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x) tvar.instantiate(fromBelow = false) else { - val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span) + // since the symbols we're creating may have inter-dependencies in their bounds, + // we add them to the GADT constraint later, simultaneously + val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span, addToGadt = false) tvar.instantiateWith(wildCard.typeRef) patternBound += wildCard } } } - patternBound.toList + val res = patternBound.toList + if (res.nonEmpty) ctx.gadt.addToConstraint(res) + res } type VarianceMap = SimpleIdentityMap[TypeVar, Integer] diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 09ab7f6d528a..f1ff0070469c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -821,11 +821,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } def registerAsGadtSyms(typeBinds: TypeBindsMap)(implicit ctx: Context): Unit = - typeBinds.foreachBinding { case (sym, _) => - val TypeBounds(lo, hi) = sym.info.bounds - ctx.gadt.addBound(sym, lo, isUpper = false) - ctx.gadt.addBound(sym, hi, isUpper = true) - } + if (typeBinds.size > 0) ctx.gadt.addToConstraint(typeBinds.keys) pat match { case Typed(pat1, tpt) => diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 65a9a0789d08..ba38ea9b75bc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1338,12 +1338,11 @@ class Namer { typer: Typer => var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) if (typeParams.nonEmpty) { + // we'll be typing an expression from a polymorphic definition's body, + // so we must allow constraining its type parameters + // compare with typedDefDef, see tests/pos/gadt-inference.scala rhsCtx.setFreshGADTBounds - typeParams.foreach { tdef => - val TypeBounds(lo, hi) = tdef.info.bounds - rhsCtx.gadt.addBound(tdef, lo, isUpper = false) - rhsCtx.gadt.addBound(tdef, hi, isUpper = true) - } + rhsCtx.gadt.addToConstraint(typeParams) } def rhsType = typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(rhsCtx).tpe diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index ca08f8eb7d89..bd4295c44f73 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1508,38 +1508,28 @@ class Typer extends Namer if (sym is ImplicitOrImplied) checkImplicitConversionDefOK(sym) val tpt1 = checkSimpleKinded(typedType(tpt)) - val rhsCtx: Context = { - var _result: FreshContext = null - def resultCtx(): FreshContext = { - if (_result == null) _result = ctx.fresh - _result - } - - if (tparams1.nonEmpty) { - resultCtx().setFreshGADTBounds - if (!sym.isConstructor) { - // if we're _not_ in a constructor, allow constraining type parameters - tparams1.foreach { tdef => - val tb @ TypeBounds(lo, hi) = tdef.symbol.info.bounds - resultCtx().gadt.addBound(tdef.symbol, lo, isUpper = false) - resultCtx().gadt.addBound(tdef.symbol, hi, isUpper = true) - } - } else if (!sym.isPrimaryConstructor) { - // otherwise, for secondary constructors we need a context that "knows" - // that their type parameters are aliases of the class type parameters. - // See pos/i941.scala - (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => - val tr = tparam.typeRef - resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = false) - resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = true) - } + val rhsCtx = ctx.fresh + if (tparams1.nonEmpty) { + rhsCtx.setFreshGADTBounds + if (!sym.isConstructor) { + // we're typing a polymorphic definition's body, + // so we allow constraining all of its type parameters + // constructors are an exception as we don't allow constraining type params of classes + rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol)) + } else if (!sym.isPrimaryConstructor) { + // otherwise, for secondary constructors we need a context that "knows" + // that their type parameters are aliases of the class type parameters. + // See pos/i941.scala + rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol)) + (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => + val tr = tparam.typeRef + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) } } - - if (sym.isInlineMethod) resultCtx().addMode(Mode.InlineableBody) - - if (_result ne null) _result else ctx } + + if (sym.isInlineMethod) rhsCtx.addMode(Mode.InlineableBody) val rhs1 = typedExpr(ddef.rhs, tpt1.tpe.widenExpr)(rhsCtx) if (sym.isInlineMethod) { From e32c2d6643dfe18fde07082ea378cc3b85e0571b Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Thu, 18 Apr 2019 14:50:27 +0200 Subject: [PATCH 13/20] Restore old constrainPatternType doc --- .../dotty/tools/dotc/typer/Inferencing.scala | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index ae13037a2f52..4ebeb399f425 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -153,20 +153,39 @@ object Inferencing { def isSkolemFree(tp: Type)(implicit ctx: Context): Boolean = !tp.existsPart(_.isInstanceOf[SkolemType]) - /** Infer constraints that should be in scope for a case body with given pattern and scrutinee types. + /** Derive information about a pattern type by comparing it with some variant of the + * static scrutinee type. We have the following situation in case of a (dynamic) pattern match: * - * If `termPattern`, infer constraints from knowing that there exists a value which of both scrutinee - * and pattern types (which is the case for normal pattern matching). If not `termPattern`, instead - * infer constraints from knowing that `tp <: pt`. + * StaticScrutineeType PatternType + * \ / + * DynamicScrutineeType * - * If a pattern matches during normal pattern matching, we can be certain that there exists a value - * which is of both scrutinee and pattern types (the value we're matching on). If this value - * was in a variable, say `x`, then we could simply infer constraints from `x.type <: pt`. Since we might - * be matching on an expression as well, we take a skolem of the scrutinee, which is essentially an existential - * singleton type (see [[dotty.tools.dotc.core.Types.SkolemType]]). + * If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained. + * Now let's say we can prove that `PatternType <: StaticScrutineeType`. * - * Note that we need to sometimes widen type parameters of the scrutinee type to avoid unsoundness - - * see i3989c.scala and related issue discussion on Github. + * StaticScrutineeType + * | \ + * | \ + * | \ + * | PatternType + * | / + * DynamicScrutineeType + * + * What can we say about the relationship of parameter types between `PatternType` and + * `DynamicScrutineeType`? + * + * - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType` + * in the same way as `PatternType` ("invariant refinement"), the subtype test + * `PatternType <:< StaticScrutineeType` tells us all we need to know. + * - Otherwise, if variant refinement is a possibility we can only make predictions + * about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test + * where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams` + * replaces all type argument of variant parameters with empty bounds. + * + * Invariant refinement can be assumed if `PatternType`'s class(es) are final or + * case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`). + * + * @param termPattern are we dealing with a term-level or a type-level pattern? */ def constrainPatternType(tp: Type, pt: Type, termPattern: Boolean)(implicit ctx: Context): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { From cb15213d7fdbfa3268a0d6c522874dcbea5bfd27 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Tue, 30 Apr 2019 10:54:02 +0200 Subject: [PATCH 14/20] Remove unnecessary calls to constrainPatternType constrainPatternType is specific to term patterns, whereas in match types there is a simple subtyping relationship between the pattern and the scrutinee. In the future, simply calling isSubType in GADTFlexible context would likely be sufficient. --- compiler/src/dotty/tools/dotc/typer/Applications.scala | 2 +- compiler/src/dotty/tools/dotc/typer/Inferencing.scala | 6 ++---- compiler/src/dotty/tools/dotc/typer/Typer.scala | 4 +--- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index c56aef84aebc..ca5d94dba23b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1090,7 +1090,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic => * - If a type proxy P is not a reference to a class, P's supertype is in G */ def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean = - if (constrainPatternType(subtp, tp, termPattern = true)) true + if (constrainPatternType(subtp, tp)) true else tp match { case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent) case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType) diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 4ebeb399f425..05d8a7568acc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -184,10 +184,8 @@ object Inferencing { * * Invariant refinement can be assumed if `PatternType`'s class(es) are final or * case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`). - * - * @param termPattern are we dealing with a term-level or a type-level pattern? */ - def constrainPatternType(tp: Type, pt: Type, termPattern: Boolean)(implicit ctx: Context): Boolean = { + def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case) case tp: TypeProxy => refinementIsInvariant(tp.underlying) @@ -209,7 +207,7 @@ object Inferencing { } val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt) - val narrowTp = if (termPattern) SkolemType(tp) else tp + val narrowTp = SkolemType(tp) trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { narrowTp <:< widePt } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index bd4295c44f73..6faac5b0614b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -604,7 +604,7 @@ class Typer extends Namer def handlePattern: Tree = { val tpt1 = typedTpt if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef) - constrainPatternType(tpt1.tpe, pt, termPattern = true)(ctx.addMode(Mode.GADTflexible)) + constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible)) // special case for an abstract type that comes with a class tag tryWithClassTag(ascription(tpt1, isWildcard = true), pt) } @@ -1103,8 +1103,6 @@ class Typer extends Namer def typedTypeCase(cdef: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { def caseRest(implicit ctx: Context) = { val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern))) - if (!ctx.isAfterTyper) - constrainPatternType(pat1.tpe, selType, termPattern = false)(ctx.addMode(Mode.GADTflexible)) val pat2 = indexPattern(cdef).transform(pat1) val body1 = typedType(cdef.body, pt) assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1) From af5010d8ad9ba974888c8058a0573d9cdb0b0a99 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 6 May 2019 17:36:31 +0200 Subject: [PATCH 15/20] Further improve GadtConstraint comment --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index f4a30c8391a1..e5fd3186e5f1 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -167,6 +167,9 @@ final class ProperGadtConstraint private( // class Foo[F[t] >: List[t]] // type T = Foo[List] // precisely because Foo[List] is desugared to Foo[[A] => List[A]]. + // + // Ideally we'd adapt the bound in ConstraintHandling#addOneBound, + // but doing it there actually interferes with type inference. val bound1 = bound.adaptHkVariances(oldUpperBound) if (isUpper) addUpperBound(symTvar.origin, bound1) else addLowerBound(symTvar.origin, bound1) From 65e2aa29cd8953674941dd874d207713df5c8c58 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 6 May 2019 17:36:57 +0200 Subject: [PATCH 16/20] Comment out accidentally committed sanity check --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index e5fd3186e5f1..43a810c04dfe 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -204,7 +204,7 @@ final class ProperGadtConstraint private( } retrieveBounds //.reporting({ res => i"gadt bounds $sym: $res" }, gadts) - .ensuring(containsNoInternalTypes(_)) + //.ensuring(containsNoInternalTypes(_)) } } From c9b093b49dec4769bb825fb5cadb9795719d093d Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 6 May 2019 17:47:30 +0200 Subject: [PATCH 17/20] Simplify NoMatchingImplicit#clarify --- compiler/src/dotty/tools/dotc/typer/Implicits.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index cdc68f2214b7..61647f5a1f8b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -397,13 +397,7 @@ object Implicits { * what was expected */ override def clarify(tp: Type)(implicit ctx: Context): Type = { - val ctx0 = ctx - locally { - implicit val ctx = ctx0.fresh.setTyperState { - val ts = ctx0.typerState.fresh() - ts.constraint_=(constraint)(ctx0) - ts - } + def replace(implicit ctx: Context): Type = { val map = new TypeMap { def apply(t: Type): Type = t match { case t: TypeParamRef => @@ -420,6 +414,10 @@ object Implicits { } map(tp) } + + val ctx1 = ctx.fresh.setExploreTyperState() + ctx1.typerState.constraint = constraint + replace(ctx1) } def explanation(implicit ctx: Context): String = From 9b96317680d53ba2696d8b232e8c0d5fa985cb91 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Tue, 7 May 2019 09:51:18 +0200 Subject: [PATCH 18/20] Further improve comments around GadtConstraint --- .../tools/dotc/core/GadtConstraint.scala | 19 +++++++++++-------- .../dotty/tools/dotc/typer/Inferencing.scala | 5 +++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 43a810c04dfe..11b869b8f995 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -19,15 +19,18 @@ sealed abstract class GadtConstraint extends Showable { /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. * - * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` - * of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead. + * @note this performs subtype checks between ordered symbols. + * Using this in isSubType can lead to infinite recursion. Consider `bounds` instead. */ def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean - /** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */ + /** Add symbols to constraint, correctly handling inter-dependencies. + * + * @see [[ConstraintHandling.addToConstraint]] + */ def addToConstraint(syms: List[Symbol])(implicit ctx: Context): Boolean def addToConstraint(sym: Symbol)(implicit ctx: Context): Boolean = addToConstraint(sym :: Nil) @@ -36,8 +39,7 @@ sealed abstract class GadtConstraint extends Showable { /** Is the symbol registered in the constraint? * - * Note that this is returns `true` even if `sym` is already instantiated to some type, - * unlike [[Constraint.contains]]. + * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. */ def contains(sym: Symbol)(implicit ctx: Context): Boolean @@ -81,8 +83,9 @@ final class ProperGadtConstraint private( val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })( pt => params.map { param => - // replace the symbols in bound type `tp` which are in dependent positions - // with their internal TypeParamRefs + // In bound type `tp`, replace the symbols in dependent positions with their internal TypeParamRefs. + // The replaced symbols will be later picked up in `ConstraintHandling#addToConstraint` + // and used as orderings. def substDependentSyms(tp: Type, isUpper: Boolean)(implicit ctx: Context): Type = { def loop(tp: Type) = substDependentSyms(tp, isUpper) tp match { @@ -119,7 +122,7 @@ final class ProperGadtConstraint private( tv } - // the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings + // The replaced symbols are picked up here. addToConstraint(poly1, tvars).reporting({ _ => i"added to constraint: $params%, %\n$debugBoundsDescription" }, gadts) diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 05d8a7568acc..0902964edea7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -301,8 +301,8 @@ object Inferencing { if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x) tvar.instantiate(fromBelow = false) else { - // since the symbols we're creating may have inter-dependencies in their bounds, - // we add them to the GADT constraint later, simultaneously + // We do not add the created symbols to GADT constraint immediately, since they may have inter-dependencies. + // Instead, we simultaneously add them later on. val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span, addToGadt = false) tvar.instantiateWith(wildCard.typeRef) patternBound += wildCard @@ -310,6 +310,7 @@ object Inferencing { } } val res = patternBound.toList + // We add the created symbols to GADT constraint here. if (res.nonEmpty) ctx.gadt.addToConstraint(res) res } From 12b8999d2983bf0a7259fa98ab3cd5422af20e79 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 13 May 2019 14:55:15 +0200 Subject: [PATCH 19/20] Inline a variable --- compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 844533376725..ffa4aa9a68a2 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -208,10 +208,7 @@ class PlainPrinter(_ctx: Context) extends Printer { else { val constr = ctx.typerState.constraint val bounds = - if (constr.contains(tp)) { - val ctx0 = ctx.addMode(Mode.Printing) - ctx0.typeComparer.fullBounds(tp.origin) - } + if (constr.contains(tp)) ctx.addMode(Mode.Printing).typeComparer.fullBounds(tp.origin) else TypeBounds.empty if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value) else if (ctx.settings.YshowVarBounds.value) "(" ~ toText(tp.origin) ~ "?" ~ toText(bounds) ~ ")" From 3c3b3db63be69936c391db3a965ebda1aff03391 Mon Sep 17 00:00:00 2001 From: Aleksander Boruch-Gruszecki Date: Mon, 13 May 2019 14:55:24 +0200 Subject: [PATCH 20/20] Split TypeComparer#either and document both cases --- .../dotty/tools/dotc/core/TypeComparer.scala | 95 ++++++++++++++++--- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 64e6534be70e..86776c4fda00 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1219,6 +1219,17 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { fix(tp) } + /** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints. + * + * If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints. + * If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all. + * + * @see [[sufficientEither]] for the normal case + * @see [[necessaryEither]] for the GADTFlexible case + */ + private def either(op1: => Boolean, op2: => Boolean): Boolean = + if (ctx.mode.is(Mode.GADTflexible)) necessaryEither(op1, op2) else sufficientEither(op1, op2) + /** Returns true iff the result of evaluating either `op1` or `op2` is true, * trying at the same time to keep the constraint as wide as possible. * E.g, if @@ -1247,13 +1258,79 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { * * Here, each precondition leads to a different constraint, and neither of * the two post-constraints subsumes the other. + * + * Note that to be complete when it comes to typechecking, we would instead need to backtrack + * and attempt to typecheck with the other constraint. + * + * Method name comes from the notion that we are keeping a constraint which is sufficient to satisfy + * one of subtyping relationships. */ - private def either(op1: => Boolean, op2: => Boolean): Boolean = { + private def sufficientEither(op1: => Boolean, op2: => Boolean): Boolean = { + val preConstraint = constraint + op1 && { + val leftConstraint = constraint + constraint = preConstraint + if (!(op2 && subsumes(leftConstraint, constraint, preConstraint))) { + if (constr != noPrinter && !subsumes(constraint, leftConstraint, preConstraint)) + constr.println(i"CUT - prefer $leftConstraint over $constraint") + constraint = leftConstraint + } + true + } || op2 + } + + /** Returns true iff the result of evaluating either `op1` or `op2` is true, keeping the smaller constraint if any. + * E.g., if + * + * tp11 <:< tp12 = true with constraint c1 and GADT constraint g1 + * tp12 <:< tp22 = true with constraint c2 and GADT constraint g2 + * + * We keep: + * - (c1, g1) if c2 subsumes c1 and g2 subsumes g1 + * - (c2, g2) if c1 subsumes c2 and g1 subsumes g2 + * - neither constraint pair otherwise. + * + * Like [[sufficientEither]], this method is used to approximate a solution in one of the following cases: + * + * T1 & T2 <:< T3 + * T1 <:< T2 | T3 + * + * Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT + * constraints that necessarily follow from the subtyping relationship. For instance, if we have + * + * enum Expr[T] { + * case IntExpr(i: Int) extends Expr[Int] + * case StrExpr(s: String) extends Expr[String] + * } + * + * and `A` is an abstract type and we know that + * + * Expr[A] <: IntExpr | StrExpr + * + * (the case with &-type is analogous) then this may follow either from + * + * Expr[A] <: IntExpr or Expr[A] <: StrExpr + * + * Since we don't know which branch is true, we need to give up and not keep either constraint. OTOH, if one + * constraint pair is subsumed by the other, we know that it is necessary for both cases and therefore we can + * keep it. + * + * Like [[sufficientEither]], this method is not complete because sometimes, the necessary constraint + * is neither of the pairs. For instance, if + * + * g1 = { A = Int, B = String } + * g2 = { A = Int, B = Int } + * + * then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive. + * + * Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both + * subtyping relationships. + */ + private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = { val preConstraint = constraint - if (ctx.mode.is(Mode.GADTflexible)) { val preGadt = ctx.gadt.fresh - // if GADTflexible mode is on, we always have a ProperGadtConstraint + // if GADTflexible mode is on, we expect to always have a ProperGadtConstraint val pre = preGadt.asInstanceOf[ProperGadtConstraint] if (op1) { val leftConstraint = constraint @@ -1284,18 +1361,6 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { true } } else op2 - } else { - op1 && { - val leftConstraint = constraint - constraint = preConstraint - if (!(op2 && subsumes(leftConstraint, constraint, preConstraint))) { - if (constr != noPrinter && !subsumes(constraint, leftConstraint, preConstraint)) - constr.println(i"CUT - prefer $leftConstraint over $constraint") - constraint = leftConstraint - } - true - } || op2 - } } /** Does type `tp1` have a member with name `name` whose normalized type is a subtype of