Skip to content

Commit 757a4be

Browse files
committed
Add defn.RefinedFunctionOf extractor
1 parent ea1d8f9 commit 757a4be

File tree

5 files changed

+27
-14
lines changed

5 files changed

+27
-14
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,9 +876,8 @@ object CaptureSet:
876876
empty
877877
case CapturingType(parent, refs) =>
878878
recur(parent) ++ refs
879-
case tpd @ RefinedType(parent, _, rinfo: MethodType)
880-
if followResult && defn.isFunctionNType(tpd) =>
881-
ofType(parent, followResult = false) // pick up capture set from parent type
879+
case tpd @ defn.RefinedFunctionOf(rinfo: MethodType) if followResult =>
880+
ofType(tpd.parent, followResult = false) // pick up capture set from parent type
882881
++ (recur(rinfo.resType) // add capture set of result
883882
-- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters
884883
case tpd @ AppliedType(tycon, args) =>

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer:
195195
capt.println(i"solving $t")
196196
refs.solve()
197197
traverse(parent)
198-
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) =>
198+
case t @ defn.RefinedFunctionOf(rinfo) =>
199199
traverse(rinfo)
200200
case tp: TypeVar =>
201201
case tp: TypeRef =>
@@ -302,8 +302,8 @@ class CheckCaptures extends Recheck, SymTransformer:
302302
t
303303
case _ =>
304304
val t1 = t match
305-
case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) =>
306-
t.derivedRefinedType(parent, rname, this(rinfo))
305+
case t @ defn.RefinedFunctionOf(rinfo: MethodType) =>
306+
t.derivedRefinedType(t.parent, t.refinedName, this(rinfo))
307307
case _ =>
308308
mapOver(t)
309309
if variance > 0 then t1
@@ -845,7 +845,7 @@ class CheckCaptures extends Recheck, SymTransformer:
845845
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
846846
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
847847
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
848-
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
848+
case actual @ defn.RefinedFunctionOf(rinfo: MethodType) =>
849849
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
850850
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
851851
(aargs1, ares1) =>
@@ -855,11 +855,11 @@ class CheckCaptures extends Recheck, SymTransformer:
855855
adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox,
856856
(aargs1, ares1) =>
857857
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
858-
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) =>
858+
case actual @ defn.RefinedFunctionOf(rinfo: PolyType) =>
859859
adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox,
860860
ares1 =>
861861
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
862-
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
862+
val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1)
863863
actual1
864864
)
865865
case _ =>
@@ -1080,7 +1080,7 @@ class CheckCaptures extends Recheck, SymTransformer:
10801080
case CapturingType(parent, refs) =>
10811081
healCaptureSet(refs)
10821082
traverse(parent)
1083-
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
1083+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
10841084
traverse(rinfo)
10851085
case tp: TermLambda =>
10861086
val saved = allowed

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
5454
val boxedRes = recur(res)
5555
if boxedRes eq res then tp
5656
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
57-
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) =>
57+
case tp1 @ defn.RefinedFunctionOf(rinfo: MethodType) =>
5858
val boxedRinfo = recur(rinfo)
5959
if boxedRinfo eq rinfo then tp
6060
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
@@ -149,7 +149,7 @@ extends tpd.TreeTraverser:
149149
tp.derivedAppliedType(tycon1, args1 :+ res1)
150150
else
151151
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
152-
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
152+
case defn.RefinedFunctionOf(rinfo: MethodType) =>
153153
val rinfo1 = apply(rinfo)
154154
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
155155
else tp

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,20 @@ class Definitions {
11291129
}
11301130
}
11311131

1132+
object RefinedFunctionOf {
1133+
/** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
1134+
* Extracts the method type type and apply info.
1135+
*/
1136+
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = {
1137+
tpe.refinedInfo match
1138+
case mt: MethodOrPoly
1139+
if tpe.refinedName == nme.apply
1140+
&& (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) =>
1141+
Some(mt)
1142+
case _ => None
1143+
}
1144+
}
1145+
11321146
object PolyFunctionOf {
11331147

11341148
/** Creates a refined `PolyFunction` with an `apply` method with the given info. */

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4097,8 +4097,8 @@ object Types {
40974097
tp.derivedAppliedType(tycon, addInto(args.head) :: Nil)
40984098
case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) =>
40994099
wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last)))
4100-
case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) =>
4101-
wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo)))
4100+
case tp @ defn.RefinedFunctionOf(rinfo) =>
4101+
wrapConvertible(tp.derivedRefinedType(tp.parent, tp.refinedName, addInto(rinfo)))
41024102
case tp: MethodOrPoly =>
41034103
tp.derivedLambdaType(resType = addInto(tp.resType))
41044104
case ExprType(resType) =>

0 commit comments

Comments
 (0)