Skip to content

Commit e422066

Browse files
authored
Implement individual erased parameters (#16507)
### Syntax Erased parameters in a method / lambda comes with an `erased` modifier before its name: ```scala def erasedSecondParam(x: Int, erased y: Int): Int = x type EraseSecondParam[T, U] = (T, erased U) => T val esp: EraseSecondParam[Int, Int] = (x, erased y) => erasedSecondParam(x, y) ``` This is a breaking change, as previously erased methods / functions with multiple parameters now only have its first parameter erased. ### Semantics `[Impure][Contextual]ErasedFunctionN` traits are no longer available. Instead, erased function values are denoted by refining the `scala.runtime.ErasedFunction` trait: ```scala type Int_EInt = (Int, erased Int) => Int // is equivalent to type Int_EInt2 = scala.runtime.ErasedFunction { def apply(x$0: Int, erased x$1: Int): Int } ``` They are subsequently compiled (during Erasure) into `[Contextual]FunctionM` where `M` is the number of non-erased parameters. ### Erased Classes Any parameter that is an instance of an erased class is automatically erased. This is different from before, where the parameters are erased only if all parameters are instances of erased classes.
2 parents 6ea3ea6 + 0f7c3ab commit e422066

File tree

71 files changed

+879
-403
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+879
-403
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+26-3
Original file line numberDiff line numberDiff line change
@@ -1498,10 +1498,10 @@ object desugar {
14981498
case vd: ValDef => vd
14991499
}
15001500

1501-
def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(using Context): Function = {
1502-
val mods = if (isErased) Given | Erased else Given
1501+
def makeContextualFunction(formals: List[Tree], body: Tree, erasedParams: List[Boolean])(using Context): Function = {
1502+
val mods = Given
15031503
val params = makeImplicitParameters(formals, mods)
1504-
FunctionWithMods(params, body, Modifiers(mods))
1504+
FunctionWithMods(params, body, Modifiers(mods), erasedParams)
15051505
}
15061506

15071507
private def derivedValDef(original: Tree, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) = {
@@ -1834,6 +1834,7 @@ object desugar {
18341834
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
18351835
case _ =>
18361836
annotate(tpnme.retains, parent)
1837+
case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt)
18371838
}
18381839
desugared.withSpan(tree.span)
18391840
}
@@ -1909,6 +1910,28 @@ object desugar {
19091910
TypeDef(tpnme.REFINE_CLASS, impl).withFlags(Trait)
19101911
}
19111912

1913+
/** Ensure the given function tree use only ValDefs for parameters.
1914+
* For example,
1915+
* FunctionWithMods(List(TypeTree(A), TypeTree(B)), body, mods, erasedParams)
1916+
* gets converted to
1917+
* FunctionWithMods(List(ValDef(x$1, A), ValDef(x$2, B)), body, mods, erasedParams)
1918+
*/
1919+
def makeFunctionWithValDefs(tree: Function, pt: Type)(using Context): Function = {
1920+
val Function(args, result) = tree
1921+
args match {
1922+
case (_ : ValDef) :: _ => tree // ValDef case can be easily handled
1923+
case _ if !ctx.mode.is(Mode.Type) => tree
1924+
case _ =>
1925+
val applyVParams = args.zipWithIndex.map {
1926+
case (p, n) => makeSyntheticParameter(n + 1, p)
1927+
}
1928+
tree match
1929+
case tree: FunctionWithMods =>
1930+
untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, tree.erasedParams)
1931+
case _ => untpd.Function(applyVParams, result)
1932+
}
1933+
}
1934+
19121935
/** Returns list of all pattern variables, possibly with their types,
19131936
* without duplicates
19141937
*/

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
960960
&& tree.isTerm
961961
&& {
962962
val qualType = tree.qualifier.tpe
963-
hasRefinement(qualType) && !qualType.derivesFrom(defn.PolyFunctionClass)
963+
hasRefinement(qualType) && !defn.isRefinedFunctionType(qualType)
964964
}
965965
def loop(tree: Tree): Boolean = tree match
966966
case TypeApply(fun, _) =>

