Skip to content

Improve our ability to override default parameters #11704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -602,13 +602,8 @@ object desugar {
// def _1: T1 = this.p1
// ...
// def _N: TN = this.pN (unless already given as valdef or parameterless defdef)
// def copy(p1: T1 = p1: @uncheckedVariance, ...,
// pN: TN = pN: @uncheckedVariance)(moreParams) =
// def copy(p1: T1 = p1..., pN: TN = pN)(moreParams) =
// new C[...](p1, ..., pN)(moreParams)
//
// Note: copy default parameters need @uncheckedVariance; see
// neg/t1843-variances.scala for a test case. The test would give
// two errors without @uncheckedVariance, one of them spurious.
val (caseClassMeths, enumScaffolding) = {
def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) =
DefDef(name, Nil, tpt, rhs).withMods(synthetic)
Expand Down Expand Up @@ -638,10 +633,8 @@ object desugar {
}
if (mods.is(Abstract) || hasRepeatedParam) Nil // cannot have default arguments for repeated parameters, hence copy method is not issued
else {
def copyDefault(vparam: ValDef) =
makeAnnotated("scala.annotation.unchecked.uncheckedVariance", refOfDef(vparam))
val copyFirstParams = derivedVparamss.head.map(vparam =>
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
cpy.ValDef(vparam)(rhs = refOfDef(vparam)))
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree))
DefDef(
Expand Down
62 changes: 35 additions & 27 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1386,13 +1386,10 @@ class Namer { typer: Typer =>
}
end inherited

/** The proto-type to be used when inferring the result type from
* the right hand side. This is `WildcardType` except if the definition
* is a default getter. In that case, the proto-type is the type of
* the corresponding parameter where bound parameters are replaced by
* Wildcards.
/** If this is a default getter, the type of the corresponding method parameter,
* otherwise NoType.
*/
def rhsProto = sym.asTerm.name collect {
def defaultParamType = sym.name match
case DefaultGetterName(original, idx) =>
val meth: Denotation =
if (original.isConstructorName && (sym.owner.is(ModuleClass)))
Expand All @@ -1401,37 +1398,24 @@ class Namer { typer: Typer =>
ctx.defContext(sym).denotNamed(original)
def paramProto(paramss: List[List[Type]], idx: Int): Type = paramss match {
case params :: paramss1 =>
if (idx < params.length) wildApprox(params(idx))
if (idx < params.length) params(idx)
else paramProto(paramss1, idx - params.length)
case nil =>
WildcardType
NoType
}
val defaultAlts = meth.altsWith(_.hasDefaultParams)
if (defaultAlts.length == 1)
paramProto(defaultAlts.head.info.widen.paramInfoss, idx)
else
WildcardType
} getOrElse WildcardType
NoType
case _ =>
NoType

// println(s"final inherited for $sym: ${inherited.toString}") !!!
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
// TODO Scala 3.1: only check for inline vals (no final ones)
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)

// Widen rhs type and eliminate `|' but keep ConstantTypes if
// definition is inline (i.e. final in Scala2) and keep module singleton types
// instead of widening to the underlying module class types.
// We also drop the @Repeated annotation here to avoid leaking it in method result types
// (see run/inferred-repeated-result).
def widenRhs(tp: Type): Type =
tp.widenTermRefExpr.simplified match
case ctp: ConstantType if isInlineVal => ctp
case tp => TypeComparer.widenInferred(tp, rhsProto)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp

var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType)
if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod)
Expand All @@ -1443,8 +1427,32 @@ class Namer { typer: Typer =>
rhsCtx.setFreshGADTBounds
rhsCtx.gadt.addToConstraint(typeParams)
}
def rhsType = PrepareInlineable.dropInlineIfError(sym,
typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(using rhsCtx)).tpe

def typedAheadRhs(pt: Type) =
PrepareInlineable.dropInlineIfError(sym,
typedAheadExpr(mdef.rhs, pt)(using rhsCtx))

def rhsType =
// For default getters, we use the corresponding parameter type as an
// expected type but we run it through `wildApprox` to allow default
// parameters like in `def mkList[T](value: T = 1): List[T]`.
val defaultTp = defaultParamType
val pt = inherited.orElse(wildApprox(defaultTp)).orElse(WildcardType).widenExpr
val tp = typedAheadRhs(pt).tpe
if (defaultTp eq pt) && (tp frozen_<:< defaultTp) then
// When possible, widen to the default getter parameter type to permit a
// larger choice of overrides (see `default-getter.scala`).
// For justification on the use of `@uncheckedVariance`, see
// `default-getter-variance.scala`.
AnnotatedType(defaultTp, Annotation(defn.UncheckedVarianceAnnot))
else tp.widenTermRefExpr.simplified match
case ctp: ConstantType if isInlineVal => ctp
case tp =>
TypeComparer.widenInferred(tp, pt)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp

// Approximate a type `tp` with a type that does not contain skolem types.
val deskolemize = new ApproximatingTypeMap {
Expand All @@ -1455,7 +1463,7 @@ class Namer { typer: Typer =>
}
}

def cookedRhsType = deskolemize(dealiasIfUnit(widenRhs(rhsType)))
def cookedRhsType = deskolemize(dealiasIfUnit(rhsType))
def lhsType = fullyDefinedType(cookedRhsType, "right-hand side", mdef.span)
//if (sym.name.toString == "y") println(i"rhs = $rhsType, cooked = $cookedRhsType")
if (inherited.exists)
Expand Down
45 changes: 45 additions & 0 deletions tests/pos/default-getter-variance.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
class Foo[+A] {
def count(f: A => Boolean = _ => true): Unit = {}
// The preceding line is valid, even though the generated default getter
// has type `A => Boolean` which wouldn't normally pass variance checks
// because it's equivalent to the following overloads which are valid:
def count2(f: A => Boolean): Unit = {}
def count2(): Unit = count(_ => true)
}

class Bar1[+A] extends Foo[A] {
override def count(f: A => Boolean): Unit = {}
// This reasoning extends to overrides:
override def count2(f: A => Boolean): Unit = {}
}

class Bar2[+A] extends Foo[A] {
override def count(f: A => Boolean = _ => true): Unit = {}
// ... including overrides which also override the default getter:
override def count2(f: A => Boolean): Unit = {}
override def count2(): Unit = count(_ => true)
}

// This can be contrasted with the need for variance checks in
// `protected[this] methods (cf tests/neg/t7093.scala),
// default getters do not have the same problem since they cannot
// appear in arbitrary contexts.


// Crucially, this argument does not apply to situations in which the default
// getter result type is not a subtype of the parameter type, for example (from
// tests/neg/variance.scala):
//
// class Foo[+A: ClassTag](x: A) {
// private[this] val elems: Array[A] = Array(x)
// def f[B](x: Array[B] = elems): Array[B] = x
// }
//
// If we tried to rewrite this with an overload, it would fail
// compilation:
//
// def f[B](): Array[B] = f(elems) // error: Found: Array[A], Expected: Array[B]
//
// So we only disable variance checking for default getters whose
// result type is the method parameter type, this is checked by
// `tests/neg/variance.scala`
10 changes: 10 additions & 0 deletions tests/pos/default-getter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class X
class Y extends X

class A {
def foo(param: X = new Y): X = param
}

class B extends A {
override def foo(param: X = new X): X = param
}
2 changes: 1 addition & 1 deletion tests/neg/i4659b.scala → tests/pos/i4659c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ case class SourcePosition(outer: SourcePosition = NoSourcePosition) {
assert(outer != null) // crash
}

object NoSourcePosition extends SourcePosition() // error
object NoSourcePosition extends SourcePosition()