diff --git a/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs b/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs index 776e54ba2..6150f448b 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs +++ b/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs @@ -138,6 +138,17 @@ public void TestCallbackHandlers() Assert.Empty(callbackHandler.Inputs); } } + + [Fact] + public void TestJvmCallbackClientProperty() + { + var server = new CallbackServer(_mockJvm.Object, run: false); + Assert.Throws<InvalidOperationException>(() => server.JvmCallbackClient); + + using ISocketWrapper callbackSocket = SocketFactory.CreateSocket(); + server.Run(callbackSocket); + Assert.NotNull(server.JvmCallbackClient); + } private void TestCallbackConnection( ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict, diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs index ef6c0407a..d86fd7305 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs @@ -64,9 +64,25 @@ internal sealed class CallbackServer private bool _isRunning = false; private ISocketWrapper _listener; + + private JvmObjectReference _jvmCallbackClient; internal int CurrentNumConnections => _connections.Count; + internal JvmObjectReference JvmCallbackClient + { + get + { + if (_jvmCallbackClient is null) + { + throw new InvalidOperationException( + "Please make sure that CallbackServer was started before accessing JvmCallbackClient."); + } + + return _jvmCallbackClient; + } + } + internal CallbackServer(IJvmBridge jvm, bool run = true) { AppDomain.CurrentDomain.ProcessExit += (s, e) => Shutdown(); @@ -113,7 +129,7 @@ internal void Run(ISocketWrapper listener) // Communicate with the JVM the callback server's address and port.
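// The connectCallback call below now returns a reference to the JVM-side CallbackClient; caching it in _jvmCallbackClient is what backs the new JvmCallbackClient property, so callback-based APIs such as DataStreamWriter.ForeachBatch can hand the client back to the JVM.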
var localEndPoint = (IPEndPoint)_listener.LocalEndPoint; - _jvm.CallStaticJavaMethod( + _jvmCallbackClient = (JvmObjectReference)_jvm.CallStaticJavaMethod( "DotnetHandler", "connectCallback", localEndPoint.Address.ToString(), diff --git a/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs b/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs index f371b0665..07cf658be 100644 --- a/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs +++ b/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs @@ -228,6 +228,7 @@ public DataStreamWriter ForeachBatch(Action<DataFrame, long> func) _jvmObject.Jvm.CallStaticJavaMethod( "org.apache.spark.sql.api.dotnet.DotnetForeachBatchHelper", "callForeachBatch", + SparkEnvironment.CallbackServer.JvmCallbackClient, this, callbackId); return this; diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala index 90ad92439..aea355dfa 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -20,12 +20,12 @@ import scala.collection.mutable.Queue * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackClient(address: String, port: Int) extends Logging { +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() private[this] var isShutdown: Boolean = false - final def send(callbackId: Int, writeBody: DataOutputStream => Unit): Unit = + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = getOrCreateConnection() match { case Some(connection) => try { @@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging { return Some(connectionPool.dequeue()) } - Some(new CallbackConnection(address, port)) + Some(new CallbackConnection(serDe, address, port)) } private def addConnection(connection: CallbackConnection): Unit = synchronized { diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala index 36726181e..604cf029b 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackConnection(address: String, port: Int) extends Logging { +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val socket: Socket = new Socket(address, port) private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) def send( callbackId: Int, - writeBody: DataOutputStream => Unit): Unit = { + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { logInfo(s"Calling callback [callback id = $callbackId] ...") try { - SerDe.writeInt(outputStream, CallbackFlags.CALLBACK) - SerDe.writeInt(outputStream,
callbackId) + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) val byteArrayOutputStream = new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream)) - SerDe.writeInt(outputStream, byteArrayOutputStream.size) + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) byteArrayOutputStream.writeTo(outputStream); } catch { case e: Exception => { @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Signaling END_OF_STREAM.") try { - SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) outputStream.flush() val endOfStreamResponse = readFlag(inputStream) @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { def close(): Unit = { try { - SerDe.writeInt(outputStream, CallbackFlags.CLOSE) + serDe.writeInt(outputStream, CallbackFlags.CLOSE) outputStream.flush() } catch { case e: Exception => logInfo("Unable to send close to .NET callback server.", e) @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging { } private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = SerDe.readInt(inputStream) + val callbackFlag = serDe.readInt(inputStream) if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = SerDe.readString(inputStream) + val exceptionMessage = serDe.readString(inputStream) throw new DotnetException(exceptionMessage) } callbackFlag @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging { val DOTNET_EXCEPTION_THROWN: Int = -3 val END_OF_STREAM: Int = -4 } -} \ No newline at end of file +} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index f7ee92f0f..1d8215d44 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -30,6 +30,10 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None def init(portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -55,7 +59,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. 
strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self)) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) } }) @@ -64,6 +68,23 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback server has already been shutdown.") + } + callbackClient = None + } + def run(): Unit = { channelFuture.channel.closeFuture().syncUninterruptibly() } @@ -82,30 +103,12 @@ class DotnetBackend extends Logging { } bootstrap = null + objectTracker.clear() + // Send close to .NET callback server. - DotnetBackend.shutdownCallbackClient() + shutdownCallbackClient() // Shutdown the thread pool whose executors could still be running. ThreadPool.shutdown() } } - -object DotnetBackend extends Logging { - @volatile private[spark] var callbackClient: CallbackClient = null - - private[spark] def setCallbackClient(address: String, port: Int) = synchronized { - if (DotnetBackend.callbackClient == null) { - logInfo(s"Connecting to a callback server at $address:$port") - DotnetBackend.callbackClient = new CallbackClient(address, port) - } else { - throw new Exception("Callback client already set.") - } - } - - private[spark] def shutdownCallbackClient(): Unit = synchronized { - if (callbackClient != null) { - callbackClient.shutdown() - callbackClient = null - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index e632589e4..4d32e43fb 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -7,12 +7,9 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} - -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -20,10 +17,12 @@ import org.apache.spark.util.Utils * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. 
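 * Each handler now builds its own SerDe around the backend's JVMObjectTracker, so tracked object ids are scoped to a single DotnetBackend instead of shared, JVM-wide mutable state.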
*/ -class DotnetBackendHandler(server: DotnetBackend) +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + private[this] val serDe = new SerDe(objectsTracker) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { val reply = handleBackendRequest(msg) ctx.write(reply) @@ -41,56 +40,56 @@ class DotnetBackendHandler(server: DotnetBackend) val dos = new DataOutputStream(bos) // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) + val isStatic = serDe.readBoolean(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) if (objId == "DotnetHandler") { methodName match { case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") server.close() case "rm" => try { - val t = readObjectType(dis) + val t = serDe.readObjectType(dis) assert(t == 'c') - val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) - writeInt(dos, 0) - writeObject(dos, null) + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) } catch { case e: Exception => logError(s"Removing $objId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "rmThread" => try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, result.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port) - writeInt(dos, 0) - writeType(dos, "void") + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case _ => dos.writeInt(-1) } @@ -131,7 +130,7 @@ class DotnetBackendHandler(server: DotnetBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { + objectsTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) case Some(o) => obj = o @@ -159,8 +158,8 @@ class DotnetBackendHandler(server: DotnetBackend) val ret = selectedMethods(index.get).invoke(obj, args: _*) // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor val ctor = cls.getConstructors.filter { x => @@ -169,15 
+168,15 @@ class DotnetBackendHandler(server: DotnetBackend) val obj = ctor.newInstance(args: _*) - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) } else { throw new IllegalArgumentException( "invalid method " + methodName + " for object " + objId) } } catch { case e: Throwable => - val jvmObj = JVMObjectTracker.get(objId) + val jvmObj = objectsTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" @@ -199,15 +198,15 @@ class DotnetBackendHandler(server: DotnetBackend) methods.foreach(m => logDebug(m.toString)) } - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) } } // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { arg => - readObject(dis) + serDe.readObject(dis) }.toArray } @@ -325,41 +324,3 @@ class DotnetBackendHandler(server: DotnetBackend) def logError(id: String, e: Exception): Unit = {} } - -/** - * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. - */ -private object JVMObjectTracker { - - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. - private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 - - def getObject(id: String): Object = { - synchronized { - objMap(id) - } - } - - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - } - - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..aceb58c01 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because get method return Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. 
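+ // clear() below resets both the map and the counter; DotnetBackend.close() calls it so that a restarted backend hands out object ids starting from "1" again (see JVMObjectTrackerTest.shouldResetCounter).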
+ private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 427df61b6..44cad97c1 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -15,10 +15,10 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ /** - * Functions to serialize and deserialize between CLR & JVM. + * Class responsible for serialization and deserialization between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. */ -object SerDe { +class SerDe(val tracker: JVMObjectTracker) { def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -28,7 +28,7 @@ object SerDe { readTypedObject(dis, dataType) } - def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) @@ -41,14 +41,14 @@ object SerDe { case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => tracker.getObject(readString(dis)) case 'R' => readRowArr(dis) case 'O' => readObjectArr(dis) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } - def readBytes(in: DataInputStream): Array[Byte] = { + private def readBytes(in: DataInputStream): Array[Byte] = { val len = readInt(in) val out = new Array[Byte](len) in.readFully(out) @@ -59,15 +59,15 @@ object SerDe { in.readInt() } - def readLong(in: DataInputStream): Long = { + private def readLong(in: DataInputStream): Long = { in.readLong() } - def readDouble(in: DataInputStream): Double = { + private def readDouble(in: DataInputStream): Double = { in.readDouble() } - def readStringBytes(in: DataInputStream, len: Int): String = { + private def readStringBytes(in: DataInputStream, len: Int): String = { val bytes = new Array[Byte](len) in.readFully(bytes) val str = new String(bytes, "UTF-8") @@ -83,11 +83,11 @@ object SerDe { in.readBoolean() } - def readDate(in: DataInputStream): Date = { + private def readDate(in: DataInputStream): Date = { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Timestamp = { + private def readTime(in: DataInputStream): Timestamp = { val seconds = in.readDouble() val sec = Math.floor(seconds).toLong val t = new Timestamp(sec * 1000L) @@ -95,57 +95,57 @@ object SerDe { t } - def readRow(in: DataInputStream): Row = { + private def readRow(in: DataInputStream): Row = { val len = readInt(in) Row.fromSeq((0 until len).map(_ => readObject(in))) } - def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + private def readBytesArr(in: DataInputStream): 
Array[Array[Byte]] = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray } - def readIntArr(in: DataInputStream): Array[Int] = { + private def readIntArr(in: DataInputStream): Array[Int] = { val len = readInt(in) (0 until len).map(_ => readInt(in)).toArray } - def readLongArr(in: DataInputStream): Array[Long] = { + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray } - def readDoubleArr(in: DataInputStream): Array[Double] = { + private def readDoubleArr(in: DataInputStream): Array[Double] = { val len = readInt(in) (0 until len).map(_ => readDouble(in)).toArray } - def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { val len = readInt(in) (0 until len).map(_ => readDoubleArr(in)).toArray } - def readBooleanArr(in: DataInputStream): Array[Boolean] = { + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { val len = readInt(in) (0 until len).map(_ => readBoolean(in)).toArray } - def readStringArr(in: DataInputStream): Array[String] = { + private def readStringArr(in: DataInputStream): Array[String] = { val len = readInt(in) (0 until len).map(_ => readString(in)).toArray } - def readRowArr(in: DataInputStream): java.util.List[Row] = { + private def readRowArr(in: DataInputStream): java.util.List[Row] = { val len = readInt(in) (0 until len).map(_ => readRow(in)).toList.asJava } - def readObjectArr(in: DataInputStream): Seq[Any] = { + private def readObjectArr(in: DataInputStream): Seq[Any] = { val len = readInt(in) (0 until len).map(_ => readObject(in)) } - def readList(dis: DataInputStream): Array[_] = { + private def readList(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -154,13 +154,13 @@ object SerDe { case 'd' => readDoubleArr(dis) case 'A' => readDoubleArrArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { val keysType = readObjectType(in) @@ -299,23 +299,23 @@ object SerDe { out.writeLong(value) } - def writeDouble(out: DataOutputStream, value: Double): Unit = { + private def writeDouble(out: DataOutputStream, value: Double): Unit = { out.writeDouble(value) } - def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { out.writeBoolean(value) } - def writeDate(out: DataOutputStream, value: Date): Unit = { + private def writeDate(out: DataOutputStream, value: Date): Unit = { writeString(out, value.toString) } - def writeTime(out: DataOutputStream, value: Time): Unit = { + private def writeTime(out: DataOutputStream, value: Time): Unit = { out.writeDouble(value.getTime.toDouble / 1000.0) } - def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } @@ -326,53 +326,53 @@ object SerDe { out.write(utf8, 0, len) } - def writeBytes(out: DataOutputStream, value: 
Array[Byte]): Unit = { + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { out.writeInt(value.length) out.write(value) } def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) + val objId = tracker.put(value) writeString(out, objId) } - def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { writeType(out, "integer") out.writeInt(value.length) value.foreach(v => out.writeInt(v)) } - def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { writeType(out, "long") out.writeInt(value.length) value.foreach(v => out.writeLong(v)) } - def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { writeType(out, "double") out.writeInt(value.length) value.foreach(v => out.writeDouble(v)) } - def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { writeType(out, "doublearray") out.writeInt(value.length) value.foreach(v => writeDoubleArr(out, v)) } - def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { writeType(out, "logical") out.writeInt(value.length) value.foreach(v => writeBoolean(out, v)) } - def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { writeType(out, "character") out.writeInt(value.length) value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { writeType(out, "raw") out.writeInt(value.length) value.foreach(v => writeBytes(out, v)) diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..990887276 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var handler: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + handler = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + m.writeInt(1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = handler.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('n', reply.readByte()) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..1abf10e20 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var backend: DotnetBackend = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // Specifying port = 0 to select port dynamically. + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(backend.callbackClient.isDefined) + assertThrows( + classOf[Exception], + new ThrowingRunnable { + override def run(): Unit = { + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + } + }) + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..8c6e51608 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. 
+ * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +private[dotnet] object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readNBytes(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..78ca905bb --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConverter} + +@Test +class SerDeTest { + private var serDe: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + serDe = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, serDe.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + serDe.readObject(input) + } + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, serDe.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, serDe.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, serDe.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, serDe.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, serDe.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + Map( + 11 -> true, + 22 -> 42.42, + 33 -> null).asJava, + serDe.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(Map().asJava, serDe.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = 
givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + serDe.readObject(input) + } + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, serDe.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows( + classOf[NoSuchElementException], + new ThrowingRunnable { + override def run(): Unit = { + serDe.readObject(input) + } + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2nd row + in.writeByte('b') // type of 1st element in 2nd row + in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3rd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99))).asJava, + serDe.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + serDe.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type
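+ // 'c' is the SerDe type tag for strings; a 4-byte length prefix and the UTF-8 bytes follow.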
+ assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + serDe.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + serDe.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + serDe.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifiers + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream(in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala index 90ad92439..aea355dfa 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -20,12 +20,12 @@ import scala.collection.mutable.Queue * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackClient(address: String, port: Int) extends Logging { +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() private[this] var isShutdown: Boolean = false - final def send(callbackId: Int, writeBody: 
DataOutputStream => Unit): Unit = + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = getOrCreateConnection() match { case Some(connection) => try { @@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging { return Some(connectionPool.dequeue()) } - Some(new CallbackConnection(address, port)) + Some(new CallbackConnection(serDe, address, port)) } private def addConnection(connection: CallbackConnection): Unit = synchronized { diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala index 36726181e..258f02aeb 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackConnection(address: String, port: Int) extends Logging { +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val socket: Socket = new Socket(address, port) private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) def send( callbackId: Int, - writeBody: DataOutputStream => Unit): Unit = { + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { logInfo(s"Calling callback [callback id = $callbackId] ...") try { - SerDe.writeInt(outputStream, CallbackFlags.CALLBACK) - SerDe.writeInt(outputStream, callbackId) + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) val byteArrayOutputStream = new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream)) - SerDe.writeInt(outputStream, byteArrayOutputStream.size) + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) byteArrayOutputStream.writeTo(outputStream); } catch { case e: Exception => { @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Signaling END_OF_STREAM.") try { - SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) outputStream.flush() val endOfStreamResponse = readFlag(inputStream) @@ -53,7 +53,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.") case _ => { throw new Exception(s"Error verifying end of stream. 
Expected: ${CallbackFlags.END_OF_STREAM}, " + - s"Received: $endOfStreamResponse") + s"Received: $endOfStreamResponse") } } } catch { @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { def close(): Unit = { try { - SerDe.writeInt(outputStream, CallbackFlags.CLOSE) + serDe.writeInt(outputStream, CallbackFlags.CLOSE) outputStream.flush() } catch { case e: Exception => logInfo("Unable to send close to .NET callback server.", e) @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging { } private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = SerDe.readInt(inputStream) + val callbackFlag = serDe.readInt(inputStream) if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = SerDe.readString(inputStream) + val exceptionMessage = serDe.readString(inputStream) throw new DotnetException(exceptionMessage) } callbackFlag @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging { val DOTNET_EXCEPTION_THROWN: Int = -3 val END_OF_STREAM: Int = -4 } -} \ No newline at end of file +} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index f7ee92f0f..1d8215d44 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -30,6 +30,10 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None def init(portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -55,7 +59,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self)) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) } }) @@ -64,6 +68,23 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback server has already been shutdown.") + } + callbackClient = None + } + def run(): Unit = { channelFuture.channel.closeFuture().syncUninterruptibly() } @@ -82,30 +103,12 @@ class DotnetBackend extends Logging { } bootstrap = null + objectTracker.clear() + // Send close to .NET callback server. - DotnetBackend.shutdownCallbackClient() + shutdownCallbackClient() // Shutdown the thread pool whose executors could still be running. 
ThreadPool.shutdown() } } - -object DotnetBackend extends Logging { - @volatile private[spark] var callbackClient: CallbackClient = null - - private[spark] def setCallbackClient(address: String, port: Int) = synchronized { - if (DotnetBackend.callbackClient == null) { - logInfo(s"Connecting to a callback server at $address:$port") - DotnetBackend.callbackClient = new CallbackClient(address, port) - } else { - throw new Exception("Callback client already set.") - } - } - - private[spark] def shutdownCallbackClient(): Unit = synchronized { - if (callbackClient != null) { - callbackClient.shutdown() - callbackClient = null - } - } -} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index e632589e4..d95b18313 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -7,12 +7,9 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} - -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -20,10 +17,12 @@ import org.apache.spark.util.Utils * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. */ -class DotnetBackendHandler(server: DotnetBackend) +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + private[this] val serDe = new SerDe(objectsTracker) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { val reply = handleBackendRequest(msg) ctx.write(reply) @@ -41,56 +40,60 @@ class DotnetBackendHandler(server: DotnetBackend) val dos = new DataOutputStream(bos) // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) + val isStatic = serDe.readBoolean(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) if (objId == "DotnetHandler") { methodName match { case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") server.close() case "rm" => try { - val t = readObjectType(dis) + val t = serDe.readObjectType(dis) assert(t == 'c') - val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) - writeInt(dos, 0) - writeObject(dos, null) + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) } catch { case e: Exception => logError(s"Removing $objId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "rmThread" => try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, result.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, 
result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port) - writeInt(dos, 0) - writeType(dos, "void") + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Sends reference of CallbackClient to dotnet side, + // so that dotnet process can send the client back to Java side + // when calling any API containing callback functions. + serDe.writeObject(dos, server.callbackClient) case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case _ => dos.writeInt(-1) } @@ -131,7 +134,7 @@ class DotnetBackendHandler(server: DotnetBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { + objectsTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) case Some(o) => obj = o @@ -159,8 +162,8 @@ class DotnetBackendHandler(server: DotnetBackend) val ret = selectedMethods(index.get).invoke(obj, args: _*) // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor val ctor = cls.getConstructors.filter { x => @@ -169,15 +172,15 @@ class DotnetBackendHandler(server: DotnetBackend) val obj = ctor.newInstance(args: _*) - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) } else { throw new IllegalArgumentException( "invalid method " + methodName + " for object " + objId) } } catch { case e: Throwable => - val jvmObj = JVMObjectTracker.get(objId) + val jvmObj = objectsTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" @@ -199,15 +202,15 @@ class DotnetBackendHandler(server: DotnetBackend) methods.foreach(m => logDebug(m.toString)) } - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) } } // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { arg => - readObject(dis) + serDe.readObject(dis) }.toArray } @@ -326,40 +329,4 @@ class DotnetBackendHandler(server: DotnetBackend) def logError(id: String, e: Exception): Unit = {} } -/** - * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. - */ -private object JVMObjectTracker { - - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. 
- private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 - def getObject(id: String): Object = { - synchronized { - objMap(id) - } - } - - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - } - - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } - } -}
diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..aceb58c01 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
@@ -0,0 +1,54 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET, which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because the get method returns an Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +}
diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 427df61b6..44cad97c1 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -15,10 +15,10 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ /** - * Functions to serialize and deserialize between CLR & JVM. + * Class responsible for serialization and deserialization between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R.
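+ * Note: a SerDe instance is bound to the JVMObjectTracker it is constructed with; the backend handler and the callback client share the backend's tracker, so 'j'-typed object references written on one path resolve against the same map on the other.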
*/ -object SerDe { +class SerDe(val tracker: JVMObjectTracker) { def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -28,7 +28,7 @@ object SerDe { readTypedObject(dis, dataType) } - def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) @@ -41,14 +41,14 @@ object SerDe { case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => tracker.getObject(readString(dis)) case 'R' => readRowArr(dis) case 'O' => readObjectArr(dis) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } - def readBytes(in: DataInputStream): Array[Byte] = { + private def readBytes(in: DataInputStream): Array[Byte] = { val len = readInt(in) val out = new Array[Byte](len) in.readFully(out) @@ -59,15 +59,15 @@ object SerDe { in.readInt() } - def readLong(in: DataInputStream): Long = { + private def readLong(in: DataInputStream): Long = { in.readLong() } - def readDouble(in: DataInputStream): Double = { + private def readDouble(in: DataInputStream): Double = { in.readDouble() } - def readStringBytes(in: DataInputStream, len: Int): String = { + private def readStringBytes(in: DataInputStream, len: Int): String = { val bytes = new Array[Byte](len) in.readFully(bytes) val str = new String(bytes, "UTF-8") @@ -83,11 +83,11 @@ object SerDe { in.readBoolean() } - def readDate(in: DataInputStream): Date = { + private def readDate(in: DataInputStream): Date = { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Timestamp = { + private def readTime(in: DataInputStream): Timestamp = { val seconds = in.readDouble() val sec = Math.floor(seconds).toLong val t = new Timestamp(sec * 1000L) @@ -95,57 +95,57 @@ object SerDe { t } - def readRow(in: DataInputStream): Row = { + private def readRow(in: DataInputStream): Row = { val len = readInt(in) Row.fromSeq((0 until len).map(_ => readObject(in))) } - def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray } - def readIntArr(in: DataInputStream): Array[Int] = { + private def readIntArr(in: DataInputStream): Array[Int] = { val len = readInt(in) (0 until len).map(_ => readInt(in)).toArray } - def readLongArr(in: DataInputStream): Array[Long] = { + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray } - def readDoubleArr(in: DataInputStream): Array[Double] = { + private def readDoubleArr(in: DataInputStream): Array[Double] = { val len = readInt(in) (0 until len).map(_ => readDouble(in)).toArray } - def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { val len = readInt(in) (0 until len).map(_ => readDoubleArr(in)).toArray } - def readBooleanArr(in: DataInputStream): Array[Boolean] = { + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { val len = readInt(in) (0 until len).map(_ => readBoolean(in)).toArray } - def readStringArr(in: DataInputStream): Array[String] = { + private def readStringArr(in: DataInputStream): Array[String] = { val len = readInt(in) (0 until len).map(_ => readString(in)).toArray } - def readRowArr(in: DataInputStream): 
java.util.List[Row] = { + private def readRowArr(in: DataInputStream): java.util.List[Row] = { val len = readInt(in) (0 until len).map(_ => readRow(in)).toList.asJava } - def readObjectArr(in: DataInputStream): Seq[Any] = { + private def readObjectArr(in: DataInputStream): Seq[Any] = { val len = readInt(in) (0 until len).map(_ => readObject(in)) } - def readList(dis: DataInputStream): Array[_] = { + private def readList(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -154,13 +154,13 @@ object SerDe { case 'd' => readDoubleArr(dis) case 'A' => readDoubleArrArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { val keysType = readObjectType(in) @@ -299,23 +299,23 @@ object SerDe { out.writeLong(value) } - def writeDouble(out: DataOutputStream, value: Double): Unit = { + private def writeDouble(out: DataOutputStream, value: Double): Unit = { out.writeDouble(value) } - def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { out.writeBoolean(value) } - def writeDate(out: DataOutputStream, value: Date): Unit = { + private def writeDate(out: DataOutputStream, value: Date): Unit = { writeString(out, value.toString) } - def writeTime(out: DataOutputStream, value: Time): Unit = { + private def writeTime(out: DataOutputStream, value: Time): Unit = { out.writeDouble(value.getTime.toDouble / 1000.0) } - def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } @@ -326,53 +326,53 @@ object SerDe { out.write(utf8, 0, len) } - def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { out.writeInt(value.length) out.write(value) } def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) + val objId = tracker.put(value) writeString(out, objId) } - def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { writeType(out, "integer") out.writeInt(value.length) value.foreach(v => out.writeInt(v)) } - def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { writeType(out, "long") out.writeInt(value.length) value.foreach(v => out.writeLong(v)) } - def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { writeType(out, "double") out.writeInt(value.length) value.foreach(v => out.writeDouble(v)) } - def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { writeType(out, "doublearray") out.writeInt(value.length) value.foreach(v => writeDoubleArr(out, v)) } - def 
writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { writeType(out, "logical") out.writeInt(value.length) value.foreach(v => writeBoolean(out, v)) } - def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { writeType(out, "character") out.writeInt(value.length) value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { writeType(out, "raw") out.writeInt(value.length) value.foreach(v => writeBytes(out, v)) diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala index c0de9c7bc..afde1931e 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -15,20 +15,19 @@ class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int def call(batchDF: DataFrame, batchId: Long): Unit = callbackClient.send( callbackId, - dos => { - SerDe.writeJObj(dos, batchDF) - SerDe.writeLong(dos, batchId) + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) }) } object DotnetForeachBatchHelper { - def callForeachBatch(dsw: DataStreamWriter[Row], callbackId: Int): Unit = { - val callbackClient = DotnetBackend.callbackClient - if (callbackClient == null) { - throw new Exception("DotnetBackend.callbackClient is null.") + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + case Some(value) => new DotnetForeachBatchFunction(value, callbackId) + case None => throw new Exception("CallbackClient is null.") } - val dotnetForeachFunc = new DotnetForeachBatchFunction(callbackClient, callbackId) dsw.foreachBatch(dotnetForeachFunc.call _) } } diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..79c32d6dc --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. 
+ * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions.DataInputStreamExt +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var handler: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + handler = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + m.writeInt(1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = handler.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readNBytes(1), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..1abf10e20 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var backend: DotnetBackend = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // Specifying port = 0 to select port dynamically. 
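+ // CallbackClient only records the endpoint; connections are created lazily on the first send, so no socket is ever opened to port 0 in this test.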
+ backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(backend.callbackClient.isDefined) + assertThrows( + classOf[Exception], + new ThrowingRunnable { + override def run(): Unit = { + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + } + }) + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..8c6e51608 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +private[dotnet] object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readNBytes(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..78ca905bb --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConverter} + +@Test +class SerDeTest { + private var serDe: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + serDe = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, serDe.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + serDe.readObject(input) + } + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, serDe.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, serDe.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, serDe.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, serDe.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, serDe.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + Map( + 11 -> true, + 22 -> 42.42, + 33 -> null).asJava, + serDe.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(Map().asJava, serDe.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = 
givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + serDe.readObject(input) + } + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, serDe.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows( + classOf[NoSuchElementException], + new ThrowingRunnable { + override def run(): Unit = { + serDe.readObject(input) + } + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2nd row + in.writeByte('b') // type of 1st element in 2nd row + in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3rd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99))).asJava, + serDe.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + serDe.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type
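+ // writeString emits the UTF-8 byte length as an Int followed by the raw bytes, so the next two reads check length and content.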
+ assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + serDe.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + serDe.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + serDe.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifier + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream(in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +}
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala index 90ad92439..aea355dfa 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
@@ -20,12 +20,12 @@ import scala.collection.mutable.Queue * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackClient(address: String, port: Int) extends Logging { +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() private[this] var isShutdown: Boolean = false - final def send(callbackId: Int, writeBody:
DataOutputStream => Unit): Unit = + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = getOrCreateConnection() match { case Some(connection) => try { @@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging { return Some(connectionPool.dequeue()) } - Some(new CallbackConnection(address, port)) + Some(new CallbackConnection(serDe, address, port)) } private def addConnection(connection: CallbackConnection): Unit = synchronized { diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala index 36726181e..604cf029b 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackConnection(address: String, port: Int) extends Logging { +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val socket: Socket = new Socket(address, port) private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) def send( callbackId: Int, - writeBody: DataOutputStream => Unit): Unit = { + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { logInfo(s"Calling callback [callback id = $callbackId] ...") try { - SerDe.writeInt(outputStream, CallbackFlags.CALLBACK) - SerDe.writeInt(outputStream, callbackId) + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) val byteArrayOutputStream = new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream)) - SerDe.writeInt(outputStream, byteArrayOutputStream.size) + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) byteArrayOutputStream.writeTo(outputStream); } catch { case e: Exception => { @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Signaling END_OF_STREAM.") try { - SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) outputStream.flush() val endOfStreamResponse = readFlag(inputStream) @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { def close(): Unit = { try { - SerDe.writeInt(outputStream, CallbackFlags.CLOSE) + serDe.writeInt(outputStream, CallbackFlags.CLOSE) outputStream.flush() } catch { case e: Exception => logInfo("Unable to send close to .NET callback server.", e) @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging { } private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = SerDe.readInt(inputStream) + val callbackFlag = serDe.readInt(inputStream) if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = SerDe.readString(inputStream) + val exceptionMessage = serDe.readString(inputStream) throw new DotnetException(exceptionMessage) } callbackFlag @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging { val DOTNET_EXCEPTION_THROWN: Int = 
-3 val END_OF_STREAM: Int = -4 } -} \ No newline at end of file +}
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index f7ee92f0f..c6f528aee 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
@@ -8,7 +8,6 @@ package org.apache.spark.api.dotnet import java.net.InetSocketAddress import java.util.concurrent.TimeUnit - import io.netty.bootstrap.ServerBootstrap import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel
@@ -30,6 +29,10 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None def init(portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
@@ -55,7 +58,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self)) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) } })
@@ -64,6 +67,23 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback client has already been shut down.") + } + callbackClient = None + } + def run(): Unit = { channelFuture.channel.closeFuture().syncUninterruptibly() }
@@ -82,30 +102,12 @@ class DotnetBackend extends Logging { } bootstrap = null + objectTracker.clear() + // Send close to .NET callback server. - DotnetBackend.shutdownCallbackClient() + shutdownCallbackClient() // Shutdown the thread pool whose executors could still be running.
ThreadPool.shutdown() } } - -object DotnetBackend extends Logging { - @volatile private[spark] var callbackClient: CallbackClient = null - - private[spark] def setCallbackClient(address: String, port: Int) = synchronized { - if (DotnetBackend.callbackClient == null) { - logInfo(s"Connecting to a callback server at $address:$port") - DotnetBackend.callbackClient = new CallbackClient(address, port) - } else { - throw new Exception("Callback client already set.") - } - } - - private[spark] def shutdownCallbackClient(): Unit = synchronized { - if (callbackClient != null) { - callbackClient.shutdown() - callbackClient = null - } - } -} diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1446e5ff6..29729657b 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -6,13 +6,11 @@ package org.apache.spark.api.dotnet -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} - import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.HashMap import scala.language.existentials @@ -20,10 +18,12 @@ import scala.language.existentials * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. */ -class DotnetBackendHandler(server: DotnetBackend) +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + private[this] val serDe = new SerDe(objectsTracker) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { val reply = handleBackendRequest(msg) ctx.write(reply) @@ -41,57 +41,60 @@ class DotnetBackendHandler(server: DotnetBackend) val dos = new DataOutputStream(bos) // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) + val isStatic = serDe.readBoolean(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) if (objId == "DotnetHandler") { methodName match { case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") server.close() case "rm" => try { - val t = readObjectType(dis) + val t = serDe.readObjectType(dis) assert(t == 'c') - val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) - writeInt(dos, 0) - writeObject(dos, null) + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) } catch { case e: Exception => logError(s"Removing $objId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "rmThread" => try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeInt(dos, 0) - 
writeObject(dos, result.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port) - writeInt(dos, 0) - writeType(dos, "void") + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Send a reference to the CallbackClient to the .NET side so that the + // .NET process can pass the client back to the JVM side when calling + // any API that involves callback functions. + serDe.writeObject(dos, server.callbackClient) case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") - + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case _ => dos.writeInt(-1) } } else {
@@ -131,7 +134,7 @@ class DotnetBackendHandler(server: DotnetBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { + objectsTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) case Some(o) => obj = o
@@ -159,8 +162,8 @@ class DotnetBackendHandler(server: DotnetBackend) val ret = selectedMethods(index.get).invoke(obj, args: _*) // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor val ctor = cls.getConstructors.filter { x =>
@@ -169,15 +172,15 @@ class DotnetBackendHandler(server: DotnetBackend) val obj = ctor.newInstance(args: _*) - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) } else { throw new IllegalArgumentException( "invalid method " + methodName + " for object " + objId) } } catch { case e: Throwable => - val jvmObj = JVMObjectTracker.get(objId) + val jvmObj = objectsTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject"
@@ -199,15 +202,15 @@ class DotnetBackendHandler(server: DotnetBackend) methods.foreach(m => logDebug(m.toString)) } - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) } } // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { arg => - readObject(dis) + serDe.readObject(dis) }.toArray }
@@ -326,40 +329,4 @@ class DotnetBackendHandler(server: DotnetBackend) def logError(id: String, e: Exception): Unit = {} } -/** - * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. - */ -private object JVMObjectTracker { - - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
- private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 - - def getObject(id: String): Object = { - synchronized { - objMap(id) - } - } - - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - } - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } - } -}
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..81cfaf88b --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
@@ -0,0 +1,55 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET, which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because the get method returns an Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +}
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 427df61b6..a3df3788a 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -15,10 +15,11 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ /** - * Functions to serialize and deserialize between CLR & JVM. + * Class responsible for serialization and deserialization between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R.
*/ -object SerDe { +class SerDe(val tracker: JVMObjectTracker) { + def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -28,7 +29,7 @@ object SerDe { readTypedObject(dis, dataType) } - def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) @@ -41,14 +42,14 @@ object SerDe { case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => tracker.getObject(readString(dis)) case 'R' => readRowArr(dis) case 'O' => readObjectArr(dis) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } - def readBytes(in: DataInputStream): Array[Byte] = { + private def readBytes(in: DataInputStream): Array[Byte] = { val len = readInt(in) val out = new Array[Byte](len) in.readFully(out) @@ -59,15 +60,15 @@ object SerDe { in.readInt() } - def readLong(in: DataInputStream): Long = { + private def readLong(in: DataInputStream): Long = { in.readLong() } - def readDouble(in: DataInputStream): Double = { + private def readDouble(in: DataInputStream): Double = { in.readDouble() } - def readStringBytes(in: DataInputStream, len: Int): String = { + private def readStringBytes(in: DataInputStream, len: Int): String = { val bytes = new Array[Byte](len) in.readFully(bytes) val str = new String(bytes, "UTF-8") @@ -83,11 +84,11 @@ object SerDe { in.readBoolean() } - def readDate(in: DataInputStream): Date = { + private def readDate(in: DataInputStream): Date = { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Timestamp = { + private def readTime(in: DataInputStream): Timestamp = { val seconds = in.readDouble() val sec = Math.floor(seconds).toLong val t = new Timestamp(sec * 1000L) @@ -95,57 +96,57 @@ object SerDe { t } - def readRow(in: DataInputStream): Row = { + private def readRow(in: DataInputStream): Row = { val len = readInt(in) Row.fromSeq((0 until len).map(_ => readObject(in))) } - def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray } - def readIntArr(in: DataInputStream): Array[Int] = { + private def readIntArr(in: DataInputStream): Array[Int] = { val len = readInt(in) (0 until len).map(_ => readInt(in)).toArray } - def readLongArr(in: DataInputStream): Array[Long] = { + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray } - def readDoubleArr(in: DataInputStream): Array[Double] = { + private def readDoubleArr(in: DataInputStream): Array[Double] = { val len = readInt(in) (0 until len).map(_ => readDouble(in)).toArray } - def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { val len = readInt(in) (0 until len).map(_ => readDoubleArr(in)).toArray } - def readBooleanArr(in: DataInputStream): Array[Boolean] = { + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { val len = readInt(in) (0 until len).map(_ => readBoolean(in)).toArray } - def readStringArr(in: DataInputStream): Array[String] = { + private def readStringArr(in: DataInputStream): Array[String] = { val len = readInt(in) (0 until len).map(_ => readString(in)).toArray } - def readRowArr(in: DataInputStream): 
java.util.List[Row] = { + private def readRowArr(in: DataInputStream): java.util.List[Row] = { val len = readInt(in) (0 until len).map(_ => readRow(in)).toList.asJava } - def readObjectArr(in: DataInputStream): Seq[Any] = { + private def readObjectArr(in: DataInputStream): Seq[Any] = { val len = readInt(in) (0 until len).map(_ => readObject(in)) } - def readList(dis: DataInputStream): Array[_] = { + private def readList(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -154,13 +155,13 @@ object SerDe { case 'd' => readDoubleArr(dis) case 'A' => readDoubleArrArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { val keysType = readObjectType(in) @@ -299,23 +300,23 @@ object SerDe { out.writeLong(value) } - def writeDouble(out: DataOutputStream, value: Double): Unit = { + private def writeDouble(out: DataOutputStream, value: Double): Unit = { out.writeDouble(value) } - def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { out.writeBoolean(value) } - def writeDate(out: DataOutputStream, value: Date): Unit = { + private def writeDate(out: DataOutputStream, value: Date): Unit = { writeString(out, value.toString) } - def writeTime(out: DataOutputStream, value: Time): Unit = { + private def writeTime(out: DataOutputStream, value: Time): Unit = { out.writeDouble(value.getTime.toDouble / 1000.0) } - def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } @@ -326,53 +327,53 @@ object SerDe { out.write(utf8, 0, len) } - def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { out.writeInt(value.length) out.write(value) } def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) + val objId = tracker.put(value) writeString(out, objId) } - def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { writeType(out, "integer") out.writeInt(value.length) value.foreach(v => out.writeInt(v)) } - def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { writeType(out, "long") out.writeInt(value.length) value.foreach(v => out.writeLong(v)) } - def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { writeType(out, "double") out.writeInt(value.length) value.foreach(v => out.writeDouble(v)) } - def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { writeType(out, "doublearray") out.writeInt(value.length) value.foreach(v => writeDoubleArr(out, v)) } - def 
writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { writeType(out, "logical") out.writeInt(value.length) value.foreach(v => writeBoolean(out, v)) } - def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { writeType(out, "character") out.writeInt(value.length) value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { writeType(out, "raw") out.writeInt(value.length) value.foreach(v => writeBytes(out, v)) diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala index c0de9c7bc..5d06d4304 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -6,7 +6,7 @@ package org.apache.spark.sql.api.dotnet -import org.apache.spark.api.dotnet.{CallbackClient, DotnetBackend, SerDe} +import org.apache.spark.api.dotnet.CallbackClient import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.streaming.DataStreamWriter @@ -15,20 +15,19 @@ class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int def call(batchDF: DataFrame, batchId: Long): Unit = callbackClient.send( callbackId, - dos => { - SerDe.writeJObj(dos, batchDF) - SerDe.writeLong(dos, batchId) + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) }) } object DotnetForeachBatchHelper { - def callForeachBatch(dsw: DataStreamWriter[Row], callbackId: Int): Unit = { - val callbackClient = DotnetBackend.callbackClient - if (callbackClient == null) { - throw new Exception("DotnetBackend.callbackClient is null.") + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + case Some(value) => new DotnetForeachBatchFunction(value, callbackId) + case None => throw new Exception("CallbackClient is null.") } - val dotnetForeachFunc = new DotnetForeachBatchFunction(callbackClient, callbackId) dsw.foreachBatch(dotnetForeachFunc.call _) } } diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..672455349 --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import Extensions._ +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var handler: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + handler = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + m.writeInt(1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = handler.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readNBytes(1), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..445486bbd --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import org.junit.Assert._
+import org.junit.{After, Before, Test}
+
+import java.net.InetAddress
+
+@Test
+class DotnetBackendTest {
+  private var backend: DotnetBackend = _
+
+  @Before
+  def before(): Unit = {
+    backend = new DotnetBackend
+  }
+
+  @After
+  def after(): Unit = {
+    backend.close()
+  }
+
+  @Test
+  def shouldNotResetCallbackClient(): Unit = {
+    // Specifying port = 0 to select a port dynamically.
+    backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+
+    assertTrue(backend.callbackClient.isDefined)
+    assertThrows(classOf[Exception], () => {
+      backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+    })
+  }
+}
diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
new file mode 100644
index 000000000..c6904403b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import java.io.DataInputStream
+
+private[dotnet] object Extensions {
+  implicit class DataInputStreamExt(stream: DataInputStream) {
+    def readNBytes(n: Int): Array[Byte] = {
+      val buf = new Array[Byte](n)
+      stream.readFully(buf)
+      buf
+    }
+  }
+}
diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
new file mode 100644
index 000000000..43ae79005
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class JVMObjectTrackerTest {
+
+  @Test
+  def shouldReleaseAllReferences(): Unit = {
+    val tracker = new JVMObjectTracker
+    val firstId = tracker.put(new Object)
+    val secondId = tracker.put(new Object)
+    val thirdId = tracker.put(new Object)
+
+    tracker.clear()
+
+    assert(tracker.get(firstId).isEmpty)
+    assert(tracker.get(secondId).isEmpty)
+    assert(tracker.get(thirdId).isEmpty)
+  }
+
+  @Test
+  def shouldResetCounter(): Unit = {
+    val tracker = new JVMObjectTracker
+    val firstId = tracker.put(new Object)
+    val secondId = tracker.put(new Object)
+
+    tracker.clear()
+
+    val thirdId = tracker.put(new Object)
+
+    assert(firstId.equals("1"))
+    assert(secondId.equals("2"))
+    assert(thirdId.equals("1"))
+  }
+}
diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
new file mode 100644
index 000000000..41401d680
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.apache.spark.api.dotnet.Extensions._
+import org.apache.spark.sql.Row
+import org.junit.Assert._
+import org.junit.{Before, Test}
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.sql.Date
+import scala.collection.JavaConverters._
+
+@Test
+class SerDeTest {
+  private var serDe: SerDe = _
+  private var tracker: JVMObjectTracker = _
+
+  @Before
+  def before(): Unit = {
+    tracker = new JVMObjectTracker
+    serDe = new SerDe(tracker)
+  }
+
+  @Test
+  def shouldReadNull(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('n')
+    })
+
+    assertEquals(null, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldThrowForUnsupportedTypes(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('_')
+    })
+
+    assertThrows(classOf[IllegalArgumentException], () => {
+      serDe.readObject(input)
+    })
+  }
+
+  @Test
+  def shouldReadInteger(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('i')
+      in.writeInt(42)
+    })
+
+    assertEquals(42, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadLong(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('g')
+      in.writeLong(42)
+    })
+
+    assertEquals(42L, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadDouble(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('d')
+      in.writeDouble(42.42)
+    })
+
+    assertEquals(42.42, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadBoolean(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('b')
+      in.writeBoolean(true)
+    })
+
+    assertEquals(true, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadString(): Unit = {
+    val payload = "Spark Dotnet"
+    val input = givenInput(in => {
+      in.writeByte('c')
+      in.writeInt(payload.getBytes("UTF-8").length)
+      in.write(payload.getBytes("UTF-8"))
+    })
+
+    assertEquals(payload, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadMap(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('e') // map type descriptor
+      in.writeInt(3) // size
+      in.writeByte('i') // key type
+      in.writeInt(3) // number of keys
+      in.writeInt(11) // first key
+      in.writeInt(22) // second key
+      in.writeInt(33) // third key
+      in.writeInt(3) // number of values
+      in.writeByte('b') // first value type
+      in.writeBoolean(true) // first value
+      in.writeByte('d') // second value type
+      in.writeDouble(42.42) // second value
+      in.writeByte('n') // third type & value
+    })
+
+    assertEquals(
+      mapAsJavaMap(Map(
+        11 -> true,
+        22 -> 42.42,
+        33 -> null)),
+      serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadEmptyMap(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('e') // map type descriptor
+      in.writeInt(0) // size
+    })
+
+    assertEquals(mapAsJavaMap(Map()), serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadBytesArray(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('r') // byte array type descriptor
+      in.writeInt(3) // length
+      in.write(Array[Byte](1, 2, 3)) // payload
+    })
+
+    assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]])
+  }
+
+  @Test
+  def shouldReadEmptyBytesArray(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('r') // byte array type descriptor
+      in.writeInt(0) // length
+    })
+
+    assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]])
+  }
+
+  @Test
+  def shouldReadEmptyList(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('l') // type descriptor
+      in.writeByte('i') // element type
+      in.writeInt(0) // length
+    })
+
+    assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]])
+  }
+
+  @Test
+  def shouldReadList(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('l') // type descriptor
+      in.writeByte('b') // element type
+      in.writeInt(3) // length
+      in.writeBoolean(true)
+      in.writeBoolean(false)
+      in.writeBoolean(true)
+    })
+
+    assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]])
+  }
+
+  @Test
+  def shouldThrowWhenReadingListWithUnsupportedType(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('l') // type descriptor
+      in.writeByte('_') // unsupported element type
+    })
+
+    assertThrows(classOf[IllegalArgumentException], () => {
+      serDe.readObject(input)
+    })
+  }
+
+  @Test
+  def shouldReadDate(): Unit = {
+    val input = givenInput(in => {
+      val date = "2020-12-31"
+      in.writeByte('D') // type descriptor
+      in.writeInt(date.getBytes("UTF-8").length) // date string size
+      in.write(date.getBytes("UTF-8"))
+    })
+
+    assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadObject(): Unit = {
+    val trackingObject = new Object
+    tracker.put(trackingObject)
+    val input = givenInput(in => {
+      val objectIndex = "1"
+      in.writeByte('j') // type descriptor
+      in.writeInt(objectIndex.getBytes("UTF-8").length) // size
+      in.write(objectIndex.getBytes("UTF-8"))
+    })
+
+    assertSame(trackingObject, serDe.readObject(input))
+  }
+
+  @Test
+  def shouldThrowWhenReadingNonTrackingObject(): Unit = {
+    val input = givenInput(in => {
+      val objectIndex = "42"
+      in.writeByte('j') // type descriptor
+      in.writeInt(objectIndex.getBytes("UTF-8").length) // size
+      in.write(objectIndex.getBytes("UTF-8"))
+    })
+
+    assertThrows(classOf[NoSuchElementException], () => {
+      serDe.readObject(input)
+    })
+  }
+
+  @Test
+  def shouldReadSparkRows(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('R') // type descriptor
+      in.writeInt(2) // number of rows
+      in.writeInt(1) // number of elements in 1st row
+      in.writeByte('i') // type of 1st element in 1st row
+      in.writeInt(11)
+      in.writeInt(3) // number of elements in 2nd row
+      in.writeByte('b') // type of 1st element in 2nd row
+      in.writeBoolean(true)
+      in.writeByte('d') // type of 2nd element in 2nd row
+      in.writeDouble(42.24)
+      in.writeByte('g') // type of 3rd element in 2nd row
+      in.writeLong(99)
+    })
+
+    assertEquals(
+      seqAsJavaList(Seq(
+        Row.fromSeq(Seq(11)),
+        Row.fromSeq(Seq(true, 42.24, 99)))),
+      serDe.readObject(input))
+  }
+
+  @Test
+  def shouldReadArrayOfObjects(): Unit = {
+    val input = givenInput(in => {
+      in.writeByte('O') // type descriptor
+      in.writeInt(2) // number of elements
+      in.writeByte('i') // type of 1st element
+      in.writeInt(42)
+      in.writeByte('b') // type of 2nd element
+      in.writeBoolean(true)
+    })
+
+    assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]])
+  }
+
+  @Test
+  def shouldWriteNull(): Unit = {
+    val in = whenOutput(out => {
+      serDe.writeObject(out, null)
+      serDe.writeObject(out, Unit)
+    })
+
+    assertEquals(in.readByte(), 'n')
+    assertEquals(in.readByte(), 'n')
+    assertEndOfStream(in)
+  }
+
+  @Test
+  def shouldWriteString(): Unit = {
+    val sparkDotnet = "Spark Dotnet"
+    val in = whenOutput(out => {
+      serDe.writeObject(out, sparkDotnet)
+    })
+
+    assertEquals(in.readByte(), 'c') // object type
+    assertEquals(in.readInt(), sparkDotnet.length) // length
+    assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
+    assertEndOfStream(in)
+  }
+
+  @Test
+  def shouldWritePrimitiveTypes(): Unit = {
+    val in = whenOutput(out => {
+      serDe.writeObject(out, 42.24f.asInstanceOf[Object])
+      serDe.writeObject(out, 42L.asInstanceOf[Object])
+      serDe.writeObject(out, 42.asInstanceOf[Object])
+      serDe.writeObject(out, true.asInstanceOf[Object])
+    })
+
+    assertEquals(in.readByte(), 'd')
+    assertEquals(in.readDouble(), 42.24F, 0.000001)
+    assertEquals(in.readByte(), 'g')
+    assertEquals(in.readLong(), 42L)
+    assertEquals(in.readByte(), 'i')
+    assertEquals(in.readInt(), 42)
+    assertEquals(in.readByte(), 'b')
+    assertEquals(in.readBoolean(), true)
+    assertEndOfStream(in)
+  }
+
+  @Test
+  def shouldWriteDate(): Unit = {
+    val date = "2020-12-31"
+    val in = whenOutput(out => {
+      serDe.writeObject(out, Date.valueOf(date))
+    })
+
+    assertEquals(in.readByte(), 'D') // type
+    assertEquals(in.readInt(), 10) // size
+    assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
+  }
+
+  @Test
+  def shouldWriteCustomObjects(): Unit = {
+    val customObject = new Object
+    val in = whenOutput(out => {
+      serDe.writeObject(out, customObject)
+    })
+
+    assertEquals(in.readByte(), 'j')
+    assertEquals(in.readInt(), 1)
+    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
+    assertSame(tracker.get("1").get, customObject)
+  }
+
+  @Test
+  def shouldWriteArrayOfCustomObjects(): Unit = {
+    val payload = Array(new Object, new Object)
+    val in = whenOutput(out => {
+      serDe.writeObject(out, payload)
+    })
+
+    assertEquals(in.readByte(), 'l') // array type
+    assertEquals(in.readByte(), 'j') // type of element in array
+    assertEquals(in.readInt(), 2) // array length
+    assertEquals(in.readInt(), 1) // size of 1st element's identifier
+    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
+    assertEquals(in.readInt(), 1) // size of 2nd element's identifier
+    assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
+    assertSame(tracker.get("1").get, payload(0))
+    assertSame(tracker.get("2").get, payload(1))
+  }
+
+  private def givenInput(func: DataOutputStream => Unit): DataInputStream = {
+    val buffer = new ByteArrayOutputStream()
+    val out = new DataOutputStream(buffer)
+    func(out)
+    new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
+  }
+
+  private def whenOutput = givenInput _
+
+  private def assertEndOfStream(in: DataInputStream): Unit = {
+    assertEquals(-1, in.read())
+  }
+}
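
Note on the wire format exercised by SerDeTest: every value is framed as a one-byte type descriptor followed by its payload ('n' null, 'i' Int, 'g' Long, 'd' Double, 'b' Boolean, 'c' UTF-8 string, 'e' map, 'r' byte array, 'l' typed array, 'D' date, 'j' tracked JVM object, 'R' rows, 'O' array of objects). A minimal round-trip sketch, assuming the same package-private access the tests use; the object name SerDeRoundTripSketch is illustrative and not part of this change:

package org.apache.spark.api.dotnet

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object SerDeRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // The tracker assigns string ids ("1", "2", ...) to JVM objects; SerDe
    // prefixes each value with its type-descriptor byte, as the tests assert.
    val tracker = new JVMObjectTracker
    val serDe = new SerDe(tracker)

    // Write a string the way the backend replies to the .NET side:
    // 'c', then the UTF-8 byte length, then the bytes.
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    serDe.writeObject(out, "Spark Dotnet")

    // Read it back the way incoming arguments are decoded.
    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    assert(serDe.readObject(in) == "Spark Dotnet")

    // JVM objects cross the wire by reference: a 'j' descriptor plus the
    // tracker id, never the serialized object itself.
    val obj = new Object
    val id = tracker.put(obj)
    assert(tracker.get(id).contains(obj))
  }
}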
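
For the new Option[CallbackClient] plumbing in DotnetForeachBatchHelper, a hypothetical caller-side sketch; the endpoint, the rate source, and callbackId = 1 are placeholders, and an actual run requires a .NET CallbackServer listening at that address (CallbackClient only opens a connection on the first send):

package org.apache.spark.api.dotnet

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.dotnet.DotnetForeachBatchHelper

object ForeachBatchWiringSketch {
  def main(args: Array[String]): Unit = {
    // setCallbackClient populates backend.callbackClient exactly once;
    // a second call throws, as DotnetBackendTest above asserts.
    val backend = new DotnetBackend
    backend.setCallbackClient("127.0.0.1", port = 55555) // placeholder endpoint

    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    val writer = spark.readStream.format("rate").load().writeStream

    // callForeachBatch matches on the Option and throws if it is None.
    DotnetForeachBatchHelper.callForeachBatch(backend.callbackClient, writer, callbackId = 1)

    writer.start().awaitTermination()
  }
}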