compiler/src/dotty/tools/dotc/ast/tpd.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
260260
// If `isParamDependent == false`, the value of `previousParamRefs` is not used.
261261
if isParamDependent then mutable.ListBuffer[TermRef]() else (null: ListBuffer[TermRef] | Null).uncheckedNN
262262

263-
def valueParam(name: TermName, origInfo: Type): TermSymbol =
263+
def valueParam(name: TermName, origInfo: Type, isErased: Boolean): TermSymbol =
264264
val maybeImplicit =
265265
if tp.isContextualMethod then Given
266266
else if tp.isImplicitMethod then Implicit
267267
else EmptyFlags
268-
val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags
268+
val maybeErased = if isErased then Erased else EmptyFlags
269269

270270
def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)
271271

@@ -283,7 +283,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
283283
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
284284
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
285285
case nil =>
286-
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
286+
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.erasedParams).map(valueParam), Nil)
287287
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
288288
(rtp, vparams :: paramss)
289289
case _ =>
@@ -1140,10 +1140,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11401140

11411141
def etaExpandCFT(using Context): Tree =
11421142
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
1143-
case defn.ContextFunctionType(argTypes, resType, isErased) =>
1143+
case defn.ContextFunctionType(argTypes, resType, _) =>
11441144
val anonFun = newAnonFun(
11451145
ctx.owner,
1146-
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
1146+
MethodType.companion(isContextual = true)(argTypes, resType),
11471147
coord = ctx.owner.coord)
11481148
def lambdaBody(refss: List[List[Tree]]) =
11491149
expand(target.select(nme.apply).appliedToArgss(refss), resType)(

compiler/src/dotty/tools/dotc/ast/untpd.scala

+7-3
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
7676
override def isType: Boolean = body.isType
7777
}
7878

79-
/** A function type or closure with `implicit`, `erased`, or `given` modifiers */
80-
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
81-
extends Function(args, body)
79+
/** A function type or closure with `implicit` or `given` modifiers and information on which parameters are `erased` */
80+
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers, val erasedParams: List[Boolean])(implicit @constructorOnly src: SourceFile)
81+
extends Function(args, body) {
82+
assert(args.length == erasedParams.length)
83+
84+
def hasErasedParams = erasedParams.contains(true)
85+
}
8286

8387
/** A polymorphic function type */
8488
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ extension (tp: Type)
146146
defn.FunctionType(
147147
fname.functionArity,
148148
isContextual = fname.isContextFunction,
149-
isErased = fname.isErasedFunction,
150149
isImpure = true).appliedTo(args)
151150
case _ =>
152151
tp

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+10-10
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,8 @@ class CheckCaptures extends Recheck, SymTransformer:
336336
mapArgUsing(_.forceBoxStatus(false))
337337
else if meth == defn.Caps_unsafeBoxFunArg then
338338
mapArgUsing {
339-
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual, isErased) =>
340-
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual, isErased)
339+
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual) =>
340+
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual)
341341
}
342342
else
343343
super.recheckApply(tree, pt) match
@@ -430,7 +430,7 @@ class CheckCaptures extends Recheck, SymTransformer:
430430
block match
431431
case closureDef(mdef) =>
432432
pt.dealias match
433-
case defn.FunctionOf(ptformals, _, _, _)
433+
case defn.FunctionOf(ptformals, _, _)
434434
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
435435
// Redo setup of the anonymous function so that formal parameters don't
436436
// get capture sets. This is important to avoid false widenings to `*`
@@ -598,18 +598,18 @@ class CheckCaptures extends Recheck, SymTransformer:
598598
//println(i"check conforms $actual1 <<< $expected1")
599599
super.checkConformsExpr(actual1, expected1, tree)
600600

601-
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: Boolean)(using Context): Type =
602-
MethodType.companion(isContextual = isContextual, isErased = isErased)(args, resultType)
601+
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type =
602+
MethodType.companion(isContextual = isContextual)(args, resultType)
603603
.toFunctionType(isJava = false, alwaysDependent = true)
604604

