Skip to content

Commit 9c36a76

Browse files
committed
[sammy] eta-expansion, overloading (SI-8310)
Playing with Java 8 Streams from the repl showed we weren't eta-expanding, nor resolving overloading for SAMs. Also, the way Java uses wildcards to represent use-site variance stresses type inference past its bendiness point (due to excessive existentials). I introduce `wildcardExtrapolation` to simplify the resulting types (without losing precision): `wildcardExtrapolation(tp) =:= tp`. For example, the `MethodType` given by `def bla(x: (_ >: String)): (_ <: Int)` is both a subtype and a supertype of `def bla(x: String): Int`. Translating http://winterbe.com/posts/2014/07/31/java8-stream-tutorial-examples/ into Scala shows most of this works, though we have some more work to do (see near the end). ``` scala> import java.util.Arrays scala> import java.util.stream.Stream scala> import java.util.stream.IntStream scala> val myList = Arrays.asList("a1", "a2", "b1", "c2", "c1") myList: java.util.List[String] = [a1, a2, b1, c2, c1] scala> myList.stream.filter(_.startsWith("c")).map(_.toUpperCase).sorted.forEach(println) C1 C2 scala> myList.stream.filter(_.startsWith("c")).map(_.toUpperCase).sorted res8: java.util.stream.Stream[?0] = java.util.stream.SortedOps$OfRef@133e7789 scala> Arrays.asList("a1", "a2", "a3").stream.findFirst.ifPresent(println) a1 scala> Stream.of("a1", "a2", "a3").findFirst.ifPresent(println) a1 scala> IntStream.range(1, 4).forEach(println) <console>:37: error: object creation impossible, since method accept in trait IntConsumer of type (x$1: Int)Unit is not defined (Note that Int does not match Any: class Int in package scala is a subclass of class Any in package scala, but method parameter types must match exactly.) IntStream.range(1, 4).forEach(println) ^ scala> IntStream.range(1, 4).forEach(println(_: Int)) // TODO: can we avoid this annotation? 1 2 3 scala> Arrays.stream(Array(1, 2, 3)).map(n => 2 * n + 1).average.ifPresent(println(_: Double)) 5.0 scala> Stream.of("a1", "a2", "a3").map(_.substring(1)).mapToInt(_.parseInt).max.ifPresent(println(_: Int)) // whoops! ReplGlobal.abort: Unknown type: <error>, <error> [class scala.reflect.internal.Types$ErrorType$, class scala.reflect.internal.Types$ErrorType$] TypeRef? false error: Unknown type: <error>, <error> [class scala.reflect.internal.Types$ErrorType$, class scala.reflect.internal.Types$ErrorType$] TypeRef? false scala.reflect.internal.FatalError: Unknown type: <error>, <error> [class scala.reflect.internal.Types$ErrorType$, class scala.reflect.internal.Types$ErrorType$] TypeRef? false at scala.reflect.internal.Reporting$class.abort(Reporting.scala:59) scala> IntStream.range(1, 4).mapToObj(i => "a" + i).forEach(println) a1 a2 a3 ```
1 parent ae70f02 commit 9c36a76

File tree

11 files changed

+108
-11
lines changed

11 files changed

+108
-11
lines changed

src/compiler/scala/tools/nsc/typechecker/Infer.scala

+6
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,17 @@ trait Infer extends Checkable {
295295
&& !isByNameParamType(tp)
296296
&& isCompatible(tp, dropByName(pt))
297297
)
298+
def isCompatibleSam(tp: Type, pt: Type): Boolean = {
299+
val samFun = typer.samToFunctionType(pt)
300+
(samFun ne NoType) && isCompatible(tp, samFun)
301+
}
302+
298303
val tp1 = normalize(tp)
299304

300305
( (tp1 weak_<:< pt)
301306
|| isCoercible(tp1, pt)
302307
|| isCompatibleByName(tp, pt)
308+
|| isCompatibleSam(tp, pt)
303309
)
304310
}
305311
def isCompatibleArgs(tps: List[Type], pts: List[Type]) = (tps corresponds pts)(isCompatible)

src/compiler/scala/tools/nsc/typechecker/Typers.scala

+26-10
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,26 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
741741
case _ =>
742742
}
743743

