Skip to content

fix 14150 - constrain refinements to type parameters #15014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
* <parent> {
* MirroredMonoType = <monoType>
* MirroredType = <mirroredType>
* MirroredLabel = <label> }
* MirroredLabel = <label>
* }
*/
private def mirrorCore(parentClass: ClassSymbol, monoType: Type, mirroredType: Type, label: Name, formal: Type)(using Context) =
formal & parentClass.typeRef
private def mirrorCore(parentClass: ClassSymbol, monoType: Type, mirroredType: Type, label: Name)(using Context) =
parentClass.typeRef
.refinedWith(tpnme.MirroredMonoType, TypeAlias(monoType))
.refinedWith(tpnme.MirroredType, TypeAlias(mirroredType))
.refinedWith(tpnme.MirroredLabel, TypeAlias(ConstantType(Constant(label.toString))))
Expand All @@ -269,6 +269,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
then report.error(
em"$name mismatch, expected: $expected, found: $actual.", ctx.source.atSpan(span))

extension (formal: Type)
/** `tp := op; tp <:< formal; formal & tp` */
private def constrained_&(op: Context ?=> Type)(using Context): Type =
val tp = op
tp <:< formal
formal & tp

private def mkMirroredMonoType(mirroredType: HKTypeLambda)(using Context): Type =
val monoMap = new TypeMap:
def apply(t: Type) = t match
Expand Down Expand Up @@ -366,10 +373,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val elemsLabels = TypeOps.nestedPairs(elemLabels)
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
val mirrorType =
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
val mirrorType = formal.constrained_& {
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
}
val mirrorRef =
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
else anonymousMirror(monoType, ExtendsProductMirror, span)
Expand All @@ -382,12 +390,15 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val singleton = tref.termSymbol // prefer alias name over the orignal name
val singletonPath = pathFor(tref).withSpan(span)
if tref.classSymbol.is(Scala2x) then // could be Scala 3 alias of Scala 2 case object.
val mirrorType =
mirrorCore(defn.Mirror_SingletonProxyClass, mirroredType, mirroredType, singleton.name, formal)
val mirrorType = formal.constrained_& {
mirrorCore(defn.Mirror_SingletonProxyClass, mirroredType, mirroredType, singleton.name)
}
val mirrorRef = New(defn.Mirror_SingletonProxyClass.typeRef, singletonPath :: Nil)
withNoErrors(mirrorRef.cast(mirrorType))
else
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, singleton.name, formal)
val mirrorType = formal.constrained_& {
mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, singleton.name)
}
withNoErrors(singletonPath.cast(mirrorType))
case MirrorSource.ClassSymbol(cls) =>
if cls.isGenericProduct then makeProductMirror(cls)
Expand Down Expand Up @@ -452,9 +463,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
(mirroredType, elems)

val mirrorType =
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
val labels = TypeOps.nestedPairs(elemLabels)
formal.constrained_& {
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(labels))
}
val mirrorRef =
if cls.useCompanionAsSumMirror then companionPath(mirroredType, span)
else anonymousMirror(monoType, ExtendsSumMirror, span)
Expand All @@ -463,7 +477,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
withErrors(i"type `$mirroredType` is not a generic sum because $acceptableMsg")
else if !clsIsGenericSum then
withErrors(i"$cls is not a generic sum because ${cls.whyNotGenericSum}")
else
else
EmptyTreeNoError
end sumMirror

Expand Down
29 changes: 29 additions & 0 deletions tests/neg/7380a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import scala.deriving.Mirror

object Lib {

trait HasFields[T]
trait HasFieldsOrNone[T]

given mirHasFields[T, ElemLabels <: NonEmptyTuple](using
mir: Mirror.ProductOf[T] { type MirroredElemLabels = ElemLabels }
): HasFields[T]()

given mirHasFieldsOrNone[T, ElemLabels <: Tuple](using
mir: Mirror.ProductOf[T] { type MirroredElemLabels = ElemLabels }
): HasFieldsOrNone[T]()

}

object Test {

Lib.mirHasFields[(Int, String), ("_1", "_2", "_3")] // error

summon[Lib.HasFields[(Int, String)]] // ok

case class NoFields()

summon[Lib.HasFields[NoFields]] // error
summon[Lib.HasFieldsOrNone[NoFields]] // ok

}
7 changes: 1 addition & 6 deletions tests/run-macros/i7987.check
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
scala.deriving.Mirror {
type MirroredType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
type MirroredMonoType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
type MirroredElemTypes >: scala.Nothing <: scala.Tuple
} & scala.deriving.Mirror.Product {
scala.deriving.Mirror.Product {
type MirroredMonoType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
type MirroredType >: scala.Some[scala.Int] <: scala.Some[scala.Int]
type MirroredLabel >: "Some" <: "Some"
} {
type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]
type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple]
}
74 changes: 74 additions & 0 deletions tests/run/i14150.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import scala.deriving.Mirror
import scala.util.NotGiven
import scala.compiletime.constValue

trait GetConstValue[T] {
type Out
def get : Out
}

object GetConstValue {
type Aux[T, O] = GetConstValue[T] { type Out = O }

inline given value[T <: Singleton](
using
ev : NotGiven[T <:< Tuple],
) : GetConstValue.Aux[T, T] = {
val out = constValue[T]

new GetConstValue[T] {
type Out = T
def get : Out = out
}
}

given empty : GetConstValue[EmptyTuple] with {
type Out = EmptyTuple
def get : Out = EmptyTuple
}

given nonEmpty[H, HRes, Tail <: Tuple, TRes <: Tuple](
using
head : GetConstValue.Aux[H, HRes],
tail : GetConstValue.Aux[Tail, TRes],
) : GetConstValue[H *: Tail] with {
type Out = HRes *: TRes

def get : Out = head.get *: tail.get
}
}

trait MirrorNamesDeriver[T] {
type Derived <: Tuple
def derive : Derived
}

object MirrorNamesDeriver {
given mirDeriver[T, ElemLabels <: NonEmptyTuple](
using
mir: Mirror.SumOf[T] { type MirroredElemLabels = ElemLabels },
ev : GetConstValue.Aux[ElemLabels, ElemLabels],
): MirrorNamesDeriver[T] with {
type Derived = ElemLabels

def derive: ElemLabels = ev.get
}

def derive[T](using d : MirrorNamesDeriver[T]) : d.Derived = d.derive
}

sealed trait SuperT
final case class SubT1(int: Int) extends SuperT
final case class SubT2(str: String, dbl : Double, bool : Boolean) extends SuperT

@main def Test =

// Works when type parameters are set explicitly
val successfulLabels = MirrorNamesDeriver.mirDeriver[SuperT, ("SubT1", "SubT2")].derive
println(successfulLabels)
assert(successfulLabels == ("SubT1", "SubT2"))

// Fails when type parameters are inferred
val failedLabels = MirrorNamesDeriver.derive[SuperT]
println(successfulLabels)
assert(failedLabels == ("SubT1", "SubT2"))