Skip to content

Commit 11eac65

Browse files
committed
reorganise construction of refined mirror types,
record constraints against the formal.
1 parent bb29e20 commit 11eac65

File tree

4 files changed

+136
-23
lines changed

4 files changed

+136
-23
lines changed

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+32-17
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2626
/** Handlers to synthesize implicits for special types */
2727
type SpecialHandler = (Type, Span) => Context ?=> TreeWithErrors
2828
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]
29-
29+
3030
val synthesizedClassTag: SpecialHandler = (formal, span) =>
3131
formal.argInfos match
3232
case arg :: Nil =>
@@ -240,11 +240,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
240240
* <parent> {
241241
* MirroredMonoType = <monoType>
242242
* MirroredType = <mirroredType>
243-
* MirroredLabel = <label> }
243+
* MirroredLabel = <label>
244244
* }
245245
*/
246-
private def mirrorCore(parentClass: ClassSymbol, monoType: Type, mirroredType: Type, label: Name, formal: Type)(using Context) =
247-
formal & parentClass.typeRef
246+
private def mirrorCore(parentClass: ClassSymbol, monoType: Type, mirroredType: Type, label: Name)(using Context) =
247+
parentClass.typeRef
248248
.refinedWith(tpnme.MirroredMonoType, TypeAlias(monoType))
249249
.refinedWith(tpnme.MirroredType, TypeAlias(mirroredType))
250250
.refinedWith(tpnme.MirroredLabel, TypeAlias(ConstantType(Constant(label.toString))))
@@ -269,6 +269,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
269269
then report.error(
270270
em"$name mismatch, expected: $expected, found: $actual.", ctx.source.atSpan(span))
271271

272+
extension (formal: Type)
273+
/** `tp := op; tp <:< formal; formal & tp` */
274+
private def constrained_&(op: Context ?=> Type)(using Context): Type =
275+
val tp = op
276+
tp <:< formal
277+
formal & tp
278+
272279
private def mkMirroredMonoType(mirroredType: HKTypeLambda)(using Context): Type =
273280
val monoMap = new TypeMap:
274281
def apply(t: Type) = t match
@@ -313,22 +320,23 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
313320
val elemsLabels = TypeOps.nestedPairs(elemLabels)
314321
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
315322
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
316-
val mirrorType =
317-
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
323+
val mirrorType = formal.constrained_& {
324+
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name)
318325
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
319326
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
327+
}
320328
val mirrorRef =
321329
if (genAnonyousMirror(cls)) anonymousMirror(monoType, ExtendsProductMirror, span)
322330
else companionPath(mirroredType, span)
323331
withNoErrors(mirrorRef.cast(mirrorType))
324332
end makeProductMirror
325333

326-
def getError(cls: Symbol): String =
334+
def getError(cls: Symbol): String =
327335
val reason = if !cls.isGenericProduct then
328336
i"because ${cls.whyNotGenericProduct}"
329337
else if !canAccessCtor(cls) then
330338
i"because the constructor of $cls is innaccessible from the calling scope."
331-
else
339+
else
332340
""
333341
i"$cls is not a generic product $reason"
334342
end getError
@@ -341,11 +349,15 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
341349
val module = mirroredType.termSymbol
342350
val modulePath = pathFor(mirroredType).withSpan(span)
343351
if module.info.classSymbol.is(Scala2x) then
344-
val mirrorType = mirrorCore(defn.Mirror_SingletonProxyClass, mirroredType, mirroredType, module.name, formal)
352+
val mirrorType = formal.constrained_& {
353+
mirrorCore(defn.Mirror_SingletonProxyClass, mirroredType, mirroredType, module.name)
354+
}
345355
val mirrorRef = New(defn.Mirror_SingletonProxyClass.typeRef, modulePath :: Nil)
346356
withNoErrors(mirrorRef.cast(mirrorType))
347357
else
348-
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name, formal)
358+
val mirrorType = formal.constrained_& {
359+
mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name)
360+
}
349361
withNoErrors(modulePath.cast(mirrorType))
350362
else
351363
val cls = mirroredType.classSymbol
@@ -419,16 +431,19 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
419431
(mirroredType, elems)
420432

421433
val mirrorType =
422-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
434+
val labels = TypeOps.nestedPairs(elemLabels)
435+
formal.constrained_& {
436+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name)
423437
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
424-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
438+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(labels))
439+
}
425440
val mirrorRef =
426441
if useCompanion then companionPath(mirroredType, span)
427442
else anonymousMirror(monoType, ExtendsSumMirror, span)
428443
withNoErrors(mirrorRef.cast(mirrorType))
429444
else if !clsIsGenericSum then
430445
(EmptyTree, List(i"$cls is not a generic sum because ${cls.whyNotGenericSum(declScope)}"))
431-
else
446+
else
432447
EmptyTreeNoError
433448
end sumMirror
434449

@@ -595,7 +610,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
595610
tp.baseType(cls)
596611
val base = baseWithRefinements(formal)
597612
val result =
598-
if (base <:< formal.widenExpr)
613+
if (base <:< formal.widenExpr)
599614
// With the subtype test we enforce that the searched type `formal` is of the right form
600615
handler(base, span)
601616
else EmptyTreeNoError
@@ -609,19 +624,19 @@ end Synthesizer
609624

610625
object Synthesizer:
611626

