Skip to content

Don't normalize in AppliedType#superType #15453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
if tycon1sym == tycon2sym && tycon1sym.isAliasType then
val preConstraint = constraint
isSubArgs(args1, args2, tp1, tparams)
&& tryAlso(preConstraint, recur(tp1.superType, tp2.superType))
&& tryAlso(preConstraint, recur(tp1.superTypeNormalized, tp2.superTypeNormalized))
else
isSubArgs(args1, args2, tp1, tparams)
}
Expand Down Expand Up @@ -1177,7 +1177,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
*/
def compareLower(tycon2bounds: TypeBounds, tyconIsTypeRef: Boolean): Boolean =
if ((tycon2bounds.lo `eq` tycon2bounds.hi) && !tycon2bounds.isInstanceOf[MatchAlias])
if (tyconIsTypeRef) recur(tp1, tp2.superType)
if (tyconIsTypeRef) recur(tp1, tp2.superTypeNormalized)
else isSubApproxHi(tp1, tycon2bounds.lo.applyIfParameterized(args2))
else
fallback(tycon2bounds.lo)
Expand Down Expand Up @@ -1249,11 +1249,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

!sym.isClass && {
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
recur(tp1.superType, tp2) ||
recur(tp1.superTypeNormalized, tp2) ||
tryLiftedToThis1
}|| byGadtBounds
case tycon1: TypeProxy =>
recur(tp1.superType, tp2)
recur(tp1.superTypeNormalized, tp2)
case _ =>
false
}
Expand Down Expand Up @@ -2645,9 +2645,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
!(tp2 <:< tp1)
&& (provablyDisjoint(tp1, tp2.tp2) || provablyDisjoint(tp1, tp2.tp1))
case (tp1: NamedType, _) if gadtBounds(tp1.symbol) != null =>
provablyDisjoint(gadtBounds(tp1.symbol).uncheckedNN.hi, tp2) || provablyDisjoint(tp1.superType, tp2)
provablyDisjoint(gadtBounds(tp1.symbol).uncheckedNN.hi, tp2)
|| provablyDisjoint(tp1.superTypeNormalized, tp2)
case (_, tp2: NamedType) if gadtBounds(tp2.symbol) != null =>
provablyDisjoint(tp1, gadtBounds(tp2.symbol).uncheckedNN.hi) || provablyDisjoint(tp1, tp2.superType)
provablyDisjoint(tp1, gadtBounds(tp2.symbol).uncheckedNN.hi)
|| provablyDisjoint(tp1, tp2.superTypeNormalized)
case (tp1: TermRef, tp2: TermRef) if isEnumValueOrModule(tp1) && isEnumValueOrModule(tp2) =>
tp1.termSymbol != tp2.termSymbol
case (tp1: TermRef, tp2: TypeRef) if isEnumValue(tp1) =>
Expand All @@ -2663,11 +2665,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case (tp1: Type, tp2: Type) if defn.isTupleNType(tp2) =>
provablyDisjoint(tp1, tp2.toNestedPairs)
case (tp1: TypeProxy, tp2: TypeProxy) =>
provablyDisjoint(tp1.superType, tp2) || provablyDisjoint(tp1, tp2.superType)
provablyDisjoint(tp1.superTypeNormalized, tp2) || provablyDisjoint(tp1, tp2.superTypeNormalized)
case (tp1: TypeProxy, _) =>
provablyDisjoint(tp1.superType, tp2)
provablyDisjoint(tp1.superTypeNormalized, tp2)
case (_, tp2: TypeProxy) =>
provablyDisjoint(tp1, tp2.superType)
provablyDisjoint(tp1, tp2.superTypeNormalized)
case _ =>
false
}
Expand Down
240 changes: 240 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeEval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package dotty.tools
package dotc
package core

import Types.*, Contexts.*, Symbols.*, Constants.*, Decorators.*
import config.Printers.typr
import reporting.trace
import StdNames.tpnme

object TypeEval:

def tryCompiletimeConstantFold(tp: AppliedType)(using Context): Type = tp.tycon match
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
extension (tp: Type) def fixForEvaluation: Type =
tp.normalized.dealias match
// enable operations for constant singleton terms. E.g.:
// ```
// final val one = 1
// type Two = one.type + one.type
// ```
case tp: TypeProxy if tp.underlying.isStable => tp.underlying.fixForEvaluation
case tp => tp

def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match
case ConstantType(Constant(n)) => Some(n)
case _ => None

def boolValue(tp: Type): Option[Boolean] = tp.fixForEvaluation match
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None

def intValue(tp: Type): Option[Int] = tp.fixForEvaluation match
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None

