Skip to content

Commit 1ba082a

Browse files
authored
Merge pull request #5497 from dotty-staging/inline-for-derive
Tests to Explore Typeclass Derivation
2 parents 11a9a78 + 6eace06 commit 1ba082a

10 files changed

+573
-35
lines changed

compiler/src/dotty/tools/dotc/printing/DecompilerPrinter.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ import dotty.tools.dotc.core.StdNames.nme
99
import dotty.tools.dotc.core.Flags._
1010
import dotty.tools.dotc.core.Symbols._
1111
import dotty.tools.dotc.core.StdNames._
12-
12+
import dotty.tools.dotc.core.Annotations.Annotation
1313

1414
class DecompilerPrinter(_ctx: Context) extends RefinedPrinter(_ctx) {
1515

16-
override protected def filterModTextAnnots(annots: List[untpd.Tree]): List[untpd.Tree] =
17-
super.filterModTextAnnots(annots).filter(_.tpe != defn.SourceFileAnnotType)
16+
override protected def dropAnnotForModText(sym: Symbol): Boolean =
17+
super.dropAnnotForModText(sym) || sym == defn.SourceFileAnnot
1818

1919
override protected def blockToText[T >: Untyped](block: Block[T]): Text =
2020
block match {

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Symbols._
1010
import NameOps._
1111
import TypeErasure.ErasedValueType
1212
import Contexts.Context
13+
import Annotations.Annotation
1314
import Denotations._
1415
import SymDenotations._
1516
import StdNames.{nme, tpnme}
@@ -624,7 +625,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
624625
def Modifiers(sym: Symbol)(implicit ctx: Context): Modifiers = untpd.Modifiers(
625626
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags),
626627
if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY,
627-
sym.annotations map (_.tree))
628+
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree))
629+
630+
protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot
628631

629632
protected def optAscription[T >: Untyped](tpt: Tree[T]): Text = optText(tpt)(": " ~ _)
630633

@@ -748,14 +751,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
748751
if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= Implicit // drop implicit from classes
749752
val flags = (if (sym.exists) sym.flags else (mods.flags)) & flagMask
750753
val flagsText = if (flags.isEmpty) "" else keywordStr(flags.toString)
751-
val annotations = filterModTextAnnots(
752-
if (sym.exists) sym.annotations.filterNot(_.isInstanceOf[Annotations.BodyAnnotation]).map(_.tree)
753-
else mods.annotations)
754+
val annotations =
755+
if (sym.exists) sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)
756+
else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol))
754757
Text(annotations.map(annotText), " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
755758
}
756759

757-
protected def filterModTextAnnots(annots: List[untpd.Tree]): List[untpd.Tree] = annots
758-
759760
def optText(name: Name)(encl: Text => Text): Text =
760761
if (name.isEmpty) "" else encl(toText(name))
761762

