@@ -3590,14 +3590,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35903590
35913591 private def pushDownDeferredEvidenceParams (tpe : Type , params : List [untpd.ValDef ], span : Span )(using Context ): Type = tpe.dealias match {
35923592 case tpe : MethodType =>
3593- MethodType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3593+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
35943594 case tpe : PolyType =>
3595- PolyType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3595+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
35963596 case tpe : RefinedType =>
3597- // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3598- RefinedType (pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3597+ tpe.derivedRefinedType(
3598+ pushDownDeferredEvidenceParams(tpe.parent, params, span),
3599+ tpe.refinedName,
3600+ pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3601+ )
35993602 case tpe @ AppliedType (tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3600- AppliedType ( tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3603+ tpe.derivedAppliedType( tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
36013604 case tpe =>
36023605 val paramNames = params.map(_.name)
36033606 val paramTpts = params.map(_.tpt)
@@ -3606,18 +3609,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36063609 typed(ctxFunction).tpe
36073610 }
36083611
3609- private def addDownDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3612+ private def extractTopMethodTermParams (tpe : Type )(using Context ): (List [TermName ], List [Type ]) = tpe match {
3613+ case tpe : MethodType =>
3614+ tpe.paramNames -> tpe.paramInfos
3615+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3616+ extractTopMethodTermParams(tpe.refinedInfo)
3617+ case _ =>
3618+ Nil -> Nil
3619+ }
3620+
3621+ private def removeTopMethodTermParams (tpe : Type )(using Context ): Type = tpe match {
3622+ case tpe : MethodType =>
3623+ tpe.resultType
3624+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3625+ tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3626+ case tpe : AppliedType if defn.isFunctionType(tpe) =>
3627+ tpe.args.last
3628+ case _ =>
3629+ tpe
3630+ }
3631+
3632+ private def healToPolyFunctionType (tree : Tree )(using Context ): Tree = tree match {
3633+ case defdef : DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam ))) && defdef.paramss.size == 1 =>
3634+ val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3635+ val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3636+ val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef (name, TypeTree (tpe), flags = SyntheticTermParam ))
3637+ val newDefDef = cpy.DefDef (defdef)(paramss = defdef.paramss ++ List (newParams), tpt = untpd.TypeTree (newTpe))
3638+ val nestedCtx = ctx.fresh.setNewTyperState()
3639+ typed(newDefDef)(using nestedCtx)
3640+ case _ => tree
3641+ }
3642+
3643+ private def addDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
36103644 tree.getAttachment(desugar.PolyFunctionApply ) match
36113645 case Some (params) if params.nonEmpty =>
36123646 tree.removeAttachment(desugar.PolyFunctionApply )
36133647 val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
36143648 TypeTree (tpe).withSpan(tree.span) -> tpe
3649+ // case Some(params) if params.isEmpty =>
3650+ // println(s"tree: $tree")
3651+ // healToPolyFunctionType(tree) -> pt
36153652 case _ => tree -> pt
36163653 }
36173654
36183655 /** Interpolate and simplify the type of the given tree. */
36193656 protected def simplify (tree : Tree , pt : Type , locked : TypeVars )(using Context ): Tree =
3620- val (tree1, pt1) = addDownDeferredEvidenceParams (tree, pt)
3657+ val (tree1, pt1) = addDeferredEvidenceParams (tree, pt)
36213658 if ! tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
36223659 if ! tree1.tpe.widen.isInstanceOf [MethodOrPoly ] // wait with simplifying until method is fully applied
36233660 || tree1.isDef // ... unless tree is a definition
0 commit comments