Skip to content

Commit 721c36d

Browse files
committed
Pretty printing, fixes, and tests
1 parent b7a11de commit 721c36d

File tree

13 files changed

+454
-22
lines changed

13 files changed

+454
-22
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+12-8
Original file line numberDiff line numberDiff line change
@@ -786,34 +786,38 @@ object desugar {
786786
* <body2> = <body1> where each method definition gets <combined-params> as last parameter section.
787787
*/
788788
def extensionDef(tree: Extension)(implicit ctx: Context): Tree = {
789-
val Extension(name, extended, impl) = tree
789+
val Extension(name, constr, extended, impl) = tree
790790
val isSimpleExtension = impl.parents.isEmpty
791791

792792
val firstParams = ValDef(nme.SELF, extended, EmptyTree).withFlags(Private | Local | ParamAccessor) :: Nil
793-
val body1 = substThis.transform(impl.body)
793+
val importSelf = Import(Ident(nme.SELF), Ident(nme.WILDCARD) :: Nil)
794+
val body1 = importSelf :: substThis.transform(impl.body)
794795
val impl1 =
795796
if (isSimpleExtension) {
796797
val (typeParams, evidenceParams) =
797-
desugarTypeBindings(impl.constr.tparams, forPrimaryConstructor = false)
798+
desugarTypeBindings(constr.tparams, forPrimaryConstructor = false)
798799
cpy.Template(impl)(
799-
constr = cpy.DefDef(impl.constr)(tparams = typeParams, vparamss = firstParams :: Nil),
800+
constr = cpy.DefDef(constr)(tparams = typeParams, vparamss = firstParams :: Nil),
800801
parents = ref(defn.AnyValType) :: Nil,
801802
body = body1.map {
802803
case ddef: DefDef =>
803804
def resetFlags(vdef: ValDef) =
804805
vdef.withMods(vdef.mods &~ PrivateLocalParamAccessor | Param)
805-
val originalParams = impl.constr.vparamss.headOption.getOrElse(Nil).map(resetFlags)
806+
val originalParams = constr.vparamss.headOption.getOrElse(Nil).map(resetFlags)
806807
addEvidenceParams(addEvidenceParams(ddef, originalParams), evidenceParams)
807808
case other =>
808809
other
809810
})
810811
}
811812
else
812813
cpy.Template(impl)(
813-
constr = cpy.DefDef(impl.constr)(vparamss = firstParams :: impl.constr.vparamss),
814+
constr = cpy.DefDef(constr)(vparamss = firstParams :: constr.vparamss),
814815
body = body1)
815-
val icls = TypeDef(name, impl1).withMods(tree.mods.withAddedMod(Mod.Extension()) | Implicit)
816-
desugr.println(i"desugar $extended --> $icls")
816+
val mods1 =
817+
if (isSimpleExtension) tree.mods
818+
else tree.mods.withAddedMod(Mod.InstanceDcl())
819+
val icls = TypeDef(name, impl1).withMods(mods1 | Implicit)
820+
desugr.println(i"desugar $tree --> $icls")
817821
classDef(icls)
818822
}
819823

compiler/src/dotty/tools/dotc/ast/untpd.scala

+18-3
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
4040
def withName(name: Name)(implicit ctx: Context) = cpy.ModuleDef(this)(name.toTermName, impl)
4141
}
4242

43-
/** extend extended impl */
44-
case class Extension(name: TypeName, extended: Tree, impl: Template) extends MemberDef
43+
/** extension name tparams vparamss for tpt impl
44+
*
45+
* where `tparams` and `vparamss` are part of `constr`.
46+
*/
47+
case class Extension(name: TypeName, constr: DefDef, tpt: Tree, impl: Template)
48+
extends MemberDef{
49+
type ThisTree[-T >: Untyped] <: Trees.NameTree[T] with Trees.MemberDef[T] with Extension
50+
def withName(name: Name)(implicit ctx: Context) = cpy.Extension(this)(name.toTypeName, constr, tpt, impl)
51+
}
4552

4653
case class ParsedTry(expr: Tree, handler: Tree, finalizer: Tree) extends TermTree
4754

@@ -143,7 +150,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
143150

144151
case class EnumCase() extends Mod(Flags.EmptyFlags)
145152

146-
case class Extension() extends Mod(Flags.EmptyFlags)
153+
case class InstanceDcl() extends Mod(Flags.EmptyFlags)
147154
}
148155

