Skip to content

Encode variances in parameter names #2103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ object desugar {
* class C { type v C$T; type v T = C$T }
*/
def typeDef(tdef: TypeDef)(implicit ctx: Context): Tree = {
val name =
if (tdef.name.hasVariance && tdef.mods.is(Param)) {
ctx.error(em"type parameter name may not start with `+' or `-'", tdef.pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this even allowed by the parser?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ++, -- are legal parameter names.

("$" + tdef.name).toTypeName
}
else tdef.name
if (tdef.mods is PrivateLocalParam) {
val tparam = cpy.TypeDef(tdef)(name = tdef.name.expandedName(ctx.owner))
.withMods(tdef.mods &~ PrivateLocal | ExpandedName)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class Definitions {
resultTypeFn: PolyType => Type, flags: FlagSet = EmptyFlags) = {
val tparamNames = tpnme.syntheticTypeParamNames(typeParamCount)
val tparamBounds = tparamNames map (_ => TypeBounds.empty)
val ptype = PolyType(tparamNames, tparamNames.map(alwaysZero))(_ => tparamBounds, resultTypeFn)
val ptype = PolyType(tparamNames)(_ => tparamBounds, resultTypeFn)
enterMethod(cls, name, ptype, flags)
}

Expand Down
26 changes: 25 additions & 1 deletion compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,31 @@ object NameOps {
if (name.isModuleClassName) name.stripModuleClassSuffix.freshened.moduleClassName
else likeTyped(ctx.freshName(name ++ NameTransformer.NAME_JOIN_STRING)))

/** Name with variance prefix: `+` for covariant, `-` for contravariant */
def withVariance(v: Int): N =
if (hasVariance) dropVariance.withVariance(v)
else v match {
case -1 => likeTyped('-' +: name)
case 1 => likeTyped('+' +: name)
case 0 => name
}

/** Does name have a `+`/`-` variance prefix? */
def hasVariance: Boolean =
name.nonEmpty && name.head == '+' || name.head == '-'

/** Drop variance prefix if name has one */
def dropVariance: N = if (hasVariance) likeTyped(name.tail) else name

/** The variance as implied by the variance prefix, or 0 if there is
* no variance prefix.
*/
def variance = name.head match {
case '-' => -1
case '+' => 1
case _ => 0
}

/** Translate a name into a list of simple TypeNames and TermNames.
* In all segments before the last, type/term is determined by whether
* the following separator char is '.' or '#'. The last segment
Expand Down Expand Up @@ -271,7 +296,6 @@ object NameOps {
else -1
}


/** The number of hops specified in an outer-select name */
def outerSelectHops: Int = {
require(isOuterSelect)
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -846,5 +846,4 @@ object StdNames {
val tpnme = new ScalaTypeNames
val jnme = new JavaTermNames
val jtpnme = new JavaTypeNames

}
9 changes: 6 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,10 @@ class TypeApplications(val self: Type) extends AnyVal {
* TODO: Handle parameterized lower bounds
*/
def LambdaAbstract(tparams: List[TypeParamInfo])(implicit ctx: Context): Type = {
def nameWithVariance(tparam: TypeParamInfo) =
tparam.paramName.withVariance(tparam.paramVariance)
def expand(tp: Type) =
PolyType(
tparams.map(_.paramName), tparams.map(_.paramVariance))(
PolyType(tparams.map(nameWithVariance))(
tl => tparams.map(tparam => tl.lifted(tparams, tparam.paramBounds).bounds),
tl => tl.lifted(tparams, tp))
if (tparams.isEmpty) self
Expand Down Expand Up @@ -364,7 +365,9 @@ class TypeApplications(val self: Type) extends AnyVal {
case arg @ PolyType(tparams, body) if
!tparams.corresponds(hkParams)(_.paramVariance == _.paramVariance) &&
tparams.corresponds(hkParams)(varianceConforms) =>
PolyType(tparams.map(_.paramName), hkParams.map(_.paramVariance))(
PolyType(
(tparams, hkParams).zipped.map((tparam, hkparam) =>
tparam.paramName.withVariance(hkparam.paramVariance)))(
tl => arg.paramBounds.map(_.subst(arg, tl).bounds),
tl => arg.resultType.subst(arg, tl)
)
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
val tparams1 = tparams1a.drop(lengthDiff)
variancesConform(tparams1, tparams) && {
if (lengthDiff > 0)
tycon1b = PolyType(tparams1.map(_.paramName), tparams1.map(_.paramVariance))(
tycon1b = PolyType(tparams1.map(_.paramName))(
tl => tparams1.map(tparam => tl.lifted(tparams, tparam.paramBounds).bounds),
tl => tycon1a.appliedTo(args1.take(lengthDiff) ++
tparams1.indices.toList.map(PolyParam(tl, _))))
Expand Down Expand Up @@ -1279,9 +1279,9 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
original(tp1.appliedTo(tp1.typeParams.map(_.paramBoundsAsSeenFrom(tp1))), tp2)
else
PolyType(
paramNames = tpnme.syntheticTypeParamNames(tparams1.length),
variances = (tparams1, tparams2).zipped.map((tparam1, tparam2) =>
(tparam1.paramVariance + tparam2.paramVariance) / 2))(
paramNames = (tpnme.syntheticTypeParamNames(tparams1.length), tparams1, tparams2)
.zipped.map((pname, tparam1, tparam2) =>
pname.withVariance((tparam1.paramVariance + tparam2.paramVariance) / 2)))(
paramBoundsExp = tl => (tparams1, tparams2).zipped.map((tparam1, tparam2) =>
tl.lifted(tparams1, tparam1.paramBoundsAsSeenFrom(tp1)).bounds &
tl.lifted(tparams2, tparam2.paramBoundsAsSeenFrom(tp2)).bounds),
Expand Down
35 changes: 23 additions & 12 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2567,8 +2567,20 @@ object Types {
}
}

/** A type lambda of the form `[v_0 X_0, ..., v_n X_n] => T` */
class PolyType(val paramNames: List[TypeName], val variances: List[Int])(
/** A type lambda of the form `[X_0 B_0, ..., X_n B_n] => T`
* This is used both as a type of a polymorphic method and as a type of
* a higher-kidned type parameter. Variances are encoded in parameter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: kidned -> kinded

* names. A name starting with `+` designates a covariant parameter,
* a name starting with `-` designates a contravariant parameter,
* and every other name designates a non-variant parameter.
*
* @param paramNames The names `X_0`, ..., `X_n`
* @param paramBoundsExp A function that, given the polytype itself, returns the
* parameter bounds `B_1`, ..., `B_n`
* @param resultTypeExp A function that, given the polytype itself, returns the
* result type `T`.
*/
class PolyType(val paramNames: List[TypeName])(
paramBoundsExp: PolyType => List[TypeBounds], resultTypeExp: PolyType => Type)
extends CachedProxyType with BindingType with MethodOrPoly {

Expand Down Expand Up @@ -2606,7 +2618,7 @@ object Types {
paramBounds.mapConserve(_.substParams(this, argTypes).bounds)

def newLikeThis(paramNames: List[TypeName], paramBounds: List[TypeBounds], resType: Type)(implicit ctx: Context): PolyType =
PolyType.apply(paramNames, variances)(
PolyType.apply(paramNames)(
x => paramBounds mapConserve (_.subst(this, x).bounds),
x => resType.subst(this, x))

Expand Down Expand Up @@ -2639,7 +2651,7 @@ object Types {
case t => mapOver(t)
}
}
PolyType(paramNames ++ that.paramNames, variances ++ that.variances)(
PolyType(paramNames ++ that.paramNames)(
x => this.paramBounds.mapConserve(_.subst(this, x).bounds) ++
that.paramBounds.mapConserve(shift(_).subst(that, x).bounds),
x => shift(that.resultType).subst(that, x).subst(this, x))
Expand All @@ -2659,28 +2671,27 @@ object Types {
case other: PolyType =>
other.paramNames == this.paramNames &&
other.paramBounds == this.paramBounds &&
other.resType == this.resType &&
other.variances == this.variances
other.resType == this.resType
case _ => false
}

override def toString = s"PolyType($variances, $paramNames, $paramBounds, $resType)"
override def toString = s"PolyType($paramNames, $paramBounds, $resType)"

override def computeHash = doHash(variances ::: paramNames, resType, paramBounds)
override def computeHash = doHash(paramNames, resType, paramBounds)
}

object PolyType {
def apply(paramNames: List[TypeName], variances: List[Int])(
def apply(paramNames: List[TypeName])(
paramBoundsExp: PolyType => List[TypeBounds],
resultTypeExp: PolyType => Type)(implicit ctx: Context): PolyType = {
unique(new PolyType(paramNames, variances)(paramBoundsExp, resultTypeExp))
unique(new PolyType(paramNames)(paramBoundsExp, resultTypeExp))
}

def unapply(tl: PolyType): Some[(List[LambdaParam], Type)] =
Some((tl.typeParams, tl.resType))

def any(n: Int)(implicit ctx: Context) =
apply(tpnme.syntheticTypeParamNames(n), List.fill(n)(0))(
apply(tpnme.syntheticTypeParamNames(n))(
pt => List.fill(n)(TypeBounds.empty), pt => defn.AnyType)
}

Expand All @@ -2693,7 +2704,7 @@ object Types {
def paramBounds(implicit ctx: Context): TypeBounds = tl.paramBounds(n)
def paramBoundsAsSeenFrom(pre: Type)(implicit ctx: Context): TypeBounds = paramBounds
def paramBoundsOrCompleter(implicit ctx: Context): Type = paramBounds
def paramVariance(implicit ctx: Context): Int = tl.variances(n)
def paramVariance(implicit ctx: Context): Int = tl.paramNames(n).variance
def toArg: Type = PolyParam(tl, n)
def paramRef(implicit ctx: Context): Type = PolyParam(tl, n)
}
Expand Down
6 changes: 1 addition & 5 deletions compiler/src/dotty/tools/dotc/core/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ Standard-Section: "ASTs" TopLevelStat*
BIND Length boundName_NameRef bounds_Type
// for type-variables defined in a type pattern
BYNAMEtype underlying_Type
POLYtype Length result_Type NamesTypes // variance encoded in front of name: +/-/=
POLYtype Length result_Type NamesTypes // variance encoded in front of name: +/-/(nothing)
METHODtype Length result_Type NamesTypes // needed for refinements
PARAMtype Length binder_ASTref paramNum_Nat // needed for refinements
SHARED type_ASTRef
Expand Down Expand Up @@ -546,8 +546,4 @@ object TastyFormat {
case POLYtype | METHODtype => -1
case _ => 0
}

/** Map between variances and name prefixes */
val varianceToPrefix = Map(-1 -> '-', 0 -> '=', 1 -> '+')
val prefixToVariance = Map('-' -> -1, '=' -> 0, '+' -> 1)
}
4 changes: 1 addition & 3 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ class TreePickler(pickler: TastyPickler) {
pickleType(tpe.underlying)
case tpe: PolyType =>
writeByte(POLYtype)
val paramNames = tpe.typeParams.map(tparam =>
varianceToPrefix(tparam.paramVariance) +: tparam.paramName)
pickleMethodic(tpe.resultType, paramNames, tpe.paramBounds)
pickleMethodic(tpe.resultType, tpe.paramNames, tpe.paramBounds)
case tpe: MethodType if richTypes =>
writeByte(METHODtype)
pickleMethodic(tpe.resultType, tpe.paramNames, tpe.paramTypes)
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,8 @@ class TreeUnpickler(reader: TastyReader, tastyName: TastyName.Table, posUnpickle
registerSym(start, sym)
TypeRef.withFixedSym(NoPrefix, sym.name, sym)
case POLYtype =>
val (rawNames, paramReader) = readNamesSkipParams
val (variances, paramNames) = rawNames
.map(name => (prefixToVariance(name.head), name.tail.toTypeName)).unzip
val result = PolyType(paramNames, variances)(
val (paramNames, paramReader) = readNamesSkipParams
val result = PolyType(paramNames.map(_.toTypeName))(
pt => registeringType(pt, paramReader.readParamTypes[TypeBounds](end)),
pt => readType())
goto(end)
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,9 @@ class PlainPrinter(_ctx: Context) extends Printer {
case tp: ExprType =>
changePrec(GlobalPrec) { "=> " ~ toText(tp.resultType) }
case tp: PolyType =>
def paramText(variance: Int, name: Name, bounds: TypeBounds): Text =
varianceString(variance) ~ name.toString ~ toText(bounds)
def paramText(name: Name, bounds: TypeBounds): Text = name.toString ~ toText(bounds)
changePrec(GlobalPrec) {
"[" ~ Text((tp.variances, tp.paramNames, tp.paramBounds).zipped.map(paramText), ", ") ~
"[" ~ Text((tp.paramNames, tp.paramBounds).zipped.map(paramText), ", ") ~
"]" ~ (" => " provided !tp.resultType.isInstanceOf[MethodType]) ~
toTextGlobal(tp.resultType)
}
Expand Down Expand Up @@ -209,7 +208,8 @@ class PlainPrinter(_ctx: Context) extends Printer {

protected def polyParamNameString(name: TypeName): String = name.toString

protected def polyParamNameString(param: PolyParam): String = polyParamNameString(param.binder.paramNames(param.paramNum))
protected def polyParamNameString(param: PolyParam): String =
polyParamNameString(param.binder.paramNames(param.paramNum))

/** The name of the symbol without a unique id. Under refined printing,
* the decoded original name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ trait FullParameterization {

info match {
case info: PolyType =>
PolyType(info.paramNames ++ ctnames, info.variances ++ ctvariances)(
PolyType(info.paramNames ++ ctnames)(
pt =>
(info.paramBounds.map(mapClassParams(_, pt).bounds) ++
mappedClassBounds(pt)).mapConserve(_.subst(info, pt).bounds),
pt => resultType(mapClassParams(_, pt)).subst(info, pt))
case _ =>
if (ctparams.isEmpty) resultType(identity)
else PolyType(ctnames, ctvariances)(mappedClassBounds, pt => resultType(mapClassParams(_, pt)))
else PolyType(ctnames)(mappedClassBounds, pt => resultType(mapClassParams(_, pt)))
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ object ProtoTypes {

/** Create a new polyparam that represents a dependent method parameter singleton */
def newDepPolyParam(tp: Type)(implicit ctx: Context): PolyParam = {
val poly = PolyType(ctx.freshName(nme.DEP_PARAM_PREFIX).toTypeName :: Nil, 0 :: Nil)(
val poly = PolyType(ctx.freshName(nme.DEP_PARAM_PREFIX).toTypeName :: Nil)(
pt => TypeBounds.upper(AndType(tp, defn.SingletonType)) :: Nil,
pt => defn.AnyType)
ctx.typeComparer.addToConstraint(poly, Nil)
Expand Down