Skip to content

Commit 345b2da

Browse files
authored
Minimal support for dependent case classes (#21698)
This lets us write: ```scala trait A: type B case class CC(a: A, b: a.B) ``` Pattern matching works but isn't dependent yet: ```scala x match case CC(a, b) => val a1: A = a // Dependent pattern matching is not currently supported // val b1: a1.B = b val b1 = b // Type is CC#a.B ``` (for my usecase this isn't a problem, I'm working on a type constraint API which lets me write things like `case class CC(a: Int, b: Int GreaterThan[a.type])`) Because case class pattern matching relies on the product selectors `_N`, making it dependent is a bit tricky, currently we generate: ```scala case class CC(a: A, b: a.B): def _1: A = a def _2: a.B = b ``` So the type of `_2` is not obviously related to the type of `_1`, we probably need to change what we generate into: ```scala case class CC(a: A, b: a.B): @uncheckedStable def _1: a.type = a def _2: _1.B = b ``` But this can be done in a separate PR. Fixes #8073.
2 parents 5929c87 + 246793a commit 345b2da

12 files changed

+290
-122
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ class Definitions {
618618
@tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType))
619619
@tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType))
620620
@tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType))
621+
@tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType))
621622
@tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long)
622623
def LongClass(using Context): ClassSymbol = LongType.symbol.asClass
623624
@tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType))

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ object StdNames {
425425
val array_length : N = "array_length"
426426
val array_update : N = "array_update"
427427
val arraycopy: N = "arraycopy"
428+
val arity: N = "arity"
428429
val as: N = "as"
429430
val asTerm: N = "asTerm"
430431
val asModule: N = "asModule"

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
6262
case Select(qual, name) if !name.is(OuterSelectName) && tree.symbol.exists =>
6363
val qualTpe = qual.tpe
6464
assert(
65-
qualTpe.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) ||
65+
qualTpe.widenDealias.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) ||
6666
tree.symbol.is(JavaStatic) && qualTpe.derivesFrom(tree.symbol.enclosingClass),
6767
i"non member selection of ${tree.symbol.showLocated} from ${qualTpe} in $tree")
6868
case _: TypeTree =>

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

