Skip to content

Commit a76470f

Browse files
committed
Implement basic version of desugaring context bounds for poly functions
1 parent a3786a5 commit a76470f

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+27
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,33 @@ object desugar {
12211221
case _ => body
12221222
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
12231223

1224+
/** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R
1225+
* Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R
1226+
*/
1227+
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
1228+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
1229+
val newTParams = tparams.map {
1230+
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
1231+
TypeDef(name, ContextBounds(bounds, List.empty))
1232+
}
1233+
var idx = -1
1234+
val collecedContextBounds = tparams.collect {
1235+
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
1236+
// TOOD(kπ) Should we handle non empty normal bounds here?
1237+
name -> ctxBounds
1238+
}.flatMap { case (name, ctxBounds) =>
1239+
ctxBounds.map { ctxBound =>
1240+
idx = idx + 1
1241+
makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given)
1242+
}
1243+
}
1244+
val contextFunctionResult =
1245+
if collecedContextBounds.isEmpty then
1246+
fun
1247+
else
1248+
Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span)
1249+
PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)
1250+
12241251
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12251252
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12261253
*/

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ object Parsers {
6868
def acceptsVariance =
6969
this == Class || this == CaseClass || this == Hk
7070
def acceptsCtxBounds =
71-
!(this == Type || this == Hk)
71+
!(this == Hk)
7272
def acceptsWildcard =
7373
this == Type || this == Hk
7474

@@ -3429,7 +3429,7 @@ object Parsers {
34293429
*
34303430
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
34313431
* TypTypeParam ::= {Annotation}
3432-
* (id | ‘_’) [HkTypeParamClause] TypeBounds
3432+
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
34333433
*
34343434
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
34353435
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]

compiler/src/dotty/tools/dotc/typer/Typer.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -1919,8 +1919,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19191919

19201920
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19211921
val tree1 = desugar.normalizePolyFunction(tree)
1922-
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1923-
else typedPolyFunctionValue(tree1, pt)
1922+
val tree2 = desugar.expandPolyFunctionContextBounds(tree1)
1923+
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt)
1924+
else typedPolyFunctionValue(tree2, pt)
19241925

19251926
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19261927
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
5+
trait Ord[X]:
6+
def compare(x: X, y: X): Int
7+
8+
val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
9+
10+
// type Comparer = [X: Ord] => (x: X, y: X) => Boolean
11+
// val less2: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
12+
13+
// type Cmp[X] = (x: X, y: X) => Boolean
14+
// type Comparer2 = [X: Ord] => Cmp[X]
15+
// val less3: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

0 commit comments

Comments
 (0)