Skip to content

Attempt to Simplify UnsafeNulls #11375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import core._
import util.Spans._, Types._, Contexts._, Constants._, Names._, Flags._, NameOps._
import Symbols._, StdNames._, Annotations._, Trees._, Symbols._
import Decorators._, DenotTransformers._
import Phases._
import collection.{immutable, mutable}
import util.{Property, SourceFile, NoSource}
import NameKinds.{TempResultName, OuterSelectName}
Expand Down Expand Up @@ -469,7 +470,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

/** The wrapped array method name for an array of type elemtp */
def wrapArrayMethodName(elemtp: Type)(using Context): TermName = {
val elemCls = elemtp.classSymbol
val elemCls = atPhase(erasurePhase.next) { elemtp.classSymbol }
if (elemCls.isPrimitiveValueClass) nme.wrapXArray(elemCls.name)
else if (elemCls.derivesFrom(defn.ObjectClass) && !elemCls.isNotRuntimeClass) nme.wrapRefArray
else nme.genericWrapArray
Expand Down
35 changes: 7 additions & 28 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -453,27 +453,11 @@ class Definitions {
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType))
def NothingType: TypeRef = NothingClass.typeRef
@tu lazy val NullClass: ClassSymbol = {
val parent = if (ctx.explicitNulls) AnyType else ObjectType
val parent = if ctx.explicitNulls then AnyType else ObjectType
enterCompleteClassSymbol(ScalaPackageClass, tpnme.Null, AbstractFinal, parent :: Nil)
}
def NullType: TypeRef = NullClass.typeRef

/** An alias for null values that originate in Java code.
* This type gets special treatment in the Typer. Specifically, `UncheckedNull` can be selected through:
* e.g.
* ```
* // x: String|Null
* x.length // error: `Null` has no `length` field
* // x2: String|UncheckedNull
* x2.length // allowed by the Typer, but unsound (might throw NPE)
* ```
*/
lazy val UncheckedNullAlias: TypeSymbol = {
assert(ctx.explicitNulls)
enterAliasType(tpnme.UncheckedNull, NullType)
}
def UncheckedNullAliasType: TypeRef = UncheckedNullAlias.typeRef

@tu lazy val ImplicitScrutineeTypeSym =
newPermanentSymbol(ScalaPackageClass, tpnme.IMPLICITkw, EmptyFlags, TypeBounds.empty).entered
def ImplicitScrutineeTypeRef: TypeRef = ImplicitScrutineeTypeSym.typeRef
Expand Down Expand Up @@ -634,7 +618,7 @@ class Definitions {
@tu lazy val StringModule: Symbol = StringClass.linkedClass
@tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final)
@tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match {
case List(pt) => pt.isAny || pt.isAnyRef
case List(pt) => pt.isAny || pt.stripNull.isAnyRef
case _ => false
}).symbol