612-
/** Tuple used to store the synthesis result with a list of errors. */
627+
/** Tuple used to store the synthesis result with a list of errors. */
613628
type TreeWithErrors = (Tree, List[String])
614629
private def withNoErrors(tree: Tree): TreeWithErrors = (tree, List.empty)
615630

616631
private val EmptyTreeNoError: TreeWithErrors = withNoErrors(EmptyTree)
617632

618633
private def orElse(treeWithErrors1: TreeWithErrors, treeWithErrors2: => TreeWithErrors): TreeWithErrors = treeWithErrors1 match
619-
case (tree, errors) if tree eq genericEmptyTree =>
634+
case (tree, errors) if tree eq genericEmptyTree =>
620635
val (tree2, errors2) = treeWithErrors2
621636
(tree2, errors ::: errors2)
622637
case _ => treeWithErrors1
623638

624-
private def clearErrorsIfNotEmpty(treeWithErrors: TreeWithErrors) = treeWithErrors match
639+
private def clearErrorsIfNotEmpty(treeWithErrors: TreeWithErrors) = treeWithErrors match
625640
case (tree, _) if tree eq genericEmptyTree => treeWithErrors
626641
case (tree, _) => withNoErrors(tree)
627642

tests/neg/7380a.scala

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.deriving.Mirror
2+
3+
object Lib {
4+
5+
trait HasFields[T]
6+
trait HasFieldsOrNone[T]
7+
8+
given mirHasFields[T, ElemLabels <: NonEmptyTuple](using
9+
mir: Mirror.ProductOf[T] { type MirroredElemLabels = ElemLabels }
10+
): HasFields[T]()
11+
12+
given mirHasFieldsOrNone[T, ElemLabels <: Tuple](using
13+
mir: Mirror.ProductOf[T] { type MirroredElemLabels = ElemLabels }
14+
): HasFieldsOrNone[T]()
15+
16+
}
17+
18+
object Test {
19+
20+
Lib.mirHasFields[(Int, String), ("_1", "_2", "_3")] // error
21+
22+
summon[Lib.HasFields[(Int, String)]] // ok
23+
24+
case class NoFields()
25+
26+
summon[Lib.HasFields[NoFields]] // error
27+
summon[Lib.HasFieldsOrNone[NoFields]] // ok
28+
29+
}

tests/run-macros/i7987.check

+1-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
scala.deriving.Mirror {
2-
type MirroredType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
3-
type MirroredMonoType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
4-
type MirroredElemTypes >: scala.Nothing <: scala.Tuple
5-
} & scala.deriving.Mirror.Product {
1+
scala.deriving.Mirror.Product {
62
type MirroredMonoType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
73
type MirroredType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
84
type MirroredLabel >: "Some" <: "Some"
9-
} {
105
type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]
116
type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple]
127
}

tests/run/i14150.scala

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import scala.deriving.Mirror
2+
import scala.util.NotGiven
3+
import scala.compiletime.constValue
4+
5+
trait GetConstValue[T] {
6+
type Out
7+
def get : Out
8+
}
9+
10+
object GetConstValue {
11+
type Aux[T, O] = GetConstValue[T] { type Out = O }
12+
13+
inline given value[T <: Singleton](
14+
using
15+
ev : NotGiven[T <:< Tuple],
16+
) : GetConstValue.Aux[T, T] = {
17+
val out = constValue[T]
18+
19+
new GetConstValue[T] {
20+
type Out = T
21+
def get : Out = out
22+
}
23+
}
24+
25+
given empty : GetConstValue[EmptyTuple] with {
26+
type Out = EmptyTuple
27+
def get : Out = EmptyTuple
28+
}
29+
30+
given nonEmpty[H, HRes, Tail <: Tuple, TRes <: Tuple](
31+
using
32+
head : GetConstValue.Aux[H, HRes],
33+
tail : GetConstValue.Aux[Tail, TRes],
34+
) : GetConstValue[H *: Tail] with {
35+
type Out = HRes *: TRes
36+
37+
def get : Out = head.get *: tail.get
38+
}
39+
}
40+
41+
trait MirrorNamesDeriver[T] {
42+
type Derived <: Tuple
43+
def derive : Derived
44+
}
45+
46+
object MirrorNamesDeriver {
47+
given mirDeriver[T, ElemLabels <: NonEmptyTuple](
48+
using
49+
mir: Mirror.SumOf[T] { type MirroredElemLabels = ElemLabels },
50+
ev : GetConstValue.Aux[ElemLabels, ElemLabels],
51+
): MirrorNamesDeriver[T] with {
52+
type Derived = ElemLabels
53+
54+
def derive: ElemLabels = ev.get
55+
}
56+
57+
def derive[T](using d : MirrorNamesDeriver[T]) : d.Derived = d.derive
58+
}
59+
60+
sealed trait SuperT
61+
final case class SubT1(int: Int) extends SuperT
62+
final case class SubT2(str: String, dbl : Double, bool : Boolean) extends SuperT
63+
64+
@main def Test =
65+
66+
// Works when type parameters are set explicitly
67+
val successfulLabels = MirrorNamesDeriver.mirDeriver[SuperT, ("SubT1", "SubT2")].derive
68+
println(successfulLabels)
69+
assert(successfulLabels == ("SubT1", "SubT2"))
70+
71+
// Fails when type parameters are inferred
72+
val failedLabels = MirrorNamesDeriver.derive[SuperT]
73+
println(successfulLabels)
74+
assert(failedLabels == ("SubT1", "SubT2"))

0 commit comments

Comments
 (0)