605605
/** Turn `expected` into a dependent function when `actual` is dependent. */
606606
private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type =
607607
def recur(expected: Type): Type = expected.dealias match
608608
case expected @ CapturingType(eparent, refs) =>
609609
CapturingType(recur(eparent), refs, boxed = expected.isBoxed)
610-
case expected @ defn.FunctionOf(args, resultType, isContextual, isErased)
610+
case expected @ defn.FunctionOf(args, resultType, isContextual)
611611
if defn.isNonRefinedFunction(expected) && defn.isFunctionType(actual) && !defn.isNonRefinedFunction(actual) =>
612-
val expected1 = toDepFun(args, resultType, isContextual, isErased)
612+
val expected1 = toDepFun(args, resultType, isContextual)
613613
expected1
614614
case _ =>
615615
expected
@@ -675,7 +675,7 @@ class CheckCaptures extends Recheck, SymTransformer:
675675

676676
try
677677
val (eargs, eres) = expected.dealias.stripCapturing match
678-
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
678+
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
679679
case expected: MethodType => (expected.paramInfos, expected.resType)
680680
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
681681
case _ => (aargs.map(_ => WildcardType), WildcardType)
@@ -739,7 +739,7 @@ class CheckCaptures extends Recheck, SymTransformer:
739739
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
740740
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
741741
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
742-
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
742+
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(actual) =>
743743
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
744744
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
745745
(aargs1, ares1) =>
@@ -962,7 +962,7 @@ class CheckCaptures extends Recheck, SymTransformer:
962962
case CapturingType(parent, refs) =>
963963
healCaptureSet(refs)
964964
traverse(parent)
965-
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
965+
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
966966
traverse(rinfo)
967967
case tp: TermLambda =>
968968
val saved = allowed

compiler/src/dotty/tools/dotc/cc/Setup.scala

+11-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import transform.Recheck.*
1212
import CaptureSet.IdentityCaptRefMap
1313
import Synthetics.isExcluded
1414
import util.Property
15+
import dotty.tools.dotc.core.Annotations.Annotation
1516

1617
/** A tree traverser that prepares a compilation unit to be capture checked.
1718
* It does the following:
@@ -38,7 +39,6 @@ extends tpd.TreeTraverser:
3839
private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type =
3940
MethodType.companion(
4041
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
41-
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
4242
)(argTypes, resType)
4343
.toFunctionType(isJava = false, alwaysDependent = true)
4444

@@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
5454
val boxedRes = recur(res)
5555
if boxedRes eq res then tp
5656
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
57-
case tp1 @ RefinedType(_, _, rinfo) if defn.isFunctionType(tp1) =>
57+
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(tp1) =>
5858
val boxedRinfo = recur(rinfo)
5959
if boxedRinfo eq rinfo then tp
6060
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
@@ -231,7 +231,7 @@ extends tpd.TreeTraverser:
231231
tp.derivedAppliedType(tycon1, args1 :+ res1)
232232
else
233233
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
234-
case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) =>
234+
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
235235
val rinfo1 = apply(rinfo)
236236
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
237237
else tp
@@ -260,7 +260,13 @@ extends tpd.TreeTraverser:
260260
private def expandThrowsAlias(tp: Type)(using Context) = tp match
261261
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
262262
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
263-
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true)
263+
defn.FunctionOf(
264+
AnnotatedType(
265+
defn.CanThrowClass.typeRef.appliedTo(exc),
266+
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil,
267+
res,
268+
isContextual = true
269+
)
264270
case _ => tp
265271

266272
private def expandThrowsAliases(using Context) = new TypeMap:
@@ -323,7 +329,7 @@ extends tpd.TreeTraverser:
323329
args.last, CaptureSet.empty, currentCs ++ outerCs)
324330
tp.derivedAppliedType(tycon1, args1 :+ resType1)
325331
tp1.capturing(outerCs)
326-
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) =>
332+
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
327333
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
328334
.capturing(outerCs)
329335
case _ =>

0 commit comments

Comments
 (0)