Expand All @@ -646,15 +630,13 @@ class Definitions {
@tu lazy val ClassCastExceptionClass: ClassSymbol = requiredClass("java.lang.ClassCastException")
@tu lazy val ClassCastExceptionClass_stringConstructor: TermSymbol = ClassCastExceptionClass.info.member(nme.CONSTRUCTOR).suchThat(_.info.firstParamTypes match {
case List(pt) =>
val pt1 = if (ctx.explicitNulls) pt.stripNull() else pt
pt1.isRef(StringClass)
pt.stripNull.isRef(StringClass)
case _ => false
}).symbol.asTerm
@tu lazy val ArithmeticExceptionClass: ClassSymbol = requiredClass("java.lang.ArithmeticException")
@tu lazy val ArithmeticExceptionClass_stringConstructor: TermSymbol = ArithmeticExceptionClass.info.member(nme.CONSTRUCTOR).suchThat(_.info.firstParamTypes match {
case List(pt) =>
val pt1 = if (ctx.explicitNulls) pt.stripNull() else pt
pt1.isRef(StringClass)
pt.stripNull.isRef(StringClass)
case _ => false
}).symbol.asTerm

Expand Down Expand Up @@ -1236,7 +1218,7 @@ class Definitions {
idx == name.length || name(idx).isDigit && digitsOnlyAfter(name, idx + 1)

def isBottomClass(cls: Symbol): Boolean =
if (ctx.explicitNulls && !ctx.phase.erasedTypes) cls == NothingClass
if ctx.explicitNulls && !ctx.phase.erasedTypes then cls == NothingClass
else isBottomClassAfterErasure(cls)

def isBottomClassAfterErasure(cls: Symbol): Boolean = cls == NothingClass || cls == NullClass
Expand Down Expand Up @@ -1700,8 +1682,8 @@ class Definitions {
// ----- Initialization ---------------------------------------------------

/** Lists core classes that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
@tu lazy val syntheticScalaClasses: List[TypeSymbol] = {
val synth = List(
@tu lazy val syntheticScalaClasses: List[TypeSymbol] =
List(
AnyClass,
MatchableClass,
AnyRefAlias,
Expand All @@ -1715,9 +1697,6 @@ class Definitions {
NothingClass,
SingletonClass)

if (ctx.explicitNulls) synth :+ UncheckedNullAlias else synth
}

@tu lazy val syntheticCoreClasses: List[Symbol] = syntheticScalaClasses ++ List(
EmptyPackageVal,
OpsPackageClass)
Expand Down
63 changes: 32 additions & 31 deletions compiler/src/dotty/tools/dotc/core/JavaNullInterop.scala
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
package dotty.tools.dotc.core
package dotty.tools.dotc
package core

import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Flags.JavaDefined
import dotty.tools.dotc.core.StdNames.{jnme, nme}
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Types._
import config.Feature._
import Contexts._
import Flags.JavaDefined
import NullOpsDecorator._
import StdNames.nme
import Symbols._
import Types._

/** This module defines methods to interpret types of Java symbols, which are implicitly nullable in Java,
* as Scala types, which are explicitly nullable.
*
* The transformation is (conceptually) a function `n` that adheres to the following rules:
* (1) n(T) = T|UncheckedNull if T is a reference type
* (1) n(T) = T | Null if T is a reference type
* (2) n(T) = T if T is a value type
* (3) n(C[T]) = C[T]|UncheckedNull if C is Java-defined
* (4) n(C[T]) = C[n(T)]|UncheckedNull if C is Scala-defined
* (5) n(A|B) = n(A)|n(B)|UncheckedNull
* (3) n(C[T]) = C[T] | Null if C is Java-defined
* (4) n(C[T]) = C[n(T)] | Null if C is Scala-defined
* (5) n(A|B) = n(A) | n(B) | Null
* (6) n(A&B) = n(A) & n(B)
* (7) n((A1, ..., Am)R) = (n(A1), ..., n(Am))n(R) for a method with arguments (A1, ..., Am) and return type R
* (8) n(T) = T otherwise
*
* Treatment of generics (rules 3 and 4):
* - if `C` is Java-defined, then `n(C[T]) = C[T]|UncheckedNull`. That is, we don't recurse
* on the type argument, and only add UncheckedNull on the outside. This is because
* - if `C` is Java-defined, then `n(C[T]) = C[T] | Null`. That is, we don't recurse
* on the type argument, and only add Null on the outside. This is because
* `C` itself will be nullified, and in particular so will be usages of `C`'s type argument within C's body.
* e.g. calling `get` on a `java.util.List[String]` already returns `String|Null` and not `String`, so
* we don't need to write `java.util.List[String|Null]`.
* - if `C` is Scala-defined, however, then we want `n(C[T]) = C[n(T)]|UncheckedNull`. This is because
* we don't need to write `java.util.List[String | Null]`.
* - if `C` is Scala-defined, however, then we want `n(C[T]) = C[n(T)] | Null`. This is because
* `C` won't be nullified, so we need to indicate that its type argument is nullable.
*
* Notice that since the transformation is only applied to types attached to Java symbols, it doesn't need
Expand All @@ -43,10 +45,9 @@ object JavaNullInterop {
*
* After calling `nullifyMember`, Scala will see the method as
*
* def foo(arg: String|UncheckedNull): String|UncheckedNull
* def foo(arg: String | Null): String | Null
*
* This nullability function uses `UncheckedNull` instead of vanilla `Null`, for usability.
* This means that we can select on the return of `foo`:
* If unsafeNulls is enabled, we can select on the return of `foo`:
*
* val len = foo("hello").length
*
Expand All @@ -57,10 +58,10 @@ object JavaNullInterop {
assert(sym.is(JavaDefined), "can only nullify java-defined members")

// Some special cases when nullifying the type
if (isEnumValueDef || sym.name == nme.TYPE_)
if isEnumValueDef || sym.name == nme.TYPE_ then
// Don't nullify the `TYPE` field in every class and Java enum instances
tp
else if (sym.name == nme.toString_ || sym.isConstructor || hasNotNullAnnot(sym))
else if sym.name == nme.toString_ || sym.isConstructor || hasNotNullAnnot(sym) then
// Don't nullify the return type of the `toString` method.
// Don't nullify the return type of constructors.
// Don't nullify the return type of methods with a not-null annotation.
Expand All @@ -81,20 +82,20 @@ object JavaNullInterop {
private def nullifyExceptReturnType(tp: Type)(using Context): Type =
new JavaNullMap(true)(tp)

/** Nullifies a Java type by adding `| UncheckedNull` in the relevant places. */
/** Nullifies a Java type by adding `| Null` in the relevant places. */
private def nullifyType(tp: Type)(using Context): Type =
new JavaNullMap(false)(tp)

/** A type map that implements the nullification function on types. Given a Java-sourced type, this adds `| UncheckedNull`
/** A type map that implements the nullification function on types. Given a Java-sourced type, this adds `| Null`
* in the right places to make the nulls explicit in Scala.
*
* @param outermostLevelAlreadyNullable whether this type is already nullable at the outermost level.
* For example, `Array[String]|UncheckedNull` is already nullable at the
* outermost level, but `Array[String|UncheckedNull]` isn't.
* For example, `Array[String] | Null` is already nullable at the
* outermost level, but `Array[String | Null]` isn't.
* If this parameter is set to true, then the types of fields, and the return
* types of methods will not be nullified.
* This is useful for e.g. constructors, and also so that `A & B` is nullified
* to `(A & B) | UncheckedNull`, instead of `(A|UncheckedNull & B|UncheckedNull) | UncheckedNull`.
* to `(A & B) | Null`, instead of `(A | Null & B | Null) | Null`.
*/
private class JavaNullMap(var outermostLevelAlreadyNullable: Boolean)(using Context) extends TypeMap {
/** Should we nullify `tp` at the outermost level? */
Expand All @@ -107,15 +108,15 @@ object JavaNullInterop {
!tp.isRef(defn.AnyClass) &&
// We don't nullify Java varargs at the top level.
// Example: if `setNames` is a Java method with signature `void setNames(String... names)`,
// then its Scala signature will be `def setNames(names: (String|UncheckedNull)*): Unit`.
// then its Scala signature will be `def setNames(names: (String|Null)*): Unit`.
// This is because `setNames(null)` passes as argument a single-element array containing the value `null`,
// and not a `null` array.
!tp.isRef(defn.RepeatedParamClass)
case _ => true
})

override def apply(tp: Type): Type = tp match {
case tp: TypeRef if needsNull(tp) => OrUncheckedNull(tp)
case tp: TypeRef if needsNull(tp) => OrNull(tp)
case appTp @ AppliedType(tycon, targs) =>
val oldOutermostNullable = outermostLevelAlreadyNullable
// We don't make the outmost levels of type arguments nullable if tycon is Java-defined.
Expand All @@ -125,7 +126,7 @@ object JavaNullInterop {
val targs2 = targs map this
outermostLevelAlreadyNullable = oldOutermostNullable
val appTp2 = derivedAppliedType(appTp, tycon, targs2)
if (needsNull(tycon)) OrUncheckedNull(appTp2) else appTp2
if needsNull(tycon) then OrNull(appTp2) else appTp2
case ptp: PolyType =>
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
case mtp: MethodType =>
Expand All @@ -136,11 +137,11 @@ object JavaNullInterop {
derivedLambdaType(mtp)(paramInfos2, this(mtp.resType))
case tp: TypeAlias => mapOver(tp)
case tp: AndType =>
// nullify(A & B) = (nullify(A) & nullify(B)) | UncheckedNull, but take care not to add
// duplicate `UncheckedNull`s at the outermost level inside `A` and `B`.
// nullify(A & B) = (nullify(A) & nullify(B)) | Null, but take care not to add
// duplicate `Null`s at the outermost level inside `A` and `B`.
outermostLevelAlreadyNullable = true
OrUncheckedNull(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
case tp: TypeParamRef if needsNull(tp) => OrUncheckedNull(tp)
OrNull(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
case tp: TypeParamRef if needsNull(tp) => OrNull(tp)
// In all other cases, return the type unchanged.
// In particular, if the type is a ConstantType, then we don't nullify it because it is the
// type of a final non-nullable field.
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ object Mode {
/** Are we resolving a TypeTest node? */
val InTypeTest: Mode = newMode(27, "InTypeTest")

/** Are we enforcing null safety */
/** Are we enforcing null safety? */
val SafeNulls = newMode(28, "SafeNulls")

/** We are typing the body of the condition of an `inline if` or the scrutinee of an `inline match`
Expand Down
118 changes: 47 additions & 71 deletions compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala
Original file line number Diff line number Diff line change
@@ -1,86 +1,62 @@
package dotty.tools.dotc.core
package dotty.tools.dotc
package core

import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Symbols.defn
import dotty.tools.dotc.core.Types._
import ast.Trees._
import Contexts._
import Symbols.defn
import Types._

/** Defines operations on nullable types. */
object NullOpsDecorator {

extension (self: Type) {
/** Is this type exactly `UncheckedNull` (no vars, aliases, refinements etc allowed)? */
def isUncheckedNullType(using Context): Boolean = {
assert(ctx.explicitNulls)
// We can't do `self == defn.UncheckedNull` because when trees are unpickled new references
// to `UncheckedNull` could be created that are different from `defn.UncheckedNull`.
// Instead, we compare the symbol.
self.isDirectRef(defn.UncheckedNullAlias)
}
/** Defines operations on nullable types and tree. */
object NullOpsDecorator:

extension (self: Type)
/** Syntactically strips the nullability from this type.
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null` (or `UncheckedNull`),
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
* then return `T1 | ... | Ti-1 | Ti+1 | ... | Tn`.
* If this type isn't (syntactically) nullable, then returns the type unchanged.
*
* @param onlyUncheckedNull whether we only remove `UncheckedNull`, the default value is false
* The type will not be changed if explicit-nulls is not enabled.
*/
def stripNull(onlyUncheckedNull: Boolean = false)(using Context): Type = {
assert(ctx.explicitNulls)

def isNull(tp: Type) =
if (onlyUncheckedNull) tp.isUncheckedNullType
else tp.isNullType

def strip(tp: Type): Type = tp match {
case tp @ OrType(lhs, rhs) =>
val llhs = strip(lhs)
val rrhs = strip(rhs)
if (isNull(rrhs)) llhs
else if (isNull(llhs)) rrhs
else tp.derivedOrType(llhs, rrhs)
case tp @ AndType(tp1, tp2) =>
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
// since `stripNull((A | Null) & B)` would produce the wrong
// result `(A & B) | Null`.
val tp1s = strip(tp1)
val tp2s = strip(tp2)
if((tp1s ne tp1) && (tp2s ne tp2))
tp.derivedAndType(tp1s, tp2s)
else tp
case _ => tp
}

val self1 = self.widenDealias
val stripped = strip(self1)
if (stripped ne self1) stripped else self
}

/** Like `stripNull`, but removes only the `UncheckedNull`s. */
def stripUncheckedNull(using Context): Type = self.stripNull(true)

/** Collapses all `UncheckedNull` unions within this type, and not just the outermost ones (as `stripUncheckedNull` does).
* e.g. (Array[String|UncheckedNull]|UncheckedNull).stripUncheckedNull => Array[String|UncheckedNull]
* (Array[String|UncheckedNull]|UncheckedNull).stripAllUncheckedNull => Array[String]
* If no `UncheckedNull` unions are found within the type, then returns the input type unchanged.
*/
def stripAllUncheckedNull(using Context): Type = {
object RemoveNulls extends TypeMap {
override def apply(tp: Type): Type = mapOver(tp.stripNull(true))
}
val rem = RemoveNulls(self)
if (rem ne self) rem else self
def stripNull(using Context): Type = {
def strip(tp: Type): Type =
val tpWiden = tp.widenDealias
val tpStripped = tpWiden match {
case tp @ OrType(lhs, rhs) =>
val llhs = strip(lhs)
val rrhs = strip(rhs)
if rrhs.isNullType then llhs
else if llhs.isNullType then rrhs
else tp.derivedOrType(llhs, rrhs)
case tp @ AndType(tp1, tp2) =>
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
// since `stripNull((A | Null) & B)` would produce the wrong
// result `(A & B) | Null`.
val tp1s = strip(tp1)
val tp2s = strip(tp2)
if (tp1s ne tp1) && (tp2s ne tp2) then
tp.derivedAndType(tp1s, tp2s)
else tp
case tp @ TypeBounds(lo, hi) =>
tp.derivedTypeBounds(strip(lo), strip(hi))
case tp => tp
}
if tpStripped ne tpWiden then tpStripped else tp

if ctx.explicitNulls then strip(self) else self
}

/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
def isNullableUnion(using Context): Boolean = {
val stripped = self.stripNull()
val stripped = self.stripNull
stripped ne self
}
end extension

/** Is self (after widening and dealiasing) a type of the form `T | UncheckedNull`? */
def isUncheckedNullableUnion(using Context): Boolean = {
val stripped = self.stripNull(true)
stripped ne self
import ast.tpd._

extension (self: Tree)
// cast the type of the tree to a non-nullable type
def castToNonNullable(using Context): Tree = self.typeOpt match {
case OrNull(tp) => self.cast(tp)
case _ => self
}
}
}
end NullOpsDecorator
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ object StdNames {
final val Nothing: N = "Nothing"
final val NotNull: N = "NotNull"
final val Null: N = "Null"
final val UncheckedNull: N = "UncheckedNull"
final val Object: N = "Object"
final val FromJavaObject: N = "<FromJavaObject>"
final val Product: N = "Product"
Expand Down
Loading