Skip to content

Commit d5efc05

Browse files
authored
Merge pull request #9523 from dotty-staging/partial-hk1
Improve partial unification handling
2 parents 87cb1d8 + 64e8797 commit d5efc05

File tree

10 files changed

+211
-56
lines changed

10 files changed

+211
-56
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 118 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,85 @@ class TypeComparer(using val comparerCtx: Context) extends ConstraintHandling wi
849849
case _ =>
850850
isSubType(pre1, pre2)
851851

852+
/** Compare `tycon[args]` with `other := otherTycon[otherArgs]`, via `>:>` if fromBelow is true, `<:<` otherwise
853+
* (we call this relationship `~:~` in the rest of this comment).
854+
*
855+
* This method works by:
856+
*
857+
* 1. Choosing an appropriate type constructor `adaptedTycon`
858+
* 2. Constraining `tycon` such that `tycon ~:~ adaptedTycon`
859+
* 3. Recursing on `adaptedTycon[args] ~:~ other`
860+
*
861+
* So, how do we pick `adaptedTycon`? When `args` and `otherArgs` have the
862+
* same length the answer is simply:
863+
*
864+
* adaptedTycon := otherTycon
865+
*
866+
* But we also handle having `args.length < otherArgs.length`, in which
867+
* case we need to make up a type constructor of the right kind. For
868+
* example, if `fromBelow = false` and we're comparing:
869+
*
870+
* ?F[A] <:< Either[String, B] where `?F <: [X] =>> Any`
871+
*
872+
* we will choose:
873+
*
874+
* adaptedTycon := [X] =>> Either[String, X]
875+
*
876+
* this allows us to constrain:
877+
*
878+
* ?F <: adaptedTycon
879+
*
880+
* and then recurse on:
881+
*
882+
* adaptedTycon[A] <:< Either[String, B]
883+
*
884+
* In general, given:
885+
*
886+
* - k := args.length
887+
* - d := otherArgs.length - k
888+
*
889+
* `adaptedTycon` will be:
890+
*
891+
* [T_0, ..., T_k-1] =>> otherTycon[otherArgs(0), ..., otherArgs(d-1), T_0, ..., T_k-1]
892+
*
893+
* where `T_n` has the same bounds as `otherTycon.typeParams(d+n)`
894+
*
895+
* Historical note: this strategy is known in Scala as "partial unification"
896+
* (even though the type constructor variable isn't actually unified but only
897+
* has one of its bounds constrained), for background see:
898+
* - The infamous SI-2712: https://github.com/scala/bug/issues/2712
899+
* - The PR against Scala 2.12 implementing -Ypartial-unification: https://github.com/scala/scala/pull/5102
900+
* - Some explanations on how this impacts API design: https://gist.github.com/djspiewak/7a81a395c461fd3a09a6941d4cd040f2
901+
*/
902+
def compareAppliedTypeParamRef(tycon: TypeParamRef, args: List[Type], other: AppliedType, fromBelow: Boolean): Boolean =
903+
def directionalIsSubType(tp1: Type, tp2: Type): Boolean =
904+
if fromBelow then isSubType(tp2, tp1) else isSubType(tp1, tp2)
905+
def directionalRecur(tp1: Type, tp2: Type): Boolean =
906+
if fromBelow then recur(tp2, tp1) else recur(tp1, tp2)
907+
908+
val otherTycon = other.tycon
909+
val otherArgs = other.args
910+
911+
val d = otherArgs.length - args.length
912+
d >= 0 && {
913+
val tparams = tycon.typeParams
914+
val remainingTparams = otherTycon.typeParams.drop(d)
915+
variancesConform(remainingTparams, tparams) && {
916+
val adaptedTycon =
917+
if d > 0 then
918+
HKTypeLambda(remainingTparams.map(_.paramName))(
919+
tl => remainingTparams.map(remainingTparam =>
920+
tl.integrate(remainingTparams, remainingTparam.paramInfo).bounds),
921+
tl => otherTycon.appliedTo(
922+
otherArgs.take(d) ++ tl.paramRefs))
923+
else
924+
otherTycon
925+
(assumedTrue(tycon) || directionalIsSubType(tycon, adaptedTycon.ensureLambdaSub)) &&
926+
directionalRecur(adaptedTycon.appliedTo(args), other)
927+
}
928+
}
929+
end compareAppliedTypeParamRef
930+
852931
/** Subtype test for the hk application `tp2 = tycon2[args2]`.
853932
*/
854933
def compareAppliedType2(tp2: AppliedType, tycon2: Type, args2: List[Type]): Boolean = {
@@ -860,13 +939,35 @@ class TypeComparer(using val comparerCtx: Context) extends ConstraintHandling wi
860939
*/
861940
def isMatchingApply(tp1: Type): Boolean = tp1 match {
862941
case AppliedType(tycon1, args1) =>
863-
def loop(tycon1: Type, args1: List[Type]): Boolean = tycon1.dealiasKeepRefiningAnnots match {
942+
// We intentionally do not dealias `tycon1` or `tycon2` here.
943+
// `TypeApplications#appliedTo` already takes care of dealiasing type
944+
// constructors when this can be done without affecting type
945+
// inference, doing it here would not only prevent code from compiling
946+
// but could also result in the wrong thing being inferred later, for example
947+
// in `tests/run/hk-alias-unification.scala` we end up checking:
948+
//
949+
// Foo[?F, ?T] <:< Foo[[X] =>> (X, String), Int]
950+
//
951+
// Naturally, we'd like to infer:
952+
//
953+
// ?F := [X] => (X, String)
954+
//
955+
// but if we dealias `Foo` then we'll end up trying to check:
956+
//
957+
// ErasedFoo[?F[?T]] <:< ErasedFoo[(Int, String)]
958+
//
959+
// Because of partial unification, this will succeed, but will produce the constraint:
960+
//
961+
// ?F := [X] =>> (Int, X)
962+
//
963+
// Which is not what we wanted!
964+
def loop(tycon1: Type, args1: List[Type]): Boolean = tycon1 match {
864965
case tycon1: TypeParamRef =>
865966
(tycon1 == tycon2 ||
866967
canConstrain(tycon1) && isSubType(tycon1, tycon2)) &&
867968
isSubArgs(args1, args2, tp1, tparams)
868969
case tycon1: TypeRef =>
869-
tycon2.dealiasKeepRefiningAnnots match {
970+
tycon2 match {
870971
case tycon2: TypeRef =>
871972
val tycon1sym = tycon1.symbol
872973
val tycon2sym = tycon2.symbol
@@ -926,60 +1027,26 @@ class TypeComparer(using val comparerCtx: Context) extends ConstraintHandling wi
9261027

9271028
/** `param2` can be instantiated to a type application prefix of the LHS
9281029
* or to a type application prefix of one of the LHS base class instances
929-
* and the resulting type application is a supertype of `tp1`,
930-
* or fallback to fourthTry.
1030+
* and the resulting type application is a supertype of `tp1`.
9311031
*/
9321032
def canInstantiate(tycon2: TypeParamRef): Boolean = {
933-
934-
/** Let
935-
*
936-
* `tparams_1, ..., tparams_k-1` be the type parameters of the rhs
937-
* `tparams1_1, ..., tparams1_n-1` be the type parameters of the constructor of the lhs
938-
* `args1_1, ..., args1_n-1` be the type arguments of the lhs
939-
* `d = n - k`
940-
*
941-
* Returns `true` iff `d >= 0` and `tycon2` can be instantiated to
942-
*
943-
* [tparams1_d, ... tparams1_n-1] -> tycon1[args_1, ..., args_d-1, tparams_d, ... tparams_n-1]
944-
*
945-
* such that the resulting type application is a supertype of `tp1`.
946-
*/
9471033
def appOK(tp1base: Type) = tp1base match {
9481034
case tp1base: AppliedType =>
949-
var tycon1 = tp1base.tycon
950-
val args1 = tp1base.args
951-
val tparams1all = tycon1.typeParams
952-
val lengthDiff = tparams1all.length - tparams.length
953-
lengthDiff >= 0 && {
954-
val tparams1 = tparams1all.drop(lengthDiff)
955-
variancesConform(tparams1, tparams) && {
956-
if (lengthDiff > 0)
957-
tycon1 = HKTypeLambda(tparams1.map(_.paramName))(
958-
tl => tparams1.map(tparam => tl.integrate(tparams, tparam.paramInfo).bounds),
959-
tl => tp1base.tycon.appliedTo(args1.take(lengthDiff) ++
960-
tparams1.indices.toList.map(tl.paramRefs(_))))
961-
(assumedTrue(tycon2) || isSubType(tycon1.ensureLambdaSub, tycon2)) &&
962-
recur(tp1, tycon1.appliedTo(args2))
963-
}
964-
}
1035+
compareAppliedTypeParamRef(tycon2, args2, tp1base, fromBelow = true)
9651036
case _ => false
9661037
}
9671038

968-
tp1.widen match {
969-
case tp1w: AppliedType => appOK(tp1w)
970-
case tp1w =>
971-
tp1w.typeSymbol.isClass && {
972-
val classBounds = tycon2.classSymbols
973-
def liftToBase(bcs: List[ClassSymbol]): Boolean = bcs match {
974-
case bc :: bcs1 =>
975-
classBounds.exists(bc.derivesFrom) && appOK(nonExprBaseType(tp1, bc))
976-
|| liftToBase(bcs1)
977-
case _ =>
978-
false
979-
}
980-
liftToBase(tp1w.baseClasses)
981-
} ||
982-
fourthTry
1039+
val tp1w = tp1.widen
1040+
appOK(tp1w) || tp1w.typeSymbol.isClass && {
1041+
val classBounds = tycon2.classSymbols
1042+
def liftToBase(bcs: List[ClassSymbol]): Boolean = bcs match {
1043+
case bc :: bcs1 =>
1044+
classBounds.exists(bc.derivesFrom) && appOK(nonExprBaseType(tp1, bc))
1045+
|| liftToBase(bcs1)
1046+
case _ =>
1047+
false
1048+
}
1049+
liftToBase(tp1w.baseClasses)
9831050
}
9841051
}
9851052

@@ -1043,8 +1110,8 @@ class TypeComparer(using val comparerCtx: Context) extends ConstraintHandling wi
10431110
tycon1 match {
10441111
case param1: TypeParamRef =>
10451112
def canInstantiate = tp2 match {
1046-
case AppliedType(tycon2, args2) =>
1047-
isSubType(param1, tycon2.ensureLambdaSub) && isSubArgs(args1, args2, tp1, tycon2.typeParams)
1113+
case tp2base: AppliedType =>
1114+
compareAppliedTypeParamRef(param1, args1, tp2base, fromBelow = false)
10481115
case _ =>
10491116
false
10501117
}

compiler/test/dotty/tools/vulpix/ParallelTesting.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import scala.io.Source
1515
import scala.util.{Random, Try, Failure => TryFailure, Success => TrySuccess, Using}
1616
import scala.util.control.NonFatal
1717
import scala.util.matching.Regex
18+
import scala.collection.mutable.ListBuffer
1819

1920
import dotc.{Compiler, Driver}
2021
import dotc.core.Contexts._
@@ -543,11 +544,18 @@ trait ParallelTesting extends RunnerOrchestration { self =>
543544
}
544545

545546
pool.shutdown()
547+
546548
if (!pool.awaitTermination(20, TimeUnit.MINUTES)) {
549+
val remaining = new ListBuffer[TestSource]
550+
filteredSources.lazyZip(eventualResults).foreach { (src, res) =>
551+
if (!res.isDone)
552+
remaining += src
553+
}
554+
547555
pool.shutdownNow()
548556
System.setOut(realStdout)
549557
System.setErr(realStderr)
550-
throw new TimeoutException("Compiling targets timed out")
558+
throw new TimeoutException(s"Compiling targets timed out, remaining targets: ${remaining.mkString(", ")}")
551559
}
552560

553561
eventualResults.foreach { x =>

tests/neg/i3452.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,22 @@ object Test {
66
implicit def case1[F[_]](implicit t: => TC[F[Any]]): TC[Tuple2K[[_] =>> Any, F, Any]] = ???
77
implicit def case2[A, F[_]](implicit r: TC[F[Any]]): TC[A] = ???
88

9+
// Disabled because it leads to an infinite loop in implicit search
10+
// this is probably the same issue as https://github.com/lampepfl/dotty/issues/9568
11+
// implicitly[TC[Int]] // was: error
12+
}
13+
14+
object Test1 {
15+
case class Tuple2K[H[_], T[_], X](h: H[X], t: T[X])
16+
17+
trait TC[A]
18+
19+
implicit def case1[F[_]](implicit t: TC[F[Any]]): TC[Tuple2K[[_] =>> Any, F, Any]] = ???
20+
implicit def case2[A, F[_]](implicit r: TC[F[Any]]): TC[A] = ???
21+
922
implicitly[TC[Int]] // error
1023
}
24+
1125
object Test2 {
1226
trait TC[A]
1327

tests/neg/t2712-8.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Test extends App {
2+
class L[A]
3+
class Quux0[B, CC[_]]
4+
class Quux[C] extends Quux0[C, L]
5+
6+
def foo[D[_]](x: D[D[Boolean]]) = ???
7+
def bar: Quux[Int] = ???
8+
9+
foo(bar) // error: Found: Test.Quux[Int] Required: D[D[Boolean]]
10+
}

tests/pos/anykind.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ object Test {
5656
object Kinder extends KinderLowerImplicits {
5757
type Aux[MA, M0 <: AnyKind, Args0 <: HList] = Kinder[MA] { type M = M0; type Args = Args0 }
5858

59-
implicit def kinder2[M0[_, _], A0, B0]: Kinder.Aux[M0[A0, B0], M0, A0 :: B0 :: HNil] = new Kinder[M0[A0, B0]] { type M[t, u] = M0[t, u]; type Args = A0 :: B0 :: HNil }
6059
implicit def kinder1[M0[_], A0]: Kinder.Aux[M0[A0], M0, A0 :: HNil] = new Kinder[M0[A0]] { type M[t] = M0[t]; type Args = A0 :: HNil }
6160
}
6261

6362
trait KinderLowerImplicits {
63+
implicit def kinder2[M0[_, _], A0, B0]: Kinder.Aux[M0[A0, B0], M0, A0 :: B0 :: HNil] = new Kinder[M0[A0, B0]] { type M[t, u] = M0[t, u]; type Args = A0 :: B0 :: HNil }
6464
implicit def kinder0[A]: Kinder.Aux[A, A, HNil] = new Kinder[A] { type M = A; type Args = HNil }
6565
}
6666

tests/pos/i6565.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ extension [O, U](o: Lifted[O]) def flatMap(f: O => Lifted[U]): Lifted[U] = ???
88

99
val error: Err = Err()
1010

11-
lazy val ok: Lifted[String] = { // ok despite map returning a union
12-
point("a").map(_ => if true then "foo" else error) // ok
11+
lazy val ok: Lifted[String] = {
12+
point("a").flatMap(_ => if true then "foo" else error)
1313
}
1414

1515
lazy val nowAlsoOK: Lifted[String] = {
16-
point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error))
16+
point("a").flatMap(_ => point("b").flatMap(_ => if true then "foo" else error))
1717
}

tests/pos/i9478.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class Foo[T[_, _], F[_], A, B](val fa: T[F[A], F[B]])
2+
3+
object Test {
4+
def x[T[_, _]](tmab: T[Either[Int, String], Either[Int, Int]]) =
5+
new Foo(tmab)
6+
}

tests/pos/t2712-2b.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package test
2+
3+
class X1
4+
class X2
5+
class X3
6+
7+
trait One[A]
8+
trait Two[A, B]
9+
10+
class Foo extends Two[X1, X2] with One[X3]
11+
object Test {
12+
def test1[M[_], A](x: M[A]): M[A] = x
13+
14+
val foo = new Foo
15+
16+
test1(foo): One[X3] // fails in Scala 2 with partial unification enabled, works in Dotty
17+
test1(foo): Two[X1, X2]
18+
}

tests/pos/t2712-8.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class Two[A, B]
2+
class One[A] extends Two[A, A]
3+
4+
object Test {
5+
def foo[F[_, _]](x: F[Int, Int]) = x
6+
7+
val t: One[Int] = ???
8+
foo(t)
9+
}

tests/run/hk-alias-unification.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
trait Bla[T]
2+
object Bla {
3+
implicit def blaInt: Bla[Int] = new Bla[Int] {}
4+
implicit def blaString: Bla[String] = new Bla[String] {
5+
assert(false, "I should not be summoned!")
6+
}
7+
}
8+
9+
trait ErasedFoo[FT]
10+
object Test {
11+
type Foo[F[_], T] = ErasedFoo[F[T]]
12+
type Foo2[F[_], T] = Foo[F, T]
13+
14+
def mkFoo[F[_], T](implicit gen: Bla[T]): Foo[F, T] = new Foo[F, T] {}
15+
def mkFoo2[F[_], T](implicit gen: Bla[T]): Foo2[F, T] = new Foo2[F, T] {}
16+
17+
def main(args: Array[String]): Unit = {
18+
val a: Foo[[X] =>> (X, String), Int] = mkFoo
19+
val b: Foo2[[X] =>> (X, String), Int] = mkFoo
20+
val c: Foo[[X] =>> (X, String), Int] = mkFoo2
21+
}
22+
}
23+

0 commit comments

Comments
 (0)