Skip to content

Commit 1a61d8d

Browse files
committed
Implement returns in methods with context function result types
1 parent 2ccba7c commit 1a61d8d

File tree

4 files changed

+62
-12
lines changed

4 files changed

+62
-12
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

+34-11
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ class Typer extends Namer
11291129
case _ =>
11301130
}
11311131

1132-
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
1132+
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree)
11331133

11341134
/** The inferred parameter type for a parameter in a lambda that does
11351135
* not have an explicit type given.
@@ -1445,17 +1445,40 @@ class Typer extends Namer
14451445
}
14461446

14471447
def typedReturn(tree: untpd.Return)(using Context): Return = {
1448+
1449+
/** If `pt` is a context function type, its return type. If the CFT
1450+
* is dependent, instantiate with the parameters of the associated
1451+
* anonymous function.
1452+
* @param paramss the parameters of the anonymous functions
1453+
* enclosing the return expression
1454+
*/
1455+
def instantiateCFT(pt: Type, paramss: List[List[Symbol]]): Type =
1456+
val ift = defn.asContextFunctionType(pt)
1457+
if ift.exists then
1458+
ift.nonPrivateMember(nme.apply).info match
1459+
case appType: MethodType =>
1460+
instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail)
1461+
else pt
1462+
14481463
def returnProto(owner: Symbol, locals: Scope): Type =
14491464
if (owner.isConstructor) defn.UnitType
1450-
else owner.info match {
1451-
case info: PolyType =>
1452-
val tparams = locals.toList.takeWhile(_ is TypeParam)
1453-
assert(info.paramNames.length == tparams.length,
1454-
i"return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, %")
1455-
info.instantiate(tparams.map(_.typeRef)).finalResultType
1456-
case info =>
1457-
info.finalResultType
1458-
}
1465+
else
1466+
val rt = owner.info match
1467+
case info: PolyType =>
1468+
val tparams = locals.toList.takeWhile(_ is TypeParam)
1469+
assert(info.paramNames.length == tparams.length,
1470+
i"return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, %")
1471+
info.instantiate(tparams.map(_.typeRef)).finalResultType
1472+
case info =>
1473+
info.finalResultType
1474+
def iftParamss = ctx.owner.ownersIterator
1475+
.filter(_.is(Method, butNot = Accessor))
1476+
.takeWhile(_.isAnonymousFunction)
1477+
.toList
1478+
.reverse
1479+
.map(_.paramSymss.head)
1480+
instantiateCFT(rt, iftParamss)
1481+
14591482
def enclMethInfo(cx: Context): (Tree, Type) = {
14601483
val owner = cx.owner
14611484
if (owner.isType) {
@@ -3147,7 +3170,7 @@ class Typer extends Namer
31473170

31483171
def isContextFunctionRef(wtp: Type): Boolean = wtp match {
31493172
case RefinedType(parent, nme.apply, _) =>
3150-
isContextFunctionRef(parent) // apply refinements indicate a dependent IFT
3173+
isContextFunctionRef(parent) // apply refinements indicate a dependent CFT
31513174
case _ =>
31523175
val underlying = wtp.underlyingClassRef(refinementOK = false) // other refinements are not OK
31533176
defn.isContextFunctionClass(underlying.classSymbol)

tests/neg/curried-dependent-ift.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ trait B
1414
def h(x: Boolean): A ?=> B ?=> (A, B) =
1515
(summon[A], summon[B]) // OK
1616

17-
def g(x: Boolean): (c1: Ctx1) ?=> Ctx2 ?=> (c1.T, Ctx2) = ??? // error
17+
def g(x: Boolean): (c1: Ctx1) ?=> Ctx2 ?=> (c1.T, Ctx2) =
18+
return ??? // error
19+
???

tests/run/ift-return.check

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
(22,abc)
2+
(22,def)

tests/run/ift-return.scala

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
trait A:
2+
val x: Int
3+
4+
trait Ctx:
5+
type T
6+
val x: T
7+
val y: T
8+
9+
def f(x: Boolean): A ?=> (c: Ctx) ?=> (Int, c.T) =
10+
if x then return (summon[A].x, summon[Ctx].x)
11+
(summon[A].x, summon[Ctx].y)
12+
13+
@main def Test =
14+
given A:
15+
val x = 22
16+
given Ctx:
17+
type T = String
18+
val x = "abc"
19+
val y = "def"
20+
21+
println(f(true))
22+
println(f(false))
23+

0 commit comments

Comments
 (0)