149156
/** Modifiers and annotations for definitions
@@ -416,6 +423,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
416423
case tree: ModuleDef if (name eq tree.name) && (impl eq tree.impl) => tree
417424
case _ => finalize(tree, untpd.ModuleDef(name, impl))
418425
}
426+
def Extension(tree: Tree)(name: TypeName, constr: DefDef, tpt: Tree, impl: Template) = tree match {
427+
case tree: Extension if (name eq tree.name) && (constr eq tree.constr) && (tpt eq tree.tpt) && (impl eq tree.impl) => tree
428+
case _ => finalize(tree, untpd.Extension(name, constr, tpt, impl))
429+
}
419430
def ParsedTry(tree: Tree)(expr: Tree, handler: Tree, finalizer: Tree) = tree match {
420431
case tree: ParsedTry
421432
if (expr eq tree.expr) && (handler eq tree.handler) && (finalizer eq tree.finalizer) => tree
@@ -499,6 +510,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
499510
override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
500511
case ModuleDef(name, impl) =>
501512
cpy.ModuleDef(tree)(name, transformSub(impl))
513+
case Extension(name, constr, tpt, impl) =>
514+
cpy.Extension(tree)(name, transformSub(constr), transform(tpt), transformSub(impl))
502515
case ParsedTry(expr, handler, finalizer) =>
503516
cpy.ParsedTry(tree)(transform(expr), transform(handler), transform(finalizer))
504517
case SymbolLit(str) =>
@@ -548,6 +561,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
548561
override def foldOver(x: X, tree: Tree)(implicit ctx: Context): X = tree match {
549562
case ModuleDef(name, impl) =>
550563
this(x, impl)
564+
case Extension(name, constr, tpt, impl) =>
565+
this(this(this(x, constr), tpt), impl)
551566
case ParsedTry(expr, handler, finalizer) =>
552567
this(this(this(x, expr), handler), finalizer)
553568
case SymbolLit(str) =>

compiler/src/dotty/tools/dotc/config/Printers.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ object Printers {
1717
val checks: Printer = noPrinter
1818
val config: Printer = noPrinter
1919
val cyclicErrors: Printer = noPrinter
20-
val desugr: Printer = new Printer
20+
val desugr: Printer = noPrinter
2121
val dottydoc: Printer = noPrinter
2222
val exhaustivity: Printer = noPrinter
2323
val gadts: Printer = noPrinter

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -2291,16 +2291,16 @@ object Parsers {
22912291
val name = ident().toTypeName
22922292
val tparams = typeParamClauseOpt(ParamOwner.Class)
22932293
val vparamss = paramClauses(tpnme.EMPTY, ofExtension = true).take(1)
2294-
val constr = makeConstructor(Nil, vparamss)
2294+
val constr = makeConstructor(tparams, vparamss)
22952295
accept(FOR)
22962296
val extended = annotType()
22972297
val templ =
22982298
if (in.token == COLON) {
22992299
in.nextToken()
2300-
template(constr, bodyRequired = true)._1
2300+
template(emptyConstructor, bodyRequired = true)._1
23012301
}
23022302
else {
2303-
val templ = templateClauseOpt(constr, bodyRequired = true)
2303+
val templ = templateClauseOpt(emptyConstructor, bodyRequired = true)
23042304
def checkDef(tree: Tree) = tree match {
23052305
case _: DefDef | EmptyValDef => // ok
23062306
case _ => syntaxError("`def` expected", tree.pos.startPos.orElse(templ.pos.startPos))
@@ -2309,7 +2309,7 @@ object Parsers {
23092309
templ.body.foreach(checkDef)
23102310
templ
23112311
}
2312-
Extension(name, extended, templ)
2312+
Extension(name, constr, extended, templ)
23132313
}
23142314

23152315
/* -------- TEMPLATES ------------------------------------------- */

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+19-3
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,17 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
443443
withEnclosingDef(tree) {
444444
modText(tree.mods, keywordStr("object")) ~~ nameIdText(tree) ~ toTextTemplate(impl)
445445
}
446+
case tree @ Extension(name, constr, tpt, impl) =>
447+
withEnclosingDef(tree) {
448+
modText(tree.mods, keywordStr("extension")) ~~
449+
nameIdText(tree) ~
450+
{ withEnclosingDef(constr) {
451+
addVparamssText(tparamsText(constr.tparams), constr.vparamss.drop(1))
452+
}
453+
} ~
454+
" for " ~ atPrec(DotPrec) { toText(tpt) } ~~
455+
toTextTemplateBody(impl, Str(" :") `provided` impl.parents.nonEmpty)
456+
}
446457
case SymbolLit(str) =>
447458
"'" + str
448459
case InterpolatedString(id, segments) =>
@@ -624,7 +635,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
624635
}
625636

