Skip to content

Topic/tree checking #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Run(comp: Compiler)(implicit ctx: Context) {

private def printTree(ctx: Context) = {
val unit = ctx.compilationUnit
println(s"result of $unit after ${ctx.phase}:")
println(s"result of $unit after ${ctx.phase.prev}:")
println(unit.tpdTree.show(ctx))
}

Expand Down
6 changes: 4 additions & 2 deletions src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ object Contexts {
protected def searchHistory_= (searchHistory: SearchHistory) = _searchHistory = searchHistory
def searchHistory: SearchHistory = _searchHistory

/** Caches for withPhase */
private var phasedCtx: Context = _
private var phasedCtxs: Array[Context] = _


/** This context at given phase.
* This method will always return a phase period equal to phaseId, thus will never return squashed phases
*/
Expand All @@ -205,7 +205,8 @@ object Contexts {

final def withPhase(phase: Phase): Context =
withPhase(phase.id)
/** If -Ydebug is on, the top of the stack trace where this context

/** If -Ydebug is on, the top of the stack trace where this context
* was created, otherwise `null`.
*/
private var creationTrace: Array[StackTraceElement] = _
Expand Down Expand Up @@ -298,6 +299,7 @@ object Contexts {
setCreationTrace()
this
}

/** A fresh clone of this context. */
def fresh: FreshContext = clone.asInstanceOf[FreshContext].init(this)

Expand Down
51 changes: 25 additions & 26 deletions src/dotty/tools/dotc/core/Denotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,35 +273,34 @@ object Denotations {

def unionDenot(denot1: SingleDenotation, denot2: SingleDenotation): Denotation =
if (denot1.signature matches denot2.signature) {
val sym1 = denot1.symbol
val sym2 = denot2.symbol
val info1 = denot1.info
val info2 = denot2.info
val sym2 = denot2.symbol
def sym2Accessible = sym2.isAccessibleFrom(pre)
if (info1 <:< info2 && sym2Accessible) denot2
val sameSym = sym1 eq sym2
if (sameSym && info1 <:< info2) denot2
else if (sameSym && info2 <:< info1) denot1
else {
val sym1 = denot1.symbol
def sym1Accessible = sym1.isAccessibleFrom(pre)
if (info2 <:< info1 && sym1Accessible) denot1
else {
val owner2 = if (sym2 ne NoSymbol) sym2.owner else NoSymbol
/** Determine a symbol which is overridden by both sym1 and sym2.
* Preference is given to accessible symbols.
*/
def lubSym(overrides: Iterator[Symbol], previous: Symbol): Symbol =
if (!overrides.hasNext) previous
else {
val candidate = overrides.next
if (owner2 derivesFrom candidate.owner)
if (candidate isAccessibleFrom pre) candidate
else lubSym(overrides, previous orElse candidate)
else
lubSym(overrides, previous)
}
new JointRefDenotation(
lubSym(sym1.allOverriddenSymbols, NoSymbol),
info1 | info2,
denot1.validFor & denot2.validFor)
}
val jointSym =
if (sameSym) sym1
else {
val owner2 = if (sym2 ne NoSymbol) sym2.owner else NoSymbol
/** Determine a symbol which is overridden by both sym1 and sym2.
* Preference is given to accessible symbols.
*/
def lubSym(overrides: Iterator[Symbol], previous: Symbol): Symbol =
if (!overrides.hasNext) previous
else {
val candidate = overrides.next
if (owner2 derivesFrom candidate.owner)
if (candidate isAccessibleFrom pre) candidate
else lubSym(overrides, previous orElse candidate)
else
lubSym(overrides, previous)
}
lubSym(sym1.allOverriddenSymbols, NoSymbol)
}
new JointRefDenotation(jointSym, info1 | info2, denot1.validFor & denot2.validFor)
}
}
else NoDenotation
Expand Down
1 change: 0 additions & 1 deletion src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ object Phases {
override def lastPhaseId(implicit ctx: Context) = id
}


/** Use the following phases in the order they are given.
* The list should never contain NoPhase.
* if squashing is enabled, phases in same subgroup will be squashed to single phase.
Expand Down
9 changes: 7 additions & 2 deletions src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package printing

import core._
import Texts._, Types._, Flags._, Names._, Symbols._, NameOps._, Constants._
import Contexts.Context, Scopes.Scope, Denotations.Denotation, Annotations.Annotation
import Contexts.Context, Scopes.Scope, Denotations._, Annotations.Annotation
import StdNames.nme
import ast.{Trees, untpd}
import typer.Namer
Expand Down Expand Up @@ -475,7 +475,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
Text(flags.flagStrings.filterNot(_.startsWith("<")) map stringToText, " ")
}

override def toText(denot: Denotation): Text = toText(denot.symbol)
override def toText(denot: Denotation): Text = denot match {
case denot: MultiDenotation => denot.toString
case _ =>
if (denot.symbol.exists) toText(denot.symbol)
else "some " ~ toText(denot.info)
}

override def plain = new PlainPrinter(_ctx)
}
106 changes: 103 additions & 3 deletions src/dotty/tools/dotc/transform/Splitter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package transform

import TreeTransforms._
import ast.Trees._
import core.Contexts._
import core.Types._
import core._
import Contexts._, Types._, Decorators._, Denotations._, Symbols._, SymDenotations._, Names._

/** This transform makes usre every identifier and select node
/** This transform makes sure every identifier and select node
* carries a symbol. To do this, certain qualifiers with a union type
* have to be "splitted" with a type test.
*
Expand All @@ -24,4 +24,104 @@ class Splitter extends TreeTransform {
This(cls) withPos tree.pos
case _ => tree
}

/** If we select a name, make sure the node has a symbol.
* If necessary, split the qualifier with type tests.
* Example: Assume:
*
* class A { def f(x: S): T }
* class B { def f(x: S): T }
* def p(): A | B
*
* Then p().f(a) translates to
*
* val ev$1 = p()
* if (ev$1.isInstanceOf[A]) ev$1.asInstanceOf[A].f(a)
* else ev$1.asInstanceOf[B].f(a)
*/
override def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo) = {
val Select(qual, name) = tree

def memberDenot(tp: Type): SingleDenotation = {
val mbr = tp.member(name)
if (!mbr.isOverloaded) mbr.asSingleDenotation
else tree.tpe match {
case tref: TermRefWithSignature => mbr.atSignature(tref.sig)
case _ => ctx.error(s"cannot disambiguate overloaded member $mbr"); NoDenotation
}
}

def candidates(tp: Type): List[Symbol] = {
val mbr = memberDenot(tp)
if (mbr.symbol.exists) mbr.symbol :: Nil
else tp.widen match {
case tref: TypeRef =>
tref.info match {
case TypeBounds(_, hi) => candidates(hi)
case _ => Nil
}
case OrType(tp1, tp2) =>
candidates(tp1) | candidates(tp2)
case AndType(tp1, tp2) =>
candidates(tp1) & candidates(tp2)
case tpw =>
Nil
}
}

def isStructuralSelect(tp: Type): Boolean = tp.stripTypeVar match {
case tp: RefinedType => tp.refinedName == name || isStructuralSelect(tp)
case tp: TypeProxy => isStructuralSelect(tp.underlying)
case AndType(tp1, tp2) => isStructuralSelect(tp1) || isStructuralSelect(tp2)
case _ => false
}

if (tree.symbol.exists) tree
else {
def choose(qual: Tree, syms: List[Symbol]): Tree = {
def testOrCast(which: Symbol, mbr: Symbol) =
TypeApply(Select(qual, which), TypeTree(mbr.owner.typeRef) :: Nil)
def select(sym: Symbol) = {
val qual1 =
if (qual.tpe derivesFrom sym.owner) qual
else testOrCast(defn.Any_asInstanceOf, sym)
Select(qual1, sym) withPos tree.pos
}
syms match {
case Nil =>
def msg =
if (isStructuralSelect(qual.tpe))
s"cannot access member '$name' from structural type ${qual.tpe.widen.show}; use Dynamic instead"
else
s"no candidate symbols for ${tree.tpe.show} found in ${qual.tpe.show}"
ctx.error(msg, tree.pos)
tree
case sym :: Nil =>
select(sym)
case sym :: syms1 =>
If(testOrCast(defn.Any_isInstanceOf, sym), select(sym), choose(qual, syms1))
}
}
evalOnce(qual)(qual => choose(qual, candidates(qual.tpe)))
}
}

/** Distribute arguments among splitted branches */
def distribute(tree: GenericApply[Type], rebuild: (Tree, List[Tree]) => Context => Tree)(implicit ctx: Context) = {
def recur(fn: Tree): Tree = fn match {
case Block(stats, expr) => Block(stats, recur(expr))
case If(cond, thenp, elsep) => If(cond, recur(thenp), recur(elsep))
case _ => rebuild(fn, tree.args)(ctx) withPos tree.pos
}
recur(tree.fun)
}

override def transformTypeApply(tree: TypeApply)(implicit ctx: Context, info: TransformerInfo) =
distribute(tree, typeApply)

override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) =
distribute(tree, apply)

private val typeApply = (fn: Tree, args: List[Tree]) => (ctx: Context) => TypeApply(fn, args)(ctx)
private val apply = (fn: Tree, args: List[Tree]) => (ctx: Context) => Apply(fn, args)(ctx)
}
37 changes: 31 additions & 6 deletions src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import core.Types._
import core.Constants._
import core.StdNames._
import core.transform.Erasure.isUnboundedGeneric
import typer._
import typer.ErrorReporting._
import ast.Trees._
import ast.{tpd, untpd}

