Skip to content

Commit 5587767

Browse files
authored
Merge pull request #14702 from griggt/fix-14473
2 parents 17e46ad + 0b0f626 commit 5587767

File tree

4 files changed

+156
-21
lines changed

4 files changed

+156
-21
lines changed

compiler/src/dotty/tools/repl/Rendering.scala

+13-7
Original file line numberDiff line numberDiff line change
@@ -129,25 +129,31 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
129129
infoDiagnostic(d.symbol.showUser, d)
130130

131131
/** Render value definition result */
132-
def renderVal(d: Denotation)(using Context): Option[Diagnostic] =
132+
def renderVal(d: Denotation)(using Context): Either[InvocationTargetException, Option[Diagnostic]] =
133133
val dcl = d.symbol.showUser
134134
def msg(s: String) = infoDiagnostic(s, d)
135135
try
136-
if (d.symbol.is(Flags.Lazy)) Some(msg(dcl))
137-
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
138-
catch case e: InvocationTargetException => Some(msg(renderError(e, d)))
136+
Right(
137+
if d.symbol.is(Flags.Lazy) then Some(msg(dcl))
138+
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
139+
)
140+
catch case e: InvocationTargetException => Left(e)
139141
end renderVal
140142

141143
/** Force module initialization in the absence of members. */
142144
def forceModule(sym: Symbol)(using Context): Seq[Diagnostic] =
145+
import scala.util.control.NonFatal
143146
def load() =
144147
val objectName = sym.fullName.encode.toString
145148
Class.forName(objectName, true, classLoader())
146149
Nil
147-
try load() catch case e: ExceptionInInitializerError => List(infoDiagnostic(renderError(e, sym.denot), sym.denot))
150+
try load()
151+
catch
152+
case e: ExceptionInInitializerError => List(renderError(e, sym.denot))
153+
case NonFatal(e) => List(renderError(InvocationTargetException(e), sym.denot))
148154

149155
/** Render the stack trace of the underlying exception. */
150-
private def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): String =
156+
def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): Diagnostic =
151157
import dotty.tools.dotc.util.StackTraceOps._
152158
val cause = ite.getCause match
153159
case e: ExceptionInInitializerError => e.getCause
@@ -159,7 +165,7 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
159165
ste.getClassName.startsWith(REPL_WRAPPER_NAME_PREFIX) // d.symbol.owner.name.show is simple name
160166
&& (ste.getMethodName == nme.STATIC_CONSTRUCTOR.show || ste.getMethodName == nme.CONSTRUCTOR.show)
161167

162-
cause.formatStackTracePrefix(!isWrapperInitialization(_))
168+
infoDiagnostic(cause.formatStackTracePrefix(!isWrapperInitialization(_)), d)
163169
end renderError
164170

165171
private def infoDiagnostic(msg: String, d: Denotation)(using Context): Diagnostic =

compiler/src/dotty/tools/repl/ReplCompiler.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class ReplCompiler extends Compiler {
6161
val rootCtx = super.rootContext.fresh
6262
.setOwner(defn.EmptyPackageClass)
6363
.withRootImports
64-
(1 to state.objectIndex).foldLeft(rootCtx)((ctx, id) =>
64+
(state.validObjectIndexes).foldLeft(rootCtx)((ctx, id) =>
6565
importPreviousRun(id)(using ctx))
6666
}
6767
}

compiler/src/dotty/tools/repl/ReplDriver.scala

+38-13
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import dotty.tools.runner.ScalaClassLoader.*
3535
import org.jline.reader._
3636

3737
import scala.annotation.tailrec
38+
import scala.collection.mutable
3839
import scala.collection.JavaConverters._
3940
import scala.util.Using
4041

@@ -55,12 +56,15 @@ import scala.util.Using
5556
* @param objectIndex the index of the next wrapper
5657
* @param valIndex the index of next value binding for free expressions
5758
* @param imports a map from object index to the list of user defined imports
59+
* @param invalidObjectIndexes the set of object indexes that failed to initialize
5860
* @param context the latest compiler context
5961
*/
6062
case class State(objectIndex: Int,
6163
valIndex: Int,
6264
imports: Map[Int, List[tpd.Import]],
63-
context: Context)
65+
invalidObjectIndexes: Set[Int],
66+
context: Context):
67+
def validObjectIndexes = (1 to objectIndex).filterNot(invalidObjectIndexes.contains(_))
6468

