Skip to content

Commit 336d7f8

Browse files
committed
wip: constrain refinements to type parameters
1 parent b636633 commit 336d7f8

File tree

2 files changed

+81
-3
lines changed

2 files changed

+81
-3
lines changed

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+9-3
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,15 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
370370
(mirroredType, elems)
371371

372372
val mirrorType =
373-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
374-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
375-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
373+
def constrainElemLabels(tpe: Type)(using Context): Type =
374+
def strip(tp: Type): Type = tp match
375+
case RefinedType(parent, name, info) if name == tpnme.MirroredElemLabels => parent
376+
case tp @ RefinedType(parent, name, info) => tp.derivedRefinedType(strip(parent), name, info)
377+
case tp => tp
378+
val labels = TypeOps.nestedPairs(elemLabels)
379+
strip(formal).refinedWith(tpnme.MirroredElemLabels, TypeAlias(labels))
380+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, constrainElemLabels(formal))
381+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
376382
val mirrorRef =
377383
if useCompanion then companionPath(mirroredType, span)
378384
else anonymousMirror(monoType, ExtendsSumMirror, span)

tests/run/i14150.scala

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import scala.deriving.Mirror
2+
import scala.util.NotGiven
3+
import scala.compiletime.constValue
4+
5+
trait GetConstValue[T] {
6+
type Out
7+
def get : Out
8+
}
9+
10+
object GetConstValue {
11+
type Aux[T, O] = GetConstValue[T] { type Out = O }
12+
13+
inline given value[T <: Singleton](
14+
using
15+
ev : NotGiven[T <:< Tuple],
16+
) : GetConstValue.Aux[T, T] = {
17+
val out = constValue[T]
18+
19+
new GetConstValue[T] {
20+
type Out = T
21+
def get : Out = out
22+
}
23+
}
24+
25+
given empty : GetConstValue[EmptyTuple] with {
26+
type Out = EmptyTuple
27+
def get : Out = EmptyTuple
28+
}
29+
30+
given nonEmpty[H, HRes, Tail <: Tuple, TRes <: Tuple](
31+
using
32+
head : GetConstValue.Aux[H, HRes],
33+
tail : GetConstValue.Aux[Tail, TRes],
34+
) : GetConstValue[H *: Tail] with {
35+
type Out = HRes *: TRes
36+
37+
def get : Out = head.get *: tail.get
38+
}
39+
}
40+
41+
trait MirrorNamesDeriver[T] {
42+
type Derived <: Tuple
43+
def derive : Derived
44+
}
45+
46+
object MirrorNamesDeriver {
47+
given mirDeriver[T, ElemLabels <: NonEmptyTuple](
48+
using
49+
mir: Mirror.SumOf[T] { type MirroredElemLabels = ElemLabels },
50+
ev : GetConstValue.Aux[ElemLabels, ElemLabels],
51+
): MirrorNamesDeriver[T] with {
52+
type Derived = ElemLabels
53+
54+
def derive: ElemLabels = ev.get
55+
}
56+
57+
def derive[T](using d : MirrorNamesDeriver[T]) : d.Derived = d.derive
58+
}
59+
60+
sealed trait SuperT
61+
final case class SubT1(int: Int) extends SuperT
62+
final case class SubT2(str: String, dbl : Double, bool : Boolean) extends SuperT
63+
64+
@main def Test =
65+
66+
// Works when type parameters are set explicitly
67+
val successfulLabels = MirrorNamesDeriver.mirDeriver[SuperT, ("SubT1", "SubT2")].derive
68+
assert(successfulLabels == ("SubT1", "SubT2"))
69+
70+
// Fails when type parameters are inferred
71+
val failedLabels = MirrorNamesDeriver.derive[SuperT]
72+
assert(failedLabels == ("SubT1", "SubT2"))

0 commit comments

Comments
 (0)