Skip to content

Commit cedef0d

Browse files
authored
Merge pull request #5214 from dotty-staging/fix-#5202
Fix #5202: Refine argForParam
2 parents e7967c2 + 948c085 commit cedef0d

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

+30-12
Original file line numberDiff line numberDiff line change
@@ -1988,7 +1988,8 @@ object Types {
19881988
}
19891989

19901990
/** The argument corresponding to class type parameter `tparam` as seen from
1991-
* prefix `pre`.
1991+
* prefix `pre`. Can produce a TypeBounds type in case prefix is an & or | type
1992+
* and parameter is non-variant.
19921993
*/
19931994
def argForParam(pre: Type)(implicit ctx: Context): Type = {
19941995
val tparam = symbol
@@ -2010,8 +2011,16 @@ object Types {
20102011
idx += 1
20112012
}
20122013
NoType
2013-
case OrType(base1, base2) => argForParam(base1) | argForParam(base2)
2014-
case AndType(base1, base2) => argForParam(base1) & argForParam(base2)
2014+
case base: AndOrType =>
2015+
var tp1 = argForParam(base.tp1)
2016+
var tp2 = argForParam(base.tp2)
2017+
val variance = tparam.paramVariance
2018+
if (tp1.isInstanceOf[TypeBounds] || tp2.isInstanceOf[TypeBounds] || variance == 0) {
2019+
// compute argument as a type bounds instead of a point type
2020+
tp1 = tp1.bounds
2021+
tp2 = tp2.bounds
2022+
}
2023+
if (base.isAnd == variance >= 0) tp1 & tp2 else tp1 | tp2
20152024
case _ =>
20162025
if (pre.termSymbol is Package) argForParam(pre.select(nme.PACKAGE))
20172026
else if (pre.isBottomType) pre
@@ -4414,6 +4423,7 @@ object Types {
44144423
protected def range(lo: Type, hi: Type): Type =
44154424
if (variance > 0) hi
44164425
else if (variance < 0) lo
4426+
else if (lo `eq` hi) lo
44174427
else Range(lower(lo), upper(hi))
44184428

44194429
protected def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]
@@ -4462,13 +4472,18 @@ object Types {
44624472
* If the expansion is a wildcard parameter reference, convert its
44634473
* underlying bounds to a range, otherwise return the expansion.
44644474
*/
4465-
def expandParam(tp: NamedType, pre: Type): Type = tp.argForParam(pre) match {
4466-
case arg @ TypeRef(pre, _) if pre.isArgPrefixOf(arg.symbol) =>
4467-
arg.info match {
4468-
case TypeBounds(lo, hi) => range(atVariance(-variance)(reapply(lo)), reapply(hi))
4469-
case arg => reapply(arg)
4470-
}
4471-
case arg => reapply(arg)
4475+
def expandParam(tp: NamedType, pre: Type): Type = {
4476+
def expandBounds(tp: TypeBounds) =
4477+
range(atVariance(-variance)(reapply(tp.lo)), reapply(tp.hi))
4478+
tp.argForParam(pre) match {
4479+
case arg @ TypeRef(pre, _) if pre.isArgPrefixOf(arg.symbol) =>
4480+
arg.info match {
4481+
case argInfo: TypeBounds => expandBounds(argInfo)
4482+
case argInfo => reapply(arg)
4483+
}
4484+
case arg: TypeBounds => expandBounds(arg)
4485+
case arg => reapply(arg)
4486+
}
44724487
}
44734488

44744489
/** Derived selection.
@@ -4482,9 +4497,12 @@ object Types {
44824497
if (tp.symbol.is(ClassTypeParam)) expandParam(tp, preHi)
44834498
else tryWiden(tp, preHi)
44844499
forwarded.orElse(
4485-
range(super.derivedSelect(tp, preLo), super.derivedSelect(tp, preHi)))
4500+
range(super.derivedSelect(tp, preLo).loBound, super.derivedSelect(tp, preHi).hiBound))
44864501
case _ =>
4487-
super.derivedSelect(tp, pre)
4502+
super.derivedSelect(tp, pre) match {
4503+
case TypeBounds(lo, hi) => range(lo, hi)
4504+
case tp => tp
4505+
}
44884506
}
44894507

44904508
override protected def derivedRefinedType(tp: RefinedType, parent: Type, info: Type): Type =

tests/neg/i5202.scala

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
object Test {
2+
val f: (Int => Int) | (String => Int) = (a: Int) => a + 3
3+
4+
f.apply(5) // error - found: Int expected: Int & String
5+
f("c") // error - found: String expected: Int & String
6+
}
7+
8+
class Foo[A] {
9+
def foo(a: A): Unit = {}
10+
}
11+
class Co[+A] {
12+
def foo(a: A): Unit = {} // error: contravariant occurs in covariant position
13+
def bar: A = ???
14+
}
15+
class Contra[-A] {
16+
def foo(a: A): Unit = {}
17+
}
18+
19+
object Test2 {
20+
def main(args: Array[String]): Unit = {
21+
val x: Foo[Int] | Foo[String] = new Foo[Int]
22+
x.foo("") // error, found: String, required: Int & String
23+
val y: Contra[Int] | Contra[String] = new Contra[Int]
24+
y.foo("") // error, found: String, required: Int & String
25+
val z: Co[Int] | Co[String] = new Co[Int]
26+
z.foo("") // OK
27+
val s: String = z.bar // error: found Int | String, required: String
28+
}
29+
}

0 commit comments

Comments
 (0)