def longValue(tp: Type): Option[Long] = tp.fixForEvaluation match
case ConstantType(Constant(n: Long)) => Some(n)
case _ => None

def floatValue(tp: Type): Option[Float] = tp.fixForEvaluation match
case ConstantType(Constant(n: Float)) => Some(n)
case _ => None

def doubleValue(tp: Type): Option[Double] = tp.fixForEvaluation match
case ConstantType(Constant(n: Double)) => Some(n)
case _ => None

def stringValue(tp: Type): Option[String] = tp.fixForEvaluation match
case ConstantType(Constant(n: String)) => Some(n)
case _ => None

// Returns Some(true) if the type is a constant.
// Returns Some(false) if the type is not a constant.
// Returns None if there is not enough information to determine if the type is a constant.
// The type is a constant if it is a constant type or a type operation composition of constant types.
// If we get a type reference for an argument, then the result is not yet known.
def isConst(tp: Type): Option[Boolean] = tp.dealias match
// known to be constant
case ConstantType(_) => Some(true)
// currently not a concrete known type
case TypeRef(NoPrefix,_) => None
// currently not a concrete known type
case _: TypeParamRef => None
// constant if the term is constant
case t: TermRef => isConst(t.underlying)
// an operation type => recursively check all argument compositions
case applied: AppliedType if defn.isCompiletimeAppliedType(applied.typeSymbol) =>
val argsConst = applied.args.map(isConst)
if (argsConst.exists(_.isEmpty)) None
else Some(argsConst.forall(_.get))
// all other types are considered not to be constant
case _ => Some(false)

def expectArgsNum(expectedNum: Int): Unit =
// We can use assert instead of a compiler type error because this error should not
// occur since the type signature of the operation enforces the proper number of args.
assert(tp.args.length == expectedNum, s"Type operation expects $expectedNum arguments but found ${tp.args.length}")

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

// Runs the op and returns the result as a constant type.
// If the op throws an exception, then this exception is converted into a type error.
def runConstantOp(op: => Any): Type =
val result =
try op
catch case e: Throwable =>
throw new TypeError(e.getMessage.nn)
ConstantType(Constant(result))

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
expectArgsNum(1)
extractor(tp.args.head).map(a => runConstantOp(op(a)))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
constantFold2AB(extractor, extractor, op)

def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
expectArgsNum(2)
for
a <- extractorA(tp.args(0))
b <- extractorB(tp.args(1))
yield runConstantOp(op(a, b))

def constantFold3[TA, TB, TC](
extractorA: Type => Option[TA],
extractorB: Type => Option[TB],
extractorC: Type => Option[TC],
op: (TA, TB, TC) => Any
): Option[Type] =
expectArgsNum(3)
for
a <- extractorA(tp.args(0))
b <- extractorB(tp.args(1))
c <- extractorC(tp.args(2))
yield runConstantOp(op(a, b, c))

