diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index a6c0646d7163..ca65e9f841d5 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -954,6 +954,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = { def isStructuralTermSelect(tree: Select) = def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match + case defn.PolyOrErasedFunctionOf(_) => + false case RefinedType(parent, rname, rinfo) => rname == tree.name || hasRefinement(parent) case tp: TypeProxy => @@ -966,10 +968,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => false !tree.symbol.exists && tree.isTerm - && { - val qualType = tree.qualifier.tpe - hasRefinement(qualType) && !defn.isPolyOrErasedFunctionType(qualType) - } + && hasRefinement(tree.qualifier.tpe) def loop(tree: Tree): Boolean = tree match case TypeApply(fun, _) => loop(fun) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index e2afe906a9c4..aff5c3ef57e4 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1114,7 +1114,7 @@ class Definitions { FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil) def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = { ft.dealias match - case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) => + case ErasedFunctionOf(mt) => Some(mt.paramInfos, mt.resType, mt.isContextualMethod) case _ => val tsym = ft.dealias.typeSymbol @@ -1126,6 +1126,42 @@ class Definitions { } } + object PolyOrErasedFunctionOf { + /** Matches a refined `PolyFunction` or `ErasedFunction` type and extracts the apply info. + * + * Pattern: `(PolyFunction | ErasedFunction) { def apply: $mt }` + */ + def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match + case RefinedType(parent, nme.apply, mt: MethodicType) + if parent.derivesFrom(defn.PolyFunctionClass) || parent.derivesFrom(defn.ErasedFunctionClass) => + Some(mt) + case _ => None + } + + object PolyFunctionOf { + /** Matches a refined `PolyFunction` type and extracts the apply info. + * + * Pattern: `PolyFunction { def apply: $pt }` + */ + def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match + case RefinedType(parent, nme.apply, pt: PolyType) + if parent.derivesFrom(defn.PolyFunctionClass) => + Some(pt) + case _ => None + } + + object ErasedFunctionOf { + /** Matches a refined `ErasedFunction` type and extracts the apply info. + * + * Pattern: `ErasedFunction { def apply: $mt }` + */ + def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match + case RefinedType(parent, nme.apply, mt: MethodType) + if parent.derivesFrom(defn.ErasedFunctionClass) => + Some(mt) + case _ => None + } + object PartialFunctionOf { def apply(arg: Type, result: Type)(using Context): Type = PartialFunctionClass.typeRef.appliedTo(arg :: result :: Nil) @@ -1713,18 +1749,6 @@ class Definitions { def isFunctionNType(tp: Type)(using Context): Boolean = isNonRefinedFunction(tp.dropDependentRefinement) - /** Does `tp` derive from `PolyFunction` or `ErasedFunction`? */ - def isPolyOrErasedFunctionType(tp: Type)(using Context): Boolean = - isPolyFunctionType(tp) || isErasedFunctionType(tp) - - /** Does `tp` derive from `PolyFunction`? */ - def isPolyFunctionType(tp: Type)(using Context): Boolean = - tp.derivesFrom(defn.PolyFunctionClass) - - /** Does `tp` derive from `ErasedFunction`? */ - def isErasedFunctionType(tp: Type)(using Context): Boolean = - tp.derivesFrom(defn.ErasedFunctionClass) - /** Returns whether `tp` is an instance or a refined instance of: * - scala.FunctionN * - scala.ContextFunctionN @@ -1732,7 +1756,9 @@ class Definitions { * - PolyFunction */ def isFunctionType(tp: Type)(using Context): Boolean = - isFunctionNType(tp) || isPolyOrErasedFunctionType(tp) + isFunctionNType(tp) + || tp.derivesFrom(defn.PolyFunctionClass) // TODO check for refinement? + || tp.derivesFrom(defn.ErasedFunctionClass) // TODO check for refinement? private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) = if !ctx.settings.Yscala2Stdlib.value then @@ -1836,7 +1862,7 @@ class Definitions { tp.stripTypeVar.dealias match case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) => asContextFunctionType(TypeComparer.bounds(tp1).hiBound) - case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod => + case tp1 @ ErasedFunctionOf(mt) if mt.isContextualMethod => tp1 case tp1 => if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1 @@ -1856,7 +1882,7 @@ class Definitions { atPhase(erasurePhase)(unapply(tp)) else asContextFunctionType(tp) match - case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) => + case ErasedFunctionOf(mt) => Some((mt.paramInfos, mt.resType, mt.erasedParams)) case tp1 if tp1.exists => val args = tp1.functionArgInfos @@ -1866,7 +1892,7 @@ class Definitions { /* Returns a list of erased booleans marking whether parameters are erased, for a function type. */ def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match { - case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams + case ErasedFunctionOf(mt) => mt.erasedParams case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false } case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala index c1b2541c460b..8ce0da9bc50f 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala @@ -509,7 +509,7 @@ class TypeApplications(val self: Type) extends AnyVal { * Handles `ErasedFunction`s and poly functions gracefully. */ final def functionArgInfos(using Context): List[Type] = self.dealias match - case RefinedType(parent, nme.apply, mt: MethodType) if defn.isPolyOrErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType) + case defn.ErasedFunctionOf(mt) => (mt.paramInfos :+ mt.resultType) case _ => self.dropDependentRefinement.dealias.argInfos /** Argument types where existential types in arguments are disallowed */ diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index db1bf85ade93..e83cefb570cb 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -666,7 +666,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling isSubType(info1, info2) if defn.isFunctionType(tp2) then - if defn.isPolyFunctionType(tp2) then + if tp2.derivesFrom(defn.PolyFunctionClass) then // TODO should we handle ErasedFunction is this same way? tp1.member(nme.apply).info match case info1: PolyType => diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index b29f88aaa010..59c8ba193850 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -654,8 +654,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst else SuperType(eThis, eSuper) case ExprType(rt) => defn.FunctionType(0) - case RefinedType(parent, nme.apply, refinedInfo) if defn.isPolyOrErasedFunctionType(parent) => - eraseRefinedFunctionApply(refinedInfo) + case defn.PolyOrErasedFunctionOf(mt) => + eraseRefinedFunctionApply(mt) case tp: TypeVar if !tp.isInstantiated => assert(inSigName, i"Cannot erase uninstantiated type variable $tp") WildcardType @@ -936,7 +936,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst sigName(defn.FunctionOf(Nil, rt)) case tp: TypeVar if !tp.isInstantiated => tpnme.Uninstantiated - case tp @ RefinedType(parent, nme.apply, _) if defn.isPolyOrErasedFunctionType(parent) => + case tp @ defn.PolyOrErasedFunctionOf(_) => // we need this case rather than falling through to the default // because RefinedTypes <: TypeProxy and it would be caught by // the case immediately below diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index f00693e535e3..504045246813 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1747,9 +1747,7 @@ object Types { if !tf1.exists then tf2 else if !tf2.exists then tf1 else NoType - case t if defn.isNonRefinedFunction(t) => - t - case t if defn.isErasedFunctionType(t) => + case t if defn.isFunctionType(t) => t case t @ SAMType(_) => t diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 3ed024429bb6..ca441fe9e799 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -679,7 +679,7 @@ object Erasure { // Instead, we manually lookup the type of `apply` in the qualifier. inContext(preErasureCtx) { val qualTp = tree.qualifier.typeOpt.widen - if defn.isPolyOrErasedFunctionType(qualTp) then + if qualTp.derivesFrom(defn.PolyFunctionClass) || qualTp.derivesFrom(defn.ErasedFunctionClass) then eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol else NoSymbol diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index b08c23db557a..b22104a462e0 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -447,7 +447,11 @@ object TreeChecker { val tpe = tree.typeOpt // PolyFunction and ErasedFunction apply methods stay structural until Erasure - val isRefinedFunctionApply = (tree.name eq nme.apply) && defn.isPolyOrErasedFunctionType(tree.qualifier.typeOpt) + val isRefinedFunctionApply = (tree.name eq nme.apply) && { + val qualTpe = tree.qualifier.typeOpt + qualTpe.derivesFrom(defn.PolyFunctionClass) || qualTpe.derivesFrom(defn.ErasedFunctionClass) + } + // Outer selects are pickled specially so don't require a symbol val isOuterSelect = tree.name.is(OuterSelectName) val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name) diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index e88546c0b71d..0a87a95120ae 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -105,7 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): expected =:= defn.FunctionOf(actualArgs, actualRet, defn.isContextFunctionType(baseFun)) val arity: Int = - if defn.isErasedFunctionType(fun) then -1 // TODO support? + if fun.derivesFrom(defn.ErasedFunctionClass) then -1 // TODO support? else if defn.isFunctionNType(fun) then // TupledFunction[(...) => R, ?] fun.functionArgInfos match diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f94c91ad1358..ee18b344a194 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1326,7 +1326,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer (pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound))) case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) - if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity => + if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => + (formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))) + case defn.ErasedFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity => (formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))) case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) => val restpe = mt.resultType match @@ -1648,11 +1650,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // If the expected type is a polymorphic function with the same number of // type and value parameters, then infer the types of value parameters from the expected type. val inferredVParams = pt match - case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) - if (parent.typeSymbol eq defn.PolyFunctionClass) - && tparams.lengthCompare(poly.paramNames) == 0 - && vparams.lengthCompare(mt.paramNames) == 0 - => + case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) + if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 => vparams.zipWithConserve(mt.paramInfos): (vparam, formal) => // Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since // it must be a valid method parameter type. @@ -1667,7 +1666,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer vparams val resultTpt = pt.dealias match - case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass => + case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) => untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) case _ => untpd.TypeTree() @@ -3234,8 +3233,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else formals.map(untpd.TypeTree) } - val erasedParams = pt.dealias match { - case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams + val erasedParams = pt match { + case defn.ErasedFunctionOf(mt: MethodType) => mt.erasedParams case _ => paramTypes.map(_ => false) } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 7df98d57b5ff..6e8d84f78abf 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -1788,7 +1788,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler def isContextFunctionType: Boolean = dotc.core.Symbols.defn.isContextFunctionType(self) def isErasedFunctionType: Boolean = - dotc.core.Symbols.defn.isErasedFunctionType(self) + self.derivesFrom(dotc.core.Symbols.defn.ErasedFunctionClass) def isDependentFunctionType: Boolean = val tpNoRefinement = self.dropDependentRefinement tpNoRefinement != self