Skip to content

Fix #9562: Fix handling of identifiers in extension methods #9576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ object desugar {
name = mdef.name.toExtensionName,
tparams = ext.tparams ++ mdef.tparams,
vparamss = mdef.vparamss match
case vparams1 :: vparamss1 if !isLeftAssoc(mdef.name) =>
case vparams1 :: vparamss1 if mdef.name.isRightAssocOperatorName =>
vparams1 :: ext.vparamss ::: vparamss1
case _ =>
ext.vparamss ++ mdef.vparamss
Expand Down Expand Up @@ -1204,10 +1204,10 @@ object desugar {
case _ =>
Apply(sel, arg :: Nil)

if (isLeftAssoc(op.name))
makeOp(left, right, Span(left.span.start, op.span.end, op.span.start))
else
if op.name.isRightAssocOperatorName then
makeOp(right, left, Span(op.span.start, right.span.end))
else
makeOp(left, right, Span(left.span.start, op.span.end, op.span.start))
}

/** Translate tuple expressions of arity <= 22
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Positioned.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import util.Spans._
import util.{SourceFile, NoSource, SourcePosition, SrcPos}
import core.Contexts._
import core.Decorators._
import core.NameOps._
import core.Flags.{JavaDefined, ExtensionMethod}
import core.StdNames.nme
import ast.Trees.mods
Expand Down Expand Up @@ -208,7 +209,7 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
check(tree.vparamss)
case tree: DefDef if tree.mods.is(ExtensionMethod) =>
tree.vparamss match {
case vparams1 :: vparams2 :: rest if !isLeftAssoc(tree.name) =>
case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName =>
check(tree.tparams)
check(vparams2)
check(vparams1)
Expand Down
3 changes: 0 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,6 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
case _ => false
}

/** Is name a left-associative operator? */
def isLeftAssoc(operator: Name): Boolean = !operator.isEmpty && (operator.toSimpleName.last != ':')

/** Is this argument node of the form <expr> : _*, or is it a reference to
* such an argument ? The latter case can happen when an argument is lifted.
*/
Expand Down
26 changes: 13 additions & 13 deletions compiler/src/dotty/tools/dotc/core/ContextOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,48 @@ object ContextOps:
def enter(sym: Symbol): Symbol = inContext(ctx) {
ctx.owner match
case cls: ClassSymbol => cls.classDenot.enter(sym)
case _ => scope.openForMutations.enter(sym)
case _ => ctx.scope.openForMutations.enter(sym)
sym
}

