diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index b6b5d569677c..a950c81d9028 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -808,7 +808,7 @@ class CheckCaptures extends Recheck, SymTransformer: try val eres = expected.dealias.stripCapturing match - case RefinedType(_, _, rinfo: PolyType) => rinfo.resType + case defn.PolyFunctionOf(rinfo: PolyType) => rinfo.resType case expected: PolyType => expected.resType case _ => WildcardType diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ea48dd2b56fa..cb90f3409e33 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1140,11 +1140,12 @@ class Definitions { * * Pattern: `PolyFunction { def apply: $mt }` */ - def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match - case RefinedType(parent, nme.apply, mt: MethodicType) - if parent.derivesFrom(defn.PolyFunctionClass) => - Some(mt) - case _ => None + def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = + tpe.refinedInfo match + case mt: MethodOrPoly + if tpe.refinedName == nme.apply && tpe.parent.derivesFrom(defn.PolyFunctionClass) => + Some(mt) + case _ => None private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean = def isValidMethodType(info: Type) = info match diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 09df6614d496..4d3b695855c5 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1648,10 +1648,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked + val dpt = pt.dealias // If the expected type is a polymorphic function with the same number of // type and value parameters, then infer the types of value parameters from the expected type. - val inferredVParams = pt match + val inferredVParams = dpt match case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 => vparams.zipWithConserve(mt.paramInfos): (vparam, formal) => @@ -1667,7 +1668,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => vparams - val resultTpt = pt.dealias match + val resultTpt = dpt match case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) => untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))