Skip to content

Commit 4f1fc94

Browse files
committed
Add Precise type class for precise type inference
1 parent 19bbf49 commit 4f1fc94

15 files changed

+257
-51
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+2
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,8 @@ class Definitions {
535535
def ConsType: TypeRef = ConsClass.typeRef
536536
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
537537

538+
@tu lazy val PreciseClass: ClassSymbol = requiredClass("scala.Precise")
539+
538540
@tu lazy val SingletonClass: ClassSymbol =
539541
// needed as a synthetic class because Scala 2.x refers to it in classfiles
540542
// but does not define it as an explicit class.

compiler/src/dotty/tools/dotc/core/Types.scala

+12-2
Original file line numberDiff line numberDiff line change
@@ -4873,7 +4873,7 @@ object Types extends TypeUtils {
48734873
initOrigin: TypeParamRef,
48744874
creatorState: TyperState | Null,
48754875
val initNestingLevel: Int,
4876-
precise: Boolean) extends CachedProxyType with ValueType {
4876+
val precise: Boolean) extends CachedProxyType with ValueType {
48774877
private var currentOrigin = initOrigin
48784878

48794879
def origin: TypeParamRef = currentOrigin
@@ -4968,9 +4968,19 @@ object Types extends TypeUtils {
49684968
else
49694969
instantiateWith(tp)
49704970

4971+
def isPrecise(using Context) =
4972+
precise
4973+
|| {
4974+
val constr = ctx.typerState.constraint
4975+
constr.upper(origin).exists: tparam =>
4976+
constr.typeVarOfParam(tparam) match
4977+
case tvar: TypeVar => tvar.precise
4978+
case _ => false
4979+
}
4980+
49714981
/** Widen unions when instantiating this variable in the current context? */
49724982
def widenPolicy(using Context): Widen =
4973-
if precise then Widen.None
4983+
if isPrecise then Widen.None
49744984
else if ctx.typerState.constraint.isHard(this) then Widen.Singletons
49754985
else Widen.Unions
49764986

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+22-13
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,18 @@ object ProtoTypes {
701701
case FunProto((arg: untpd.TypedSplice) :: Nil, _) => arg.isExtensionReceiver
702702
case _ => false
703703

704-
object SingletonConstrained:
705-
def unapply(tp: Type)(using Context): Option[Type] = tp.dealias match
706-
case RefinedType(parent, tpnme.Self, TypeAlias(tp))
707-
if parent.typeSymbol == defn.SingletonClass => Some(tp)
704+
/** An extractor for Singleton and Precise witness types.
705+
*
706+
* Singleton { type Self = T } returns Some(T, true)
707+
* Precise { type Self = T } returns Some(T, false)
708+
*/
709+
object PreciseConstrained:
710+
def unapply(tp: Type)(using Context): Option[(Type, Boolean)] = tp.dealias match
711+
case RefinedType(parent, tpnme.Self, TypeAlias(tp)) =>
712+
val tsym = parent.typeSymbol
713+
if tsym == defn.SingletonClass then Some((tp, true))
714+
else if tsym == defn.PreciseClass then Some((tp, false))
715+
else None
708716
case _ => None
709717

710718
/** Add all parameters of given type lambda `tl` to the constraint's domain.
@@ -728,30 +736,31 @@ object ProtoTypes {
728736
// hk type lambdas can be added to constraints without typevars during match reduction
729737
val added = state.constraint.ensureFresh(tl)
730738

731-
def singletonConstrainedRefs(tp: Type): Set[TypeParamRef] = tp match
739+
def preciseConstrainedRefs(tp: Type, singletonOnly: Boolean): Set[TypeParamRef] = tp match
732740
case tp: MethodType if tp.isContextualMethod =>
733741
val ownBounds =
734-
for case SingletonConstrained(ref: TypeParamRef) <- tp.paramInfos
742+
for
743+
case PreciseConstrained(ref: TypeParamRef, singleton) <- tp.paramInfos
744+
if !singletonOnly || singleton
735745
yield ref
736-
ownBounds.toSet ++ singletonConstrainedRefs(tp.resType)
746+
ownBounds.toSet ++ preciseConstrainedRefs(tp.resType, singletonOnly)
737747
case tp: LambdaType =>
738-
singletonConstrainedRefs(tp.resType)
748+
preciseConstrainedRefs(tp.resType, singletonOnly)
739749
case _ =>
740750
Set.empty
741751

742-
val singletonRefs = singletonConstrainedRefs(added)
743-
def isSingleton(ref: TypeParamRef) = singletonRefs.contains(ref)
744-
745752
def newTypeVars: List[TypeVar] =
753+
val preciseRefs = preciseConstrainedRefs(added, singletonOnly = false)
746754
for paramRef <- added.paramRefs yield
747-
val tvar = TypeVar(paramRef, state, nestingLevel, precise = isSingleton(paramRef))
755+
val tvar = TypeVar(paramRef, state, nestingLevel, precise = preciseRefs.contains(paramRef))
748756
state.ownedVars += tvar
749757
tvar
750758

751759
val tvars = if addTypeVars then newTypeVars else Nil
752760
TypeComparer.addToConstraint(added, tvars)
761+
val singletonRefs = preciseConstrainedRefs(added, singletonOnly = true)
753762
for paramRef <- added.paramRefs do
754-
if isSingleton(paramRef) then paramRef <:< defn.SingletonType
763+
if singletonRefs.contains(paramRef) then paramRef <:< defn.SingletonType
755764
(added, tvars)
756765
end constrained
757766

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
237237
end synthesizedValueOf
238238

239239
val synthesizedSingleton: SpecialHandler = (formal, span) => formal match
240-
case SingletonConstrained(tp) =>
240+
case PreciseConstrained(tp, true) =>
241241
if tp.isSingletonBounded(frozen = false) then
242242
withNoErrors:
243243
ref(defn.Compiletime_erasedValue).appliedToType(formal).withSpan(span)
@@ -246,6 +246,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
246246
case _ =>
247247
EmptyTreeNoError
248248

249+
val synthesizedPrecise: SpecialHandler = (formal, span) => formal match
250+
case PreciseConstrained(tp, false) =>
251+
withNoErrors:
252+
ref(defn.Compiletime_erasedValue).appliedToType(formal).withSpan(span)
253+
case _ =>
254+
EmptyTreeNoError
255+
249256
/** Create an anonymous class `new Object { type MirroredMonoType = ... }`
250257
* and mark it with given attachment so that it is made into a mirror at PostTyper.
251258
*/
@@ -746,6 +753,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
746753
defn.ManifestClass -> synthesizedManifest,
747754
defn.OptManifestClass -> synthesizedOptManifest,
748755
defn.SingletonClass -> synthesizedSingleton,
756+
defn.PreciseClass -> synthesizedPrecise,
749757
)
750758

751759
def tryAll(formal: Type, span: Span)(using Context): TreeWithErrors =

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -2940,7 +2940,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
29402940
cpy.Select(id)(This(cls), id.name)
29412941
case _ =>
29422942
super.transform(tree)
2943-
ValDef(impl, anchorParams.transform(rhs))
2943+
ValDef(impl, anchorParams.transform(rhs)).withSpan(impl.span.endPos)
29442944
end givenImpl
29452945

29462946
val givenImpls =

compiler/test/dotty/tools/repl/TabcompleteTests.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,11 @@ class TabcompleteTests extends ReplTest {
122122
}
123123

124124
@Test def moduleCompletion = initially {
125-
assertEquals(List("Predef"), tabComplete("object Foo { type T = Pre"))
125+
assertEquals(List("Predef"), tabComplete("object Foo { type T = Pred"))
126126
}
127127

128128
@Test def i6415 = initially {
129-
assertEquals(List("Predef"), tabComplete("object Foo { opaque type T = Pre"))
129+
assertEquals(List("Predef"), tabComplete("object Foo { opaque type T = Pred"))
130130
}
131131

132132
@Test def i6361 = initially {

docs/_docs/reference/experimental/typeclasses.md

+35-30
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,39 @@ This is less of a disruption than it might appear at first:
444444
- Simplification of the language since a feature is dropped
445445
- Eliminate non-obvious and misleading syntax.
446446

447+
448+
### Bonus: Fixing Singleton
449+
450+
We know the current treatment of `Singleton` as a type bound is broken since
451+
`x.type | y.type <: Singleton` holds by the subtyping rules for union types, even though `x.type | y.type` is clearly not a singleton.
452+
453+
A better approach is to treat `Singleton` as a type class that is interpreted specially by the compiler.
454+
455+
We can do this in a backwards-compatible way by defining `Singleton` like this:
456+
457+
```scala
458+
trait Singleton:
459+
type Self
460+
```
461+
462+
Then, instead of using an unsound upper bound we can use a context bound:
463+
464+
```scala
465+
def f[X: Singleton](x: X) = ...
466+
```
467+
468+
The context bound is treated specially by the compiler so that no using clause is generated at runtime (this is straightforward, using the erased definitions mechanism).
469+
470+
### Bonus: Precise Typing
471+
472+
This approach also presents a solution to the problem how to express precise type variables. We can introduce another special type class `Precise` and use it like this:
473+
474+
```scala
475+
def f[X: Precise](x: X) = ...
476+
```
477+
Like a `Singleton` bound, a `Precise` bound disables automatic widening of singleton types or union types in inferred instances of type variable `X`. But there is no requirement that the type argument _must_ be a singleton.
478+
479+
447480
## Summary of Syntax Changes
448481

449482
Here is the complete context-free syntax for all proposed features.
@@ -692,38 +725,10 @@ Dimi Racordon tried to [port some core elements](https://github.com/kyouko-taiga
692725

693726
With the improvements proposed here, the library can now be expressed quite clearly and straightforwardly. See tests/pos/hylolib in this PR for details.
694727

695-
## Suggested Improvements unrelated to Type Classes
696-
697-
The following two improvements elsewhere would make sense alongside the suggested changes to type classes. But only the first (fixing singleton) forms a part of this proposal and is implemented.
698-
699-
### Fixing Singleton
700-
701-
We know the current treatment of `Singleton` as a type bound is broken since
702-
`x.type | y.type <: Singleton` holds by the subtyping rules for union types, even though `x.type | y.type` is clearly not a singleton.
703-
704-
A better approach is to treat `Singleton` as a type class that is interpreted specially by the compiler.
728+
## Suggested Improvement unrelated to Type Classes
705729

706-
We can do this in a backwards-compatible way by defining `Singleton` like this:
730+
The following improvement would make sense alongside the suggested changes to type classes. But it does not form part of this proposal and is not yet implemented.
707731

708-
```scala
709-
trait Singleton:
710-
type Self
711-
```
712-
713-
Then, instead of using an unsound upper bound we can use a context bound:
714-
715-
```scala
716-
def f[X: Singleton](x: X) = ...
717-
```
718-
719-
The context bound is treated specially by the compiler so that no using clause is generated at runtime (this is straightforward, using the erased definitions mechanism).
720-
721-
_Aside_: This can also lead to a solution how to express precise type variables. We can introduce another special type class `Precise` and use it like this:
722-
723-
```scala
724-
def f[X: Precise](x: X) = ...
725-
```
726-
This would disable automatic widening of singleton types in inferred instances of type variable `X`.
727732

728733
### Using `as` also in Patterns
729734

library/src/scala/Precise.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package scala
2+
import annotation.experimental
3+
import language.experimental.erasedDefinitions
4+
5+
/** A type class-like trait intended as a context bound for type variables.
6+
* If we have `[X: Precise]`, instances of the type variable `X` are inferred
7+
* in precise mode. This means that singleton types and union types are not
8+
* widened.
9+
*/
10+
@experimental erased trait Precise:
11+
type Self

tests/neg/singleton-ctx-bound.check

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
-- [E007] Type Mismatch Error: tests/neg/singleton-ctx-bound.scala:7:5 -------------------------------------------------
2+
7 | f1(someInt) // error
3+
| ^^^^^^^
4+
| Found: Int
5+
| Required: Singleton
6+
|
7+
| longer explanation available when compiling with `-explain`
8+
-- [E007] Type Mismatch Error: tests/neg/singleton-ctx-bound.scala:12:5 ------------------------------------------------
9+
12 | f2(someInt) // error
10+
| ^^^^^^^
11+
| Found: Int
12+
| Required: Singleton
13+
|
14+
| longer explanation available when compiling with `-explain`
15+
-- [E172] Type Error: tests/neg/singleton-ctx-bound.scala:13:26 --------------------------------------------------------
16+
13 | f2(if ??? then 1 else 2) // error
17+
| ^
18+
|No given instance of type (1 : Int) | (2 : Int) is Singleton was found for parameter x$2 of method f2 in object Test. Failed to synthesize an instance of type (1 : Int) | (2 : Int) is Singleton: (1 : Int) | (2 : Int) is not a singleton
19+
-- [E007] Type Mismatch Error: tests/neg/singleton-ctx-bound.scala:17:5 ------------------------------------------------
20+
17 | f3(someInt) // error
21+
| ^^^^^^^
22+
| Found: Int
23+
| Required: Singleton
24+
|
25+
| longer explanation available when compiling with `-explain`
26+
-- [E172] Type Error: tests/neg/singleton-ctx-bound.scala:18:26 --------------------------------------------------------
27+
18 | f3(if ??? then 1 else 2) // error
28+
| ^
29+
|No given instance of type Singleton{type Self = (1 : Int) | (2 : Int)} was found for a context parameter of method f3 in object Test. Failed to synthesize an instance of type Singleton{type Self = (1 : Int) | (2 : Int)}: (1 : Int) | (2 : Int) is not a singleton
30+
-- [E172] Type Error: tests/neg/singleton-ctx-bound.scala:33:6 ---------------------------------------------------------
31+
33 |class D extends A: // error
32+
|^
33+
|No given instance of type Singleton{type Self = D.this.Elem} was found for inferring the implementation of the deferred given instance given_Singleton_Elem in trait A. Failed to synthesize an instance of type Singleton{type Self = D.this.Elem}: D.this.Elem is not a singleton
34+
34 | type Elem = Int

tests/neg/singleton-ctx-bound.scala

+15
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,18 @@ object Test:
1818
f3(if ??? then 1 else 2) // error
1919
f3(3 * 2) // OK
2020
f3(6) // OK
21+
22+
import compiletime.*
23+
24+
trait A:
25+
type Elem: Singleton
26+
27+
class B extends A:
28+
type Elem = 1 // OK
29+
30+
class C[X: Singleton] extends A:
31+
type Elem = X // OK
32+
33+
class D extends A: // error
34+
type Elem = Int
35+
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//> using options -language:experimental.modularity -source future
2+
import compiletime.*
3+
4+
trait A:
5+
type Elem: Singleton
6+
7+
class B extends A:
8+
type Elem = 1
9+
10+
class C[X: Singleton] extends A:
11+
type Elem = X
12+
13+

tests/pos/precise-ctx-bound.scala

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//> using options -language:experimental.modularity -source future
2+
object Test:
3+
4+
class Wrap[T](x: T)
5+
6+
def f0[T](x: T): Wrap[T] = Wrap(x)
7+
val x0 = f0(1)
8+
val _: Wrap[Int] = x0
9+
10+
def f1[T: Precise](x: T): Wrap[T] = Wrap(x)
11+
def l = "hello".length
12+
val x1 = Wrap(l)
13+
val _: Wrap[Int] = x1
14+
15+
def f2[T](x: T)(using Precise { type Self = T}): Wrap[T] = Wrap(x)
16+
val x2 = f2(1)
17+
val _: Wrap[1] = x2
18+
19+
def f3[T: Precise](x: T): Wrap[T] = Wrap(x)
20+
val x3 = f3(identity(1))
21+
val _: Wrap[1] = x3
22+
23+
def f4[T](x: T)(using T is Precise): Wrap[T] = Wrap(x)
24+
val x4 = f4(1)
25+
val _: Wrap[1] = x4
26+
val y4 = f4(if ??? then 1 else 2)
27+
val _: Wrap[1 | 2] = y4
28+
val z4 = f4(if ??? then B() else C())
29+
val _: Wrap[B | C] = z4
30+
trait A
31+
class B extends A
32+
class C extends A
33+
34+
class C0[T](x: T):
35+
def fld: T = x
36+
val y0 = C0("hi")
37+
val _: String = y0.fld
38+
39+
class C2[T](x: T)(using T is Precise):
40+
def fld: T = x
41+
val y2 = C2(identity("hi"))
42+
val _: "hi" = y2.fld
43+
44+
class C3[T: Precise](x: T):
45+
def fld: T = x
46+
val y3 = C3("hi")
47+
val _: "hi" = y3.fld

0 commit comments

Comments
 (0)