744+
/**
745+
* Convert a SAM type to the corresponding FunctionType,
746+
* extrapolating BoundedWildcardTypes in the process
747+
* (no type precision is lost by the extrapolation,
748+
* but this facilitates dealing with the types arising from Java's use-site variance).
749+
*/
750+
def samToFunctionType(tp: Type, sam: Symbol = NoSymbol): Type = {
751+
val samSym = sam orElse samOf(tp)
752+
753+
def correspondingFunctionSymbol = {
754+
val numVparams = samSym.info.params.length
755+
if (numVparams > definitions.MaxFunctionArity) NoSymbol
756+
else FunctionClass(numVparams)
757+
}
758+
759+
if (samSym.exists && samSym.owner != correspondingFunctionSymbol) // don't treat Functions as SAMs
760+
wildcardExtrapolation(normalize(tp memberInfo samSym))
761+
else NoType
762+
}
763+
744764
/** Perform the following adaptations of expression, pattern or type `tree` wrt to
745765
* given mode `mode` and given prototype `pt`:
746766
* (-1) For expressions with annotated types, let AnnotationCheckers decide what to do
@@ -824,7 +844,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
824844
case Block(_, tree1) => tree1.symbol
825845
case _ => tree.symbol
826846
}
827-
if (!meth.isConstructor && isFunctionType(pt)) { // (4.2)
847+
if (!meth.isConstructor && (isFunctionType(pt) || samOf(pt).exists)) { // (4.2)
828848
debuglog(s"eta-expanding $tree: ${tree.tpe} to $pt")
829849
checkParamsConvertible(tree, tree.tpe)
830850
val tree0 = etaExpand(context.unit, tree, this)
@@ -2838,7 +2858,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
28382858
* as `(a => a): Int => Int` should not (yet) get the sam treatment.
28392859
*/
28402860
val sam =
2841-
if (!settings.Xexperimental || pt.typeSymbol == FunctionSymbol) NoSymbol
2861+
if (pt.typeSymbol == FunctionSymbol) NoSymbol
28422862
else samOf(pt)
28432863