trace(i"compiletime constant fold $tp", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
val constantType =
if defn.isCompiletime_S(tycon.symbol) then
constantFold1(natValue, _ + 1)
else if owner == defn.CompiletimeOpsAnyModuleClass then name match
case tpnme.Equals => constantFold2(constValue, _ == _)
case tpnme.NotEquals => constantFold2(constValue, _ != _)
case tpnme.ToString => constantFold1(constValue, _.toString)
case tpnme.IsConst => isConst(tp.args.head).map(b => ConstantType(Constant(b)))
case _ => None
else if owner == defn.CompiletimeOpsIntModuleClass then name match
case tpnme.Abs => constantFold1(intValue, _.abs)
case tpnme.Negate => constantFold1(intValue, x => -x)
// ToString is deprecated for ops.int, and moved to ops.any
case tpnme.ToString => constantFold1(intValue, _.toString)
case tpnme.Plus => constantFold2(intValue, _ + _)
case tpnme.Minus => constantFold2(intValue, _ - _)
case tpnme.Times => constantFold2(intValue, _ * _)
case tpnme.Div => constantFold2(intValue, _ / _)
case tpnme.Mod => constantFold2(intValue, _ % _)
case tpnme.Lt => constantFold2(intValue, _ < _)
case tpnme.Gt => constantFold2(intValue, _ > _)
case tpnme.Ge => constantFold2(intValue, _ >= _)
case tpnme.Le => constantFold2(intValue, _ <= _)
case tpnme.Xor => constantFold2(intValue, _ ^ _)
case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
case tpnme.ASR => constantFold2(intValue, _ >> _)
case tpnme.LSL => constantFold2(intValue, _ << _)
case tpnme.LSR => constantFold2(intValue, _ >>> _)
case tpnme.Min => constantFold2(intValue, _ min _)
case tpnme.Max => constantFold2(intValue, _ max _)
case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
case tpnme.ToLong => constantFold1(intValue, _.toLong)
case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
case _ => None
else if owner == defn.CompiletimeOpsLongModuleClass then name match
case tpnme.Abs => constantFold1(longValue, _.abs)
case tpnme.Negate => constantFold1(longValue, x => -x)
case tpnme.Plus => constantFold2(longValue, _ + _)
case tpnme.Minus => constantFold2(longValue, _ - _)
case tpnme.Times => constantFold2(longValue, _ * _)
case tpnme.Div => constantFold2(longValue, _ / _)
case tpnme.Mod => constantFold2(longValue, _ % _)
case tpnme.Lt => constantFold2(longValue, _ < _)
case tpnme.Gt => constantFold2(longValue, _ > _)
case tpnme.Ge => constantFold2(longValue, _ >= _)
case tpnme.Le => constantFold2(longValue, _ <= _)
case tpnme.Xor => constantFold2(longValue, _ ^ _)
case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
case tpnme.ASR => constantFold2(longValue, _ >> _)
case tpnme.LSL => constantFold2(longValue, _ << _)
case tpnme.LSR => constantFold2(longValue, _ >>> _)
case tpnme.Min => constantFold2(longValue, _ min _)
case tpnme.Max => constantFold2(longValue, _ max _)
case tpnme.NumberOfLeadingZeros =>
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
case tpnme.ToInt => constantFold1(longValue, _.toInt)
case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
case _ => None
else if owner == defn.CompiletimeOpsFloatModuleClass then name match
case tpnme.Abs => constantFold1(floatValue, _.abs)
case tpnme.Negate => constantFold1(floatValue, x => -x)
case tpnme.Plus => constantFold2(floatValue, _ + _)
case tpnme.Minus => constantFold2(floatValue, _ - _)
case tpnme.Times => constantFold2(floatValue, _ * _)
case tpnme.Div => constantFold2(floatValue, _ / _)
case tpnme.Mod => constantFold2(floatValue, _ % _)
case tpnme.Lt => constantFold2(floatValue, _ < _)
case tpnme.Gt => constantFold2(floatValue, _ > _)
case tpnme.Ge => constantFold2(floatValue, _ >= _)
case tpnme.Le => constantFold2(floatValue, _ <= _)
case tpnme.Min => constantFold2(floatValue, _ min _)
case tpnme.Max => constantFold2(floatValue, _ max _)
case tpnme.ToInt => constantFold1(floatValue, _.toInt)
case tpnme.ToLong => constantFold1(floatValue, _.toLong)
case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
case _ => None
else if owner == defn.CompiletimeOpsDoubleModuleClass then name match
case tpnme.Abs => constantFold1(doubleValue, _.abs)
case tpnme.Negate => constantFold1(doubleValue, x => -x)
case tpnme.Plus => constantFold2(doubleValue, _ + _)
case tpnme.Minus => constantFold2(doubleValue, _ - _)
case tpnme.Times => constantFold2(doubleValue, _ * _)
case tpnme.Div => constantFold2(doubleValue, _ / _)
case tpnme.Mod => constantFold2(doubleValue, _ % _)
case tpnme.Lt => constantFold2(doubleValue, _ < _)
case tpnme.Gt => constantFold2(doubleValue, _ > _)
case tpnme.Ge => constantFold2(doubleValue, _ >= _)
case tpnme.Le => constantFold2(doubleValue, _ <= _)
case tpnme.Min => constantFold2(doubleValue, _ min _)
case tpnme.Max => constantFold2(doubleValue, _ max _)
case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
case _ => None
else if owner == defn.CompiletimeOpsStringModuleClass then name match
case tpnme.Plus => constantFold2(stringValue, _ + _)
case tpnme.Length => constantFold1(stringValue, _.length)
case tpnme.Matches => constantFold2(stringValue, _ matches _)
case tpnme.Substring =>
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
case tpnme.CharAt =>
constantFold2AB(stringValue, intValue, _ charAt _)
case _ => None
else if owner == defn.CompiletimeOpsBooleanModuleClass then name match
case tpnme.Not => constantFold1(boolValue, x => !x)
case tpnme.And => constantFold2(boolValue, _ && _)
case tpnme.Or => constantFold2(boolValue, _ || _)
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
case _ => None
else None

constantType.getOrElse(NoType)
}

case _ => NoType
end tryCompiletimeConstantFold
end TypeEval
Loading