Skip to content

Commit 8915fd7

Browse files
kasiaMarekKordyjan
authored andcommitted
bugfix: suggest correct arg name completions for lambda expressions
[Cherry-picked 77dcdb7]
1 parent 82f8d88 commit 8915fd7

File tree

3 files changed

+107
-8
lines changed

3 files changed

+107
-8
lines changed

presentation-compiler/src/main/dotty/tools/pc/completions/CompletionValue.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,10 @@ object CompletionValue:
247247
description
248248
override def insertMode: Option[InsertTextMode] = Some(InsertTextMode.AsIs)
249249

250-
def namedArg(label: String, sym: Symbol)(using
250+
def namedArg(label: String, sym: ParamSymbol)(using
251251
Context
252252
): CompletionValue =
253-
NamedArg(label, sym.info.widenTermRefExpr, sym)
253+
NamedArg(label, sym.info.widenTermRefExpr, sym.symbol)
254254

255255
def keyword(label: String, insertText: String): CompletionValue =
256256
Keyword(label, Some(insertText))

presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala

+53-6
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@ import dotty.tools.dotc.core.Symbols
1616
import dotty.tools.dotc.core.Symbols.Symbol
1717
import dotty.tools.dotc.core.Types.AndType
1818
import dotty.tools.dotc.core.Types.AppliedType
19+
import dotty.tools.dotc.core.Types.MethodType
1920
import dotty.tools.dotc.core.Types.OrType
21+
import dotty.tools.dotc.core.Types.RefinedType
2022
import dotty.tools.dotc.core.Types.TermRef
2123
import dotty.tools.dotc.core.Types.Type
2224
import dotty.tools.dotc.core.Types.TypeBounds
2325
import dotty.tools.dotc.core.Types.WildcardType
2426
import dotty.tools.dotc.util.SourcePosition
2527
import dotty.tools.pc.IndexedContext
2628
import dotty.tools.pc.utils.MtagsEnrichments.*
29+
import scala.annotation.tailrec
2730

2831
object NamedArgCompletions:
2932

@@ -195,9 +198,40 @@ object NamedArgCompletions:
195198
// def curry(x: Int)(apple: String, banana: String) = ???
196199
// curry(1)(apple = "test", b@@)
197200
// ```
198-
val (baseParams, baseArgs) =
201+
val (baseParams0, baseArgs) =
199202
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))
200203

204+
val baseParams: List[ParamSymbol] =
205+
def defaultBaseParams = baseParams0.map(JustSymbol(_))
206+
@tailrec
207+
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
208+
if level > 0 then
209+
val resultTypeOpt =
210+
refinedType match
211+
case RefinedType(AppliedType(_, args), _, _) => args.lastOption
212+
case AppliedType(_, args) => args.lastOption
213+
case _ => None
214+
resultTypeOpt match
215+
case Some(resultType) => getRefinedParams(resultType, level - 1)
216+
case _ => defaultBaseParams
217+
else
218+
refinedType match
219+
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
220+
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
221+
RefinedSymbol(sym, name, arg)
222+
}
223+
case _ => defaultBaseParams
224+
// finds param refinements for lambda expressions
225+
// val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
226+
@tailrec
227+
def refineParams(method: Tree, level: Int): List[ParamSymbol] =
228+
method match
229+
case Select(Apply(f, _), _) => refineParams(f, level + 1)
230+
case Select(h, v) => getRefinedParams(h.symbol.info, level)
231+
case _ => defaultBaseParams
232+
refineParams(method, 0)
233+
end baseParams
234+
201235
val args = ident
202236
.map(i => baseArgs.filterNot(_ == i))
203237
.getOrElse(baseArgs)
@@ -221,7 +255,7 @@ object NamedArgCompletions:
221255

