Skip to content

Commit 624c276

Browse files
committed
reorganise construction of refined mirror types,
record constraints against the formal.
1 parent 001bfc3 commit 624c276

File tree

4 files changed

+129
-16
lines changed

4 files changed

+129
-16
lines changed

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+25-10
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
240240
* <parent> {
241241
* MirroredMonoType = <monoType>
242242
* MirroredType = <mirroredType>
243-
* MirroredLabel = <label> }
243+
* MirroredLabel = <label>
244244
* }
245245
*/
246-
private def mirrorCore(parentClass: ClassSymbol, monoType: Type, mirroredType: Type, label: Name, formal: Type)(using Context) =
247-
formal & parentClass.typeRef
246+
private def mirrorCore(parentClass: ClassSymbol, monoType: Type, mirroredType: Type, label: Name)(using Context) =
247+
parentClass.typeRef
248248
.refinedWith(tpnme.MirroredMonoType, TypeAlias(monoType))
249249
.refinedWith(tpnme.MirroredType, TypeAlias(mirroredType))
250250
.refinedWith(tpnme.MirroredLabel, TypeAlias(ConstantType(Constant(label.toString))))
@@ -269,6 +269,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
269269
then report.error(
270270
em"$name mismatch, expected: $expected, found: $actual.", ctx.source.atSpan(span))
271271

272+
extension (formal: Type)
273+
/** `tp := op; tp <:< formal; formal & tp` */
274+
private def constrained_&(op: Context ?=> Type)(using Context): Type =
275+
val tp = op
276+
tp <:< formal
277+
formal & tp
278+
272279
private def mkMirroredMonoType(mirroredType: HKTypeLambda)(using Context): Type =
273280
val monoMap = new TypeMap:
274281
def apply(t: Type) = t match
@@ -300,10 +307,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
300307
val elemsLabels = TypeOps.nestedPairs(elemLabels)
301308
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
302309
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
303-
val mirrorType =
304-
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
310+
val mirrorType = formal.constrained_& {
311+
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name)
305312
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
306313
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
314+
}
307315
val mirrorRef =
308316
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
309317
else anonymousMirror(monoType, ExtendsProductMirror, span)
@@ -328,11 +336,15 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
328336
else (cls.sourceModule, cls.sourceModule.reachableTermRef)
329337
val singletonPath = pathFor(singletonRef).withSpan(span)
330338
if singleton.info.classSymbol.is(Scala2x) then // could be Scala 3 alias of Scala 2 case object.
331-
val mirrorType = mirrorCore(defn.Mirror_SingletonProxyClass, mirroredType, mirroredType, singleton.name, formal)
339+
val mirrorType = formal.constrained_& {
340+
mirrorCore(defn.Mirror_SingletonProxyClass, mirroredType, mirroredType, singleton.name)
341+
}
332342
val mirrorRef = New(defn.Mirror_SingletonProxyClass.typeRef, singletonPath :: Nil)
333343
withNoErrors(mirrorRef.cast(mirrorType))
334344
else
335-
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, singleton.name, formal)
345+
val mirrorType = formal.constrained_& {
346+
mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, singleton.name)
347+
}
336348
withNoErrors(singletonPath.cast(mirrorType))
337349
else
338350
val acceptableMsg = whyNotAcceptableType(mirroredType, cls)
@@ -408,9 +420,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
408420
(mirroredType, elems)
409421

410422
val mirrorType =
411-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
423+
val labels = TypeOps.nestedPairs(elemLabels)
424+
formal.constrained_& {
425+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name)
412426
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
413-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
427+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(labels))
428+
}
414429
val mirrorRef =
415430
if cls.useCompanionAsSumMirror then companionPath(mirroredType, span)
416431
else anonymousMirror(monoType, ExtendsSumMirror, span)
@@ -419,7 +434,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
419434
withErrors(i"type $mirroredType is not a generic sum because $acceptableMsg")
420435
else if !clsIsGenericSum then
421436
withErrors(i"$cls is not a generic sum because ${cls.whyNotGenericSum}")
422-
else
437+
else
423438
EmptyTreeNoError
424439
end sumMirror
425440