6569
/** Main REPL instance, orchestrating input, compilation and presentation */
6670
class ReplDriver(settings: Array[String],
@@ -94,7 +98,7 @@ class ReplDriver(settings: Array[String],
9498
}
9599

96100
/** the initial, empty state of the REPL session */
97-
final def initialState: State = State(0, 0, Map.empty, rootCtx)
101+
final def initialState: State = State(0, 0, Map.empty, Set.empty, rootCtx)
98102

99103
/** Reset state of repl to the initial state
100104
*
@@ -237,7 +241,7 @@ class ReplDriver(settings: Array[String],
237241
completions.map(_.label).distinct.map(makeCandidate)
238242
}
239243
.getOrElse(Nil)
240-
end completions
244+
end completions
241245

242246
private def interpret(res: ParseResult)(implicit state: State): State = {
243247
res match {
@@ -353,14 +357,33 @@ class ReplDriver(settings: Array[String],
353357
val typeAliases =
354358
info.bounds.hi.typeMembers.filter(_.symbol.info.isTypeAlias)
355359

356-
val formattedMembers =
357-
typeAliases.map(rendering.renderTypeAlias) ++
358-
defs.map(rendering.renderMethod) ++
359-
vals.flatMap(rendering.renderVal)
360-
361-
val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers
362-
363-
(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
360+
// The wrapper object may fail to initialize if the rhs of a ValDef throws.
361+
// In that case, don't attempt to render any subsequent vals, and mark this
362+
// wrapper object index as invalid.
363+
var failedInit = false
364+
val renderedVals =
365+
val buf = mutable.ListBuffer[Diagnostic]()
366+
for d <- vals do if !failedInit then rendering.renderVal(d) match
367+
case Right(Some(v)) =>
368+
buf += v
369+
case Left(e) =>
370+
buf += rendering.renderError(e, d)
371+
failedInit = true
372+
case _ =>
373+
buf.toList
374+
375+
if failedInit then
376+
// We limit the returned diagnostics here to `renderedVals`, which will contain the rendered error
377+
// for the val which failed to initialize. Since any other defs, aliases, imports, etc. from this
378+
// input line will be inaccessible, we avoid rendering those so as not to confuse the user.
379+
(state.copy(invalidObjectIndexes = state.invalidObjectIndexes + state.objectIndex), renderedVals)
380+
else
381+
val formattedMembers =
382+
typeAliases.map(rendering.renderTypeAlias)
383+
++ defs.map(rendering.renderMethod)
384+
++ renderedVals
385+
val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers
386+
(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
364387
}
365388
else (state, Seq.empty)
366389

@@ -378,8 +401,10 @@ class ReplDriver(settings: Array[String],
378401
tree.symbol.info.memberClasses
379402
.find(_.symbol.name == newestWrapper.moduleClassName)
380403
.map { wrapperModule =>
381-
val formattedTypeDefs = typeDefs(wrapperModule.symbol)
382404
val (newState, formattedMembers) = extractAndFormatMembers(wrapperModule.symbol)
405+
val formattedTypeDefs = // don't render type defs if wrapper initialization failed
406+
if newState.invalidObjectIndexes.contains(state.objectIndex) then Seq.empty
407+
else typeDefs(wrapperModule.symbol)
383408
val highlighted = (formattedTypeDefs ++ formattedMembers)
384409
.map(d => new Diagnostic(d.msg.mapMsg(SyntaxHighlighting.highlight), d.pos, d.level))
385410
(newState, highlighted)
@@ -420,7 +445,7 @@ class ReplDriver(settings: Array[String],
420445

421446
case Imports =>
422447
for {
423-
objectIndex <- 1 to state.objectIndex
448+
objectIndex <- state.validObjectIndexes
424449
imp <- state.imports.getOrElse(objectIndex, Nil)
425450
} out.println(imp.show(using state.context))
426451
state

compiler/test/dotty/tools/repl/ReplCompilerTests.scala

+104
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,110 @@ class ReplCompilerTests extends ReplTest:
243243
assertEquals(List("// defined class C"), lines())
244244
}
245245

246+
def assertNotFoundError(id: String): Unit =
247+
val lines = storedOutput().linesIterator
248+
assert(lines.next().startsWith("-- [E006] Not Found Error:"))
249+
assert(lines.drop(2).next().trim().endsWith(s"Not found: $id"))
250+
251+
@Test def i4416 = initially {
252+
val state = run("val x = 1 / 0")
253+
val all = lines()
254+
assertEquals(2, all.length)
255+
assert(all.head.startsWith("java.lang.ArithmeticException:"))
256+
state
257+
} andThen {
258+
val state = run("def foo = x")
259+
assertNotFoundError("x")
260+
state
261+
} andThen {
262+
run("x")
263+
assertNotFoundError("x")
264+
}
265+
266+
@Test def i4416b = initially {
267+
val state = run("val a = 1234")
268+
val _ = storedOutput() // discard output
269+
state
270+
} andThen {
271+
val state = run("val a = 1; val x = ???; val y = x")
272+
val all = lines()
273+
assertEquals(3, all.length)
274+
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
275+
state
276+
} andThen {
277+
val state = run("x")
278+
assertNotFoundError("x")
279+
state
280+
} andThen {
281+
val state = run("y")
282+
assertNotFoundError("y")
283+
state
284+
} andThen {
285+
run("a") // `a` should retain its original binding
286+
assertEquals("val res0: Int = 1234", storedOutput().trim)
287+
}
288+
289+
@Test def i4416_imports = initially {
290+
run("import scala.collection.mutable")
291+
} andThen {
292+
val state = run("import scala.util.Try; val x = ???")
293+
val _ = storedOutput() // discard output
294+
state
295+
} andThen {
296+
run(":imports") // scala.util.Try should not be imported
297+
assertEquals("import scala.collection.mutable", storedOutput().trim)
298+
}
299+
300+
@Test def i4416_types_defs_aliases = initially {
301+
val state =
302+
run("""|type Foo = String
303+
|trait Bar
304+
|def bar: Bar = ???
305+
|val x = ???
306+
|""".stripMargin)
307+
val all = lines()
308+
assertEquals(3, all.length)
309+
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
310+
assert("type alias in failed wrapper should not be rendered",
311+
!all.exists(_.startsWith("// defined alias type Foo = String")))
312+
assert("type definitions in failed wrapper should not be rendered",
313+
!all.exists(_.startsWith("// defined trait Bar")))
314+
assert("defs in failed wrapper should not be rendered",
315+
!all.exists(_.startsWith("def bar: Bar")))
316+
state
317+
} andThen {
318+
val state = run("def foo: Foo = ???")
319+
assertNotFoundError("type Foo")
320+
state
321+
} andThen {
322+
val state = run("type B = Bar")
323+
assertNotFoundError("type Bar")
324+
state
325+
} andThen {
326+
run("bar")
327+
assertNotFoundError("bar")
328+
}
329+
330+
@Test def i14473 = initially {
331+
run("""val (x,y) = if true then "hi" else (42,17)""")
332+
val all = lines()
333+
assertEquals(2, all.length)
334+
assertEquals("scala.MatchError: hi (of class java.lang.String)", all.head)
335+
}
336+
337+
@Test def i14701 = initially {
338+
val state = run("val _ = ???")
339+
val all = lines()
340+
assertEquals(3, all.length)
341+
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
342+
state
343+
} andThen {
344+
run("val _ = assert(false)")
345+
val all = lines()
346+
assertEquals(3, all.length)
347+
assertEquals("java.lang.AssertionError: assertion failed", all.head)
348+
}
349+
246350
@Test def i14491 =
247351
initially {
248352
run("import language.experimental.fewerBraces")

0 commit comments

Comments
 (0)