Skip to content

Commit 2699715

Browse files
authored
Merge pull request #2015 from dotty-staging/add-pf-overloading
Add overloading support for case-closures
2 parents a0f47a0 + 47b6c6b commit 2699715

File tree

4 files changed

+70
-27
lines changed

4 files changed

+70
-27
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
287287
case ValDef(_, tpt, _) => tpt.isEmpty
288288
case _ => false
289289
}
290+
case Match(EmptyTree, _) =>
291+
true
290292
case _ => false
291293
}
292294

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
12271227
def typeShape(tree: untpd.Tree): Type = tree match {
12281228
case untpd.Function(args, body) =>
12291229
defn.FunctionOf(args map Function.const(defn.AnyType), typeShape(body))
1230+
case Match(EmptyTree, _) =>
1231+
defn.PartialFunctionType.appliedTo(defn.AnyType :: defn.NothingType :: Nil)
12301232
case _ =>
12311233
defn.NothingType
12321234
}
@@ -1271,7 +1273,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
12711273
alts filter (alt => sizeFits(alt, alt.widen))
12721274

12731275
def narrowByShapes(alts: List[TermRef]): List[TermRef] = {
1274-
if (normArgs exists (_.isInstanceOf[untpd.Function]))
1276+
if (normArgs exists untpd.isFunctionWithUnknownParamType)
12751277
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
12761278
else narrowByTypes(alts, normArgs map typeShape, resultType)
12771279
else
@@ -1351,33 +1353,31 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
13511353
case ValDef(_, tpt, _) => tpt.isEmpty
13521354
case _ => false
13531355
}
1354-
arg match {
1355-
case arg: untpd.Function if arg.args.exists(isUnknownParamType) =>
1356-
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1357-
val formalsForArg: List[Type] = altFormals.map(_.head)
1358-
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1359-
// (p_1_1, ..., p_m_1) => r_1
1360-
// ...
1361-
// (p_1_n, ..., p_m_n) => r_n
1362-
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
1363-
formalsForArg.map(defn.FunctionOf.unapply)
1364-
if (decomposedFormalsForArg.forall(_.isDefined)) {
1365-
val formalParamTypessForArg: List[List[Type]] =
1366-
decomposedFormalsForArg.map(_.get._1)
1367-
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1368-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1369-
// Given definitions above, for i = 1,...,m,
1370-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1371-
// If all p_i_k's are the same, assume the type as formal parameter
1372-
// type of the i'th parameter of the closure.
1373-
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1374-
else WildcardType)
1375-
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1376-
overload.println(i"pretype arg $arg with expected type $commonFormal")
1377-
pt.typedArg(arg, commonFormal)
1378-
}
1356+
if (untpd.isFunctionWithUnknownParamType(arg)) {
1357+
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1358+
val formalsForArg: List[Type] = altFormals.map(_.head)
1359+
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1360+
// (p_1_1, ..., p_m_1) => r_1
1361+
// ...
1362+
// (p_1_n, ..., p_m_n) => r_n
1363+
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
1364+
formalsForArg.map(defn.FunctionOf.unapply)
1365+
if (decomposedFormalsForArg.forall(_.isDefined)) {
1366+
val formalParamTypessForArg: List[List[Type]] =
1367+
decomposedFormalsForArg.map(_.get._1)
1368+
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1369+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1370+
// Given definitions above, for i = 1,...,m,
1371+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1372+
// If all p_i_k's are the same, assume the type as formal parameter
1373+
// type of the i'th parameter of the closure.
1374+
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1375+
else WildcardType)
1376+
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1377+
overload.println(i"pretype arg $arg with expected type $commonFormal")
1378+
pt.typedArg(arg, commonFormal)
13791379
}
1380-
case _ =>
1380+
}
13811381
}
13821382
recur(altFormals.map(_.tail), args1)
13831383
case _ =>

tests/pos/inferOverloaded.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
class MySeq[T] {
2+
def map1[U](f: T => U): MySeq[U] = new MySeq[U]
3+
def map2[U](f: T => U): MySeq[U] = new MySeq[U]
4+
}
5+
6+
class MyMap[A, B] extends MySeq[(A, B)] {
7+
def map1[C](f: (A, B) => C): MySeq[C] = new MySeq[C]
8+
def map1[C, D](f: (A, B) => (C, D)): MyMap[C, D] = new MyMap[C, D]
9+
def map1[C, D](f: ((A, B)) => (C, D)): MyMap[C, D] = new MyMap[C, D]
10+
11+
def foo(f: Function2[Int, Int, Int]): Unit = ()
12+
def foo[R](pf: PartialFunction[(A, B), R]): MySeq[R] = new MySeq[R]
13+
}
14+
15+
object Test {
16+
val m = new MyMap[Int, String]
17+
18+
// This one already worked because it is not overloaded:
19+
m.map2 { case (k, v) => k - 1 }
20+
21+
// These already worked because preSelectOverloaded eliminated the non-applicable overload:
22+
m.map1(t => t._1)
23+
m.map1((kInFunction, vInFunction) => kInFunction - 1)
24+
val r1 = m.map1(t => (t._1, 42.0))
25+
val r1t: MyMap[Int, Double] = r1
26+
27+
// These worked because the argument types are known for overload resolution:
28+
m.map1({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])
29+
m.map2({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])
30+
31+
// These ones did not work before:
32+
m.map1 { case (k, v) => k }
33+
val r = m.map1 { case (k, v) => (k, k*10) }
34+
val rt: MyMap[Int, Int] = r
35+
m.foo { case (k, v) => k - 1 }
36+
37+
// Used to be ambiguous but overload resolution now favors PartialFunction
38+
def h[R](pf: Function2[Int, String, R]): Unit = ()
39+
def h[R](pf: PartialFunction[(Double, Double), R]): Unit = ()
40+
h { case (a: Double, b: Double) => 42: Int }
41+
}

0 commit comments

Comments
 (0)