compiler/src/dotty/tools/dotc/typer/ConstFold.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ object ConstFold {
2020
def apply[T <: Tree](tree: T)(implicit ctx: Context): T = finish(tree) {
2121
tree match {
2222
case Apply(Select(xt, op), yt :: Nil) =>
23-
xt.tpe.widenTermRefExpr match {
23+
xt.tpe.widenTermRefExpr.normalized match {
2424
case ConstantType(x) =>
2525
yt.tpe.widenTermRefExpr match {
2626
case ConstantType(y) => foldBinop(op, x, y)
@@ -42,7 +42,7 @@ object ConstFold {
4242
*/
4343
def apply[T <: Tree](tree: T, pt: Type)(implicit ctx: Context): T =
4444
finish(apply(tree)) {
45-
tree.tpe.widenTermRefExpr match {
45+
tree.tpe.widenTermRefExpr.normalized match {
4646
case ConstantType(x) => x convertTo pt
4747
case _ => null
4848
}

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -461,24 +461,19 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
461461
}
462462

463463
// Drop unused bindings
464-
val matchBindings = reducer.matchBindingsBuf.toList
465-
val (finalBindings, finalExpansion) = dropUnusedDefs(bindingsBuf.toList ++ matchBindings, expansion1)
466-
val (finalMatchBindings, finalArgBindings) = finalBindings.partition(matchBindings.contains(_))
464+
val (finalBindings, finalExpansion) = dropUnusedDefs(bindingsBuf.toList, expansion1)
467465

468466
if (inlinedMethod == defn.Typelevel_error) issueError()
469467

470468
// Take care that only argument bindings go into `bindings`, since positions are
471469
// different for bindings from arguments and bindings from body.
472-
tpd.Inlined(call, finalArgBindings, seq(finalMatchBindings, finalExpansion))
470+
tpd.Inlined(call, finalBindings, finalExpansion)
473471
}
474472
}
475473

476474
/** A utility object offering methods for rewriting inlined code */
477475
object reducer {
478476

479-
/** Additional bindings established by reducing match expressions */
480-
val matchBindingsBuf = new mutable.ListBuffer[MemberDef]
481-
482477
/** An extractor for terms equivalent to `new C(args)`, returning the class `C`,
483478
* a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can
484479
* follow a reference to an inline value binding to its right hand side.
@@ -599,7 +594,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
599594
def unapply(tree: Trees.Ident[_])(implicit ctx: Context): Option[Tree] = {
600595
def search(buf: mutable.ListBuffer[MemberDef]) = buf.find(_.name == tree.name)
601596
if (paramProxies.contains(tree.typeOpt))
602-
search(bindingsBuf).orElse(search(matchBindingsBuf)) match {
597+
search(bindingsBuf) match {
603598
case Some(vdef: ValDef) if vdef.symbol.is(Inline) =>
604599
Some(integrate(vdef.rhs, vdef.symbol))
605600
case Some(ddef: DefDef) =>
@@ -611,7 +606,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
611606
}
612607

613608
object ConstantValue {
614-
def unapply(tree: Tree)(implicit ctx: Context): Option[Any] = tree.tpe.widenTermRefExpr match {
609+
def unapply(tree: Tree)(implicit ctx: Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match {
615610
case ConstantType(Constant(x)) => Some(x)
616611
case _ => None
617612
}
@@ -662,7 +657,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
662657
* for the pattern-bound variables and the RHS of the selected case.
663658
* Returns `None` if no case was selected.
664659
*/
665-
type MatchRedux = Option[(List[MemberDef], untpd.Tree)]
660+
type MatchRedux = Option[(List[MemberDef], tpd.Tree)]
666661

667662
/** Reduce an inline match
668663
* @param mtch the match tree
@@ -674,7 +669,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
674669
* @return optionally, if match can be reduced to a matching case: A pair of
675670
* bindings for all pattern-bound variables and the RHS of the case.
676671
*/
677-
def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[untpd.CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = {
672+
def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = {
678673

679674
val isImplicit = scrutinee.isEmpty
680675
val gadtSyms = typer.gadtSyms(scrutType)
@@ -712,7 +707,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
712707
val getBoundVars = new TreeAccumulator[List[TypeSymbol]] {
713708
def apply(syms: List[TypeSymbol], t: Tree)(implicit ctx: Context) = {
714709
val syms1 = t match {
715-
case t: Bind if t.symbol.isType && t.name != tpnme.WILDCARD =>
710+
case t: Bind if t.symbol.isType =>
716711
t.symbol.asType :: syms
717712
case _ =>
718713
syms
@@ -739,7 +734,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
739734
// ConstraintHandler#approximation does. However, this only works for constrained paramrefs
740735
// not GADT-bound variables. Hopefully we will get some way to improve this when we
741736
// re-implement GADTs in terms of constraints.
742-
bindingsBuf += TypeDef(bv)
737+
if (bv.name != nme.WILDCARD) bindingsBuf += TypeDef(bv)
743738
}
744739
reducePattern(bindingsBuf, scrut, pat1)
745740
}
@@ -805,7 +800,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
805800
val scrutineeSym = newSym(InlineScrutineeName.fresh(), Synthetic, scrutType).asTerm
806801
val scrutineeBinding = normalizeBinding(ValDef(scrutineeSym, scrutinee))
807802

808-
def reduceCase(cdef: untpd.CaseDef): MatchRedux = {
803+
def reduceCase(cdef: CaseDef): MatchRedux = {
809804
val caseBindingsBuf = new mutable.ListBuffer[MemberDef]()
810805
def guardOK(implicit ctx: Context) = cdef.guard.isEmpty || {
811806
val guardCtx = ctx.fresh.setNewScope
@@ -824,7 +819,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
824819
None
825820
}
826821

827-
def recur(cases: List[untpd.CaseDef]): MatchRedux = cases match {
822+
def recur(cases: List[CaseDef]): MatchRedux = cases match {
828823
case Nil => None
829824
case cdef :: cases1 => reduceCase(cdef) `orElse` recur(cases1)
830825
}
@@ -895,14 +890,15 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
895890
super.typedMatchFinish(tree, sel, wideSelType, cases, pt)
896891
else {
897892
val selType = if (sel.isEmpty) wideSelType else sel.tpe
898-
reduceInlineMatch(sel, selType, cases, this) match {
899-
case Some((caseBindings, rhs)) =>
900-
var rhsCtx = ctx.fresh.setNewScope
901-
for (binding <- caseBindings) {
902-
matchBindingsBuf += binding
903-
rhsCtx.enter(binding.symbol)
904-
}
905-
typedExpr(rhs, pt)(rhsCtx)
893+
reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match {
894+
case Some((caseBindings, rhs0)) =>
895+
val (usedBindings, rhs1) = dropUnusedDefs(caseBindings, rhs0)
896+
val rhs = seq(usedBindings, rhs1)
897+
inlining.println(i"""--- reduce:
898+
|$tree
899+
|--- to:
900+
|$rhs""")
901+
typedExpr(rhs, pt)
906902
case None =>
907903
def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard"
908904
def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}"
@@ -993,7 +989,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
993989
val dealiasedType = dealias(t.tpe)
994990
val t1 = t match {
995991
case t: RefTree =>
996-
if (boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos)
992+
if (t.name != nme.WILDCARD && boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos)
997993
else t.withType(dealiasedType)
998994
case t: DefTree =>
999995
t.symbol.info = dealias(t.symbol.info)

compiler/test/dotc/pos-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ i4125.scala
99
implicit-dep.scala
1010
inline-access-levels
1111
inline-rewrite.scala
12+
inline-caseclass.scala
1213
macro-with-array
1314
macro-with-type
1415
matchtype.scala

compiler/test/dotc/run-from-tasty.blacklist

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ puzzle.scala
66

77
# Need to print empty tree for implicit match
88
implicitMatch.scala
9+
typeclass-derivation1.scala
10+
typeclass-derivation2.scala
11+

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ t8133b
99
tuples1.scala
1010
tuples1a.scala
1111
implicitMatch.scala
12+
typeclass-derivation1.scala
13+
typeclass-derivation2.scala

tests/run/typeclass-derivation1.scala

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
object Deriving {
2+
import scala.typelevel._
3+
4+
sealed trait Shape
5+
6+
class HasSumShape[T, S <: Tuple]
7+
8+
abstract class HasProductShape[T, Xs <: Tuple] {
9+
def toProduct(x: T): Xs
10+
def fromProduct(x: Xs): T
11+
}
12+
13+
enum Lst[+T] {
14+
case Cons(hd: T, tl: Lst[T])
15+
case Nil
16+
}
17+
18+
object Lst {
19+
implicit def lstShape[T]: HasSumShape[Lst[T], (Cons[T], Nil.type)] = new HasSumShape
20+
21+
implicit def consShape[T]: HasProductShape[Lst.Cons[T], (T, Lst[T])] = new {
22+
def toProduct(xs: Lst.Cons[T]) = (xs.hd, xs.tl)
23+
def fromProduct(xs: (T, Lst[T])): Lst.Cons[T] = Lst.Cons(xs(0), xs(1)).asInstanceOf
24+
}
25+
26+
implicit def nilShape[T]: HasProductShape[Lst.Nil.type, Unit] = new {
27+
def toProduct(xs: Lst.Nil.type) = ()
28+
def fromProduct(xs: Unit) = Lst.Nil
29+
}
30+
31+
implicit def LstEq[T: Eq]: Eq[Lst[T]] = Eq.derivedForSum
32+
implicit def ConsEq[T: Eq]: Eq[Cons[T]] = Eq.derivedForProduct
33+
implicit def NilEq[T]: Eq[Nil.type] = Eq.derivedForProduct
34+
}
35+
36+
trait Eq[T] {
37+
def equals(x: T, y: T): Boolean
38+
}
39+
40+
object Eq {
41+
inline def tryEq[T](x: T, y: T) = implicit match {
42+
case eq: Eq[T] => eq.equals(x, y)
43+
}
44+
45+
inline def deriveForSum[Alts <: Tuple](x: Any, y: Any): Boolean = inline erasedValue[Alts] match {
46+
case _: (alt *: alts1) =>
47+
x match {
48+
case x: `alt` =>
49+
y match {
50+
case y: `alt` => tryEq[alt](x, y)
51+
case _ => false
52+
}
53+
case _ => deriveForSum[alts1](x, y)
54+
}
55+
case _: Unit =>
56+
false
57+
}
58+
59+
inline def deriveForProduct[Elems <: Tuple](xs: Elems, ys: Elems): Boolean = inline erasedValue[Elems] match {
60+
case _: (elem *: elems1) =>
61+
val xs1 = xs.asInstanceOf[elem *: elems1]
62+
val ys1 = ys.asInstanceOf[elem *: elems1]
63+
tryEq[elem](xs1.head, ys1.head) &&
64+
deriveForProduct[elems1](xs1.tail, ys1.tail)
65+
case _: Unit =>
66+
true
67+
}
68+
69+
inline def derivedForSum[T, Alts <: Tuple](implicit ev: HasSumShape[T, Alts]): Eq[T] = new {
70+
def equals(x: T, y: T): Boolean = deriveForSum[Alts](x, y)
71+
}
72+
73+
inline def derivedForProduct[T, Elems <: Tuple](implicit ev: HasProductShape[T, Elems]): Eq[T] = new {
74+
def equals(x: T, y: T): Boolean = deriveForProduct[Elems](ev.toProduct(x), ev.toProduct(y))
75+
}
76+
77+
implicit object eqInt extends Eq[Int] {
78+
def equals(x: Int, y: Int) = x == y
79+
}
80+
}
81+
}
82+
83+
object Test extends App {
84+
import Deriving._
85+
val eq = implicitly[Eq[Lst[Int]]]
86+
val xs = Lst.Cons(1, Lst.Cons(2, Lst.Cons(3, Lst.Nil)))
87+
val ys = Lst.Cons(1, Lst.Cons(2, Lst.Nil))
88+
assert(eq.equals(xs, xs))
89+
assert(!eq.equals(xs, ys))
90+
assert(!eq.equals(ys, xs))
91+
assert(eq.equals(ys, ys))
92+
93+
val eq2 = implicitly[Eq[Lst[Lst[Int]]]]
94+
val xss = Lst.Cons(xs, Lst.Cons(ys, Lst.Nil))
95+
val yss = Lst.Cons(xs, Lst.Nil)
96+
assert(eq2.equals(xss, xss))
97+
assert(!eq2.equals(xss, yss))
98+
assert(!eq2.equals(yss, xss))
99+
assert(eq2.equals(yss, yss))
100+
}

tests/run/typeclass-derivation2.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
ListBuffer(0, 11, 0, 22, 0, 33, 1)
2+
Cons(11,Cons(22,Cons(33,Nil)))
3+
ListBuffer(0, 0, 11, 0, 22, 0, 33, 1, 0, 0, 11, 0, 22, 1, 1)
4+
Cons(Cons(11,Cons(22,Cons(33,Nil))),Cons(Cons(11,Cons(22,Nil)),Nil))
5+
ListBuffer(1, 2)
6+
Pair(1,2)
7+
Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil())))
8+
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil())), tl = Nil()))

0 commit comments

Comments
 (0)