@@ -1504,7 +1504,7 @@ object desugar {
1504
1504
.withSpan(original.span.withPoint(named.span.start))
1505
1505
1506
1506
/** Main desugaring method */
1507
- def apply (tree : Tree )(using Context ): Tree = {
1507
+ def apply (tree : Tree , pt : Type = NoType )(using Context ): Tree = {
1508
1508
1509
1509
/** Create tree for for-comprehension `<for (enums) do body>` or
1510
1510
* `<for (enums) yield body>` where mapName and flatMapName are chosen
@@ -1698,11 +1698,11 @@ object desugar {
1698
1698
}
1699
1699
}
1700
1700
1701
- def makePolyFunction (targs : List [Tree ], body : Tree ): Tree = body match {
1701
+ def makePolyFunction (targs : List [Tree ], body : Tree , pt : Type ): Tree = body match {
1702
1702
case Parens (body1) =>
1703
- makePolyFunction(targs, body1)
1703
+ makePolyFunction(targs, body1, pt )
1704
1704
case Block (Nil , body1) =>
1705
- makePolyFunction(targs, body1)
1705
+ makePolyFunction(targs, body1, pt )
1706
1706
case Function (vargs, res) =>
1707
1707
assert(targs.nonEmpty)
1708
1708
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
@@ -1726,12 +1726,26 @@ object desugar {
1726
1726
}
1727
1727
else {
1728
1728
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1729
- // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1729
+ // with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1730
+ // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1731
+ // where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1732
+
1733
+ def typeTree (tp : Type ) = tp match
1734
+ case RefinedType (parent, nme.apply, PolyType (_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1735
+ var bail = false
1736
+ def mapper (tp : Type , topLevel : Boolean = false ): Tree = tp match
1737
+ case tp : TypeRef => ref(tp)
1738
+ case tp : TypeParamRef => Ident (applyTParams(tp.paramNum).name)
1739
+ case AppliedType (tycon, args) => AppliedTypeTree (mapper(tycon), args.map(mapper(_)))
1740
+ case _ => if topLevel then TypeTree () else { bail = true ; genericEmptyTree }
1741
+ val mapped = mapper(mt.resultType, topLevel = true )
1742
+ if bail then TypeTree () else mapped
1743
+ case _ => TypeTree ()
1730
1744
1731
1745
val applyVParams = vargs.asInstanceOf [List [ValDef ]]
1732
1746
.map(varg => varg.withAddedFlags(mods.flags | Param ))
1733
1747
New (Template (emptyConstructor, List (polyFunctionTpt), Nil , EmptyValDef ,
1734
- List (DefDef (nme.apply, applyTParams :: applyVParams :: Nil , TypeTree ( ), res))
1748
+ List (DefDef (nme.apply, applyTParams :: applyVParams :: Nil , typeTree(pt ), res))
1735
1749
))
1736
1750
}
1737
1751
case _ =>
@@ -1753,7 +1767,7 @@ object desugar {
1753
1767
1754
1768
val desugared = tree match {
1755
1769
case PolyFunction (targs, body) =>
1756
- makePolyFunction(targs, body) orElse tree
1770
+ makePolyFunction(targs, body, pt ) orElse tree
1757
1771
case SymbolLit (str) =>
1758
1772
Apply (
1759
1773
ref(defn.ScalaSymbolClass .companionModule.termRef),
0 commit comments