+81-31
Original file line numberDiff line numberDiff line change
@@ -504,53 +504,103 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
504504
/** The class
505505
*
506506
* ```
507-
* case class C[T <: U](x: T, y: String*)
507+
* trait U:
508+
* type Elem
509+
*
510+
* case class C[T <: U](a: T, b: a.Elem, c: String*)
508511
* ```
509512
*
510513
* gets the `fromProduct` method:
511514
*
512515
* ```
513516
* def fromProduct(x$0: Product): MirroredMonoType =
514-
* new C[U](
515-
* x$0.productElement(0).asInstanceOf[U],
516-
* x$0.productElement(1).asInstanceOf[Seq[String]]: _*)
517+
* val a$1 = x$0.productElement(0).asInstanceOf[U]
518+
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
519+
* val c$1 = x$0.productElement(2).asInstanceOf[Seq[String]]
520+
* new C[U](a$1, b$1, c$1*)
517521
* ```
518522
* where
519523
* ```
520524
* type MirroredMonoType = C[?]
521525
* ```
526+
*
527+
* However, if the last parameter is annotated `@unroll` then we generate:
528+
*
529+
* def fromProduct(x$0: Product): MirroredMonoType =
530+
* val arity = x$0.productArity
531+
* val a$1 = x$0.productElement(0).asInstanceOf[U]
532+
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
533+
* val c$1 = (
534+
* if arity > 2 then
535+
* x$0.productElement(2)
536+
* else
537+
* <default getter for the third parameter of C>
538+
* ).asInstanceOf[Seq[String]]
539+
* new C[U](a$1, b$1, c$1*)
522540
*/
523-
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
524-
def extractParams(tpe: Type): List[Type] =
525-
tpe.asInstanceOf[MethodType].paramInfos
526-
527-
def computeFromCaseClass: (Type, List[Type]) =
528-
val (baseRef, baseInfo) =
529-
val rawRef = caseClass.typeRef
530-
val rawInfo = caseClass.primaryConstructor.info
531-
optInfo match
532-
case Some(info) =>
533-
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
534-
case _ =>
535-
(rawRef, rawInfo)
536-
baseInfo match
541+
def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
542+
val classRef = optInfo match
543+
case Some(info) => TypeRef(info.pre, caseClass)
544+
case _ => caseClass.typeRef
545+
val (newPrefix, constrMeth, constrSyms) =
546+
val constr = TermRef(classRef, caseClass.primaryConstructor)
547+
val symss = caseClass.primaryConstructor.paramSymss
548+
(constr.info: @unchecked) match
537549
case tl: PolyType =>
538550
val tvars = constrained(tl)
539551
val targs = for tvar <- tvars yield
540552
tvar.instantiate(fromBelow = false)
541-
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
542-
case methTpe =>
543-
(baseRef, extractParams(methTpe))
544-
end computeFromCaseClass
545-
546-
val (classRefApplied, paramInfos) = computeFromCaseClass
547-
val elems =
548-
for ((formal, idx) <- paramInfos.zipWithIndex) yield
549-
val elem =
550-
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
551-
.ensureConforms(formal.translateFromRepeated(toArray = false))
552-
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
553-
New(classRefApplied, elems)
553+
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType], symss(1))
554+
case mt: MethodType =>
555+
(classRef, mt, symss.head)
556+
557+
// Index of the first parameter marked `@unroll` or -1
558+
val unrolledFrom =
559+
constrSyms.indexWhere(_.hasAnnotation(defn.UnrollAnnot))
560+
561+
// `val arity = x$0.productArity`
562+
val arityDef: Option[ValDef] =
563+
if unrolledFrom != -1 then
564+
Some(SyntheticValDef(nme.arity, productParam.select(defn.Product_productArity).withSpan(ctx.owner.span.focus)))
565+
else None
566+
val arityRefTree = arityDef.map(vd => ref(vd.symbol))
567+
568+
// Create symbols for the vals corresponding to each parameter
569+
// If there are dependent parameters, the infos won't be correct yet.
570+
val bindingSyms = constrMeth.paramRefs.map: pref =>
571+
newSymbol(ctx.owner, pref.paramName.freshened, Synthetic,
572+
pref.underlying.translateFromRepeated(toArray = false), coord = ctx.owner.span.focus)
573+
val bindingRefs = bindingSyms.map(TermRef(NoPrefix, _))
574+
// Fix the infos for dependent parameters
575+
if constrMeth.isParamDependent then
576+
bindingSyms.foreach: bindingSym =>
577+
bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs)
578+
579+
def defaultGetterAtIndex(idx: Int): Tree =
580+
val defaultGetterPrefix = caseClass.primaryConstructor.name.toTermName
581+
ref(caseClass.companionModule).select(NameKinds.DefaultGetterName(defaultGetterPrefix, idx))
582+
583+
val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) =>
584+
val selection = productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
585+
val rhs = (
586+
if unrolledFrom != -1 && idx >= unrolledFrom then
587+
If(arityRefTree.get.select(defn.Int_>).appliedTo(Literal(Constant(idx))),
588+
thenp =
589+
selection,
590+
elsep =
591+
defaultGetterAtIndex(idx))
592+
else
593+
selection
594+
).ensureConforms(bindingSym.info)
595+
ValDef(bindingSym, rhs)
596+
597+
val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) =>
598+
val refTree = ref(bindingRef)
599+
if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree
600+
Block(
601+
arityDef.toList ::: bindingDefs,
602+
New(newPrefix, newArgs)
603+
)
554604
end fromProductBody
555605

