Skip to content

Commit 08fc93c

Browse files
committed
SAM Overloading and existential bounds
In some cases existential bounds can be simplified without losing precision. For example: trait Blargle[T] { def compare(a: T, b: T): Int } trait Test { def foo(a: Blargle[_ >: String]): Int } can be simplified to: trait Test { def foo(a: Blargle[String]): Int } see: scala/scala#4101 #SCL8956 Fixed
1 parent 360d478 commit 08fc93c

File tree

2 files changed

+133
-14
lines changed

2 files changed

+133
-14
lines changed

src/org/jetbrains/plugins/scala/lang/psi/ScalaPsiUtil.scala

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ import java.lang.ref.WeakReference
77
import com.intellij.codeInsight.PsiEquivalenceUtil
88
import com.intellij.lang.java.JavaLanguage
99
import com.intellij.openapi.diagnostic.Logger
10-
import com.intellij.openapi.module.{JavaModuleType, ModuleUtil, ModuleUtilCore, Module}
10+
import com.intellij.openapi.module.{JavaModuleType, Module, ModuleUtil}
1111
import com.intellij.openapi.progress.ProgressManager
1212
import com.intellij.openapi.project.Project
1313
import com.intellij.openapi.roots.{ProjectFileIndex, ProjectRootManager}
14-
import com.intellij.openapi.updateSettings.impl.pluginsAdvertisement.PluginsAdvertiser.Plugin
1514
import com.intellij.openapi.util.TextRange
1615
import com.intellij.psi._
1716
import com.intellij.psi.codeStyle.CodeStyleSettingsManager
@@ -61,7 +60,7 @@ import org.jetbrains.plugins.scala.lang.resolve.{ResolvableReferenceExpression,
6160
import org.jetbrains.plugins.scala.lang.structureView.ScalaElementPresentation
6261
import org.jetbrains.plugins.scala.project.ScalaLanguageLevel.Scala_2_11
6362
import org.jetbrains.plugins.scala.project.settings.ScalaCompilerConfiguration
64-
import org.jetbrains.plugins.scala.project.{Version, ModuleExt, ProjectPsiElementExt}
63+
import org.jetbrains.plugins.scala.project.{ModuleExt, ProjectPsiElementExt}
6564
import org.jetbrains.plugins.scala.settings.ScalaProjectSettings
6665
import org.jetbrains.plugins.scala.util.ScEquivalenceUtil
6766

@@ -2316,7 +2315,7 @@ object ScalaPsiUtil {
23162315
* @see SCL-6140
23172316
* @see https://github.com/scala/scala/pull/3018/
23182317
*/
2319-
def toSAMType(expected: ScType, scope: GlobalSearchScope): Option[ScType] = {
2318+
def toSAMType(expected: ScType, scalaScope: GlobalSearchScope): Option[ScType] = {
23202319

23212320
def constructorValidForSAM(constructors: Array[PsiMethod]): Boolean = {
23222321
//primary constructor (if any) must be public, no-args, not overloaded
@@ -2346,8 +2345,14 @@ object ScalaPsiUtil {
23462345
!abst.head.hasTypeParameters
23472346

23482347
if (valid) {
2349-
abst.head.getType() match {
2350-
case Success(tp, _) => Some(sub.subst(tp))
2348+
val fun = abst.head
2349+
fun.getType() match {
2350+
case Success(tp, _) =>
2351+
val subbed = sub.subst(tp)
2352+
extrapolateWildcardBounds(subbed, expected, fun.getProject, scalaScope) match {
2353+
case s@Some(_) => s
2354+
case _ => Some(subbed)
2355+
}
23512356
case _ => None
23522357
}
23532358
} else None
@@ -2359,15 +2364,63 @@ object ScalaPsiUtil {
23592364
//need to generate ScType for Java method
23602365
val method = abst.head
23612366
val project = method.getProject
2362-
val returnType: ScType = ScType.create(method.getReturnType, project, scope)
2367+
val returnType: ScType = ScType.create(method.getReturnType, project, scalaScope)
23632368
val params: Array[ScType] = method.getParameterList.getParameters.map {
2364-
param: PsiParameter => ScType.create(param.getTypeElement.getType, project, scope)
2369+
param: PsiParameter => ScType.create(param.getTypeElement.getType, project, scalaScope)
2370+
}
2371+
val fun = ScFunctionType(returnType, params)(project, scalaScope)
2372+
val subbed = sub.subst(fun)
2373+
extrapolateWildcardBounds(subbed, expected, project, scalaScope) match {
2374+
case s@Some(_) => s
2375+
case _ => Some(subbed)
23652376
}
2366-
val result = ScFunctionType(returnType, params)(project, scope)
2367-
Some(sub.subst(result))
23682377
} else None
23692378
}
23702379
case None => None
23712380
}
23722381
}
2382+
2383+
/**
2384+
* In some cases existential bounds can be simplified without losing precision
2385+
*
2386+
* trait Comparinator[T] { def compare(a: T, b: T): Int }
2387+
*
2388+
* trait Test {
2389+
* def foo(a: Comparinator[_ >: String]): Int
2390+
* }
2391+
*
2392+
* can be simplified to:
2393+
*
2394+
* trait Test {
2395+
* def foo(a: Comparinator[String]): Int
2396+
* }
2397+
*
2398+
* @see https://github.com/scala/scala/pull/4101
2399+
* @see SCL-8956
2400+
*/
2401+
private def extrapolateWildcardBounds(tp: ScType, expected: ScType, proj: Project, scope: GlobalSearchScope): Option[ScType] = {
2402+
expected match {
2403+
case ScExistentialType(ScParameterizedType(expectedDesignator, _), wildcards) =>
2404+
tp match {
2405+
case ScFunctionType(retTp, params) =>
2406+
def convertParameter(tpArg: ScType, variance: Int): ScType = {
2407+
wildcards.find(_.name == tpArg.canonicalText) match {
2408+
case Some(wildcard) =>
2409+
(wildcard.lowerBound, wildcard.upperBound) match {
2410+
case (lo, Any) if variance == ScTypeParam.Contravariant => lo
2411+
case (Nothing, hi) if variance == ScTypeParam.Covariant => hi
2412+
case _ => tpArg
2413+
}
2414+
case _ => tpArg
2415+
}
2416+
}
2417+
//parameter clauses are contravariant positions, return types are covariant positions
2418+
val newParams = params.map(convertParameter(_, ScTypeParam.Contravariant))
2419+
val newRetTp = convertParameter(retTp, ScTypeParam.Covariant)
2420+
Some(ScFunctionType(newRetTp, newParams)(proj, scope))
2421+
case _ => None
2422+
}
2423+
case _ => None
2424+
}
2425+
}
23732426
}

test/org/jetbrains/plugins/scala/annotator/SingleAbstractMethodTest.scala

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,83 @@ class SingleAbstractMethodTest extends ScalaLightPlatformCodeInsightTestCaseAdap
293293
checkCodeHasNoErrors(code)
294294
}
295295

296-
def checkCodeHasNoErrors(code: String) {
297-
assertMatches(messages(code)) {
296+
def testExistentialBounds(): Unit = {
297+
val code =
298+
"""
299+
|trait Blargle[T] {
300+
| def foo(a: T): String
301+
|}
302+
|
303+
|def f(b: Blargle[_ >: Int]) = -1
304+
|f(s => s.toString)
305+
|
306+
|def g[T](b: Blargle[_ >: T]) = -1
307+
|g((s: String) => s)
308+
|
309+
|trait Blergh[T] {
310+
| def foo(): T
311+
|}
312+
|
313+
|def h[T](b: Blergh[_ <: T]) = -1
314+
|h(() => "")
315+
|def i(b: Blergh[_ <: String]) = -1
316+
|i(() => "")
317+
|
318+
""".stripMargin
319+
checkCodeHasNoErrors(code)
320+
}
321+
322+
def testOverload(): Unit = {
323+
val code =
324+
"""
325+
|trait SAMOverload[A] {
326+
| def foo(s: A): Int = ???
327+
|}
328+
|
329+
|def f[T](s: T): Unit = ()
330+
|def f[T](s: T, a: SAMOverload[_ >: T]) = ()
331+
|f("", (s: String) => 2)
332+
|
333+
""".stripMargin
334+
checkCodeHasNoErrors(code)
335+
}
336+
337+
def testJavaSAM(): Unit = {
338+
val scalaCode = "new ObservableCopy(1).mapFunc(x => x + 1)"
339+
val javaCode =
340+
"""
341+
|public interface Func1<T, R> {
342+
| R call(T t);
343+
|}
344+
|
345+
|public class ObservableCopy<T> {
346+
| public ObservableCopy(T t) {}
347+
|
348+
| public final <R> ObservableCopy<R> mapFunc(Func1<? super T, ? extends R> func) {
349+
| return null;
350+
| }
351+
|}
352+
|
353+
""".stripMargin
354+
checkCodeHasNoErrors(scalaCode, Some(javaCode))
355+
}
356+
357+
def checkCodeHasNoErrors(scalaCode: String, javaCode: Option[String] = None) {
358+
assertMatches(messages(scalaCode, javaCode)) {
298359
case Nil =>
299360
}
300361
}
301362

302-
def messages(code: String): List[Message] = {
363+
def messages(@Language("Scala") scalaCode: String, javaCode: Option[String] = None): List[Message] = {
364+
javaCode match {
365+
case Some(s) => configureFromFileTextAdapter("dummy.java", s)
366+
case _ =>
367+
}
368+
303369
val annotator = new ScalaAnnotator() {}
304370
val mock = new AnnotatorHolderMock
305371

306-
val parse: ScalaFile = parseText(code)
372+
val parse: ScalaFile = parseText(scalaCode)
307373

308374
parse.depthFirst.foreach(annotator.annotate(_, mock))
309375

0 commit comments

Comments
 (0)