Skip to content

Commit e8a8428

Browse files
rochalaWojciechMazur
authored andcommitted
Support completions for extension definition parameter (#18331)
Extension methods are extended into normal definitions. Because of that typed trees don't include any information about the extension method definition parameter: ```scala extension (x: In@@) ``` In order to add completions, we check if there is an exact path to the untyped tree, and if not, we fall back to it. There may also be more possible cases like that, but I can't think of any at the moment. [Cherry-picked de4ad2b]
1 parent 507eb90 commit e8a8428

File tree

9 files changed

+421
-209
lines changed

9 files changed

+421
-209
lines changed

compiler/src/dotty/tools/dotc/interactive/Completion.scala

+83-45
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package dotty.tools.dotc.interactive
22

3-
import scala.language.unsafeNulls
4-
53
import dotty.tools.dotc.ast.untpd
4+
import dotty.tools.dotc.ast.NavigateAST
65
import dotty.tools.dotc.config.Printers.interactiv
76
import dotty.tools.dotc.core.Contexts._
87
import dotty.tools.dotc.core.Decorators._
@@ -25,6 +24,10 @@ import dotty.tools.dotc.util.SourcePosition
2524

2625
import scala.collection.mutable
2726
import scala.util.control.NonFatal
27+
import dotty.tools.dotc.core.ContextOps.localContext
28+
import dotty.tools.dotc.core.Names
29+
import dotty.tools.dotc.core.Types
30+
import dotty.tools.dotc.core.Symbols
2831

2932
/**
3033
* One of the results of a completion query.
@@ -37,18 +40,17 @@ import scala.util.control.NonFatal
3740
*/
3841
case class Completion(label: String, description: String, symbols: List[Symbol])
3942

40-
object Completion {
43+
object Completion:
4144

4245
import dotty.tools.dotc.ast.tpd._
4346

4447
/** Get possible completions from tree at `pos`
4548
*
4649
* @return offset and list of symbols for possible completions
4750
*/
48-
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) = {
49-
val path = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
51+
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) =
52+
val path: List[Tree] = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
5053
computeCompletions(pos, path)(using Interactive.contextOfPath(path).withPhase(Phases.typerPhase))
51-
}
5254