556606
/** For an enum T:

compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala

+19-67
Original file line numberDiff line numberDiff line change
@@ -228,46 +228,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
228228
forwarderDef
229229
}
230230

231-
private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
232-
cpy.DefDef(defdef)(
233-
name = defdef.name,
234-
paramss = defdef.paramss,
235-
tpt = defdef.tpt,
236-
rhs = Match(
237-
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
238-
startParamIndices.map { paramIndex =>
239-
val Apply(select, args) = defdef.rhs: @unchecked
240-
CaseDef(
241-
Literal(Constant(paramIndex)),
242-
EmptyTree,
243-
Apply(
244-
select,
245-
args.take(paramIndex) ++
246-
Range(paramIndex, paramCount).map(n =>
247-
ref(defdef.symbol.owner.companionModule)
248-
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
249-
)
250-
)
251-
)
252-
} :+ CaseDef(
253-
Underscore(defn.IntType),
254-
EmptyTree,
255-
defdef.rhs
256-
)
257-
)
258-
).setDefTree
259-
}
260-
261-
private enum Gen:
262-
case Substitute(origin: Symbol, newDef: DefDef)
263-
case Forwarders(origin: Symbol, forwarders: List[DefDef])
231+
case class Forwarders(origin: Symbol, forwarders: List[DefDef])
264232

265-
def origin: Symbol
266-
def extras: List[DefDef] = this match
267-
case Substitute(_, d) => d :: Nil
268-
case Forwarders(_, ds) => ds
269-
270-
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Gen] = tree match {
233+
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Forwarders] = tree match {
271234
case defdef: DefDef if defdef.paramss.nonEmpty =>
272235
import dotty.tools.dotc.core.NameOps.isConstructorName
273236

@@ -277,38 +240,29 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
277240
val isCaseApply =
278241
defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass)
279242

280-
val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass)
281-
282243
val annotated =
283244
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
284245
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
285-
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
286246
else defdef.symbol
287247

288248
compute(annotated) match {
289249
case Nil => None
290250
case (paramClauseIndex, annotationIndices) :: Nil =>
291251
val paramCount = annotated.paramSymss(paramClauseIndex).size
292-
if isCaseFromProduct then
293-
Some(Gen.Substitute(
294-
origin = defdef.symbol,
295-
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
296-
))
297-
else
298-
val generatedDefs =
299-
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
300-
indices.foldLeft(List.empty[DefDef]):
301-
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
302-
generateSingleForwarder(
303-
defdef,
304-
paramIndex,
305-
paramCount,
306-
nextParamIndex,
307-
paramClauseIndex,
308-
isCaseApply
309-
) :: defdefs
310-
case _ => unreachable("sliding with at least 2 elements")
311-
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
252+
val generatedDefs =
253+
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
254+
indices.foldLeft(List.empty[DefDef]):
255+
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
256+
generateSingleForwarder(
257+
defdef,
258+
paramIndex,
259+
paramCount,
260+
nextParamIndex,
261+
paramClauseIndex,
262+
isCaseApply
263+
) :: defdefs
264+
case _ => unreachable("sliding with at least 2 elements")
265+
Some(Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
312266

313267
case multiple =>
314268
report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos)
@@ -323,14 +277,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
323277
val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
324278
val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
325279
val allGenerated = generatedBody ++ generatedConstr0
326-
val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet
327-
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))
328280

329281
if allGenerated.nonEmpty then
330-
val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
282+
val byName = (tmpl.constr :: tmpl.body).groupMap(_.symbol.name.toString)(_.symbol)
331283
for
332284
syntheticDefs <- allGenerated
333-
dcl <- syntheticDefs.extras
285+
dcl <- syntheticDefs.forwarders
334286
do
335287
val replaced = dcl.symbol
336288
byName.get(dcl.name.toString).foreach { syms =>
@@ -348,7 +300,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
348300
tmpl.parents,
349301
tmpl.derived,
350302
tmpl.self,
351-
otherDecls ++ allGenerated.flatMap(_.extras)
303+
tmpl.body ++ allGenerated.flatMap(_.forwarders)
352304
)
353305
}
354306

compiler/src/dotty/tools/dotc/typer/Namer.scala

+1-13
Original file line numberDiff line numberDiff line change
@@ -1957,9 +1957,7 @@ class Namer { typer: Typer =>
19571957
if isConstructor then
19581958
// set result type tree to unit, but take the current class as result type of the symbol
19591959
typedAheadType(ddef.tpt, defn.UnitType)
1960-
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
1961-
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
1962-
mt
1960+
wrapMethType(effectiveResultType(sym, paramSymss))
19631961
else
19641962
val paramFn = if Feature.enabled(Feature.modularity) && sym.isAllOf(Given | Method) then wrapRefinedMethType else wrapMethType
19651963
valOrDefDefSig(ddef, sym, paramSymss, paramFn)
@@ -2001,16 +1999,6 @@ class Namer { typer: Typer =>
20011999
ddef.trailingParamss.foreach(completeParams)
20022000
end completeTrailingParamss
20032001

2004-
/** Checks an implementation restriction on case classes. */
2005-
def checkCaseClassParamDependencies(mt: Type, cls: Symbol)(using Context): Unit =
2006-
mt.stripPoly match
2007-
case mt: MethodType if cls.is(Case) && mt.isParamDependent =>
2008-
// See issue #8073 for background
2009-
report.error(
2010-
em"""Implementation restriction: case classes cannot have dependencies between parameters""",
2011-
cls.srcPos)
2012-
case _ =>
2013-
20142002
private def setParamTrackedWithAccessors(psym: Symbol, ownerTpe: Type)(using Context): Unit =
20152003
for acc <- ownerTpe.decls.lookupAll(psym.name) if acc.is(ParamAccessor) do
20162004
acc.resetFlag(PrivateLocal)

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3198,7 +3198,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
31983198
.withType(dummy.termRef)
31993199
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
32003200
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
3201-
if cls.isEnum || firstParentTpe.classSymbol.isEnum then
3201+
if cls.isEnum || !cls.isRefinementClass && firstParentTpe.classSymbol.isEnum then
32023202
checkEnum(cdef, cls, firstParent)
32033203
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)
32043204

tests/neg/i8069.scala

-8
This file was deleted.

tests/pos/enum-refinement.scala

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
enum Enum:
2+
case EC(val x: Int)
3+
4+
val a: Enum.EC { val x: 1 } = Enum.EC(1).asInstanceOf[Enum.EC { val x: 1 }]
5+
6+
import scala.language.experimental.modularity
7+
8+
enum EnumT:
9+
case EC(tracked val x: Int)
10+
11+
val b: EnumT.EC { val x: 1 } = EnumT.EC(1)
12+

0 commit comments

Comments
 (0)