Skip to content

Commit 1724d84

Browse files
authored
Merge pull request #15570 from dwijnand/gadt/poly
2 parents 9d07d52 + 34c2918 commit 1724d84

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+21-7
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,7 @@ object desugar {
15041504
.withSpan(original.span.withPoint(named.span.start))
15051505

15061506
/** Main desugaring method */
1507-
def apply(tree: Tree)(using Context): Tree = {
1507+
def apply(tree: Tree, pt: Type = NoType)(using Context): Tree = {
15081508

15091509
/** Create tree for for-comprehension `<for (enums) do body>` or
15101510
* `<for (enums) yield body>` where mapName and flatMapName are chosen
@@ -1698,11 +1698,11 @@ object desugar {
16981698
}
16991699
}
17001700

1701-
def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match {
1701+
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
17021702
case Parens(body1) =>
1703-
makePolyFunction(targs, body1)
1703+
makePolyFunction(targs, body1, pt)
17041704
case Block(Nil, body1) =>
1705-
makePolyFunction(targs, body1)
1705+
makePolyFunction(targs, body1, pt)
17061706
case Function(vargs, res) =>
17071707
assert(targs.nonEmpty)
17081708
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
@@ -1726,12 +1726,26 @@ object desugar {
17261726
}
17271727
else {
17281728
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1729-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1729+
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1730+
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1731+
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1732+
1733+
def typeTree(tp: Type) = tp match
1734+
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1735+
var bail = false
1736+
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
1737+
case tp: TypeRef => ref(tp)
1738+
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
1739+
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
1740+
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
1741+
val mapped = mapper(mt.resultType, topLevel = true)
1742+
if bail then TypeTree() else mapped
1743+
case _ => TypeTree()
17301744

17311745
val applyVParams = vargs.asInstanceOf[List[ValDef]]
17321746
.map(varg => varg.withAddedFlags(mods.flags | Param))
17331747
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1734-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res))
1748+
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
17351749
))
17361750
}
17371751
case _ =>
@@ -1753,7 +1767,7 @@ object desugar {
17531767

17541768
val desugared = tree match {
17551769
case PolyFunction(targs, body) =>
1756-
makePolyFunction(targs, body) orElse tree
1770+
makePolyFunction(targs, body, pt) orElse tree
17571771
case SymbolLit(str) =>
17581772
Apply(
17591773
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/typer/Typer.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -2872,7 +2872,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28722872

28732873
typedTypeOrClassDef
28742874
case tree: untpd.Labeled => typedLabeled(tree)
2875-
case _ => typedUnadapted(desugar(tree), pt, locked)
2875+
case _ => typedUnadapted(desugar(tree, pt), pt, locked)
28762876
}
28772877
}
28782878

@@ -2925,7 +2925,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
29252925
case tree: untpd.Splice => typedSplice(tree, pt)
29262926
case tree: untpd.MacroTree => report.error("Unexpected macro", tree.srcPos); tpd.nullLiteral // ill-formed code may reach here
29272927
case tree: untpd.Hole => typedHole(tree, pt)
2928-
case _ => typedUnadapted(desugar(tree), pt, locked)
2928+
case _ => typedUnadapted(desugar(tree, pt), pt, locked)
29292929
}
29302930

29312931
try

tests/pos/i15554.scala

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
enum PingMessage[Response]:
2+
case Ping(from: String) extends PingMessage[String]
3+
4+
val pongBehavior: [O] => (Unit, PingMessage[O]) => (Unit, O) =
5+
[P] =>
6+
(state: Unit, msg: PingMessage[P]) =>
7+
msg match
8+
case PingMessage.Ping(from) => ((), s"Pong from $from")

0 commit comments

Comments
 (0)