@@ -17,6 +17,8 @@ import transform.SyntheticMembers._
1717import util .Property
1818import annotation .{tailrec , constructorOnly }
1919
20+ import scala .collection .mutable
21+
2022/** Synthesize terms for special classes */
2123class Synthesizer (typer : Typer )(using @ constructorOnly c : Context ):
2224 import ast .tpd ._
@@ -337,7 +339,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
337339 if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
338340 val elemLabels = cls.children.map(c => ConstantType (Constant (c.name.toString)))
339341
340- def solve (sym : Symbol ): Type = sym match
342+ def solve (target : Type )( sym : Symbol ): Type = sym match
341343 case childClass : ClassSymbol =>
342344 assert(childClass.isOneOf(Case | Sealed ))
343345 if childClass.is(Module ) then
@@ -348,36 +350,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
348350 // Compute the the full child type by solving the subtype constraint
349351 // `C[X1, ..., Xn] <: P`, where
350352 //
351- // - P is the current `mirroredType `
353+ // - P is the current `targetPart `
352354 // - C is the child class, with type parameters X1, ..., Xn
353355 //
354356 // Contravariant type parameters are minimized, all other type parameters are maximized.
355- def instantiate (using Context ) =
356- val poly = constrained(info, untpd. EmptyTree )._1
357+ def instantiate (targetPart : Type )( using Context ) =
358+ val poly = constrained(info)
357359 val resType = poly.finalResultType
358- val target = mirroredType match
359- case tp : HKTypeLambda => tp.resultType
360- case tp => tp
361- resType <:< target
360+ resType <:< targetPart // record constraints
362361 val tparams = poly.paramRefs
363362 val variances = childClass.typeParams.map(_.paramVarianceSign)
364363 val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
365364 TypeComparer .instanceType(tparam, fromBelow = variance < 0 ))
366365 resType.substParams(poly, instanceTypes)
367- instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
366+
367+ def instantiateAll (using Context ): Type =
368+
369+ // instantiate for each part of a union type, compute lub of the results
370+ def loop (explore : List [Type ], acc : mutable.ListBuffer [Type ]): Type = explore match
371+ case OrType (tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
372+ case tp :: rest => loop(rest , acc += instantiate(tp))
373+ case _ => TypeComparer .lub(acc.toList)
374+
375+ def instantiateLub (tp1 : Type , tp2 : Type ): Type =
376+ loop(tp1 :: tp2 :: Nil , new mutable.ListBuffer [Type ])
377+
378+ target match
379+ case OrType (tp1, tp2) => instantiateLub(tp1, tp2)
380+ case _ => instantiate(target)
381+
382+ instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
368383 case _ =>
369384 childClass.typeRef
370385 case child => child.termRef
371386 end solve
372387
373388 val (monoType, elemsType) = mirroredType match
374389 case mirroredType : HKTypeLambda =>
390+ val target = mirroredType.resultType
375391 val elems = mirroredType.derivedLambdaType(
376- resType = TypeOps .nestedPairs(cls.children.map(solve))
392+ resType = TypeOps .nestedPairs(cls.children.map(solve(target) ))
377393 )
378394 (mkMirroredMonoType(mirroredType), elems)
379- case _ =>
380- val elems = TypeOps .nestedPairs(cls.children.map(solve))
395+ case target =>
396+ val elems = TypeOps .nestedPairs(cls.children.map(solve(target) ))
381397 (mirroredType, elems)
382398
383399 val mirrorType =
0 commit comments