/** The denotation with the given `name` and all `required` flags in current context
*/
def denotNamed(name: Name, required: FlagSet = EmptyFlags): Denotation = inContext(ctx) {
if (owner.isClass)
if (outer.owner == owner) { // inner class scope; check whether we are referring to self
if (scope.size == 1) {
val elem = scope.lastEntry
if (ctx.owner.isClass)
if (ctx.outer.owner == ctx.owner) { // inner class scope; check whether we are referring to self
if (ctx.scope.size == 1) {
val elem = ctx.scope.lastEntry
if (elem.name == name) return elem.sym.denot // return self
}
val pre = owner.thisType
val pre = ctx.owner.thisType
pre.findMember(name, pre, required, EmptyFlags)
}
else // we are in the outermost context belonging to a class; self is invisible here. See inClassContext.
owner.findMember(name, owner.thisType, required, EmptyFlags)
ctx.owner.findMember(name, ctx.owner.thisType, required, EmptyFlags)
else
scope.denotsNamed(name).filterWithFlags(required, EmptyFlags).toDenot(NoPrefix)
ctx.scope.denotsNamed(name).filterWithFlags(required, EmptyFlags).toDenot(NoPrefix)
}

/** A fresh local context with given tree and owner.
* Owner might not exist (can happen for self valdefs), in which case
* no owner is set in result context
*/
def localContext(tree: untpd.Tree, owner: Symbol): FreshContext = inContext(ctx) {
val freshCtx = fresh.setTree(tree)
if (owner.exists) freshCtx.setOwner(owner) else freshCtx
val freshCtx = ctx.fresh.setTree(tree)
if owner.exists then freshCtx.setOwner(owner) else freshCtx
}

/** Context where `sym` is defined, assuming we are in a nested context. */
def defContext(sym: Symbol): Context = inContext(ctx) {
outersIterator
ctx.outersIterator
.dropWhile(_.owner != sym)
.dropWhile(_.owner == sym)
.next()
}

/** A new context for the interior of a class */
def inClassContext(selfInfo: TypeOrSymbol): Context = inContext(ctx) {
val localCtx: Context = fresh.setNewScope
val localCtx: Context = ctx.fresh.setNewScope
selfInfo match {
case sym: Symbol if sym.exists && sym.name != nme.WILDCARD => localCtx.scope.openForMutations.enter(sym)
case _ =>
Expand All @@ -69,7 +69,7 @@ object ContextOps:
}

def packageContext(tree: untpd.PackageDef, pkg: Symbol): Context = inContext(ctx) {
if (pkg.is(Package)) fresh.setOwner(pkg.moduleClass).setTree(tree)
if (pkg.is(Package)) ctx.fresh.setOwner(pkg.moduleClass).setTree(tree)
else ctx
}
end ContextOps
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ object NameOps {
def isAnonymousClassName: Boolean = name.startsWith(str.ANON_CLASS)
def isAnonymousFunctionName: Boolean = name.startsWith(str.ANON_FUN)
def isUnapplyName: Boolean = name == nme.unapply || name == nme.unapplySeq
def isRightAssocOperatorName: Boolean = name.lastPart.last == ':'

def isOperatorName: Boolean = name match
case name: SimpleName => name.exists(isOperatorPart)
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ object SymDenotations {
else recurWithParamss(info, rawParamss)
end paramSymss

/** The extension parameter of this extension method
* @pre this symbol is an extension method
*/
final def extensionParam(using Context): Symbol =
def leadParam(paramss: List[List[Symbol]]): Symbol = paramss match
case (param :: _) :: paramss1 if param.isType => leadParam(paramss1)
case _ :: (snd :: Nil) :: _ if name.isRightAssocOperatorName => snd
case (fst :: Nil) :: _ => fst
case _ => NoSymbol
assert(isAllOf(ExtensionMethod))
leadParam(rawParamss)

/** The denotation is completed: info is not a lazy type and attributes have defined values */
final def isCompleted: Boolean = !myInfo.isInstanceOf[LazyType]

Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ object Parsers {
var opStack: List[OpInfo] = Nil

def checkAssoc(offset: Token, op1: Name, op2: Name, op2LeftAssoc: Boolean): Unit =
if (isLeftAssoc(op1) != op2LeftAssoc)
if (op1.isRightAssocOperatorName == op2LeftAssoc)
syntaxError(MixedLeftAndRightAssociativeOps(op1, op2, op2LeftAssoc), offset)

def reduceStack(base: List[OpInfo], top: Tree, prec: Int, leftAssoc: Boolean, op2: Name, isType: Boolean): Tree = {
Expand Down Expand Up @@ -967,7 +967,7 @@ object Parsers {
def recur(top: Tree): Tree =
if (isIdent && isOperator) {
val op = if (isType) typeIdent() else termIdent()
val top1 = reduceStack(base, top, precedence(op.name), isLeftAssoc(op.name), op.name, isType)
val top1 = reduceStack(base, top, precedence(op.name), !op.name.isRightAssocOperatorName, op.name, isType)
opStack = OpInfo(top1, op, in.offset) :: opStack
colonAtEOLOpt()
newLineOptWhenFollowing(canStartOperand)
Expand Down Expand Up @@ -3316,7 +3316,7 @@ object Parsers {
typeParamClause(ParamOwner.Def)
else leadingTparams
val vparamss = paramClauses() match
case rparams :: rparamss if leadingVparamss.nonEmpty && !isLeftAssoc(ident.name) =>
case rparams :: rparamss if leadingVparamss.nonEmpty && ident.name.isRightAssocOperatorName =>
rparams :: leadingVparamss ::: rparamss
case rparamss =>
leadingVparamss ::: rparamss
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
val (prefix, vparamss) =
if isExtension then
val (leadingParams, otherParamss) = (tree.vparamss: @unchecked) match
case vparams1 :: vparams2 :: rest if !isLeftAssoc(tree.name) =>
case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName =>
(vparams2, vparams1 :: rest)
case vparams1 :: rest =>
(vparams1, rest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {

/** Is this an anonymous class deriving from an enum definition? */
extension (cls: ClassSymbol) private def isEnumValueImplementation(using Context): Boolean =
isAnonymousClass && classParents.head.typeSymbol.is(Enum) // asserted in Typer
cls.isAnonymousClass && cls.classParents.head.typeSymbol.is(Enum) // asserted in Typer

/** If this is the class backing a serializable singleton enum value with base class `MyEnum`,
* and not deriving from `java.lang.Enum` add the method:
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,10 @@ object Applications {
* I.e., if the expected type is a PolyProto, then `app` will be a `TypeApply(_, args)` where
* `args` are the type arguments of the expected type.
*/
class IntegratedTypeArgs(val app: Tree)(implicit @constructorOnly src: SourceFile) extends tpd.Tree {
class IntegratedTypeArgs(val app: Tree)(implicit @constructorOnly src: SourceFile) extends ProxyTree {
override def span = app.span

def forwardTo = app
def canEqual(that: Any): Boolean = app.canEqual(that)
def productArity: Int = app.productArity
def productElement(n: Int): Any = app.productElement(n)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ trait Checking {
def checkEnumParent(cls: Symbol, firstParent: Symbol)(using Context): Unit =

extension (sym: Symbol) def typeRefApplied(using Context): Type =
typeRef.appliedTo(typeParams.map(_.info.loBound))
sym.typeRef.appliedTo(sym.typeParams.map(_.info.loBound))

def ensureParentDerivesFrom(enumCase: Symbol)(using Context) =
val enumCls = enumCase.owner.linkedClass
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1620,7 +1620,7 @@ class RefChecks extends MiniPhase { thisPhase =>
private def checkByNameRightAssociativeDef(tree: DefDef) {
tree match {
case DefDef(_, name, _, params :: _, _, _) =>
if (settings.lint && !treeInfo.isLeftAssoc(name.decodedName) && params.exists(p => isByName(p.symbol)))
if (settings.lint && name.decodedName.isRightAssocOperatorName && params.exists(p => isByName(p.symbol)))
unit.warning(tree.pos,
"by-name parameters will be evaluated eagerly when called as a right-associative infix operator. For more details, see SI-1980.")
case _ =>
Expand Down
42 changes: 17 additions & 25 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -447,25 +447,6 @@ class Typer extends Namer
if (name == nme.ROOTPKG)
return tree.withType(defn.RootPackage.termRef)

/** Convert a reference `f` to an extension method select `p.f`, where
* `p` is the closest enclosing extension parameter, or else `this`.
*/
def extensionMethodSelect: untpd.Tree =
val xmethod = ctx.owner.enclosingExtensionMethod
val qualifier =
if xmethod.exists then // TODO: see whether we can use paramss for that
val leadParamName = xmethod.info.paramNamess.head.head
def isLeadParam(sym: Symbol) =
sym.is(Param) && sym.owner.owner == xmethod.owner && sym.name == leadParamName
def leadParam(ctx: Context): Symbol =
ctx.scope.lookupAll(leadParamName).find(isLeadParam) match
case Some(param) => param
case None => leadParam(ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next)
untpd.ref(leadParam(ctx).termRef)
else
untpd.This(untpd.EmptyTypeIdent)
untpd.cpy.Select(tree)(qualifier, name)

val rawType = {
val saved1 = unimported
val saved2 = foundUnderScala2
Expand Down Expand Up @@ -516,7 +497,20 @@ class Typer extends Namer
else if name.toTermName == nme.ERROR then
setType(UnspecifiedErrorType)
else if name.isTermName then
tryEither(typed(extensionMethodSelect, pt))((_, _) => fail)
// Convert a reference `f` to an extension method select `p.f`, where
// `p` is the closest enclosing extension parameter, or else convert to `this.f`.
val xmethod = ctx.owner.enclosingExtensionMethod
val qualifier =
if xmethod.exists then untpd.ref(xmethod.extensionParam.termRef)
else untpd.This(untpd.EmptyTypeIdent)
val selection = untpd.cpy.Select(tree)(qualifier, name)
val result = tryEither(typed(selection, pt))((_, _) => fail)
def canAccessUnqualified(sym: Symbol) =
sym.is(ExtensionMethod) && (sym.extensionParam.span == xmethod.extensionParam.span)
if !xmethod.exists || result.tpe.isError || canAccessUnqualified(result.symbol) then
result
else
fail
else
fail
end typedIdent
Expand Down Expand Up @@ -2348,20 +2342,18 @@ class Typer extends Namer
typedUnApply(cpy.Apply(tree)(op, l :: r :: Nil), pt)
else {
val app = typedApply(desugar.binop(l, op, r), pt)
if (untpd.isLeftAssoc(op.name)) app
else {
if op.name.isRightAssocOperatorName then
val defs = new mutable.ListBuffer[Tree]
def lift(app: Tree): Tree = (app: @unchecked) match {
def lift(app: Tree): Tree = (app: @unchecked) match
case Apply(fn, args) =>
if (app.tpe.isError) app
else tpd.cpy.Apply(app)(fn, LiftImpure.liftArgs(defs, fn.tpe, args))
case Assign(lhs, rhs) =>
tpd.cpy.Assign(app)(lhs, lift(rhs))
case Block(stats, expr) =>
tpd.cpy.Block(app)(stats, lift(expr))
}
wrapDefs(defs, lift(app))
}
else app
}
checkValidInfix(tree, result.symbol)
result
Expand Down
14 changes: 8 additions & 6 deletions docs/docs/reference/contextual/extension-methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ extension [T: Numeric](x: T)
```

If an extension method has type parameters, they come immediately after `extension` and are followed by the extended parameter.
When calling a generic extension method, any explicitly given type arguments follow the method name.
When calling a generic extension method, any explicitly given type arguments follow the method name.
So the `second` method could be instantiated as follows:

```scala
Expand All @@ -92,8 +92,8 @@ extension [T](x: T)(using n: Numeric[T])
def + (y: T): T = n.plus(x, y)
```

**Note**: Type parameters have to be given after the `extension` keyword; they cannot be given after the `def`.
This restriction might be lifted in the future once we support multiple type parameter clauses in a method.
**Note**: Type parameters have to be given after the `extension` keyword; they cannot be given after the `def`.
This restriction might be lifted in the future once we support multiple type parameter clauses in a method.
By contrast, using clauses can be defined for the `extension` as well as per `def`.

### Collective Extensions
Expand Down Expand Up @@ -215,7 +215,7 @@ List(1, 2) < List(3)

The precise rules for resolving a selection to an extension method are as follows.

Assume a selection `e.m[Ts]` where `m` is not a member of `e`, where the type arguments `[Ts]` are optional, and where `T` is the expected type.
Assume a selection `e.m[Ts]` where `m` is not a member of `e`, where the type arguments `[Ts]` are optional, and where `T` is the expected type.
The following two rewritings are tried in order:

1. The selection is rewritten to `extension_m[Ts](e)`.
Expand All @@ -233,9 +233,11 @@ An extension method can also be used as an identifier by itself. If an identifie
resolve, the identifier is rewritten to:

- `x.m` if the identifier appears in an extension with parameter `x`
and the method `m` resolves to an extension method in
a (possibly collective) extension that also contains the call,
- `this.m` otherwise

and the rewritten term is again tried as an application of an extension method. Example:
and the rewritten term is again tried as an application of an extension method. In

```scala
extension (s: String)
Expand Down Expand Up @@ -264,7 +266,7 @@ def extension_position(s: String)(ch: Char, n: Int): Int =
extension (x: Double) def ** (exponent: Int): Double =
require(exponent >= 0)
if exponent == 0 then 1 else x * (x ** (exponent - 1))

import DoubleOps.{**, extension_**}
assert(2.0 ** 3 == extension_**(2.0)(3))
```
Expand Down
12 changes: 12 additions & 0 deletions tests/neg/i9562.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class Foo:
def foo = 23

object Unrelated:
extension (f: Foo)
def g = f.foo // OK

extension (f: Foo)
def h1: Int = foo // error
def h2: Int = h1 + 1 // OK
def h3: Int = g // error
def ++: (x: Int): Int = h1 + x // OK
11 changes: 11 additions & 0 deletions tests/pos/i9562.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class Foo:
def foo = 23

object Unrelated:
extension (f: Foo)
def g = f.foo // OK

extension (f: Foo)
def h1: Int = 0
def h2: Int = h1 + 1 // OK
def ++: (x: Int): Int = h2 + x // OK