5355
/**
5456
* Inspect `path` to determine what kinds of symbols should be considered.
@@ -60,10 +62,11 @@ object Completion {
6062
*
6163
* Otherwise, provide no completion suggestion.
6264
*/
63-
def completionMode(path: List[Tree], pos: SourcePosition): Mode =
64-
path match {
65-
case Ident(_) :: Import(_, _) :: _ => Mode.ImportOrExport
66-
case (ref: RefTree) :: _ =>
65+
def completionMode(path: List[untpd.Tree], pos: SourcePosition): Mode =
66+
path match
67+
case untpd.Ident(_) :: untpd.Import(_, _) :: _ => Mode.ImportOrExport
68+
case untpd.Ident(_) :: (_: untpd.ImportSelector) :: _ => Mode.ImportOrExport
69+
case (ref: untpd.RefTree) :: _ =>
6770
if (ref.name.isTermName) Mode.Term
6871
else if (ref.name.isTypeName) Mode.Type
6972
else Mode.None
@@ -72,9 +75,8 @@ object Completion {
7275
if sel.imported.span.contains(pos.span) then Mode.ImportOrExport
7376
else Mode.None // Can't help completing the renaming
7477

75-
case (_: ImportOrExport) :: _ => Mode.ImportOrExport
78+
case (_: untpd.ImportOrExport) :: _ => Mode.ImportOrExport
7679
case _ => Mode.None
77-
}
7880

7981
/** When dealing with <errors> in varios palces we check to see if they are
8082
* due to incomplete backticks. If so, we ensure we get the full prefix
@@ -101,10 +103,13 @@ object Completion {
101103
case (sel: untpd.ImportSelector) :: _ =>
102104
completionPrefix(sel.imported :: Nil, pos)
103105

106+
case untpd.Ident(_) :: (sel: untpd.ImportSelector) :: _ if !sel.isGiven =>
107+
completionPrefix(sel.imported :: Nil, pos)
108+
104109
case (tree: untpd.ImportOrExport) :: _ =>
105-
tree.selectors.find(_.span.contains(pos.span)).map { selector =>
110+
tree.selectors.find(_.span.contains(pos.span)).map: selector =>
106111
completionPrefix(selector :: Nil, pos)
107-
}.getOrElse("")
112+
.getOrElse("")
108113

109114
// Foo.`se<TAB> will result in Select(Ident(Foo), <error>)
110115
case (select: untpd.Select) :: _ if select.name == nme.ERROR =>
@@ -118,27 +123,65 @@ object Completion {
118123
if (ref.name == nme.ERROR) ""
119124
else ref.name.toString.take(pos.span.point - ref.span.point)
120125

121-
case _ =>
122-
""
126+
case _ => ""
127+
123128
end completionPrefix
124129

125130
/** Inspect `path` to determine the offset where the completion result should be inserted. */
126-
def completionOffset(path: List[Tree]): Int =
127-
path match {
128-
case (ref: RefTree) :: _ => ref.span.point
131+
def completionOffset(untpdPath: List[untpd.Tree]): Int =
132+
untpdPath match {
133+
case (ref: untpd.RefTree) :: _ => ref.span.point
129134
case _ => 0
130135
}
131136

132-
private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
133-
val mode = completionMode(path, pos)
134-
val rawPrefix = completionPrefix(path, pos)
137+
/** Some information about the trees is lost after Typer such as Extension method construct
138+
* is expanded into methods. In order to support completions in those cases
139+
* we have to rely on untyped trees and only when types are necessary use typed trees.
140+
*/
141+
def resolveTypedOrUntypedPath(tpdPath: List[Tree], pos: SourcePosition)(using Context): List[untpd.Tree] =
142+
lazy val untpdPath: List[untpd.Tree] = NavigateAST
143+
.pathTo(pos.span, List(ctx.compilationUnit.untpdTree), true).collect:
144+
case untpdTree: untpd.Tree => untpdTree
145+
146+
tpdPath match
147+
case (_: Bind) :: _ => tpdPath
148+
case (_: untpd.TypTree) :: _ => tpdPath
149+
case _ => untpdPath
150+
151+
/** Handle case when cursor position is inside extension method construct.
152+
* The extension method construct is then desugared into methods, and consturct parameters
153+
* are no longer a part of a typed tree, but instead are prepended to method parameters.
154+
*
155+
* @param untpdPath The typed or untyped path to the tree that is being completed
156+
* @param tpdPath The typed path that will be returned if no extension method construct is found
157+
* @param pos The cursor position
158+
*
159+
* @return Typed path to the parameter of the extension construct if found or tpdPath
160+
*/
161+
private def typeCheckExtensionConstructPath(
162+
untpdPath: List[untpd.Tree], tpdPath: List[Tree], pos: SourcePosition
163+
)(using Context): List[Tree] =
164+
untpdPath.collectFirst:
165+
case untpd.ExtMethods(paramss, _) =>
166+
val enclosingParam = paramss.flatten.find(_.span.contains(pos.span))
167+
enclosingParam.map: param =>
168+
ctx.typer.index(paramss.flatten)
169+
val typedEnclosingParam = ctx.typer.typed(param)
170+
Interactive.pathTo(typedEnclosingParam, pos.span)
171+
.flatten.getOrElse(tpdPath)
172+
173+
private def computeCompletions(pos: SourcePosition, tpdPath: List[Tree])(using Context): (Int, List[Completion]) =
174+
val path0 = resolveTypedOrUntypedPath(tpdPath, pos)
175+
val mode = completionMode(path0, pos)
176+
val rawPrefix = completionPrefix(path0, pos)
135177

136178
val hasBackTick = rawPrefix.headOption.contains('`')
137179
val prefix = if hasBackTick then rawPrefix.drop(1) else rawPrefix
138180

139181
val completer = new Completer(mode, prefix, pos)
140182

141-
val completions = path match {
183+
val adjustedPath = typeCheckExtensionConstructPath(path0, tpdPath, pos)
184+
val completions = adjustedPath match
142185
// Ignore synthetic select from `This` because in code it was `Ident`
143186
// See example in dotty.tools.languageserver.CompletionTest.syntheticThis
144187
case Select(qual @ This(_), _) :: _ if qual.span.isSynthetic => completer.scopeCompletions
@@ -147,21 +190,19 @@ object Completion {
147190
case (tree: ImportOrExport) :: _ => completer.directMemberCompletions(tree.expr)
148191
case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => completer.directMemberCompletions(expr)
149192
case _ => completer.scopeCompletions
150-
}
151193

152194
val describedCompletions = describeCompletions(completions)
153195
val backtickedCompletions =
154196
describedCompletions.map(completion => backtickCompletions(completion, hasBackTick))
155197

156-
val offset = completionOffset(path)
198+
val offset = completionOffset(path0)
157199

158200
interactiv.println(i"""completion with pos = $pos,
159201
| prefix = ${completer.prefix},
160202
| term = ${completer.mode.is(Mode.Term)},
161203
| type = ${completer.mode.is(Mode.Type)}
162204
| results = $backtickedCompletions%, %""")
163205
(offset, backtickedCompletions)
164-
}
165206

166207
def backtickCompletions(completion: Completion, hasBackTick: Boolean) =
167208
if hasBackTick || needsBacktick(completion.label) then
@@ -174,17 +215,17 @@ object Completion {
174215
// https://github.com/scalameta/metals/blob/main/mtags/src/main/scala/scala/meta/internal/mtags/KeywordWrapper.scala
175216
// https://github.com/com-lihaoyi/Ammonite/blob/73a874173cd337f953a3edc9fb8cb96556638fdd/amm/util/src/main/scala/ammonite/util/Model.scala
176217
private def needsBacktick(s: String) =
177-
val chunks = s.split("_", -1)
218+
val chunks = s.split("_", -1).nn
178219

179220
val validChunks = chunks.zipWithIndex.forall { case (chunk, index) =>
180-
chunk.forall(Chars.isIdentifierPart) ||
181-
(chunk.forall(Chars.isOperatorPart) &&
221+
chunk.nn.forall(Chars.isIdentifierPart) ||
222+
(chunk.nn.forall(Chars.isOperatorPart) &&
182223
index == chunks.length - 1 &&
183224
!(chunks.lift(index - 1).contains("") && index - 1 == 0))
184225
}
185226

186227
val validStart =
187-
Chars.isIdentifierStart(s(0)) || chunks(0).forall(Chars.isOperatorPart)
228+
Chars.isIdentifierStart(s(0)) || chunks(0).nn.forall(Chars.isOperatorPart)
188229

189230
val valid = validChunks && validStart && !keywords.contains(s)
190231

@@ -216,7 +257,7 @@ object Completion {
216257
* For the results of all `xyzCompletions` methods term names and type names are always treated as different keys in the same map
217258
* and they never conflict with each other.
218259
*/
219-
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition) {
260+
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition):
220261
/** Completions for terms and types that are currently in scope:
221262
* the members of the current class, local definitions and the symbols that have been imported,
222263
* recursively adding completions from outer scopes.
@@ -230,7 +271,7 @@ object Completion {
230271
* (even if the import follows it syntactically)
231272
* - a more deeply nested import shadowing a member or a local definition causes an ambiguity
232273
*/
233-
def scopeCompletions(using context: Context): CompletionMap = {
274+
def scopeCompletions(using context: Context): CompletionMap =
234275
val mappings = collection.mutable.Map.empty[Name, List[ScopedDenotations]].withDefaultValue(List.empty)
235276
def addMapping(name: Name, denots: ScopedDenotations) =
236277
mappings(name) = mappings(name) :+ denots
@@ -302,7 +343,7 @@ object Completion {
302343
}
303344

304345
resultMappings
305-
}
346+
end scopeCompletions
306347

307348
/** Widen only those types which are applied or are exactly nothing
308349
*/
@@ -335,16 +376,16 @@ object Completion {
335376
/** Completions introduced by imports directly in this context.
336377
* Completions from outer contexts are not included.
337378
*/
338-
private def importedCompletions(using Context): CompletionMap = {
379+
private def importedCompletions(using Context): CompletionMap =
339380
val imp = ctx.importInfo
340381

341-
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
342-
imp.site.member(name).alternatives
343-
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }
344-
345382
if imp == null then
346383
Map.empty
347384
else
385+
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
386+
imp.site.member(name).alternatives
387+
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }
388+
348389
val givenImports = imp.importedImplicits
349390
.map { ref => (ref.implicitName: Name, ref.underlyingRef.denot.asSingleDenotation) }
350391
.filter((name, denot) => include(denot, name))
@@ -370,7 +411,7 @@ object Completion {
370411
}.toSeq.groupByName
371412

372413
givenImports ++ wildcardMembers ++ explicitMembers
373-
}
414+
end importedCompletions
374415

375416
/** Completions from implicit conversions including old style extensions using implicit classes */
376417
private def implicitConversionMemberCompletions(qual: Tree)(using Context): CompletionMap =
@@ -532,7 +573,6 @@ object Completion {
532573
extension [N <: Name](namedDenotations: Seq[(N, SingleDenotation)])
533574
@annotation.targetName("groupByNameTupled")
534575
def groupByName: CompletionMap = namedDenotations.groupMap((name, denot) => name)((name, denot) => denot)
535-
}
536576

537577
private type CompletionMap = Map[Name, Seq[SingleDenotation]]
538578

@@ -545,11 +585,11 @@ object Completion {
545585
* The completion mode: defines what kinds of symbols should be included in the completion
546586
* results.
547587
*/
548-
class Mode(val bits: Int) extends AnyVal {
588+
class Mode(val bits: Int) extends AnyVal:
549589
def is(other: Mode): Boolean = (bits & other.bits) == other.bits
550590
def |(other: Mode): Mode = new Mode(bits | other.bits)
551-
}
552-
object Mode {
591+
592+
object Mode:
553593
/** No symbol should be included */
554594
val None: Mode = new Mode(0)
555595

@@ -561,6 +601,4 @@ object Completion {
561601

562602
/** Both term and type symbols are allowed */
563603
val ImportOrExport: Mode = new Mode(4) | Term | Type
564-
}
565-
}
566604

compiler/src/dotty/tools/repl/ReplCompiler.scala

+24-11
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
9393
end compile
9494

9595
final def typeOf(expr: String)(using state: State): Result[String] =
96-
typeCheck(expr).map { tree =>
96+
typeCheck(expr).map { (_, tpdTree) =>
9797
given Context = state.context
98-
tree.rhs match {
98+
tpdTree.rhs match {
9999
case Block(xs, _) => xs.last.tpe.widen.show
100100
case _ =>
101101
"""Couldn't compute the type of your expression, so sorry :(
@@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
129129
Iterator(sym) ++ sym.allOverriddenSymbols
130130
}
131131

132-
typeCheck(expr).map {
132+
typeCheck(expr).map { (_, tpdTree) => tpdTree match
133133
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
134134
val stat = stats.last.asInstanceOf[tpd.Tree]
135135
if (stat.tpe.isError) stat.tpe.show
@@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
152152
}
153153
}
154154

155-
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
155+
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {
156156

157157
def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
158158
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
@@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
181181
}
182182
}
183183

184-
def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
185-
def error: Result[tpd.ValDef] =
186-
List(new Diagnostic.Error(s"Invalid scala expression",
187-
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
184+
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
185+
List(new Diagnostic.Error(s"Invalid scala expression",
186+
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
188187

188+
def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
189189
import tpd._
190190
tree match {
191191
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
192192
tmpl.body
193193
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
194-
.getOrElse(error)
194+
.getOrElse(error[tpd.ValDef](sourceFile0))
195195
case _ =>
196-
error
196+
error[tpd.ValDef](sourceFile0)
197197
}
198198
}
199199

200+
def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
201+
import untpd._
202+
tree match {
203+
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
204+
tmpl.body
205+
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
206+
.getOrElse(error[untpd.ValDef](sourceFile0))
207+
case _ =>
208+
error[untpd.ValDef](sourceFile0)
209+
}
200210

201211
val src = SourceFile.virtual("<typecheck>", expr)
202212
inContext(state.context.fresh
@@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
209219
ctx.run.nn.compileUnits(unit :: Nil, ctx)
210220

211221
if (errorsAllowed || !ctx.reporter.hasErrors)
212-
unwrapped(unit.tpdTree, src)
222+
for
223+
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
224+
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
225+
yield untpdTree -> tpdTree
213226
else
214227
ctx.reporter.removeBufferedMessages.errors
215228
}

compiler/src/dotty/tools/repl/ReplDriver.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
251251
given state: State = newRun(state0)
252252
compiler
253253
.typeCheck(expr, errorsAllowed = true)
254-
.map { tree =>
254+
.map { (untpdTree, tpdTree) =>
255255
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
256256
val unit = CompilationUnit(file)(using state.context)
257-
unit.tpdTree = tree
257+
unit.untpdTree = untpdTree
258+
unit.tpdTree = tpdTree
258259
given Context = state.context.fresh.setCompilationUnit(unit)
259260
val srcPos = SourcePosition(file, Span(cursor))
260261
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil

0 commit comments

Comments
 (0)