@@ -16,14 +16,17 @@ import dotty.tools.dotc.core.Symbols
16
16
import dotty .tools .dotc .core .Symbols .Symbol
17
17
import dotty .tools .dotc .core .Types .AndType
18
18
import dotty .tools .dotc .core .Types .AppliedType
19
+ import dotty .tools .dotc .core .Types .MethodType
19
20
import dotty .tools .dotc .core .Types .OrType
21
+ import dotty .tools .dotc .core .Types .RefinedType
20
22
import dotty .tools .dotc .core .Types .TermRef
21
23
import dotty .tools .dotc .core .Types .Type
22
24
import dotty .tools .dotc .core .Types .TypeBounds
23
25
import dotty .tools .dotc .core .Types .WildcardType
24
26
import dotty .tools .dotc .util .SourcePosition
25
27
import dotty .tools .pc .IndexedContext
26
28
import dotty .tools .pc .utils .MtagsEnrichments .*
29
+ import scala .annotation .tailrec
27
30
28
31
object NamedArgCompletions :
29
32
@@ -195,9 +198,40 @@ object NamedArgCompletions:
195
198
// def curry(x: Int)(apple: String, banana: String) = ???
196
199
// curry(1)(apple = "test", b@@)
197
200
// ```
198
- val (baseParams , baseArgs) =
201
+ val (baseParams0 , baseArgs) =
199
202
vparamss.zip(argss).lastOption.getOrElse((Nil , Nil ))
200
203
204
+ val baseParams : List [ParamSymbol ] =
205
+ def defaultBaseParams = baseParams0.map(JustSymbol (_))
206
+ @ tailrec
207
+ def getRefinedParams (refinedType : Type , level : Int ): List [ParamSymbol ] =
208
+ if level > 0 then
209
+ val resultTypeOpt =
210
+ refinedType match
211
+ case RefinedType (AppliedType (_, args), _, _) => args.lastOption
212
+ case AppliedType (_, args) => args.lastOption
213
+ case _ => None
214
+ resultTypeOpt match
215
+ case Some (resultType) => getRefinedParams(resultType, level - 1 )
216
+ case _ => defaultBaseParams
217
+ else
218
+ refinedType match
219
+ case RefinedType (AppliedType (_, args), _, MethodType (ri)) =>
220
+ baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
221
+ RefinedSymbol (sym, name, arg)
222
+ }
223
+ case _ => defaultBaseParams
224
+ // finds param refinements for lambda expressions
225
+ // val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
226
+ @ tailrec
227
+ def refineParams (method : Tree , level : Int ): List [ParamSymbol ] =
228
+ method match
229
+ case Select (Apply (f, _), _) => refineParams(f, level + 1 )
230
+ case Select (h, v) => getRefinedParams(h.symbol.info, level)
231
+ case _ => defaultBaseParams
232
+ refineParams(method, 0 )
233
+ end baseParams
234
+
201
235
val args = ident
202
236
.map(i => baseArgs.filterNot(_ == i))
203
237
.getOrElse(baseArgs)
@@ -221,7 +255,7 @@ object NamedArgCompletions:
221
255
222
256
baseParams.filterNot(param =>
223
257
isNamed(param.name) ||
224
- param.denot.is(
258
+ param.symbol. denot.is(
225
259
Flags .Synthetic
226
260
) // filter out synthesized param, like evidence
227
261
)
@@ -232,7 +266,7 @@ object NamedArgCompletions:
232
266
.map(_.name.toString)
233
267
.getOrElse(" " )
234
268
.replace(Cursor .value, " " )
235
- val params : List [Symbol ] =
269
+ val params : List [ParamSymbol ] =
236
270
allParams
237
271
.filter(param => param.name.startsWith(prefix))
238
272
.distinctBy(sym => (sym.name, sym.info))
@@ -249,7 +283,7 @@ object NamedArgCompletions:
249
283
.filter(name => name != " Nil" && name != " None" )
250
284
.sorted
251
285
252
- def findDefaultValue (param : Symbol ): String =
286
+ def findDefaultValue (param : ParamSymbol ): String =
253
287
val matchingType = matchingTypesInScope(param.info)
254
288
if matchingType.size == 1 then s " : ${matchingType.head}"
255
289
else if matchingType.size > 1 then s " |???, ${matchingType.mkString(" ," )}| "
@@ -260,12 +294,12 @@ object NamedArgCompletions:
260
294
def shouldShow =
261
295
allParams.exists(param => param.name.startsWith(prefix))
262
296
def isExplicitlyCalled = suffix.startsWith(prefix)
263
- def hasParamsToFill = allParams.count(! _.is(Flags .HasDefault )) > 1
297
+ def hasParamsToFill = allParams.count(! _.symbol. is(Flags .HasDefault )) > 1
264
298
if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill
265
299
then
266
300
val editText = allParams.zipWithIndex
267
301
.collect {
268
- case (param, index) if ! param.is(Flags .HasDefault ) =>
302
+ case (param, index) if ! param.symbol. is(Flags .HasDefault ) =>
269
303
s " ${param.nameBackticked.replace(" $" , " $$" )} = $$ { ${index + 1 }${findDefaultValue(param)}} "
270
304
}
271
305
.mkString(" , " )
@@ -355,3 +389,16 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context):
355
389
case _ => t
356
390
357
391
end FuzzyArgMatcher
392
+
393
+ sealed trait ParamSymbol :
394
+ def name : Name
395
+ def info : Type
396
+ def symbol : Symbol
397
+ def nameBackticked (using Context ) = name.decoded.backticked
398
+
399
+ case class JustSymbol (symbol : Symbol )(using Context ) extends ParamSymbol :
400
+ def name : Name = symbol.name
401
+ def info : Type = symbol.info
402
+
403
+ case class RefinedSymbol (symbol : Symbol , name : Name , info : Type )
404
+ extends ParamSymbol
0 commit comments