Skip to content

SI-6187 Make partial functions re-typable #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/compiler/scala/tools/nsc/typechecker/PatternMatching.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ trait PatternMatching extends Transform with TypingTransformers with ast.TreeDSL
import definitions._
import analyzer._ //Typer


case class DefaultOverrideMatchAttachment(default: Tree)

object vpmName {
val one = newTermName("one")
val drop = newTermName("drop")
Expand Down Expand Up @@ -222,11 +219,11 @@ trait PatternMatching extends Transform with TypingTransformers with ast.TreeDSL
// However this is a pain (at least the way I'm going about it)
// and I have to think these detailed errors are primarily useful
// for beginners, not people writing nested pattern matches.
def checkMatchVariablePatterns(m: Match) {
def checkMatchVariablePatterns(cases: List[CaseDef]) {
// A string describing the first variable pattern
var vpat: String = null
// Using an iterator so we can recognize the last case
val it = m.cases.iterator
val it = cases.iterator

def addendum(pat: Tree) = {
matchingSymbolInScope(pat) match {
Expand Down Expand Up @@ -269,7 +266,15 @@ trait PatternMatching extends Transform with TypingTransformers with ast.TreeDSL
*/
def translateMatch(match_ : Match): Tree = {
val Match(selector, cases) = match_
checkMatchVariablePatterns(match_)

val (nonSyntheticCases, defaultOverride) = cases match {
case init :+ last if treeInfo isSyntheticDefaultCase last =>
(init, Some(((scrut: Tree) => last.body)))
case _ =>
(cases, None)
}

checkMatchVariablePatterns(nonSyntheticCases)

// we don't transform after uncurry
// (that would require more sophistication when generating trees,
Expand All @@ -296,14 +301,10 @@ trait PatternMatching extends Transform with TypingTransformers with ast.TreeDSL

// val packedPt = repeatedToSeq(typer.packedType(match_, context.owner))

// the alternative to attaching the default case override would be to simply
// append the default to the list of cases and suppress the unreachable case error that may arise (once we detect that...)
val matchFailGenOverride = match_.attachments.get[DefaultOverrideMatchAttachment].map{case DefaultOverrideMatchAttachment(default) => ((scrut: Tree) => default)}

val selectorSym = freshSym(selector.pos, pureType(selectorTp)) setFlag treeInfo.SYNTH_CASE_FLAGS

// pt = Any* occurs when compiling test/files/pos/annotDepMethType.scala with -Xexperimental
val combined = combineCases(selector, selectorSym, cases map translateCase(selectorSym, pt), pt, matchOwner, matchFailGenOverride)
val combined = combineCases(selector, selectorSym, nonSyntheticCases map translateCase(selectorSym, pt), pt, matchOwner, defaultOverride)

if (Statistics.canEnable) Statistics.stopTimer(patmatNanos, start)
combined
Expand Down
87 changes: 73 additions & 14 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ trait Typers extends Modes with Adaptations with Tags {
import global._
import definitions._
import TypersStats._
import patmat.DefaultOverrideMatchAttachment

final def forArgMode(fun: Tree, mode: Int) =
if (treeInfo.isSelfOrSuperConstrCall(fun)) mode | SCCmode
Expand Down Expand Up @@ -2674,7 +2673,6 @@ trait Typers extends Modes with Adaptations with Tags {
import CODE._

val Match(sel, cases) = tree

// need to duplicate the cases before typing them to generate the apply method, or the symbols will be all messed up
val casesTrue = cases map (c => deriveCaseDef(c)(x => atPos(x.pos.focus)(TRUE_typed)).duplicate.asInstanceOf[CaseDef])

Expand All @@ -2696,8 +2694,13 @@ trait Typers extends Modes with Adaptations with Tags {
def mkParam(methodSym: Symbol, tp: Type = argTp) =
methodSym.newValueParameter(paramName, paramPos.focus, SYNTHETIC) setInfo tp

def mkDefaultCase(body: Tree) =
atPos(tree.pos.makeTransparent) {
CaseDef(Bind(nme.DEFAULT_CASE, Ident(nme.WILDCARD)), body)
}

// `def applyOrElse[A1 <: $argTp, B1 >: $matchResTp](x: A1, default: A1 => B1): B1 =
// ${`$selector match { $cases }` updateAttachment DefaultOverrideMatchAttachment(REF(default) APPLY (REF(x)))}`
// ${`$selector match { $cases; case default$ => default(x) }`
def applyOrElseMethodDef = {
val methodSym = anonClass.newMethod(nme.applyOrElse, tree.pos, FINAL | OVERRIDE)

Expand All @@ -2706,7 +2709,7 @@ trait Typers extends Modes with Adaptations with Tags {
val x = mkParam(methodSym, A1.tpe)

// applyOrElse's default parameter:
val B1 = methodSym newTypeParameter (newTypeName("B1")) setInfo TypeBounds.empty //lower(resTp)
val B1 = methodSym newTypeParameter (newTypeName("B1")) setInfo TypeBounds.empty
val default = methodSym newValueParameter (newTermName("default"), tree.pos.focus, SYNTHETIC) setInfo functionType(List(A1.tpe), B1.tpe)

val paramSyms = List(x, default)
Expand All @@ -2716,19 +2719,72 @@ trait Typers extends Modes with Adaptations with Tags {
// should use the DefDef for the context's tree, but it doesn't exist yet (we need the typer we're creating to create it)
paramSyms foreach (methodBodyTyper.context.scope enter _)

val match_ = methodBodyTyper.typedMatch(selector, cases, mode, resTp)
// First, type without the default case; only the cases provided
// by the user are typed. The LUB of these becomes `B`, the lower
// bound of `B1`, which in turn is the result type of the default
// case
val match0 = methodBodyTyper.typedMatch(selector, cases, mode, resTp)
val matchResTp = match0.tpe

val matchResTp = match_.tpe
B1 setInfo TypeBounds.lower(matchResTp) // patch info

// the default uses applyOrElse's first parameter since the scrut's type has been widened
val match_ = {
val defaultCase = methodBodyTyper.typedCase(
mkDefaultCase(methodBodyTyper.typed1(REF(default) APPLY (REF(x)), mode, B1.tpe).setType(B1.tpe)), argTp, B1.tpe)
treeCopy.Match(match0, match0.selector, match0.cases :+ defaultCase)
}
match_ setType B1.tpe

// the default uses applyOrElse's first parameter since the scrut's type has been widened
val matchWithDefault = match_ updateAttachment DefaultOverrideMatchAttachment(REF(default) APPLY (REF(x)))
(DefDef(methodSym, methodBodyTyper.virtualizedMatch(matchWithDefault, mode, B1.tpe)), matchResTp)
// SI-6187 Do you really want to know? Okay, here's what's going on here.
//
// Well behaved trees satisfy the property:
//
// typed(tree) == typed(resetLocalAttrs(typed(tree))
//
// Trees constructed without low-level symbol manipulation get this for free;
// references to local symbols are cleared by `ResetAttrs`, but bind to the
// corresponding symbol in the re-typechecked tree. But PartialFunction synthesis
// doesn't play by these rules.
//
// During typechecking of method bodies, references to method type parameter from
// the declared types of the value parameters should bind to a fresh set of skolems,
// which have been entered into scope by `Namer#methodSig`. A comment therein:
//
// "since the skolemized tparams are in scope, the TypeRefs in vparamSymss refer to skolemized tparams"
//
// But, if we retypecheck the reset `applyOrElse`, the TypeTree of the `default`
// parameter contains no type. Somehow (where?!) it recovers a type that is _almost_ okay:
// `A1 => B1`. But it should really be `A1&0 => B1&0`. In the test, run/t6187.scala, this
// difference results in a type error, as `default.apply(x)` types as `B1`, which doesn't
// conform to the required `B1&0`
//
// I see three courses of action.
//
// 1) synthesize a `asInstanceOf[B1]` below (I tried this first. But... ewwww.)
// 2) install an 'original' TypeTree that will used after ResetAttrs (the solution below)
// 3) Figure out how the almost-correct type is recovered on re-typechecking, and
// substitute in the skolems.
//
// For 2.11, we'll probably shift this transformation back a phase or two, so macros
// won't be affected. But in any case, we should satisfy retypecheckability.
//
val originals: Map[Symbol, Tree] = {
def typedIdent(sym: Symbol) = methodBodyTyper.typedType(Ident(sym), mode)
val A1Tpt = typedIdent(A1)
val B1Tpt = typedIdent(B1)
Map(
x -> A1Tpt,
default -> gen.scalaFunctionConstr(List(A1Tpt), B1Tpt)
)
}
val rhs = methodBodyTyper.virtualizedMatch(match_, mode, B1.tpe)
val defdef = DefDef(methodSym, Modifiers(methodSym.flags), originals, rhs)

(defdef, matchResTp)
}

// `def isDefinedAt(x: $argTp): Boolean = ${`$selector match { $casesTrue ` updateAttachment DefaultOverrideMatchAttachment(FALSE_typed)}`
// `def isDefinedAt(x: $argTp): Boolean = ${`$selector match { $casesTrue; case default$ => false } }`
def isDefinedAtMethod = {
val methodSym = anonClass.newMethod(nme.isDefinedAt, tree.pos.makeTransparent, FINAL)
val paramSym = mkParam(methodSym)
Expand All @@ -2737,10 +2793,10 @@ trait Typers extends Modes with Adaptations with Tags {
methodBodyTyper.context.scope enter paramSym
methodSym setInfo MethodType(List(paramSym), BooleanClass.tpe)

val match_ = methodBodyTyper.typedMatch(selector, casesTrue, mode, BooleanClass.tpe)
val defaultCase = mkDefaultCase(FALSE_typed)
val match_ = methodBodyTyper.typedMatch(selector, casesTrue :+ defaultCase, mode, BooleanClass.tpe)

val matchWithDefault = match_ updateAttachment DefaultOverrideMatchAttachment(FALSE_typed)
DefDef(methodSym, methodBodyTyper.virtualizedMatch(matchWithDefault, mode, BooleanClass.tpe))
DefDef(methodSym, methodBodyTyper.virtualizedMatch(match_, mode, BooleanClass.tpe))
}

// only used for @cps annotated partial functions
Expand Down Expand Up @@ -2785,7 +2841,9 @@ trait Typers extends Modes with Adaptations with Tags {
members foreach (m => anonClass.info.decls enter m.symbol)

val typedBlock = typedPos(tree.pos, mode, pt) {
Block(ClassDef(anonClass, NoMods, ListOfNil, members, tree.pos.focus), atPos(tree.pos.focus)(New(anonClass.tpe)))
Block(ClassDef(anonClass, NoMods, ListOfNil, members, tree.pos.focus), atPos(tree.pos.focus)(
Apply(Select(New(Ident(anonClass.name).setSymbol(anonClass)), nme.CONSTRUCTOR), List())
))
}

if (typedBlock.isErrorTyped) typedBlock
Expand Down Expand Up @@ -5877,4 +5935,5 @@ object TypersStats {
val visitsByType = Statistics.newByClass("#visits by tree node", "typer")(Statistics.newCounter(""))
val byTypeNanos = Statistics.newByClass("time spent by tree node", "typer")(Statistics.newStackableTimer("", typerNanos))
val byTypeStack = Statistics.newTimerStack()

}
1 change: 1 addition & 0 deletions src/reflect/scala/reflect/internal/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ trait StdNames {
// Compiler internal names
val ANYname: NameType = "<anyname>"
val CONSTRUCTOR: NameType = "<init>"
val DEFAULT_CASE: NameType = "defaultCase$"
val EQEQ_LOCAL_VAR: NameType = "eqEqTemp$"
val FAKE_LOCAL_THIS: NameType = "this$"
val INITIALIZER: NameType = CONSTRUCTOR // Is this buying us something?
Expand Down
7 changes: 7 additions & 0 deletions src/reflect/scala/reflect/internal/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ abstract class TreeInfo {
case _ => false
}

/** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know
* whether the user provided cases are exhaustive. */
def isSyntheticDefaultCase(cdef: CaseDef) = cdef match {
case CaseDef(Bind(nme.DEFAULT_CASE, _), EmptyTree, _) => true
case _ => false
}

/** Does this CaseDef catch Throwable? */
def catchesThrowable(cdef: CaseDef) = catchesAllOf(cdef, ThrowableClass.tpe)

Expand Down
18 changes: 17 additions & 1 deletion src/reflect/scala/reflect/internal/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,11 @@ trait Trees extends api.Trees { self: SymbolTable =>
override private[scala] def copyAttrs(tree: Tree) = {
super.copyAttrs(tree)
tree match {
case other: TypeTree => wasEmpty = other.wasEmpty // SI-6648 Critical for correct operation of `resetAttrs`.
case other: TypeTree =>
// SI-6648 Critical for correct operation of `resetAttrs`.
wasEmpty = other.wasEmpty
if (other.orig != null)
orig = other.orig.duplicate
case _ =>
}
this
Expand Down Expand Up @@ -989,6 +993,18 @@ trait Trees extends api.Trees { self: SymbolTable =>
def DefDef(sym: Symbol, mods: Modifiers, rhs: Tree): DefDef =
DefDef(sym, mods, mapParamss(sym)(ValDef), rhs)

/** A DefDef with original trees attached to the TypeTree of each parameter */
def DefDef(sym: Symbol, mods: Modifiers, originalParamTpts: Symbol => Tree, rhs: Tree): DefDef = {
val paramms = mapParamss(sym){ sym =>
val vd = ValDef(sym, EmptyTree)
(vd.tpt : @unchecked) match {
case tt: TypeTree => tt setOriginal (originalParamTpts(sym) setPos sym.pos.focus)
}
vd
}
DefDef(sym, mods, paramms, rhs)
}

def DefDef(sym: Symbol, rhs: Tree): DefDef =
DefDef(sym, Modifiers(sym.flags), rhs)

Expand Down
25 changes: 0 additions & 25 deletions test/files/run/idempotency-partial-functions.scala

This file was deleted.

32 changes: 32 additions & 0 deletions test/files/run/t6187.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Type in expressions to have them evaluated.
Type :help for more information.

scala> import language.experimental.macros, reflect.macros.Context
import language.experimental.macros
import reflect.macros.Context

scala> def macroImpl[T: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[List[T]] = {
val r = c.universe.reify { List(t.splice) }
c.Expr[List[T]]( c.resetLocalAttrs(r.tree) )
}
macroImpl: [T](c: scala.reflect.macros.Context)(t: c.Expr[T])(implicit evidence$1: c.WeakTypeTag[T])c.Expr[List[T]]

scala> def demo[T](t: T): List[T] = macro macroImpl[T]
demo: [T](t: T)List[T]

scala> def m[T](t: T): List[List[T]] =
demo( List((t,true)) collect { case (x,true) => x } )
m: [T](t: T)List[List[T]]

scala> m(List(1))
res0: List[List[List[Int]]] = List(List(List(1)))

scala> // Showing we haven't added unreachable warnings

scala> List(1) collect { case x => x }
res1: List[Int] = List(1)

scala> List("") collect { case x => x }
res2: List[String] = List("")

scala>
18 changes: 18 additions & 0 deletions test/files/run/t6187.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.tools.partest.ReplTest

object Test extends ReplTest {
override def code = """
import language.experimental.macros, reflect.macros.Context
def macroImpl[T: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[List[T]] = {
val r = c.universe.reify { List(t.splice) }
c.Expr[List[T]]( c.resetLocalAttrs(r.tree) )
}
def demo[T](t: T): List[T] = macro macroImpl[T]
def m[T](t: T): List[List[T]] =
demo( List((t,true)) collect { case (x,true) => x } )
m(List(1))
// Showing we haven't added unreachable warnings
List(1) collect { case x => x }
List("") collect { case x => x }
""".trim
}
5 changes: 5 additions & 0 deletions test/files/run/t6187b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
object Test extends App {
val x: PartialFunction[Int, Int] = { case 1 => 1 }
val o: Any = ""
assert(x.applyOrElse(0, (_: Int) => o) == "")
}
28 changes: 28 additions & 0 deletions test/pending/run/idempotency-partial-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import scala.reflect.runtime.universe._
import scala.reflect.runtime.{currentMirror => cm}
import scala.tools.reflect.{ToolBox, ToolBoxError}
import scala.tools.reflect.Eval

// Related to SI-6187
//
// Moved to pending as we are currently blocked by the inability
// to reify the parent types of the anoymous function class,
// which are not part of the tree, but rather only part of the
// ClassInfoType.
object Test extends App {
val partials = reify {
List((false,true)) collect { case (x,true) => x }
}
println(Seq(show(partials), showRaw(partials)).mkString("\n\n"))
try {
println(partials.eval)
} catch {
case e: ToolBoxError => println(e)
}
val tb = cm.mkToolBox()
val tpartials = tb.typeCheck(partials.tree)
println(tpartials)
val rtpartials = tb.resetAllAttrs(tpartials)
println(tb.eval(rtpartials))
}
Test.main(null)