/** This transform eliminates patterns. Right now it's a dummy.
* Awaiting the real pattern matcher.
Expand All @@ -22,14 +24,37 @@ class TreeChecker {

def check(ctx: Context) = {
println(s"checking ${ctx.compilationUnit} after phase ${ctx.phase.prev}")
Checker.transform(ctx.compilationUnit.tpdTree)(ctx)
Checker.typedExpr(ctx.compilationUnit.tpdTree)(ctx)
}

object Checker extends TreeMap {
override def transform(tree: Tree)(implicit ctx: Context) = {
println(i"checking $tree")
assert(tree.isEmpty || tree.hasType, tree.show)
super.transform(tree)
object Checker extends ReTyper {
override def typed(tree: untpd.Tree, pt: Type)(implicit ctx: Context) =
if (tree.isEmpty) tree.asInstanceOf[Tree]
else {
assert(tree.hasType, tree.show)
val tree1 = super.typed(tree, pt)
def sameType(tp1: Type, tp2: Type) =
(tp1 eq tp2) || // accept NoType / NoType
(tp1 =:= tp2)
def divergenceMsg =
s"""Types differ
|Original type : ${tree.typeOpt.show}
|After checking: ${tree1.tpe.show}
|Original tree : ${tree.show}
|After checking: ${tree1.show}
""".stripMargin
assert(sameType(tree1.tpe, tree.typeOpt), divergenceMsg)
tree1
}

override def typedIdent(tree: untpd.Ident, pt: Type)(implicit ctx: Context): Tree = {
assert(tree.isTerm, tree.show)
super.typedIdent(tree, pt)
}

override def typedSelect(tree: untpd.Select, pt: Type)(implicit ctx: Context): Tree = {
assert(tree.isTerm, tree.show)
super.typedSelect(tree, pt)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -794,8 +794,8 @@ trait Applications extends Compatibility { self: Typer =>
tp
}

val owner1 = alt1.symbol.owner
val owner2 = alt2.symbol.owner
val owner1 = if (alt1.symbol.exists) alt1.symbol.owner else NoSymbol
val owner2 = if (alt2.symbol.exists) alt2.symbol.owner else NoSymbol
val tp1 = stripImplicit(alt1.widen)
val tp2 = stripImplicit(alt2.widen)

Expand Down
4 changes: 3 additions & 1 deletion src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package typer

import core.Contexts._
import core.Types._
import core.Symbols.Symbol
import core.Symbols._
import typer.ProtoTypes._
import ast.{tpd, untpd}
import ast.Trees._
Expand Down Expand Up @@ -48,6 +48,8 @@ class ReTyper extends Typer {
untpd.cpy.Bind(tree, tree.name, body1).withType(tree.typeOpt)
}

override def localDummy(cls: ClassSymbol, impl: untpd.Template)(implicit ctx: Context) = impl.symbol

override def retrieveSym(tree: untpd.Tree)(implicit ctx: Context): Symbol = tree.symbol

override def localTyper(sym: Symbol) = this
Expand Down
9 changes: 6 additions & 3 deletions src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -809,11 +809,11 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
val parents1 = ensureConstrCall(ensureFirstIsClass(
parents mapconserve typedParent, cdef.pos.toSynthetic))
val self1 = typed(self)(ctx.outer).asInstanceOf[ValDef] // outer context where class memebers are not visible
val localDummy = ctx.newLocalDummy(cls, impl.pos)
val body1 = typedStats(body, localDummy)(inClassContext(self1.symbol))
val dummy = localDummy(cls, impl)
val body1 = typedStats(body, dummy)(inClassContext(self1.symbol))
checkNoDoubleDefs(cls)
val impl1 = cpy.Template(impl, constr1, parents1, self1, body1)
.withType(localDummy.termRef)
.withType(dummy.termRef)
assignType(cpy.TypeDef(cdef, mods1, name, impl1), cls)

// todo later: check that
Expand All @@ -825,6 +825,9 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
// 4. Polymorphic type defs override nothing.
}

def localDummy(cls: ClassSymbol, impl: untpd.Template)(implicit ctx: Context): Symbol =
ctx.newLocalDummy(cls, impl.pos)

def typedImport(imp: untpd.Import, sym: Symbol)(implicit ctx: Context): Import = track("typedImport") {
val expr1 = typedExpr(imp.expr, AnySelectionProto)
checkStable(expr1.tpe, imp.expr.pos)
Expand Down
2 changes: 2 additions & 0 deletions test/dotc/tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class tests extends CompilerTest {
@Test def neg_i39 = compileFile(negDir, "i39", xerrors = 1)
@Test def neg_i50_volatile = compileFile(negDir, "i50-volatile", xerrors = 4)
@Test def neg_t0273_doubledefs = compileFile(negDir, "t0273", xerrors = 1)
@Test def neg_t0586_structural = compileFile(negDir, "t0586", xerrors = 1)
@Test def neg_t0625_structural = compileFile(negDir, "t0625", xerrors = 1)
@Test def neg_t0654_polyalias = compileFile(negDir, "t0654", xerrors = 2)
@Test def neg_t1192_legalPrefix = compileFile(negDir, "t1192", xerrors = 1)

Expand Down
File renamed without changes.
File renamed without changes.
11 changes: 11 additions & 0 deletions tests/pos/unions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@ object unions {

class A {
def f: String = "abc"

def g(x: Int): Int = x
def g(x: Double): Double = x
}

class B {
def f: String = "bcd"

def g(x: Int) = -x
def g(x: Double): Double = -x
}

val x: A | B = if (true) new A else new B
def y: B | A = if (true) new A else new B
println(x.f)
println(x.g(2))
println(y.f)
println(y.g(1.0))


}