diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index fb21695a1b07..b212e3cbfcfb 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -724,15 +724,18 @@ object Contexts { def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds def contains(sym: Symbol)(implicit ctx: Context): Boolean + def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type def debugBoundsDescription(implicit ctx: Context): String def fresh: GADTMap + def restore(other: GADTMap): Unit + def isEmpty: Boolean } final class SmartGADTMap private ( - private[this] var myConstraint: Constraint, - private[this] var mapping: SimpleIdentityMap[Symbol, TypeVar], - private[this] var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private[this] var boundCache: SimpleIdentityMap[Symbol, TypeBounds] + private var myConstraint: Constraint, + private var mapping: SimpleIdentityMap[Symbol, TypeVar], + private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], + private var boundCache: SimpleIdentityMap[Symbol, TypeBounds] ) extends GADTMap with ConstraintHandling[Context] { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -833,6 +836,12 @@ object Contexts { override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { + val res = removeTypeVars(approximation(tvar(sym).origin, fromBelow = fromBelow)) + gadts.println(i"approximating $sym ~> $res") + res + } + override def fresh: GADTMap = new SmartGADTMap( myConstraint, mapping, @@ -840,6 +849,17 @@ object Contexts { boundCache ) + def restore(other: GADTMap): Unit = other match { + case other: SmartGADTMap => + this.myConstraint = other.myConstraint + this.mapping = other.mapping + this.reverseMapping = other.reverseMapping + this.boundCache = other.boundCache + case _ => ; + } + + override def isEmpty: Boolean = mapping.size == 0 + // ---- Private ---------------------------------------------------------- private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { @@ -916,7 +936,12 @@ object Contexts { override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound") override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null override def contains(sym: Symbol)(implicit ctx: Context) = false + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation") override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" override def fresh = new SmartGADTMap + override def restore(other: GADTMap): Unit = { + if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") + } + override def isEmpty: Boolean = true } } diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index e9a18e09380a..706d29c9be4b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -346,7 +346,7 @@ object Implicits { * @param level The level where the reference was found * @param tstate The typer state to be committed if this alternative is chosen */ - case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState) extends SearchResult with Showable + case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GADTMap) extends SearchResult with Showable /** A failed search */ case class SearchFailure(tree: Tree) extends SearchResult { @@ -880,6 +880,7 @@ trait Implicits { self: Typer => result0 match { case result: SearchSuccess => result.tstate.commit() + ctx.gadt.restore(result.gstate) implicits.println(i"success: $result") implicits.println(i"committing ${result.tstate.constraint} yielding ${ctx.typerState.constraint} in ${ctx.typerState}") result @@ -1004,7 +1005,7 @@ trait Implicits { self: Typer => val generated2 = if (cand.isExtension) Applications.ExtMethodApply(generated1).withType(generated1.tpe) else generated1 - SearchSuccess(generated2, ref, cand.level)(ctx.typerState) + SearchSuccess(generated2, ref, cand.level)(ctx.typerState, ctx.gadt) } }} @@ -1016,7 +1017,8 @@ trait Implicits { self: Typer => SearchFailure(new DivergingImplicit(cand.ref, pt.widenExpr, argument)) else { val history = ctx.searchHistory.nest(cand, pt) - val result = typedImplicit(cand, contextual)(nestedContext().setNewTyperState().setSearchHistory(history)) + val result = + typedImplicit(cand, contextual)(nestedContext().setNewTyperState().setFreshGADTBounds.setSearchHistory(history)) result match { case res: SearchSuccess => ctx.searchHistory.defineBynameImplicit(pt.widenExpr, res) @@ -1118,7 +1120,9 @@ trait Implicits { self: Typer => result match { case _: SearchFailure => SearchSuccess(ref(defn.Not_value), defn.Not_value.termRef, 0)( - ctx.typerState.fresh().setCommittable(true)) + ctx.typerState.fresh().setCommittable(true), + ctx.gadt + ) case _: SearchSuccess => NoMatchingImplicitsFailure } @@ -1188,7 +1192,7 @@ trait Implicits { self: Typer => // other candidates need to be considered. ctx.searchHistory.recursiveRef(pt) match { case ref: TermRef => - SearchSuccess(tpd.ref(ref).withPos(pos.startPos), ref, 0)(ctx.typerState) + SearchSuccess(tpd.ref(ref).withPos(pos.startPos), ref, 0)(ctx.typerState, ctx.gadt) case _ => val eligible = if (contextual) ctx.implicits.eligible(wildProto) @@ -1428,7 +1432,7 @@ final class SearchRoot extends SearchHistory { implicitDictionary.get(tpe) match { case Some((ref, _)) => implicitDictionary.put(tpe, (ref, result.tree)) - SearchSuccess(tpd.ref(ref).withPos(result.tree.pos), result.ref, result.level)(result.tstate) + SearchSuccess(tpd.ref(ref).withPos(result.tree.pos), result.ref, result.level)(result.tstate, result.gstate) case None => result } } @@ -1529,7 +1533,7 @@ final class SearchRoot extends SearchHistory { val blk = Block(classDef :: inst :: Nil, res) - success.copy(tree = blk)(success.tstate) + success.copy(tree = blk)(success.tstate, success.gstate) } } } diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 11638c5dbd12..d27a6199f957 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -2,7 +2,7 @@ package dotty.tools package dotc package typer -import dotty.tools.dotc.ast.{Trees, untpd, tpd, TreeTypeMap} +import ast._ import Trees._ import core._ import Flags._ @@ -14,16 +14,17 @@ import StdNames._ import transform.SymUtils._ import Contexts.Context import Names.{Name, TermName} -import NameKinds.{InlineAccessorName, InlineScrutineeName, InlineBinderName} +import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName} import ProtoTypes.selectionProto import SymDenotations.SymDenotation import Inferencing.fullyDefinedType import config.Printers.inlining import ErrorReporting.errorTree +import dotty.tools.dotc.util.{SimpleIdentityMap, SimpleIdentitySet} + import collection.mutable import reporting.trace import util.Positions.Position -import ast.TreeInfo object Inliner { import tpd._ @@ -667,7 +668,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { * for the pattern-bound variables and the RHS of the selected case. * Returns `None` if no case was selected. */ - type MatchRedux = Option[(List[MemberDef], tpd.Tree)] + type MatchRedux = Option[(List[MemberDef], Tree)] /** Reduce an inline match * @param mtch the match tree @@ -687,7 +688,13 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { /** Try to match pattern `pat` against scrutinee reference `scrut`. If successful add * bindings for variables bound in this pattern to `bindingsBuf`. */ - def reducePattern(bindingsBuf: mutable.ListBuffer[MemberDef], scrut: TermRef, pat: Tree)(implicit ctx: Context): Boolean = { + def reducePattern( + bindingsBuf: mutable.ListBuffer[MemberDef], + fromBuf: mutable.ListBuffer[TypeSymbol], + toBuf: mutable.ListBuffer[TypeSymbol], + scrut: TermRef, + pat: Tree + )(implicit ctx: Context): Boolean = { /** Create a binding of a pattern bound variable with matching part of * scrutinee as RHS and type that corresponds to RHS. @@ -712,50 +719,82 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } } - pat match { - case Typed(pat1, tpt) => - val getBoundVars = new TreeAccumulator[List[TypeSymbol]] { - def apply(syms: List[TypeSymbol], t: Tree)(implicit ctx: Context) = { - val syms1 = t match { - case t: Bind if t.symbol.isType => - t.symbol.asType :: syms - case _ => - syms - } - foldOver(syms1, t) + type TypeBindsMap = SimpleIdentityMap[TypeSymbol, java.lang.Boolean] + + def getTypeBindsMap(pat: Tree, tpt: Tree): TypeBindsMap = { + val getBinds = new TreeAccumulator[Set[TypeSymbol]] { + def apply(syms: Set[TypeSymbol], t: Tree)(implicit ctx: Context): Set[TypeSymbol] = { + val syms1 = t match { + case t: Bind if t.symbol.isType => + syms + t.symbol.asType + case _ => syms } + foldOver(syms1, t) } - var boundVars = getBoundVars(Nil, tpt) - // UnApply nodes with pattern bound variables translate to something like this - // UnApply[t @ t](pats)(implicits): T[t] - // Need to traverse any binds in type arguments of the UnAppyl to get the set of - // all instantiable type variables. Test case is pos/inline-caseclass.scala. - pat1 match { - case UnApply(TypeApply(_, tpts), _, _) => - for (tpt <- tpts) boundVars = getBoundVars(boundVars, tpt) - case _ => - } - for (bv <- boundVars) { - val TypeBounds(lo, hi) = bv.info.bounds - ctx.gadt.addBound(bv, lo, isUpper = false) - ctx.gadt.addBound(bv, hi, isUpper = true) + } + + // Extractors contain Bind nodes in type parameter lists, the tree looks like this: + // UnApply[t @ t](pats)(implicits): T[t] + // Test case is pos/inline-caseclass.scala. + val binds: Set[TypeSymbol] = pat match { + case UnApply(TypeApply(_, tpts), _, _) => getBinds(Set.empty[TypeSymbol], tpts) + case _ => getBinds(Set.empty[TypeSymbol], tpt) + } + + val extractBindVariance = new TypeAccumulator[TypeBindsMap] { + def apply(syms: TypeBindsMap, t: Type) = { + val syms1 = t match { + // `binds` is used to check if the symbol was actually bound by the pattern we're processing + case tr: TypeRef if tr.symbol.is(Case) && binds.contains(tr.symbol.asType) => + val trSym = tr.symbol.asType + // Exact same logic as in IsFullyDefinedAccumulator: + // the binding is to be maximized iff it only occurs contravariantly in the type + val wasToBeMinimized: Boolean = { + val v = syms(trSym) + if (v ne null) v else false + } + syms.updated(trSym, wasToBeMinimized || variance >= 0 : java.lang.Boolean) + case _ => + syms + } + foldOver(syms1, t) } + } + + extractBindVariance(SimpleIdentityMap.Empty, tpt.tpe) + } + + def registerAsGadtSyms(typeBinds: TypeBindsMap)(implicit ctx: Context): Unit = + typeBinds.foreachBinding { case (sym, _) => + val TypeBounds(lo, hi) = sym.info.bounds + ctx.gadt.addBound(sym, lo, isUpper = false) + ctx.gadt.addBound(sym, hi, isUpper = true) + } + + def addTypeBindings(typeBinds: TypeBindsMap)(implicit ctx: Context): Unit = + typeBinds.foreachBinding { case (sym, shouldBeMinimized) => + val copied = sym.copy(info = TypeAlias(ctx.gadt.approximation(sym, fromBelow = shouldBeMinimized))).asType + fromBuf += sym + toBuf += copied + } + + pat match { + case Typed(pat1, tpt) => + val typeBinds = getTypeBindsMap(pat1, tpt) + registerAsGadtSyms(typeBinds) scrut <:< tpt.tpe && { - for (bv <- boundVars) { - bv.info = TypeAlias(ctx.gadt.bounds(bv).lo) - // FIXME: This is very crude. We should approximate with lower or higher bound depending - // on variance, and we should also take care of recursive bounds. Basically what - // ConstraintHandler#approximation does. However, this only works for constrained paramrefs - // not GADT-bound variables. Hopefully we will get some way to improve this when we - // re-implement GADTs in terms of constraints. - if (bv.name != nme.WILDCARD) bindingsBuf += TypeDef(bv) - } - reducePattern(bindingsBuf, scrut, pat1) + addTypeBindings(typeBinds) + reducePattern(bindingsBuf, fromBuf, toBuf, scrut, pat1) } case pat @ Bind(name: TermName, Typed(_, tpt)) if isImplicit => - searchImplicit(pat.symbol.asTerm, tpt) + val typeBinds = getTypeBindsMap(tpt, tpt) + registerAsGadtSyms(typeBinds) + searchImplicit(pat.symbol.asTerm, tpt) && { + addTypeBindings(typeBinds) + true + } case pat @ Bind(name: TermName, body) => - reducePattern(bindingsBuf, scrut, body) && { + reducePattern(bindingsBuf, fromBuf, toBuf, scrut, body) && { if (name != nme.WILDCARD) newBinding(pat.symbol.asTerm, ref(scrut)) true } @@ -780,7 +819,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { case (pat :: pats1, selector :: selectors1) => val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenTermRefExpr).asTerm newBinding(elem, selector) - reducePattern(bindingsBuf, elem.termRef, pat) && + reducePattern(bindingsBuf, fromBuf, toBuf, elem.termRef, pat) && reduceSubPatterns(pats1, selectors1) case _ => false } @@ -826,9 +865,19 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } if (!isImplicit) caseBindingsBuf += scrutineeBinding val gadtCtx = typer.gadtContext(gadtSyms).addMode(Mode.GADTflexible) - val pat1 = typer.typedPattern(cdef.pat, scrutType)(gadtCtx) - if (reducePattern(caseBindingsBuf, scrutineeSym.termRef, pat1)(gadtCtx) && guardOK) - Some((caseBindingsBuf.toList, cdef.body)) + val fromBuf = mutable.ListBuffer.empty[TypeSymbol] + val toBuf = mutable.ListBuffer.empty[TypeSymbol] + if (reducePattern(caseBindingsBuf, fromBuf, toBuf, scrutineeSym.termRef, cdef.pat)(gadtCtx) && guardOK) { + val caseBindings = caseBindingsBuf.toList + val from = fromBuf.toList + val to = toBuf.toList + if (from.isEmpty) Some((caseBindings, cdef.body)) + else { + val Block(stats, expr) = tpd.Block(caseBindings, cdef.body).subst(from, to) + val typeDefs = to.collect { case sym if sym.name != tpnme.WILDCARD => tpd.TypeDef(sym).withPos(sym.pos) } + Some((typeDefs ::: stats.asInstanceOf[List[MemberDef]], expr)) + } + } else None } @@ -906,8 +955,23 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { val selType = if (sel.isEmpty) wideSelType else sel.tpe reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match { case Some((caseBindings, rhs0)) => - val (usedBindings, rhs1) = dropUnusedDefs(caseBindings, rhs0) - val rhs = seq(usedBindings, rhs1) + // drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match) + // note that any actually necessary casts will be reinserted by the typing pass below + val rhs1 = rhs0 match { + case Block(stats, t) if t.pos.isSynthetic => + t match { + case Typed(expr, _) => + Block(stats, expr) + case TypeApply(Select(expr, n), _) if n == defn.Any_asInstanceOf.name => + Block(stats, expr) + case _ => + rhs0 + } + case _ => rhs0 + } + + val (usedBindings, rhs2) = dropUnusedDefs(caseBindings, rhs1) + val rhs = seq(usedBindings, rhs2) inlining.println(i"""--- reduce: |$tree |--- to: @@ -936,123 +1000,107 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { */ def dropUnusedDefs(bindings: List[MemberDef], tree: Tree)(implicit ctx: Context): (List[MemberDef], Tree) = { // inlining.println(i"drop unused $bindings%, % in $tree") - val refCount = newMutableSymbolMap[Int] - val bindingOfSym = newMutableSymbolMap[MemberDef] - val dealiased = new java.util.IdentityHashMap[Type, Type]() - - def isInlineable(binding: MemberDef) = binding match { - case DefDef(_, Nil, Nil, _, _) => true - case vdef @ ValDef(_, _, _) => isPureExpr(vdef.rhs) - case _ => false - } - for (binding <- bindings if isInlineable(binding)) { - refCount(binding.symbol) = 0 - bindingOfSym(binding.symbol) = binding - } - val countRefs = new TreeTraverser { - override def traverse(t: Tree)(implicit ctx: Context) = { - def updateRefCount(sym: Symbol, inc: Int) = - for (x <- refCount.get(sym)) refCount(sym) = x + inc - def updateTermRefCounts(t: Tree) = - t.typeOpt.foreachPart { - case ref: TermRef => updateRefCount(ref.symbol, 2) // can't be inlined, so make sure refCount is at least 2 - case _ => - } + def inlineTermBindings(termBindings: List[MemberDef], tree: Tree)(implicit ctx: Context): (List[MemberDef], Tree) = { + val refCount = newMutableSymbolMap[Int] + val bindingOfSym = newMutableSymbolMap[MemberDef] - t match { - case t: RefTree => - updateRefCount(t.symbol, 1) - updateTermRefCounts(t) - case _: New | _: TypeTree => - updateTermRefCounts(t) - case _ => - } - traverseChildren(t) + def isInlineable(binding: MemberDef) = binding match { + case DefDef(_, Nil, Nil, _, _) => true + case vdef @ ValDef(_, _, _) => isPureExpr(vdef.rhs) + case _ => false } - } - countRefs.traverse(tree) - for (binding <- bindings) countRefs.traverse(binding) - - def retain(boundSym: Symbol) = { - refCount.get(boundSym) match { - case Some(x) => x > 1 || x == 1 && !boundSym.is(Method) - case none => true + for (binding <- termBindings if isInlineable(binding)) { + refCount(binding.symbol) = 0 + bindingOfSym(binding.symbol) = binding } - } && !boundSym.is(ImplicitInlineMethod) - val (termBindings, typeBindings) = bindings.partition(_.symbol.isTerm) - - /** drop any referenced type symbols from the given set of type symbols */ - val dealiasTypeBindings = new TreeMap { - val boundTypes = typeBindings.map(_.symbol).toSet - - val dealias = new TypeMap { - override def apply(tp: Type) = dealiased.get(tp) match { - case null => - val tp1 = mapOver { - tp match { - case tp: TypeRef if boundTypes.contains(tp.symbol) => - val TypeAlias(alias) = tp.info - alias - case _ => tp - } + val countRefs = new TreeTraverser { + override def traverse(t: Tree)(implicit ctx: Context) = { + def updateRefCount(sym: Symbol, inc: Int) = + for (x <- refCount.get(sym)) refCount(sym) = x + inc + def updateTermRefCounts(t: Tree) = + t.typeOpt.foreachPart { + case ref: TermRef => updateRefCount(ref.symbol, 2) // can't be inlined, so make sure refCount is at least 2 + case _ => } - dealiased.put(tp, tp1) - tp1 - case tp1 => tp1 + + t match { + case t: RefTree => + updateRefCount(t.symbol, 1) + updateTermRefCounts(t) + case _: New | _: TypeTree => + updateTermRefCounts(t) + case _ => + } + traverseChildren(t) } } + countRefs.traverse(tree) + for (binding <- bindings) countRefs.traverse(binding) - override def transform(t: Tree)(implicit ctx: Context) = { - val dealiasedType = dealias(t.tpe) - val t1 = t match { + def retain(boundSym: Symbol) = { + refCount.get(boundSym) match { + case Some(x) => x > 1 || x == 1 && !boundSym.is(Method) + case none => true + } + } && !boundSym.is(ImplicitInlineMethod) + + val inlineBindings = new TreeMap { + override def transform(t: Tree)(implicit ctx: Context) = t match { case t: RefTree => - if (t.name != nme.WILDCARD && boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos) - else t.withType(dealiasedType) - case t: DefTree => - t.symbol.info = dealias(t.symbol.info) - t + val sym = t.symbol + val t1 = refCount.get(sym) match { + case Some(1) => + bindingOfSym(sym) match { + case binding: ValOrDefDef => integrate(binding.rhs, sym) + } + case none => t + } + super.transform(t1) + case t: Apply => + val t1 = super.transform(t) + if (t1 `eq` t) t else reducer.betaReduce(t1) + case Block(Nil, expr) => + super.transform(expr) case _ => - t.withType(dealiasedType) + super.transform(t) } - super.transform(t1) } - } - val inlineBindings = new TreeMap { - override def transform(t: Tree)(implicit ctx: Context) = t match { - case t: RefTree => - val sym = t.symbol - val t1 = refCount.get(sym) match { - case Some(1) => - bindingOfSym(sym) match { - case binding: ValOrDefDef => integrate(binding.rhs, sym) - } - case none => t - } - super.transform(t1) - case t: Apply => - val t1 = super.transform(t) - if (t1 `eq` t) t else reducer.betaReduce(t1) - case Block(Nil, expr) => - super.transform(expr) - case _ => - super.transform(t) + val retained = termBindings.filterConserve(binding => retain(binding.symbol)) + if (retained `eq` termBindings) { + (termBindings, tree) + } + else { + val expanded = inlineBindings.transform(tree) + dropUnusedDefs(retained, expanded) } } - val dealiasedTermBindings = - termBindings.mapconserve(dealiasTypeBindings.transform).asInstanceOf[List[MemberDef]] - val dealiasedTree = dealiasTypeBindings.transform(tree) - - val retained = dealiasedTermBindings.filterConserve(binding => retain(binding.symbol)) - if (retained `eq` dealiasedTermBindings) { - (dealiasedTermBindings, dealiasedTree) - } + val (termBindings, typeBindings) = bindings.partition(_.symbol.isTerm) + if (typeBindings.isEmpty) inlineTermBindings(termBindings, tree) else { - val expanded = inlineBindings.transform(dealiasedTree) - dropUnusedDefs(retained, expanded) + val typeBindingsSet = typeBindings.foldLeft[SimpleIdentitySet[Symbol]](SimpleIdentitySet.empty)(_ + _.symbol) + val inlineTypeBindings = new TreeTypeMap( + typeMap = new TypeMap() { + override def apply(tp: Type): Type = tp match { + case tr: TypeRef if tr.prefix.eq(NoPrefix) && typeBindingsSet.contains(tr.symbol) => + val TypeAlias(res) = tr.info + res + case tp => mapOver(tp) + } + }, + treeMap = { + case ident: Ident if ident.isType && typeBindingsSet.contains(ident.symbol) => + val TypeAlias(r) = ident.symbol.info + TypeTree(r).withPos(ident.pos) + case tree => tree + } + ) + + val Block(termBindings1, tree1) = inlineTypeBindings(Block(termBindings, tree)) + inlineTermBindings(termBindings1.asInstanceOf[List[MemberDef]], tree1) } } } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index f71d6351a1e1..9c3af1eef743 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -573,7 +573,9 @@ object ProtoTypes { */ private def wildApprox(tp: Type, theMap: WildApproxMap, seen: Set[TypeParamRef])(implicit ctx: Context): Type = tp match { case tp: NamedType => // default case, inlined for speed - if (tp.symbol.isStatic || (tp.prefix `eq` NoPrefix)) tp + val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.is(Flags.Case) && !tp.symbol.isClass + if (isPatternBoundTypeRef) WildcardType(tp.underlying.bounds) + else if (tp.symbol.isStatic || (tp.prefix `eq` NoPrefix)) tp else tp.derivedSelect(wildApprox(tp.prefix, theMap, seen)) case tp @ AppliedType(tycon, args) => wildApprox(tycon, theMap, seen) match { diff --git a/compiler/test/dotc/pos-from-tasty.blacklist b/compiler/test/dotc/pos-from-tasty.blacklist index 64aff439e0f8..c078f439062d 100644 --- a/compiler/test/dotc/pos-from-tasty.blacklist +++ b/compiler/test/dotc/pos-from-tasty.blacklist @@ -20,3 +20,8 @@ default-super.scala i3050.scala i4006b.scala i4006c.scala + +# Need to print empty tree for implicit match +implicit-match.scala +implicit-match-nested.scala +implicit-match-and-inline-match.scala diff --git a/tests/neg/implicit-match-ambiguous-bind.scala b/tests/neg/implicit-match-ambiguous-bind.scala new file mode 100644 index 000000000000..04cec2a79d5f --- /dev/null +++ b/tests/neg/implicit-match-ambiguous-bind.scala @@ -0,0 +1,9 @@ +object `implicit-match-ambiguous-bind` { + case class Box[T](value: T) + implicit val ibox: Box[Int] = Box(0) + implicit val sbox: Box[String] = Box("") + inline def unbox = implicit match { + case b: Box[t] => b.value // error + } + val unboxed = unbox +} diff --git a/tests/pos/i5574.scala b/tests/pos/i5574.scala new file mode 100644 index 000000000000..85ec4e9121b2 --- /dev/null +++ b/tests/pos/i5574.scala @@ -0,0 +1,14 @@ +import scala.typelevel._ + +object i5574 { + class Box[F[_]] + + inline def foo[T] <: Any = + inline erasedValue[T] match { + case _: Box[f] => + type t = f + 23 + } + + foo[Box[List]] +} diff --git a/tests/pos/implicit-match-and-inline-match.scala b/tests/pos/implicit-match-and-inline-match.scala new file mode 100644 index 000000000000..987b04ed81b0 --- /dev/null +++ b/tests/pos/implicit-match-and-inline-match.scala @@ -0,0 +1,24 @@ +object `implicit-match-and-inline-match` { + import scala.typelevel._ + + case class Box[T](value: T) + implicit val ibox: Box[Int] = Box(0) + + object a { + inline def isTheBoxInScopeAnInt = implicit match { + case _: Box[t] => inline erasedValue[t] match { + case _: Int => true + } + } + val wellIsIt = isTheBoxInScopeAnInt + } + + object b { + inline def isTheBoxInScopeAnInt = implicit match { + case _: Box[t] => inline 0 match { + case _: t => true + } + } + val wellIsIt = isTheBoxInScopeAnInt + } +} diff --git a/tests/pos/implicit-match-nested.scala b/tests/pos/implicit-match-nested.scala new file mode 100644 index 000000000000..f49ba65fe834 --- /dev/null +++ b/tests/pos/implicit-match-nested.scala @@ -0,0 +1,16 @@ +object `implicit-match-nested` { + case class A[T]() + case class B[T]() + + implicit val a: A[Int] = A[Int]() + implicit val b1: B[Int] = B[Int]() + implicit val b2: B[String] = B[String]() + + inline def locateB <: B[_] = implicit match { + case _: A[t] => implicit match { + case b: B[`t`] => b + } + } + + locateB +} diff --git a/tests/pos/implicit-match.scala b/tests/pos/implicit-match.scala new file mode 100644 index 000000000000..b3789a17d1c8 --- /dev/null +++ b/tests/pos/implicit-match.scala @@ -0,0 +1,34 @@ +object `implicit-match` { + object invariant { + case class Box[T](value: T) + implicit val box: Box[Int] = Box(0) + inline def unbox <: Any = implicit match { + case b: Box[t] => b.value + } + val i: Int = unbox + val i2 = unbox + val i3: Int = i2 + } + + object covariant { + case class Box[+T](value: T) + implicit val box: Box[Int] = Box(0) + inline def unbox <: Any = implicit match { + case b: Box[t] => b.value + } + val i: Int = unbox + val i2 = unbox + val i3: Int = i2 + } + + object contravariant { + case class TrashCan[-T](trash: T => Unit) + implicit val trashCan: TrashCan[Int] = TrashCan { i => ; } + inline def trash <: Nothing => Unit = implicit match { + case c: TrashCan[t] => c.trash + } + val t1: Int => Unit = trash + val t2 = trash + val t3: Int => Unit = t2 + } +} diff --git a/tests/pos/inline-match-gadt-nested.scala b/tests/pos/inline-match-gadt-nested.scala new file mode 100644 index 000000000000..2570e7f898fe --- /dev/null +++ b/tests/pos/inline-match-gadt-nested.scala @@ -0,0 +1,14 @@ +object `inline-match-gadt-nested` { + import scala.typelevel._ + + enum Gadt[A, B] { + case Nested(gadt: Gadt[A, Int]) extends Gadt[A, Int] + case Simple extends Gadt[String, Int] + } + import Gadt._ + + inline def foo[A, B](g: Gadt[A, B]): (A, B) = + inline g match { + case Nested(Simple) => ("", 0) + } +} diff --git a/tests/pos/inline-match-gadt.scala b/tests/pos/inline-match-gadt.scala new file mode 100644 index 000000000000..0f22f7b96c22 --- /dev/null +++ b/tests/pos/inline-match-gadt.scala @@ -0,0 +1,10 @@ +object `inline-match-gadt` { + class Exactly[T] + erased def exactType[T]: Exactly[T] = ??? + + inline def foo[T](t: T): T = + inline exactType[T] match { + case _: Exactly[Int] => 23 + case _ => t + } +} diff --git a/tests/pos/inline-match-specialize.scala b/tests/pos/inline-match-specialize.scala new file mode 100644 index 000000000000..10377a3cc079 --- /dev/null +++ b/tests/pos/inline-match-specialize.scala @@ -0,0 +1,8 @@ +object `inline-match-specialize` { + case class Box[+T](value: T) + inline def specialize[T](box: Box[T]) <: Box[T] = inline box match { + case box: Box[t] => box + } + + val ibox: Box[Int] = specialize[Any](Box(0)) +}