222256
baseParams.filterNot(param =>
223257
isNamed(param.name) ||
224-
param.denot.is(
258+
param.symbol.denot.is(
225259
Flags.Synthetic
226260
) // filter out synthesized param, like evidence
227261
)
@@ -232,7 +266,7 @@ object NamedArgCompletions:
232266
.map(_.name.toString)
233267
.getOrElse("")
234268
.replace(Cursor.value, "")
235-
val params: List[Symbol] =
269+
val params: List[ParamSymbol] =
236270
allParams
237271
.filter(param => param.name.startsWith(prefix))
238272
.distinctBy(sym => (sym.name, sym.info))
@@ -249,7 +283,7 @@ object NamedArgCompletions:
249283
.filter(name => name != "Nil" && name != "None")
250284
.sorted
251285

252-
def findDefaultValue(param: Symbol): String =
286+
def findDefaultValue(param: ParamSymbol): String =
253287
val matchingType = matchingTypesInScope(param.info)
254288
if matchingType.size == 1 then s":${matchingType.head}"
255289
else if matchingType.size > 1 then s"|???,${matchingType.mkString(",")}|"
@@ -260,12 +294,12 @@ object NamedArgCompletions:
260294
def shouldShow =
261295
allParams.exists(param => param.name.startsWith(prefix))
262296
def isExplicitlyCalled = suffix.startsWith(prefix)
263-
def hasParamsToFill = allParams.count(!_.is(Flags.HasDefault)) > 1
297+
def hasParamsToFill = allParams.count(!_.symbol.is(Flags.HasDefault)) > 1
264298
if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill
265299
then
266300
val editText = allParams.zipWithIndex
267301
.collect {
268-
case (param, index) if !param.is(Flags.HasDefault) =>
302+
case (param, index) if !param.symbol.is(Flags.HasDefault) =>
269303
s"${param.nameBackticked.replace("$", "$$")} = $${${index + 1}${findDefaultValue(param)}}"
270304
}
271305
.mkString(", ")
@@ -355,3 +389,16 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context):
355389
case _ => t
356390

357391
end FuzzyArgMatcher
392+
393+
sealed trait ParamSymbol:
394+
def name: Name
395+
def info: Type
396+
def symbol: Symbol
397+
def nameBackticked(using Context) = name.decoded.backticked
398+
399+
case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol:
400+
def name: Name = symbol.name
401+
def info: Type = symbol.info
402+
403+
case class RefinedSymbol(symbol: Symbol, name: Name, info: Type)
404+
extends ParamSymbol

presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala

+52
Original file line numberDiff line numberDiff line change
@@ -881,3 +881,55 @@ class CompletionArgSuite extends BaseCompletionSuite:
881881
|""".stripMargin,
882882
topLines = Some(1),
883883
)
884+
885+
886+
@Test def `lambda` =
887+
check(
888+
"""|val hello: (x: Int) => Unit = x => println(x)
889+
|val k = hello(@@)
890+
|""".stripMargin,
891+
"""|x = : Int
892+
|""".stripMargin,
893+
topLines = Some(1),
894+
)
895+
896+
@Test def `lambda2` =
897+
check(
898+
"""|object O:
899+
| val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
900+
|val k = O.hello(x = 1, @@)
901+
|""".stripMargin,
902+
"""|y = : Int
903+
|""".stripMargin,
904+
topLines = Some(1),
905+
)
906+
907+
@Test def `lambda3` =
908+
check(
909+
"""|val hello: (x: Int) => (j: Int) => Unit = x => j => println(x)
910+
|val k = hello(@@)
911+
|""".stripMargin,
912+
"""|x = : Int
913+
|""".stripMargin,
914+
topLines = Some(1),
915+
)
916+
917+
@Test def `lambda4` =
918+
check(
919+
"""|val hello: (x: Int) => (j: Int) => (str: String) => Unit = x => j => str => println(str)
920+
|val k = hello(x = 1)(2)(@@)
921+
|""".stripMargin,
922+
"""|str = : String
923+
|""".stripMargin,
924+
topLines = Some(1),
925+
)
926+
927+
@Test def `lambda5` =
928+
check(
929+
"""|val hello: (x: Int) => Int => (str: String) => Unit = x => j => str => println(str)
930+
|val k = hello(x = 1)(2)(@@)
931+
|""".stripMargin,
932+
"""|str = : String
933+
|""".stripMargin,
934+
topLines = Some(1),
935+
)

0 commit comments

Comments
 (0)