tests/neg/7380a.scala

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.deriving.Mirror
2+
3+
object Lib {
4+
5+
trait HasFields[T]
6+
trait HasFieldsOrNone[T]
7+
8+
given mirHasFields[T, ElemLabels <: NonEmptyTuple](using
9+
mir: Mirror.ProductOf[T] { type MirroredElemLabels = ElemLabels }
10+
): HasFields[T]()
11+
12+
given mirHasFieldsOrNone[T, ElemLabels <: Tuple](using
13+
mir: Mirror.ProductOf[T] { type MirroredElemLabels = ElemLabels }
14+
): HasFieldsOrNone[T]()
15+
16+
}
17+
18+
object Test {
19+
20+
Lib.mirHasFields[(Int, String), ("_1", "_2", "_3")] // error
21+
22+
summon[Lib.HasFields[(Int, String)]] // ok
23+
24+
case class NoFields()
25+
26+
summon[Lib.HasFields[NoFields]] // error
27+
summon[Lib.HasFieldsOrNone[NoFields]] // ok
28+
29+
}

tests/run-macros/i7987.check

+1-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
scala.deriving.Mirror {
2-
type MirroredType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
3-
type MirroredMonoType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
4-
type MirroredElemTypes >: scala.Nothing <: scala.Tuple
5-
} & scala.deriving.Mirror.Product {
1+
scala.deriving.Mirror.Product {
62
type MirroredMonoType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
73
type MirroredType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
84
type MirroredLabel >: "Some" <: "Some"
9-
} {
105
type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]
116
type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple]
127
}

tests/run/i14150.scala

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import scala.deriving.Mirror
2+
import scala.util.NotGiven
3+
import scala.compiletime.constValue
4+
5+
trait GetConstValue[T] {
6+
type Out
7+
def get : Out
8+
}
9+
10+
object GetConstValue {
11+
type Aux[T, O] = GetConstValue[T] { type Out = O }
12+
13+
inline given value[T <: Singleton](
14+
using
15+
ev : NotGiven[T <:< Tuple],
16+
) : GetConstValue.Aux[T, T] = {
17+
val out = constValue[T]
18+
19+
new GetConstValue[T] {
20+
type Out = T
21+
def get : Out = out
22+
}
23+
}
24+
25+
given empty : GetConstValue[EmptyTuple] with {
26+
type Out = EmptyTuple
27+
def get : Out = EmptyTuple
28+
}
29+
30+
given nonEmpty[H, HRes, Tail <: Tuple, TRes <: Tuple](
31+
using
32+
head : GetConstValue.Aux[H, HRes],
33+
tail : GetConstValue.Aux[Tail, TRes],
34+
) : GetConstValue[H *: Tail] with {
35+
type Out = HRes *: TRes
36+
37+
def get : Out = head.get *: tail.get
38+
}
39+
}
40+
41+
trait MirrorNamesDeriver[T] {
42+
type Derived <: Tuple
43+
def derive : Derived
44+
}
45+
46+
object MirrorNamesDeriver {
47+
given mirDeriver[T, ElemLabels <: NonEmptyTuple](
48+
using
49+
mir: Mirror.SumOf[T] { type MirroredElemLabels = ElemLabels },
50+
ev : GetConstValue.Aux[ElemLabels, ElemLabels],
51+
): MirrorNamesDeriver[T] with {
52+
type Derived = ElemLabels
53+
54+
def derive: ElemLabels = ev.get
55+
}
56+
57+
def derive[T](using d : MirrorNamesDeriver[T]) : d.Derived = d.derive
58+
}
59+
60+
sealed trait SuperT
61+
final case class SubT1(int: Int) extends SuperT
62+
final case class SubT2(str: String, dbl : Double, bool : Boolean) extends SuperT
63+
64+
@main def Test =
65+
66+
// Works when type parameters are set explicitly
67+
val successfulLabels = MirrorNamesDeriver.mirDeriver[SuperT, ("SubT1", "SubT2")].derive
68+
println(successfulLabels)
69+
assert(successfulLabels == ("SubT1", "SubT2"))
70+
71+
// Fails when type parameters are inferred
72+
val failedLabels = MirrorNamesDeriver.derive[SuperT]
73+
println(successfulLabels)
74+
assert(failedLabels == ("SubT1", "SubT2"))

0 commit comments

Comments
 (0)