Skip to content

Commit 8617765

Browse files
committed
wip: constrain refinements to type parameters
1 parent b636633 commit 8617765

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+19-3
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,25 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
370370
(mirroredType, elems)
371371

372372
val mirrorType =
373-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
374-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
375-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
373+
def constrainElemLabels(tpe: Type)(using Context): Type =
374+
def strip(tp: Type): Type = tp match
375+
case RefinedType(parent, name, info) if name == tpnme.MirroredElemLabels => info
376+
case RefinedType(parent, _, _) => strip(parent)
377+
case tp => NoType
378+
val labels = TypeOps.nestedPairs(elemLabels)
379+
val formalLabels = strip(formal)
380+
formalLabels match
381+
case TypeBounds(lo: TypeVar, hi: TypeVar) =>
382+
if labels <:< lo then // TODO: not sandboxed
383+
lo.instantiateWith(labels)
384+
if lo ne hi then
385+
if labels <:< hi then // TODO: not sandboxed
386+
hi.instantiateWith(labels)
387+
case _ =>
388+
tpe.refinedWith(tpnme.MirroredElemLabels, TypeAlias(labels))
389+
constrainElemLabels(
390+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
391+
).refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
376392
val mirrorRef =
377393
if useCompanion then companionPath(mirroredType, span)
378394
else anonymousMirror(monoType, ExtendsSumMirror, span)

tests/run/i14150.scala

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import scala.deriving.Mirror
2+
import scala.util.NotGiven
3+
import scala.compiletime.constValue
4+
5+
trait GetConstValue[T] {
6+
type Out
7+
def get : Out
8+
}
9+
10+
object GetConstValue {
11+
type Aux[T, O] = GetConstValue[T] { type Out = O }
12+
13+
inline given value[T <: Singleton](
14+
using
15+
ev : NotGiven[T <:< Tuple],
16+
) : GetConstValue.Aux[T, T] = {
17+
val out = constValue[T]
18+
19+
new GetConstValue[T] {
20+
type Out = T
21+
def get : Out = out
22+
}
23+
}
24+
25+
given empty : GetConstValue[EmptyTuple] with {
26+
type Out = EmptyTuple
27+
def get : Out = EmptyTuple
28+
}
29+
30+
given nonEmpty[H, HRes, Tail <: Tuple, TRes <: Tuple](
31+
using
32+
head : GetConstValue.Aux[H, HRes],
33+
tail : GetConstValue.Aux[Tail, TRes],
34+
) : GetConstValue[H *: Tail] with {
35+
type Out = HRes *: TRes
36+
37+
def get : Out = head.get *: tail.get
38+
}
39+
}
40+
41+
trait MirrorNamesDeriver[T] {
42+
type Derived <: Tuple
43+
def derive : Derived
44+
}
45+
46+
object MirrorNamesDeriver {
47+
given mirDeriver[T, ElemLabels <: NonEmptyTuple](
48+
using
49+
mir: Mirror.SumOf[T] { type MirroredElemLabels = ElemLabels },
50+
ev : GetConstValue.Aux[ElemLabels, ElemLabels],
51+
): MirrorNamesDeriver[T] with {
52+
type Derived = ElemLabels
53+
54+
def derive: ElemLabels = ev.get
55+
}
56+
57+
def derive[T](using d : MirrorNamesDeriver[T]) : d.Derived = d.derive
58+
}
59+
60+
sealed trait SuperT
61+
final case class SubT1(int: Int) extends SuperT
62+
final case class SubT2(str: String, dbl : Double, bool : Boolean) extends SuperT
63+
64+
@main def Test =
65+
66+
// Works when type parameters are set explicitly
67+
val successfulLabels = MirrorNamesDeriver.mirDeriver[SuperT, ("SubT1", "SubT2")].derive
68+
assert(successfulLabels == ("SubT1", "SubT2"))
69+
70+
// Fails when type parameters are inferred
71+
val failedLabels = MirrorNamesDeriver.derive[SuperT]
72+
assert(failedLabels == ("SubT1", "SubT2"))

0 commit comments

Comments
 (0)