Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ public void TestCallbackHandlers()
Assert.Empty(callbackHandler.Inputs);
}
}

[Fact]
public void TestJvmCallbackClientProperty()
{
var server = new CallbackServer(_mockJvm.Object, run: false);
Assert.Throws<InvalidOperationException>(() => server.JvmCallbackClient);

using ISocketWrapper callbackSocket = SocketFactory.CreateSocket();
server.Run(callbackSocket);
Assert.NotNull(server.JvmCallbackClient);
}

private void TestCallbackConnection(
ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict,
Expand Down
18 changes: 17 additions & 1 deletion src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,25 @@ internal sealed class CallbackServer
private bool _isRunning = false;

private ISocketWrapper _listener;

private JvmObjectReference _jvmCallbackClient;

internal int CurrentNumConnections => _connections.Count;

internal JvmObjectReference JvmCallbackClient
{
get
{
if (_jvmCallbackClient is null)
{
throw new InvalidOperationException(
"Please make sure that CallbackServer was started before accessing JvmCallbackClient.");
}

return _jvmCallbackClient;
}
}

internal CallbackServer(IJvmBridge jvm, bool run = true)
{
AppDomain.CurrentDomain.ProcessExit += (s, e) => Shutdown();
Expand Down Expand Up @@ -113,7 +129,7 @@ internal void Run(ISocketWrapper listener)

// Communicate with the JVM the callback server's address and port.
var localEndPoint = (IPEndPoint)_listener.LocalEndPoint;
_jvm.CallStaticJavaMethod(
_jvmCallbackClient = (JvmObjectReference)_jvm.CallStaticJavaMethod(
"DotnetHandler",
"connectCallback",
localEndPoint.Address.ToString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ public DataStreamWriter ForeachBatch(Action<DataFrame, long> func)
_jvmObject.Jvm.CallStaticJavaMethod(
"org.apache.spark.sql.api.dotnet.DotnetForeachBatchHelper",
"callForeachBatch",
SparkEnvironment.CallbackServer.JvmCallbackClient,
this,
callbackId);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import scala.collection.mutable.Queue
* @param address The address of the Dotnet CallbackServer
* @param port The port of the Dotnet CallbackServer
*/
class CallbackClient(address: String, port: Int) extends Logging {
class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()

private[this] var isShutdown: Boolean = false

final def send(callbackId: Int, writeBody: DataOutputStream => Unit): Unit =
final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
getOrCreateConnection() match {
case Some(connection) =>
try {
Expand All @@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging {
return Some(connectionPool.dequeue())
}

Some(new CallbackConnection(address, port))
Some(new CallbackConnection(serDe, address, port))
}

private def addConnection(connection: CallbackConnection): Unit = synchronized {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging
* @param address The address of the Dotnet CallbackServer
* @param port The port of the Dotnet CallbackServer
*/
class CallbackConnection(address: String, port: Int) extends Logging {
class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
private[this] val socket: Socket = new Socket(address, port)
private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)

def send(
callbackId: Int,
writeBody: DataOutputStream => Unit): Unit = {
writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
logInfo(s"Calling callback [callback id = $callbackId] ...")

try {
SerDe.writeInt(outputStream, CallbackFlags.CALLBACK)
SerDe.writeInt(outputStream, callbackId)
serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
serDe.writeInt(outputStream, callbackId)

val byteArrayOutputStream = new ByteArrayOutputStream()
writeBody(new DataOutputStream(byteArrayOutputStream))
SerDe.writeInt(outputStream, byteArrayOutputStream.size)
writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
serDe.writeInt(outputStream, byteArrayOutputStream.size)
byteArrayOutputStream.writeTo(outputStream);
} catch {
case e: Exception => {
Expand All @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging {

logInfo(s"Signaling END_OF_STREAM.")
try {
SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
outputStream.flush()

val endOfStreamResponse = readFlag(inputStream)
Expand All @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging {

def close(): Unit = {
try {
SerDe.writeInt(outputStream, CallbackFlags.CLOSE)
serDe.writeInt(outputStream, CallbackFlags.CLOSE)
outputStream.flush()
} catch {
case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
Expand Down Expand Up @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging {
}

private def readFlag(inputStream: DataInputStream): Int = {
val callbackFlag = SerDe.readInt(inputStream)
val callbackFlag = serDe.readInt(inputStream)
if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
val exceptionMessage = SerDe.readString(inputStream)
val exceptionMessage = serDe.readString(inputStream)
throw new DotnetException(exceptionMessage)
}
callbackFlag
Expand All @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging {
val DOTNET_EXCEPTION_THROWN: Int = -3
val END_OF_STREAM: Int = -4
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class DotnetBackend extends Logging {
private[this] var channelFuture: ChannelFuture = _
private[this] var bootstrap: ServerBootstrap = _
private[this] var bossGroup: EventLoopGroup = _
private[this] val objectTracker = new JVMObjectTracker

@volatile
private[dotnet] var callbackClient: Option[CallbackClient] = None

def init(portNumber: Int): Int = {
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
Expand All @@ -55,7 +59,7 @@ class DotnetBackend extends Logging {
// initialBytesToStrip = 4, i.e. strip out the length field itself
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("handler", new DotnetBackendHandler(self))
.addLast("handler", new DotnetBackendHandler(self, objectTracker))
}
})

Expand All @@ -64,6 +68,23 @@ class DotnetBackend extends Logging {
channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
}

private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
callbackClient = callbackClient match {
case Some(_) => throw new Exception("Callback client already set.")
case None =>
logInfo(s"Connecting to a callback server at $address:$port")
Some(new CallbackClient(new SerDe(objectTracker), address, port))
}
}

private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
callbackClient match {
case Some(client) => client.shutdown()
case None => logInfo("Callback server has already been shutdown.")
}
callbackClient = None
}

def run(): Unit = {
channelFuture.channel.closeFuture().syncUninterruptibly()
}
Expand All @@ -82,30 +103,12 @@ class DotnetBackend extends Logging {
}
bootstrap = null

objectTracker.clear()

// Send close to .NET callback server.
DotnetBackend.shutdownCallbackClient()
shutdownCallbackClient()

// Shutdown the thread pool whose executors could still be running.
ThreadPool.shutdown()
}
}

object DotnetBackend extends Logging {
@volatile private[spark] var callbackClient: CallbackClient = null

private[spark] def setCallbackClient(address: String, port: Int) = synchronized {
if (DotnetBackend.callbackClient == null) {
logInfo(s"Connecting to a callback server at $address:$port")
DotnetBackend.callbackClient = new CallbackClient(address, port)
} else {
throw new Exception("Callback client already set.")
}
}

private[spark] def shutdownCallbackClient(): Unit = synchronized {
if (callbackClient != null) {
callbackClient.shutdown()
callbackClient = null
}
}
}
Loading