diff --git a/test/src/main/scala/scalarules/test/extra_protobuf_generator/ExtraProtobufGenerator.scala b/test/src/main/scala/scalarules/test/extra_protobuf_generator/ExtraProtobufGenerator.scala index 4cd4f932e..bc83e37af 100644 --- a/test/src/main/scala/scalarules/test/extra_protobuf_generator/ExtraProtobufGenerator.scala +++ b/test/src/main/scala/scalarules/test/extra_protobuf_generator/ExtraProtobufGenerator.scala @@ -41,14 +41,30 @@ class CustomProtobufGenerator( object ExtraProtobufGenerator extends ProtocCodeGenerator { override def run(req: Array[Byte]): Array[Byte] = { - val registry = ExtensionRegistry.newInstance() - Scalapb.registerAllExtensions(registry) - val request = CodeGeneratorRequest.parseFrom(req) - handleCodeGeneratorRequest(request).toByteArray + val b = CodeGeneratorResponse.newBuilder + + try { + val registry = ExtensionRegistry.newInstance() + Scalapb.registerAllExtensions(registry) + val request = CodeGeneratorRequest.parseFrom(req) + handleCodeGeneratorRequest(request, b) + + } catch { + case e: Throwable => + // Yes, we want to catch _all_ errors and send them back to the + // requestor. Otherwise uncaught errors will cause the generator to + // die and the worker invoking it to hang. + val stackStream = new java.io.ByteArrayOutputStream + e.printStackTrace(new java.io.PrintStream(stackStream)) + b.setError(stackStream.toString()) + } + b.build.toByteArray } - def handleCodeGeneratorRequest(request: CodeGeneratorRequest): CodeGeneratorResponse = { - val b = CodeGeneratorResponse.newBuilder + def handleCodeGeneratorRequest( + request: CodeGeneratorRequest, + b: CodeGeneratorResponse.Builder + ) = { ProtobufGenerator.parseParameters(request.getParameter) match { case Right(params) => try { @@ -73,18 +89,9 @@ object ExtraProtobufGenerator extends ProtocCodeGenerator { } catch { case e: GeneratorException => b.setError(e.message) - case e: Throwable => - // Yes, we want to catch _all_ errors and send them back to the - // requestor. Otherwise uncaught errors will cause the generator to - // die and the worker invoking it to hang. - val stackStream = new java.io.ByteArrayOutputStream - e.printStackTrace(new java.io.PrintStream(stackStream)) - b.setError(stackStream.toString()) } case Left(error) => b.setError(error) } - b.build } - }