28442864
/* The SAM case comes first so that this works:
@@ -2848,15 +2868,11 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
28482868
* Note that the arity of the sam must correspond to the arity of the function.
28492869
*/
28502870
val samViable = sam.exists && sameLength(sam.info.params, fun.vparams)
2871+
val ptNorm = if (samViable) samToFunctionType(pt, sam) else pt
28512872
val (argpts, respt) =
2852-
if (samViable) {
2853-
val samInfo = pt memberInfo sam
2854-
(samInfo.paramTypes, samInfo.resultType)
2855-
} else {
2856-
pt baseType FunctionSymbol match {
2857-
case TypeRef(_, FunctionSymbol, args :+ res) => (args, res)
2858-
case _ => (fun.vparams map (_ => if (pt == ErrorType) ErrorType else NoType), WildcardType)
2859-
}
2873+
ptNorm baseType FunctionSymbol match {
2874+
case TypeRef(_, FunctionSymbol, args :+ res) => (args, res)
2875+
case _ => (fun.vparams map (_ => if (pt == ErrorType) ErrorType else NoType), WildcardType)
28602876
}
28612877

28622878
if (!FunctionSymbol.exists)

src/reflect/scala/reflect/internal/Definitions.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ trait Definitions extends api.StandardDefinitions {
790790
* The class defining the method is a supertype of `tp` that
791791
* has a public no-arg primary constructor.
792792
*/
793-
def samOf(tp: Type): Symbol = {
793+
def samOf(tp: Type): Symbol = if (!settings.Xexperimental) NoSymbol else {
794794
// if tp has a constructor, it must be public and must not take any arguments
795795
// (not even an implicit argument list -- to keep it simple for now)
796796
val tpSym = tp.typeSymbol

src/reflect/scala/reflect/internal/tpe/TypeMaps.scala

+16
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,22 @@ private[internal] trait TypeMaps {
422422
}
423423
}
424424

425+
/**
426+
* Get rid of BoundedWildcardType where variance allows us to do so.
427+
* Invariant: `wildcardExtrapolation(tp) =:= tp`
428+
*
429+
* For example, the MethodType given by `def bla(x: (_ >: String)): (_ <: Int)`
430+
* is both a subtype and a supertype of `def bla(x: String): Int`.
431+
*/
432+
object wildcardExtrapolation extends TypeMap(trackVariance = true) {
433+
def apply(tp: Type): Type =
434+
tp match {
435+
case BoundedWildcardType(TypeBounds(lo, AnyTpe)) if variance.isContravariant =>lo
436+
case BoundedWildcardType(TypeBounds(NothingTpe, hi)) if variance.isCovariant => hi
437+
case tp => mapOver(tp)
438+
}
439+
}
440+
425441
/** Might the given symbol be important when calculating the prefix
426442
* of a type? When tp.asSeenFrom(pre, clazz) is called on `tp`,
427443
* the result will be `tp` unchanged if `pre` is trivial and `clazz`

src/reflect/scala/reflect/runtime/JavaUniverseForce.scala

+1
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse =>
170170
this.dropSingletonType
171171
this.abstractTypesToBounds
172172
this.dropIllegalStarTypes
173+
this.wildcardExtrapolation
173174
this.IsDependentCollector
174175
this.ApproximateDependentMap
175176
this.wildcardToTypeVarMap

test/files/pos/sammy_exist.flags

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-Xexperimental

test/files/pos/sammy_exist.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// scala> typeOf[java.util.stream.Stream[_]].nonPrivateMember(TermName("map")).info
2+
// [R](x$1: java.util.function.Function[_ >: T, _ <: R])java.util.stream.Stream[R]
3+
4+
// java.util.function.Function
5+
trait Fun[A, B] { def apply(x: A): B }
6+
7+
// java.util.stream.Stream
8+
class S[T](x: T) { def map[R](f: Fun[_ >: T, _ <: R]): R = f(x) }
9+
10+
class Bla { def foo: Bla = this }
11+
12+
object T {
13+
val aBlaSAM = (new S(new Bla)).map(_.foo) // : Bla should be inferred, when running under -Xexperimental [TODO]
14+
val fun: Fun[Bla, Bla] = (x: Bla) => x
15+
val aBlaSAMX = (new S(new Bla)).map(fun) // : Bla should be inferred, when running under -Xexperimental [TODO]
16+
}
17+
//
18+
// // or, maybe by variance-cast?
19+
// import annotation.unchecked.{uncheckedVariance => uv}
20+
// type SFun[-A, +B] = Fun[_ >: A, _ <: B @uv]
21+
//
22+
// def jf[T, R](f: T => R): SFun[T, R] = (x: T) => f(x): R
23+
//
24+
// val aBlaSAM = (new S(new Bla)).map(jf(_.foo)) // : Bla should be inferred [type checks, but existential inferred]

test/files/pos/sammy_overload.flags

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-Xexperimental

test/files/pos/sammy_overload.scala

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
trait Consumer[T] {
2+
def consume(x: T): Unit
3+
}
4+
5+
object Test {
6+
def foo(x: String): Unit = ???
7+
def foo(): Unit = ???
8+
val f: Consumer[_ >: String] = foo
9+
}

test/files/pos/t8310.flags

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-Xexperimental

test/files/pos/t8310.scala

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
trait Comparinator[T] { def compare(a: T, b: T): Int }
2+
3+
object TestOkay {
4+
def sort(x: Comparinator[_ >: String]) = ()
5+
sort((a: String, b: String) => a.compareToIgnoreCase(b))
6+
}
7+
8+
object TestOkay2 {
9+
def sort[T](x: Comparinator[_ >: T]) = ()
10+
sort((a: String, b: String) => a.compareToIgnoreCase(b))
11+
}
12+
13+
object TestOkay3 {
14+
def sort[T](xs: Option[T], x: Comparinator[_ >: T]) = ()
15+
sort(Some(""), (a: String, b: String) => a.compareToIgnoreCase(b))
16+
}
17+
18+
object TestKoOverloaded {
19+
def sort[T](xs: Option[T]) = ()
20+
def sort[T](xs: Option[T], x: Comparinator[_ >: T]) = ()
21+
sort(Some(""), (a: String, b: String) => a.compareToIgnoreCase(b))
22+
}

0 commit comments

Comments
 (0)