diff --git a/src/dotty/tools/dotc/ast/Trees.scala b/src/dotty/tools/dotc/ast/Trees.scala index 8a36fee3a5ae..e19221841dd0 100644 --- a/src/dotty/tools/dotc/ast/Trees.scala +++ b/src/dotty/tools/dotc/ast/Trees.scala @@ -1109,36 +1109,20 @@ object Trees { cpy.Apply(tree, transform(fun), transform(args)) case TypeApply(fun, args) => cpy.TypeApply(tree, transform(fun), transform(args)) - case Literal(const) => - tree case New(tpt) => cpy.New(tree, transform(tpt)) - case Pair(left, right) => - cpy.Pair(tree, transform(left), transform(right)) case Typed(expr, tpt) => cpy.Typed(tree, transform(expr), transform(tpt)) case NamedArg(name, arg) => cpy.NamedArg(tree, name, transform(arg)) case Assign(lhs, rhs) => cpy.Assign(tree, transform(lhs), transform(rhs)) - case Block(stats, expr) => - cpy.Block(tree, transformStats(stats), transform(expr)) - case If(cond, thenp, elsep) => - cpy.If(tree, transform(cond), transform(thenp), transform(elsep)) case Closure(env, meth, tpt) => cpy.Closure(tree, transform(env), transform(meth), transform(tpt)) - case Match(selector, cases) => - cpy.Match(tree, transform(selector), transformSub(cases)) - case CaseDef(pat, guard, body) => - cpy.CaseDef(tree, transform(pat), transform(guard), transform(body)) case Return(expr, from) => cpy.Return(tree, transform(expr), transformSub(from)) - case Try(block, handler, finalizer) => - cpy.Try(tree, transform(block), transform(handler), transform(finalizer)) case Throw(expr) => cpy.Throw(tree, transform(expr)) - case SeqLiteral(elems) => - cpy.SeqLiteral(tree, transform(elems)) case TypeTree(original) => tree case SingletonTypeTree(ref) => @@ -1177,12 +1161,29 @@ object Trees { cpy.Import(tree, transform(expr), selectors) case PackageDef(pid, stats) => cpy.PackageDef(tree, transformSub(pid), transformStats(stats)) - case Annotated(annot, arg) => - cpy.Annotated(tree, transform(annot), transform(arg)) case Thicket(trees) => val trees1 = transform(trees) if (trees1 eq trees) tree else Thicket(trees1) + case Literal(const) => + tree + case Pair(left, right) => + cpy.Pair(tree, transform(left), transform(right)) + case Block(stats, expr) => + cpy.Block(tree, transformStats(stats), transform(expr)) + case If(cond, thenp, elsep) => + cpy.If(tree, transform(cond), transform(thenp), transform(elsep)) + case Match(selector, cases) => + cpy.Match(tree, transform(selector), transformSub(cases)) + case CaseDef(pat, guard, body) => + cpy.CaseDef(tree, transform(pat), transform(guard), transform(body)) + case Try(block, handler, finalizer) => + cpy.Try(tree, transform(block), transform(handler), transform(finalizer)) + case SeqLiteral(elems) => + cpy.SeqLiteral(tree, transform(elems)) + case Annotated(annot, arg) => + cpy.Annotated(tree, transform(annot), transform(arg)) } + def transformStats(trees: List[Tree])(implicit ctx: Context): List[Tree] = transform(trees) def transform(trees: List[Tree])(implicit ctx: Context): List[Tree] = diff --git a/src/dotty/tools/dotc/ast/tpd.scala b/src/dotty/tools/dotc/ast/tpd.scala index fecfefd37ac4..0ef855de261e 100644 --- a/src/dotty/tools/dotc/ast/tpd.scala +++ b/src/dotty/tools/dotc/ast/tpd.scala @@ -3,12 +3,15 @@ package dotc package ast import core._ +import dotty.tools.dotc.transform.TypeUtils import util.Positions._, Types._, Contexts._, Constants._, Names._, Flags._ import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Symbols._ import CheckTrees._, Denotations._, Decorators._ import config.Printers._ import typer.ErrorReporting._ +import scala.annotation.tailrec + /** Some creators for typed trees */ object tpd extends Trees.Instance[Type] with TypedTreeInfo { @@ -413,6 +416,68 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def tpes: List[Type] = xs map (_.tpe) } + /** RetypingTreeMap is a TreeMap that is able to propagate type changes. + * + * This is required when types can change during transformation, + * for example if `Block(stats, expr)` is being transformed + * and type of `expr` changes from `TypeRef(prefix, name)` to `TypeRef(newPrefix, name)` with different prefix, t + * type of enclosing Block should also change, otherwise the whole tree would not be type-correct anymore. + * see `propagateType` methods for propagation rulles. + * + * TreeMap does not include such logic as it assumes that types of threes do not change during transformation. + */ + class RetypingTreeMap extends TreeMap { + + override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match { + case tree@Select(qualifier, name) => + val tree1 = cpy.Select(tree, transform(qualifier), name) + propagateType(tree, tree1) + case tree@Pair(left, right) => + val left1 = transform(left) + val right1 = transform(right) + val tree1 = cpy.Pair(tree, left1, right1) + propagateType(tree, tree1) + case tree@Block(stats, expr) => + val stats1 = transform(stats) + val expr1 = transform(expr) + val tree1 = cpy.Block(tree, stats1, expr1) + propagateType(tree, tree1) + case tree@If(cond, thenp, elsep) => + val cond1 = transform(cond) + val thenp1 = transform(thenp) + val elsep1 = transform(elsep) + val tree1 = cpy.If(tree, cond1, thenp1, elsep1) + propagateType(tree, tree1) + case tree@Match(selector, cases) => + val selector1 = transform(selector) + val cases1 = transformSub(cases) + val tree1 = cpy.Match(tree, selector1, cases1) + propagateType(tree, tree1) + case tree@CaseDef(pat, guard, body) => + val pat1 = transform(pat) + val guard1 = transform(guard) + val body1 = transform(body) + val tree1 = cpy.CaseDef(tree, pat1, guard1, body1) + propagateType(tree, tree1) + case tree@Try(block, handler, finalizer) => + val expr1 = transform(block) + val handler1 = transform(handler) + val finalizer1 = transform(finalizer) + val tree1 = cpy.Try(tree, expr1, handler1, finalizer1) + propagateType(tree, tree1) + case tree@SeqLiteral(elems) => + val elems1 = transform(elems) + val tree1 = cpy.SeqLiteral(tree, elems1) + propagateType(tree, tree1) + case tree@Annotated(annot, arg) => + val annot1 = transform(annot) + val arg1 = transform(arg) + val tree1 = cpy.Annotated(tree, annot1, arg1) + propagateType(tree, tree1) + case _ => super.transform(tree) + } + } + /** A map that applies three functions together to a tree and makes sure * they are coordinated so that the result is well-typed. The functions are * @param typeMap A function from Type to type that gets applied to the @@ -425,7 +490,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { final class TreeTypeMap( val typeMap: Type => Type = IdentityTypeMap, val ownerMap: Symbol => Symbol = identity _, - val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends TreeMap { + val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends RetypingTreeMap { override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = { val tree1 = treeMap(tree) @@ -436,10 +501,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { cpy.DefDef(ddef, mods, name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(rhs)) case blk @ Block(stats, expr) => val (tmap1, stats1) = transformDefs(stats) - cpy.Block(blk, stats1, tmap1.transform(expr)) + val expr1 = tmap1.transform(expr) + val tree1 = cpy.Block(blk, stats1, expr1) + propagateType(blk, tree1) case cdef @ CaseDef(pat, guard, rhs) => val tmap = withMappedSyms(patVars(pat)) - cpy.CaseDef(cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs)) + val pat1 = tmap.transform(pat) + val guard1 = tmap.transform(guard) + val rhs1 = tmap.transform(rhs) + val tree1 = cpy.CaseDef(tree, pat1, guard1, rhs1) + propagateType(cdef, tree1) case tree1 => super.transform(tree1) } @@ -501,6 +572,56 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { acc(Nil, tree) } + def propagateType(origTree: Pair, newTree: Pair)(implicit ctx: Context) = { + if ((newTree eq origTree) || + ((newTree.left.tpe eq origTree.left.tpe) && (newTree.right.tpe eq origTree.right.tpe))) newTree + else ta.assignType(newTree, newTree.left, newTree.right) + } + + def propagateType(origTree: Block, newTree: Block)(implicit ctx: Context) = { + if ((newTree eq origTree) || (newTree.expr.tpe eq origTree.expr.tpe)) newTree + else ta.assignType(newTree, newTree.stats, newTree.expr) + } + + def propagateType(origTree: If, newTree: If)(implicit ctx: Context) = { + if ((newTree eq origTree) || + ((newTree.thenp.tpe eq origTree.thenp.tpe) && (newTree.elsep.tpe eq origTree.elsep.tpe))) newTree + else ta.assignType(newTree, newTree.thenp, newTree.elsep) + } + + def propagateType(origTree: Match, newTree: Match)(implicit ctx: Context) = { + if ((newTree eq origTree) || sameTypes(newTree.cases, origTree.cases)) newTree + else ta.assignType(newTree, newTree.cases) + } + + def propagateType(origTree: CaseDef, newTree: CaseDef)(implicit ctx: Context) = { + if ((newTree eq newTree) || (newTree.body.tpe eq origTree.body.tpe)) newTree + else ta.assignType(newTree, newTree.body) + } + + def propagateType(origTree: Try, newTree: Try)(implicit ctx: Context) = { + if ((newTree eq origTree) || + ((newTree.expr.tpe eq origTree.expr.tpe) && (newTree.handler.tpe eq origTree.handler.tpe))) newTree + else ta.assignType(newTree, newTree.expr, newTree.handler) + } + + def propagateType(origTree: SeqLiteral, newTree: SeqLiteral)(implicit ctx: Context) = { + if ((newTree eq origTree) || sameTypes(newTree.elems, origTree.elems)) newTree + else ta.assignType(newTree, newTree.elems) + } + + def propagateType(origTree: Annotated, newTree: Annotated)(implicit ctx: Context) = { + if ((newTree eq origTree) || ((newTree.arg.tpe eq origTree.arg.tpe) && (newTree.annot eq origTree.annot))) newTree + else ta.assignType(newTree, newTree.annot, newTree.arg) + } + + def propagateType(origTree: Select, newTree: Select)(implicit ctx: Context) = { + if ((origTree eq newTree) || (origTree.qualifier.tpe eq newTree.qualifier.tpe)) newTree + else newTree.tpe match { + case tpe: NamedType => newTree.withType(tpe.derivedSelect(newTree.qualifier.tpe)) + case _ => newTree + } + } // convert a numeric with a toXXX method def primitiveConversion(tree: Tree, numericCls: Symbol)(implicit ctx: Context): Tree = { val mname = ("to" + numericCls.name).toTermName @@ -515,6 +636,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } } + @tailrec + def sameTypes(trees: List[tpd.Tree], trees1: List[tpd.Tree]): Boolean = { + if (trees.isEmpty) trees.isEmpty + else if (trees1.isEmpty) trees.isEmpty + else (trees.head.tpe eq trees1.head.tpe) && sameTypes(trees.tail, trees1.tail) + } + def evalOnce(tree: Tree)(within: Tree => Tree)(implicit ctx: Context) = { if (isIdempotentExpr(tree)) within(tree) else { diff --git a/src/dotty/tools/dotc/core/Substituters.scala b/src/dotty/tools/dotc/core/Substituters.scala index 3d14317cb76a..8beb6995eb1a 100644 --- a/src/dotty/tools/dotc/core/Substituters.scala +++ b/src/dotty/tools/dotc/core/Substituters.scala @@ -89,6 +89,40 @@ trait Substituters { this: Context => } } + final def substDealias(tp: Type, from: List[Symbol], to: List[Type], theMap: SubstDealiasMap): Type = { + tp match { + case tp: NamedType => + val sym = tp.symbol + var fs = from + var ts = to + while (fs.nonEmpty) { + if (fs.head eq sym) return ts.head + fs = fs.tail + ts = ts.tail + } + if (sym.isStatic && !existsStatic(from)) tp + else { + val prefix1 = substDealias(tp.prefix, from, to, theMap) + if (prefix1 ne tp.prefix) tp.derivedSelect(prefix1) + else if (sym.isAliasType) { + val hi = sym.info.bounds.hi + val hi1 = substDealias(hi, from, to, theMap) + if (hi1 eq hi) tp else hi1 + } + else tp + } + case _: ThisType | _: BoundType | NoPrefix => + tp + case tp: RefinedType => + tp.derivedRefinedType(substDealias(tp.parent, from, to, theMap), tp.refinedName, substDealias(tp.refinedInfo, from, to, theMap)) + case tp: TypeBounds if tp.lo eq tp.hi => + tp.derivedTypeAlias(substDealias(tp.lo, from, to, theMap)) + case _ => + (if (theMap != null) theMap else new SubstDealiasMap(from, to)) + .mapOver(tp) + } + } + final def substSym(tp: Type, from: List[Symbol], to: List[Symbol], theMap: SubstSymMap): Type = tp match { case tp: NamedType => @@ -205,11 +239,11 @@ trait Substituters { this: Context => final class SubstMap(from: List[Symbol], to: List[Type]) extends DeepTypeMap { def apply(tp: Type): Type = subst(tp, from, to, this) } -/* not needed yet - final class SubstDealiasMap(from: List[Symbol], to: List[Type]) extends SubstMap(from, to) { - override def apply(tp: Type): Type = subst(tp.dealias, from, to, this) + + final class SubstDealiasMap(from: List[Symbol], to: List[Type]) extends DeepTypeMap { + override def apply(tp: Type): Type = substDealias(tp, from, to, this) } -*/ + final class SubstSymMap(from: List[Symbol], to: List[Symbol]) extends DeepTypeMap { def apply(tp: Type): Type = substSym(tp, from, to, this) } diff --git a/src/dotty/tools/dotc/core/Types.scala b/src/dotty/tools/dotc/core/Types.scala index 1d1c326a02d2..8149cce78522 100644 --- a/src/dotty/tools/dotc/core/Types.scala +++ b/src/dotty/tools/dotc/core/Types.scala @@ -818,10 +818,17 @@ object Types { } } -/* Not needed yet: + /** Same as `subst` but follows aliases as a fallback. When faced with a reference + * to an alias type, where normal substiution does not yield a new type, the + * substitution is instead applied to the alias. If that yields a new type, + * this type is returned, outherwise the original type (not the alias) is returned. + * A use case for this method is if one wants to substitute the type parameters + * of a class and also wants to substitute any parameter accessors that alias + * the type parameters. + */ final def substDealias(from: List[Symbol], to: List[Type])(implicit ctx: Context): Type = - new ctx.SubstDealiasMap(from, to).apply(this) -*/ + ctx.substDealias(this, from, to, null) + /** Substitute all types of the form `PolyParam(from, N)` by * `PolyParam(to, N)`. */ diff --git a/src/dotty/tools/dotc/transform/FullParameterization.scala b/src/dotty/tools/dotc/transform/FullParameterization.scala index e21a10492f4f..cdea5754c3c5 100644 --- a/src/dotty/tools/dotc/transform/FullParameterization.scala +++ b/src/dotty/tools/dotc/transform/FullParameterization.scala @@ -104,7 +104,7 @@ trait FullParameterization { /** Replace class type parameters by the added type parameters of the polytype `pt` */ def mapClassParams(tp: Type, pt: PolyType): Type = { val classParamsRange = (mtparamCount until mtparamCount + ctparams.length).toList - tp.subst(clazz.typeParams, classParamsRange map (PolyParam(pt, _))) + tp.substDealias(clazz.typeParams, classParamsRange map (PolyParam(pt, _))) } /** The bounds for the added type paraneters of the polytype `pt` */ @@ -142,7 +142,7 @@ trait FullParameterization { * followed by the class parameters of its enclosing class. */ private def allInstanceTypeParams(originalDef: DefDef)(implicit ctx: Context): List[Symbol] = - originalDef.tparams.map(_.symbol) ::: originalDef.symbol.owner.typeParams + originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams /** Given an instance method definition `originalDef`, return a * fully parameterized method definition derived from `originalDef`, which @@ -152,7 +152,7 @@ trait FullParameterization { def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree = polyDefDef(derived, trefs => vrefss => { val origMeth = originalDef.symbol - val origClass = origMeth.owner.asClass + val origClass = origMeth.enclosingClass.asClass val origTParams = allInstanceTypeParams(originalDef) val origVParams = originalDef.vparamss.flatten map (_.symbol) val thisRef :: argRefs = vrefss.flatten @@ -201,7 +201,7 @@ trait FullParameterization { new TreeTypeMap( typeMap = rewireType(_) - .subst(origTParams, trefs) + .substDealias(origTParams, trefs) .subst(origVParams, argRefs.map(_.tpe)) .substThisUnlessStatic(origClass, thisRef.tpe), ownerMap = (sym => if (sym eq origMeth) derived else sym), @@ -219,7 +219,7 @@ trait FullParameterization { def forwarder(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree = ref(derived.termRef) .appliedToTypes(allInstanceTypeParams(originalDef).map(_.typeRef)) - .appliedTo(This(originalDef.symbol.owner.asClass)) + .appliedTo(This(originalDef.symbol.enclosingClass.asClass)) .appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol))) .withPos(originalDef.rhs.pos) } \ No newline at end of file diff --git a/src/dotty/tools/dotc/transform/LazyVals.scala b/src/dotty/tools/dotc/transform/LazyVals.scala index 75dc10ce46b6..02e5ed5a7d3b 100644 --- a/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/src/dotty/tools/dotc/transform/LazyVals.scala @@ -144,7 +144,7 @@ class LazyValTranformContext { * flag = true * target * } - * }` + * } */ def mkNonThreadSafeDef(target: Symbol, flag: Symbol, rhs: Tree)(implicit ctx: Context) = { diff --git a/src/dotty/tools/dotc/transform/TailRec.scala b/src/dotty/tools/dotc/transform/TailRec.scala index a2278e72f01d..d3bec6f902c2 100644 --- a/src/dotty/tools/dotc/transform/TailRec.scala +++ b/src/dotty/tools/dotc/transform/TailRec.scala @@ -1,28 +1,16 @@ package dotty.tools.dotc.transform -import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, TreeTransformer} -import dotty.tools.dotc.ast.{Trees, tpd} +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.core.Contexts.Context -import scala.collection.mutable.ListBuffer -import dotty.tools.dotc.core._ -import dotty.tools.dotc.core.Symbols.NoSymbol -import scala.annotation.tailrec -import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._ -import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._ -import Decorators._ -import Symbols._ -import scala.Some -import dotty.tools.dotc.transform.TreeTransforms.{NXTransformations, TransformerInfo, TreeTransform, TreeTransformer} -import dotty.tools.dotc.core.Contexts.Context -import scala.collection.mutable -import dotty.tools.dotc.core.Names.Name -import NameOps._ -import dotty.tools.dotc.CompilationUnit -import dotty.tools.dotc.util.Positions.{Position, Coord} -import dotty.tools.dotc.util.Positions.NoPosition +import dotty.tools.dotc.core.Decorators._ import dotty.tools.dotc.core.DenotTransformers.DenotTransformer import dotty.tools.dotc.core.Denotations.SingleDenotation +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core._ import dotty.tools.dotc.transform.TailRec._ +import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform} /** * A Tail Rec Transformer @@ -74,9 +62,9 @@ import dotty.tools.dotc.transform.TailRec._ * self recursive functions, that's why it's renamed to tailrec *

*/ -class TailRec extends TreeTransform with DenotTransformer { +class TailRec extends TreeTransform with DenotTransformer with FullParameterization { - import tpd._ + import dotty.tools.dotc.ast.tpd._ override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref @@ -85,54 +73,44 @@ class TailRec extends TreeTransform with DenotTransformer { final val labelPrefix = "tailLabel" final val labelFlags = Flags.Synthetic | Flags.Label - private def mkLabel(method: Symbol, tp: Type)(implicit c: Context): TermSymbol = { + private def mkLabel(method: Symbol)(implicit c: Context): TermSymbol = { val name = c.freshName(labelPrefix) - c.newSymbol(method, name.toTermName, labelFlags , tp) + + c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass)) } override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { tree match { case dd@DefDef(mods, name, tparams, vparamss0, tpt, rhs0) - if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree)) => + if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree) || (dd.symbol is Flags.Label)) => val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass) cpy.DefDef(tree, mods, name, tparams, vparamss0, tpt, rhs = { - val owner = ctx.owner.enclosingClass.asClass - val thisTpe = owner.thisType - - val newType: Type = dd.tpe.widen match { - case t: PolyType => PolyType(t.paramNames)(x => t.paramBounds, - x => MethodType(List(nme.THIS), List(thisTpe), t.resultType)) - case t => MethodType(List(nme.THIS), List(thisTpe), t) - } + val origMeth = tree.symbol + val label = mkLabel(dd.symbol) + val owner = ctx.owner.enclosingClass.asClass + val thisTpe = owner.thisType.widen - val label = mkLabel(dd.symbol, newType) var rewrote = false // Note: this can be split in two separate transforms(in different groups), // than first one will collect info about which transformations and rewritings should be applied // and second one will actually apply, // now this speculatively transforms tree and throws away result in many cases - val res = tpd.DefDef(label, args => { - val thiz = args.head.head - val argMapping: Map[Symbol, Tree] = (vparamss0.flatten.map(_.symbol) zip args.tail.flatten).toMap - val transformer = new TailRecElimination(dd.symbol, thiz, argMapping, owner, mandatory, label) + val rhsSemiTransformed = { + val transformer = new TailRecElimination(dd.symbol, owner, thisTpe, mandatory, label) val rhs = transformer.transform(rhs0)(ctx.withPhase(ctx.phase.next)) rewrote = transformer.rewrote rhs - }) + } if (rewrote) { - val call = - if (tparams.isEmpty) Ident(label.termRef) - else TypeApply(Ident(label.termRef), tparams) - Block( - List(res), - vparamss0.foldLeft(Apply(call, List(This(owner)))) - {(call, args) => Apply(call, args.map(x => Ident(x.symbol.termRef)))} - ) - } - else { + val dummyDefDef = cpy.DefDef(tree, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, + rhsSemiTransformed) + val res = fullyParameterizedDef(label, dummyDefDef) + val call = forwarder(label, dd) + Block(List(res), call) + } else { if (mandatory) ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos) rhs0 @@ -149,11 +127,9 @@ class TailRec extends TreeTransform with DenotTransformer { } - class TailRecElimination(method: Symbol, thiz: Tree, argMapping: Map[Symbol, Tree], - enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap { - - import tpd._ + class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol) extends tpd.RetypingTreeMap { + import dotty.tools.dotc.ast.tpd._ var rewrote = false @@ -182,7 +158,6 @@ class TailRec extends TreeTransform with DenotTransformer { def noTailTransforms(trees: List[Tree])(implicit c: Context) = trees map (noTailTransform) - override def transform(tree: Tree)(implicit c: Context): Tree = { /* A possibly polymorphic apply to be considered for tail call transformation. */ def rewriteApply(tree: Tree, sym: Symbol): Tree = { @@ -204,7 +179,7 @@ class TailRec extends TreeTransform with DenotTransformer { val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen - val receiverIsThis = recv.tpe.widen =:= thiz.tpe.widen + val receiverIsThis = recv.tpe.widen =:= thisType val isRecursiveCall = (method eq sym) @@ -226,17 +201,24 @@ class TailRec extends TreeTransform with DenotTransformer { def rewriteTailCall(recv: Tree): Tree = { c.debuglog("Rewriting tail recursive call: " + tree.pos) rewrote = true - val method = if (targs.nonEmpty) TypeApply(Ident(label.termRef), targs) else Ident(label.termRef) - val recv = noTailTransform(reciever) - if (recv.tpe.widen.isParameterless) method - else argumentss.foldLeft(Apply(method, List(recv))) { - (method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet. + val reciever = noTailTransform(recv) + val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos + val trz = classTypeArgs.map(x => ref(x.typeSymbol)) + val callTargs: List[tpd.Tree] = targs ::: trz + val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef), + List(reciever)) + + val res = + if (method.tpe.widen.isParameterless) method + else argumentss.foldLeft(method) { + (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. } + res } if (isRecursiveCall) { if (ctx.tailPos) { - if (recv eq EmptyTree) rewriteTailCall(thiz) + if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv) else fail("it changes type of 'this' on a polymorphic recursive call") } @@ -247,7 +229,7 @@ class TailRec extends TreeTransform with DenotTransformer { } } - def rewriteTry(tree: Try): Tree = { + def rewriteTry(tree: Try): Try = { def transformHandlers(t: Tree): Tree = { t match { case Block(List((d: DefDef)), cl@Closure(Nil, _, EmptyTree)) => @@ -274,65 +256,61 @@ class TailRec extends TreeTransform with DenotTransformer { } val res: Tree = tree match { - case Block(stats, expr) => - tpd.cpy.Block(tree, + + case tree@Block(stats, expr) => + val tree1 = tpd.cpy.Block(tree, noTailTransforms(stats), transform(expr) ) + propagateType(tree, tree1) - case t@CaseDef(pat, guard, body) => - cpy.CaseDef(t, pat, guard, transform(body)) + case tree@CaseDef(pat, guard, body) => + val tree1 = cpy.CaseDef(tree, pat, guard, transform(body)) + propagateType(tree, tree1) - case If(cond, thenp, elsep) => - tpd.cpy.If(tree, - transform(cond), + case tree@If(cond, thenp, elsep) => + val tree1 = tpd.cpy.If(tree, + noTailTransform(cond), transform(thenp), transform(elsep) ) + propagateType(tree, tree1) - case Match(selector, cases) => - tpd.cpy.Match(tree, + case tree@Match(selector, cases) => + val tree1 = tpd.cpy.Match(tree, noTailTransform(selector), transformSub(cases) ) + propagateType(tree, tree1) - case t: Try => - rewriteTry(t) + case tree: Try => + val tree1 = rewriteTry(tree) + propagateType(tree, tree1) case Apply(fun, args) if fun.symbol == defn.Boolean_or || fun.symbol == defn.Boolean_and => tpd.cpy.Apply(tree, fun, transform(args)) case Apply(fun, args) => rewriteApply(tree, fun.symbol) + case Alternative(_) | Bind(_, _) => assert(false, "We should've never gotten inside a pattern") tree - case This(cls) if cls eq enclosingClass => - thiz - case Select(qual, name) => + + case tree: Select => val sym = tree.symbol if (sym == method && ctx.tailPos) rewriteApply(tree, sym) - else tpd.cpy.Select(tree, noTailTransform(qual), name) + else propagateType(tree, tpd.cpy.Select(tree, noTailTransform(tree.qualifier), tree.name)) + case ValDef(_, _, _, _) | EmptyTree | Super(_, _) | This(_) | Literal(_) | TypeTree(_) | DefDef(_, _, _, _, _, _) | TypeDef(_, _, _) => tree + case Ident(qual) => val sym = tree.symbol if (sym == method && ctx.tailPos) rewriteApply(tree, sym) - else argMapping.get(sym) match { - case Some(rewrite) => rewrite - case None => tree.tpe match { - case TermRef(ThisType(`enclosingClass`), _) => - if (sym.flags is Flags.Local) { - // trying to access private[this] member. toggle flag in order to access. - val d = sym.denot - val newDenot = d.copySymDenotation(initFlags = sym.flags &~ Flags.Local) - newDenot.installAfter(TailRec.this) - } - thiz.select(sym) - case _ => tree - } - } + else tree + case _ => super.transform(tree) } @@ -341,6 +319,11 @@ class TailRec extends TreeTransform with DenotTransformer { } } + /** If references to original `target` from fully parameterized method `derived` should be + * rewired to some fully parameterized method, that method symbol, + * otherwise NoSymbol. + */ + override protected def rewiredTarget(target: Symbol, derived: Symbol)(implicit ctx: Context): Symbol = NoSymbol } object TailRec { diff --git a/src/dotty/tools/dotc/transform/TypeUtils.scala b/src/dotty/tools/dotc/transform/TypeUtils.scala index f11bb980acfc..a266600929e8 100644 --- a/src/dotty/tools/dotc/transform/TypeUtils.scala +++ b/src/dotty/tools/dotc/transform/TypeUtils.scala @@ -1,23 +1,18 @@ package dotty.tools.dotc package transform -import core._ -import Types._ -import Contexts._ -import Symbols._ -import Decorators._ -import StdNames.nme -import NameOps._ -import language.implicitConversions +import dotty.tools.dotc.core.Types._ + +import scala.language.implicitConversions object TypeUtils { implicit def decorateTypeUtils(tpe: Type): TypeUtils = new TypeUtils(tpe) + } /** A decorator that provides methods for type transformations * that are needed in the transofmer pipeline (not needed right now) */ class TypeUtils(val self: Type) extends AnyVal { - import TypeUtils._ } \ No newline at end of file diff --git a/test/dotc/tests.scala b/test/dotc/tests.scala index a04d076942a3..2d06e73a8c22 100644 --- a/test/dotc/tests.scala +++ b/test/dotc/tests.scala @@ -14,7 +14,7 @@ class tests extends CompilerTest { "-pagewidth", "160") implicit val defaultOptions = noCheckOptions ++ List( - "-Ycheck:extmethods"//, "-Ystop-before:terminal" + "-Ycheck:tailrec" ) val twice = List("#runs", "2", "-YnoDoubleBindings") @@ -25,6 +25,7 @@ class tests extends CompilerTest { val newDir = "./tests/new/" val dotcDir = "./src/dotty/" + @Test def pos_erasure = compileFile(posDir, "erasure", doErase) @Test def pos_Coder() = compileFile(posDir, "Coder", doErase) @Test def pos_blockescapes() = compileFile(posDir, "blockescapes", doErase) diff --git a/tests/pos/tailcall/tailcall.scala b/tests/pos/tailcall/tailcall.scala index 9cf373cf0ba4..1e05840ea3ee 100644 --- a/tests/pos/tailcall/tailcall.scala +++ b/tests/pos/tailcall/tailcall.scala @@ -3,3 +3,7 @@ class tailcall { final def fact(x: Int, acc: Int = 1): Int = if (x == 0) acc else fact(x - shift, acc * x) def id[T <: AnyRef](x: T): T = if (x eq null) x else id(x) } + +class TypedApply[T2]{ + private def firstDiff[T <: TypedApply[T2]](xs: List[T]): Int = firstDiff(xs) +} \ No newline at end of file