Skip to content

Commit ac806b8

Browse files
authored
Merge pull request #10327 from dotty-staging/more-pretyping
Better expected type for arguments of overloaded methods
2 parents 0c9b2ca + 5e6bfa2 commit ac806b8

File tree

2 files changed

+50
-32
lines changed

2 files changed

+50
-32
lines changed

compiler/src/dotty/tools/dotc/typer/Applications.scala

+31-32
Original file line numberDiff line numberDiff line change
@@ -2028,38 +2028,37 @@ trait Applications extends Compatibility {
20282028
private def pretypeArgs(alts: List[TermRef], pt: FunProto)(using Context): Unit = {
20292029
def recur(altFormals: List[List[Type]], args: List[untpd.Tree]): Unit = args match {
20302030
case arg :: args1 if !altFormals.exists(_.isEmpty) =>
2031-
untpd.functionWithUnknownParamType(arg) match {
2032-
case Some(fn) =>
2033-
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
2034-
val formalsForArg: List[Type] = altFormals.map(_.head)
2035-
def argTypesOfFormal(formal: Type): List[Type] =
2036-
formal match {
2037-
case defn.FunctionOf(args, result, isImplicit, isErased) => args
2038-
case defn.PartialFunctionOf(arg, result) => arg :: Nil
2039-
case _ => Nil
2040-
}
2041-
val formalParamTypessForArg: List[List[Type]] =
2042-
formalsForArg.map(argTypesOfFormal)
2043-
if (formalParamTypessForArg.forall(_.nonEmpty) &&
2044-
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
2045-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
2046-
// Given definitions above, for i = 1,...,m,
2047-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
2048-
// If all p_i_k's are the same, assume the type as formal parameter
2049-
// type of the i'th parameter of the closure.
2050-
if (isUniform(ps)(_ frozen_=:= _)) ps.head
2051-
else WildcardType)
2052-
def isPartial = // we should generate a partial function for the arg
2053-
fn.isInstanceOf[untpd.Match] &&
2054-
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
2055-
val commonFormal =
2056-
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
2057-
else defn.FunctionOf(commonParamTypes, WildcardType)
2058-
overload.println(i"pretype arg $arg with expected type $commonFormal")
2059-
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom)))
2060-
withMode(Mode.ImplicitsEnabled)(pt.typedArg(arg, commonFormal))
2061-
}
2062-
case None =>
2031+
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
2032+
val formalsForArg: List[Type] = altFormals.map(_.head)
2033+
def argTypesOfFormal(formal: Type): List[Type] =
2034+
formal match {
2035+
case defn.FunctionOf(args, result, isImplicit, isErased) => args
2036+
case defn.PartialFunctionOf(arg, result) => arg :: Nil
2037+
case _ => Nil
2038+
}
2039+
val formalParamTypessForArg: List[List[Type]] =
2040+
formalsForArg.map(argTypesOfFormal)
2041+
if (formalParamTypessForArg.forall(_.nonEmpty) &&
2042+
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
2043+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
2044+
// Given definitions above, for i = 1,...,m,
2045+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
2046+
// If all p_i_k's are the same, assume the type as formal parameter
2047+
// type of the i'th parameter of the closure.
2048+
if (isUniform(ps)(_ frozen_=:= _)) ps.head
2049+
else WildcardType)
2050+
/** Should we generate a partial function for the arg ? */
2051+
def isPartial = untpd.functionWithUnknownParamType(arg) match
2052+
case Some(_: untpd.Match) =>
2053+
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
2054+
case _ =>
2055+
false
2056+
val commonFormal =
2057+
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
2058+
else defn.FunctionOf(commonParamTypes, WildcardType)
2059+
overload.println(i"pretype arg $arg with expected type $commonFormal")
2060+
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom)))
2061+
withMode(Mode.ImplicitsEnabled)(pt.typedArg(arg, commonFormal))
20632062
}
20642063
recur(altFormals.map(_.tail), args1)
20652064
case _ =>

tests/pos/i10325.scala

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
object Test {
2+
def nullToNone[K, V](tuple: (K, V)): (K, Option[V]) = {
3+
val (k, v) = tuple
4+
(k, Option(v))
5+
}
6+
7+
def test: Unit = {
8+
val scalaMap: Map[String, String] = Map()
9+
10+
val a = scalaMap.map(nullToNone)
11+
val a1: Map[String, Option[String]] = a
12+
13+
val b = scalaMap.map(nullToNone(_))
14+
val b1: Map[String, Option[String]] = b
15+
16+
val c = scalaMap.map(x => nullToNone(x))
17+
val c1: Map[String, Option[String]] = c
18+
}
19+
}

0 commit comments

Comments
 (0)