Skip to content

Commit b0a3dd9

Browse files
authored
Improve defn.PolyFunctionOf extractor (#18442)
* Only match `RefinedType` representing the `PolyFunction`. This will allow us to use `derivedRefinedType` on the function type. * Only match the refinement if it is a `MethodOrPoly`. `ExprType` is not a valid `PolyFunction` refinement. * Remove `dealias` in `PolyFunctionOf` extractor. There was only one case where this was necessary and it added unnecessary overhead.
2 parents 510aac8 + 60b2d02 commit b0a3dd9

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ class CheckCaptures extends Recheck, SymTransformer:
808808

809809
try
810810
val eres = expected.dealias.stripCapturing match
811-
case RefinedType(_, _, rinfo: PolyType) => rinfo.resType
811+
case defn.PolyFunctionOf(rinfo: PolyType) => rinfo.resType
812812
case expected: PolyType => expected.resType
813813
case _ => WildcardType
814814

compiler/src/dotty/tools/dotc/core/Definitions.scala

+6-5
Original file line numberDiff line numberDiff line change
@@ -1154,11 +1154,12 @@ class Definitions {
11541154
*
11551155
* Pattern: `PolyFunction { def apply: $mt }`
11561156
*/
1157-
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
1158-
case RefinedType(parent, nme.apply, mt: MethodicType)
1159-
if parent.derivesFrom(defn.PolyFunctionClass) =>
1160-
Some(mt)
1161-
case _ => None
1157+
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] =
1158+
tpe.refinedInfo match
1159+
case mt: MethodOrPoly
1160+
if tpe.refinedName == nme.apply && tpe.parent.derivesFrom(defn.PolyFunctionClass) =>
1161+
Some(mt)
1162+
case _ => None
11621163

11631164
private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
11641165
def isValidMethodType(info: Type) = info match

compiler/src/dotty/tools/dotc/typer/Typer.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -1651,10 +1651,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16511651
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
16521652
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
16531653
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
1654+
val dpt = pt.dealias
16541655

16551656
// If the expected type is a polymorphic function with the same number of
16561657
// type and value parameters, then infer the types of value parameters from the expected type.
1657-
val inferredVParams = pt match
1658+
val inferredVParams = dpt match
16581659
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
16591660
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
16601661
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
@@ -1670,7 +1671,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16701671
case _ =>
16711672
vparams
16721673

1673-
val resultTpt = pt.dealias match
1674+
val resultTpt = dpt match
16741675
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
16751676
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
16761677
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))

0 commit comments

Comments
 (0)