626637
protected def toTextTemplate(impl: Template, ofNew: Boolean = false): Text = {
627-
val Template(constr @ DefDef(_, tparams, vparamss, _, _), parents, self, _) = impl
638+
val constr @ DefDef(_, tparams, vparamss, _, _) = impl.constr
628639
val tparamsTxt = withEnclosingDef(constr) { tparamsText(tparams) }
629640
val primaryConstrs = if (constr.rhs.isEmpty) Nil else constr :: Nil
630641
val prefix: Text =
@@ -635,6 +646,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
635646
if (constr.mods.hasAnnotations && !constr.mods.hasFlags) modsText = modsText ~~ " this"
636647
withEnclosingDef(constr) { addVparamssText(tparamsTxt ~~ modsText, vparamss) }
637648
}
649+
prefix ~ toTextTemplateBody(impl, keywordText(" extends") `provided` !ofNew, primaryConstrs)
650+
}
651+
652+
protected def toTextTemplateBody(impl: Template, leading: Text, leadingStats: List[Tree[_]] = Nil): Text = {
653+
val Template(_, parents, self, _) = impl
638654
val parentsText = Text(parents map constrText, keywordStr(" with "))
639655
val selfText = {
640656
val selfName = if (self.name == nme.WILDCARD) keywordStr("this") else self.name.toString
@@ -652,9 +668,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
652668
params ::: rest
653669
} else impl.body
654670

655-
val bodyText = "{" ~~ selfText ~~ toTextGlobal(primaryConstrs ::: body, "\n") ~ "}"
671+
val bodyText = "{" ~~ selfText ~~ toTextGlobal(leadingStats ::: body, "\n") ~ "}"
656672

657-
prefix ~ (keywordText(" extends") provided !ofNew) ~~ parentsText ~~ bodyText
673+
leading ~~ parentsText ~~ bodyText
658674
}
659675

660676
protected def templateText(tree: TypeDef, impl: Template): Text = {

compiler/src/dotty/tools/dotc/typer/Namer.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,8 @@ class Namer { typer: Typer =>
943943
if (cls.isRefinementClass) ptype
944944
else {
945945
val pt = checkClassType(ptype, parent.pos,
946-
traitReq = parent ne parents.head, stablePrefixReq = true)
946+
traitReq = (parent `ne` parents.head) || original.mods.hasMod[Mod.InstanceDcl],
947+
stablePrefixReq = true)
947948
if (pt.derivesFrom(cls)) {
948949
val addendum = parent match {
949950
case Select(qual: Super, _) if ctx.scala2Mode =>

compiler/test/dotty/tools/dotc/transform/PatmatExhaustivityTest.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class PatmatExhaustivityTest {
7878
}
7979

8080
val failed = res.filter { case (_, expected, actual) => expected != actual }
81-
val ignored = Directory(testsDir).list.toList.filter(_.extension == "ignore")
81+
val ignored = Directory(testsDir).list.toList.filter(_.`extension` == "ignore")
8282

8383
failed.foreach { case (file, expected, actual) =>
8484
println(s"\n----------------- incorrect output for $file --------------\n" +

docs/docs/reference/extend/extension-methods.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ object PostConditions {
122122

123123
def result[T](implicit er: WrappedResult[T]): T = WrappedResult.unwrap(er)
124124

125-
extenson Ensuring[T] for T {
125+
extension Ensuring[T] for T {
126126
def ensuring(condition: implicit WrappedResult[T] => Boolean): T = {
127127
implicit val wrapped = WrappedResult.wrap(this)
128128
assert(condition)

tests/neg/extensions.scala

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import Predef.{any2stringadd => _, _}
2+
object extensions {
3+
4+
// Simple extension methods
5+
6+
case class Circle(x: Double, y: Double, radius: Double)
7+
8+
extension CircleOps for Circle {
9+
def circumference = this.radius * math.Pi * 2
10+
private val p = math.Pi // error: `def` expected
11+
}
12+
13+
// Trait implementations
14+
15+
trait HasArea {
16+
def area: Double
17+
}
18+
19+
abstract class HasAreaClass extends HasArea
20+
21+
extension Ops2 : HasArea {} // error: `def` expected
22+
extension Ops for Circle extends HasArea {} // error: `def` expected
23+
24+
extension Circle2 : HasArea { // error: `for` expected
25+
def area = this.radius * this.radius * math.Pi
26+
}
27+
28+
extension Ops3 for Circle : HasAreaClass { // error: class HasAreaClass is not a trait
29+
def area = this.radius * this.radius * math.Pi
30+
}
31+
32+
// Generic trait implementations
33+
34+
extension ListOps[T] for List[T] {
35+
type I = Int // error: `def` expected
36+
def second = this.tail.head
37+
}
38+
39+
// Specific trait implementations
40+
41+
extension ListOps2 for List[Int] { self => // error: `def` expected
42+
import java.lang._ // error: `def` expected
43+
def maxx = (0 /: this)(_ `max` _)
44+
}
45+
}

0 commit comments

Comments
 (0)