From 913ca52396b3e6b5b17f83ad1471a3f559f44e8d Mon Sep 17 00:00:00 2001
From: ap-itransition
Date: Fri, 18 Dec 2020 16:27:33 +0300
Subject: [PATCH 01/15] Introduce JVMObjectTracker.clear() and adjust DotnetBackend to release JVM objects during shutdown.

---
 .../spark/api/dotnet/DotnetBackend.scala | 3 +++
 .../api/dotnet/DotnetBackendHandler.scala | 6 +++++
 .../scala/com/microsoft/scala/AppTest.scala | 23 -------------------
 .../spark/api/dotnet/DotnetBackendTest.scala | 19 +++++++++++++++
 .../api/dotnet/JVMObjectTrackerTest.scala | 20 ++++++++++++++++
 .../spark/api/dotnet/DotnetBackend.scala | 3 +++
 .../api/dotnet/DotnetBackendHandler.scala | 6 +++++
 .../scala/com/microsoft/scala/AppTest.scala | 23 -------------------
 .../spark/api/dotnet/DotnetBackendTest.scala | 19 +++++++++++++++
 .../api/dotnet/JVMObjectTrackerTest.scala | 20 ++++++++++++++++
 .../spark/api/dotnet/DotnetBackend.scala | 3 +++
 .../api/dotnet/DotnetBackendHandler.scala | 7 ++++++
 .../scala/com/microsoft/scala/AppTest.scala | 23 -------------------
 .../spark/api/dotnet/DotnetBackendTest.scala | 19 +++++++++++++++
 .../api/dotnet/JVMObjectTrackerTest.scala | 20 ++++++++++++++++
 15 files changed, 145 insertions(+), 69 deletions(-)
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala
 create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
 create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
 delete mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala
 create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
 create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
 delete mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala
 create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
 create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala

diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
index f7ee92f0f..c51f502f3 100644
--- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
+++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
@@ -82,6 +82,9 @@ class DotnetBackend extends Logging {
 }
 bootstrap = null
+ // Release references to JVM objects to let them be collected by the GC
+ JVMObjectTracker.clear()
+
 // Send close to .NET callback server.
DotnetBackend.shutdownCallbackClient() diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index e632589e4..78a831935 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -362,4 +362,10 @@ private object JVMObjectTracker { objMap.remove(id) } } + + def clear(): Unit = { + synchronized { + objMap.clear() + } + } } diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..808827fd1 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,19 @@ +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class DotnetBackendTest { + + @Test + def shouldReleaseJVMReferencesWhenClose(): Unit = { + val backend = new DotnetBackend + val objectId = JVMObjectTracker.put(new Object) + + backend.close() + + assert( + JVMObjectTracker.get(objectId).isEmpty, + "JVMObjectTracker must be cleaned up during backend shutdown.") + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..767fcbaff --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,20 @@ +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val firstId = JVMObjectTracker.put(new Object) + val secondId = JVMObjectTracker.put(new Object) + val thirdId = JVMObjectTracker.put(new Object) + + JVMObjectTracker.clear() + + assert(JVMObjectTracker.get(firstId).isEmpty) + assert(JVMObjectTracker.get(secondId).isEmpty) + assert(JVMObjectTracker.get(thirdId).isEmpty) + } +} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index f7ee92f0f..c51f502f3 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -82,6 +82,9 @@ class DotnetBackend extends 
Logging {
 }
 bootstrap = null
+ // Release references to JVM objects to let them be collected by the GC
+ JVMObjectTracker.clear()
+
 // Send close to .NET callback server.
 DotnetBackend.shutdownCallbackClient()

diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index e632589e4..78a831935 100644
--- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -362,4 +362,10 @@ private object JVMObjectTracker {
 objMap.remove(id)
 }
 }
+
+ def clear(): Unit = {
+ synchronized {
+ objMap.clear()
+ }
+ }
 }

diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala
deleted file mode 100644
index 230042b8a..000000000
--- a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Licensed to the .NET Foundation under one or more agreements.
- * The .NET Foundation licenses this file to you under the MIT license.
- * See the LICENSE file in the project root for more information.
- */
-
-package com.microsoft.scala
-
-import org.junit._
-import Assert._
-
-@Test
-class AppTest {
-
- @Test
- def testOK() = assertTrue(true)
-
-// @Test
-// def testKO() = assertTrue(false)
-
-}
-
-

diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
new file mode 100644
index 000000000..808827fd1
--- /dev/null
+++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
@@ -0,0 +1,19 @@
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class DotnetBackendTest {
+
+ @Test
+ def shouldReleaseJVMReferencesWhenClose(): Unit = {
+ val backend = new DotnetBackend
+ val objectId = JVMObjectTracker.put(new Object)
+
+ backend.close()
+
+ assert(
+ JVMObjectTracker.get(objectId).isEmpty,
+ "JVMObjectTracker must be cleaned up during backend shutdown.")
+ }
+}

diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
new file mode 100644
index 000000000..767fcbaff
--- /dev/null
+++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -0,0 +1,20 @@
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class JVMObjectTrackerTest {
+
+ @Test
+ def shouldReleaseAllReferences(): Unit = {
+ val firstId = JVMObjectTracker.put(new Object)
+ val secondId = JVMObjectTracker.put(new Object)
+ val thirdId = JVMObjectTracker.put(new Object)
+
+ JVMObjectTracker.clear()
+
+ assert(JVMObjectTracker.get(firstId).isEmpty)
+ assert(JVMObjectTracker.get(secondId).isEmpty)
+ assert(JVMObjectTracker.get(thirdId).isEmpty)
+ }
+}

diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
index f7ee92f0f..c51f502f3 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
@@ -82,6 +82,9 @@ class DotnetBackend extends Logging {
 }
 bootstrap = null
+ // Release references to JVM objects to let them be collected by the GC
+ JVMObjectTracker.clear()
+
 // Send close to .NET callback server.
 DotnetBackend.shutdownCallbackClient()

diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index 1446e5ff6..ea50044db 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -362,4 +362,11 @@ private object JVMObjectTracker {
 objMap.remove(id)
 }
 }
+
+ def clear(): Unit = {
+ synchronized {
+ objMap.clear()
+ objCounter = 1
+ }
+ }
 }

diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala
deleted file mode 100644
index 230042b8a..000000000
--- a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Licensed to the .NET Foundation under one or more agreements.
- * The .NET Foundation licenses this file to you under the MIT license.
- * See the LICENSE file in the project root for more information.
- */
-
-package com.microsoft.scala
-
-import org.junit._
-import Assert._
-
-@Test
-class AppTest {
-
- @Test
- def testOK() = assertTrue(true)
-
-// @Test
-// def testKO() = assertTrue(false)
-
-}
-
-

diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
new file mode 100644
index 000000000..808827fd1
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
@@ -0,0 +1,19 @@
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class DotnetBackendTest {
+
+ @Test
+ def shouldReleaseJVMReferencesWhenClose(): Unit = {
+ val backend = new DotnetBackend
+ val objectId = JVMObjectTracker.put(new Object)
+
+ backend.close()
+
+ assert(
+ JVMObjectTracker.get(objectId).isEmpty,
+ "JVMObjectTracker must be cleaned up during backend shutdown.")
+ }
+}

diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
new file mode 100644
index 000000000..767fcbaff
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -0,0 +1,20 @@
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class JVMObjectTrackerTest {
+
+ @Test
+ def shouldReleaseAllReferences(): Unit = {
+ val firstId = JVMObjectTracker.put(new Object)
+ val secondId = JVMObjectTracker.put(new Object)
+ val thirdId = JVMObjectTracker.put(new Object)
+
+ JVMObjectTracker.clear()
+
+ assert(JVMObjectTracker.get(firstId).isEmpty)
+ assert(JVMObjectTracker.get(secondId).isEmpty)
+ assert(JVMObjectTracker.get(thirdId).isEmpty)
+ }
+}

From cacb0043f74b10566e2510c1223be8224baf4120 Mon Sep 17
00:00:00 2001 From: ap-itransition Date: Fri, 18 Dec 2020 16:40:11 +0300 Subject: [PATCH 02/15] Add test for objCounter under JVMObjectTracker --- .../spark/api/dotnet/DotnetBackendHandler.scala | 1 + .../spark/api/dotnet/JVMObjectTrackerTest.scala | 14 ++++++++++++++ .../spark/api/dotnet/DotnetBackendHandler.scala | 1 + .../spark/api/dotnet/JVMObjectTrackerTest.scala | 14 ++++++++++++++ .../spark/api/dotnet/JVMObjectTrackerTest.scala | 14 ++++++++++++++ 5 files changed, 44 insertions(+) diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 78a831935..7c80c3a1d 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -366,6 +366,7 @@ private object JVMObjectTracker { def clear(): Unit = { synchronized { objMap.clear() + objCounter = 1 } } } diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala index 767fcbaff..495799b85 100644 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -17,4 +17,18 @@ class JVMObjectTrackerTest { assert(JVMObjectTracker.get(secondId).isEmpty) assert(JVMObjectTracker.get(thirdId).isEmpty) } + + @Test + def shouldResetCounter(): Unit = { + val firstId = JVMObjectTracker.put(new Object) + val secondId = JVMObjectTracker.put(new Object) + + JVMObjectTracker.clear() + + val thirdId = JVMObjectTracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } } diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 78a831935..7c80c3a1d 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -366,6 +366,7 @@ private object JVMObjectTracker { def clear(): Unit = { synchronized { objMap.clear() + objCounter = 1 } } } diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala index 767fcbaff..495799b85 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -17,4 +17,18 @@ class JVMObjectTrackerTest { assert(JVMObjectTracker.get(secondId).isEmpty) assert(JVMObjectTracker.get(thirdId).isEmpty) } + + @Test + def shouldResetCounter(): Unit = { + val firstId = JVMObjectTracker.put(new Object) + val secondId = JVMObjectTracker.put(new Object) + + JVMObjectTracker.clear() + + val thirdId = JVMObjectTracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } } diff --git 
a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
index 767fcbaff..495799b85 100644
--- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -17,4 +17,18 @@ class JVMObjectTrackerTest {
 assert(JVMObjectTracker.get(secondId).isEmpty)
 assert(JVMObjectTracker.get(thirdId).isEmpty)
 }
+
+ @Test
+ def shouldResetCounter(): Unit = {
+ val firstId = JVMObjectTracker.put(new Object)
+ val secondId = JVMObjectTracker.put(new Object)
+
+ JVMObjectTracker.clear()
+
+ val thirdId = JVMObjectTracker.put(new Object)
+
+ assert(firstId.equals("1"))
+ assert(secondId.equals("2"))
+ assert(thirdId.equals("1"))
+ }
 }

From f32fe7b84fb1369c26980811002fabe7edc1a5df Mon Sep 17 00:00:00 2001
From: Aleksandr Popitich
Date: Sat, 19 Dec 2020 03:25:04 +0300
Subject: [PATCH 03/15] Add license headers to new files

---
 .../org/apache/spark/api/dotnet/DotnetBackendTest.scala | 6 ++++++
 .../org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala | 6 ++++++
 .../org/apache/spark/api/dotnet/DotnetBackendTest.scala | 6 ++++++
 .../org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala | 6 ++++++
 .../org/apache/spark/api/dotnet/DotnetBackendTest.scala | 6 ++++++
 .../org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala | 6 ++++++
 6 files changed, 36 insertions(+)

diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
index 808827fd1..9dee936d7 100644
--- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
+++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
@@ -1,3 +1,9 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
 package org.apache.spark.api.dotnet

 import org.junit.Test

diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
index 495799b85..dd1b05c79 100644
--- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
+++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -1,3 +1,9 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */ + package org.apache.spark.api.dotnet import org.junit.Test diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala index 808827fd1..9dee936d7 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + package org.apache.spark.api.dotnet import org.junit.Test diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala index 495799b85..dd1b05c79 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + package org.apache.spark.api.dotnet import org.junit.Test diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala index 808827fd1..9dee936d7 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + package org.apache.spark.api.dotnet import org.junit.Test diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala index 495799b85..dd1b05c79 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */
+
 package org.apache.spark.api.dotnet

 import org.junit.Test

From a63f6a042816d612a375183af31599e7b528bde2 Mon Sep 17 00:00:00 2001
From: Aleksandr Popitich
Date: Thu, 31 Dec 2020 02:48:34 +0300
Subject: [PATCH 04/15] Make JVMObjectTracker non-static

---
 .../Interop/Ipc/CallbackServer.cs | 18 +-
 .../Sql/Streaming/DataStreamWriter.cs | 1 +
 .../spark/api/dotnet/CallbackClient.scala | 6 +-
 .../spark/api/dotnet/CallbackConnection.scala | 22 +-
 .../spark/api/dotnet/DotnetBackend.scala | 48 ++-
 .../api/dotnet/DotnetBackendHandler.scala | 130 +++-----
 .../spark/api/dotnet/JVMObjectTracker.scala | 48 +++
 .../org/apache/spark/api/dotnet/SerDe.scala | 73 ++---
 .../sql/api/dotnet/DotnetForeachBatch.scala | 17 +-
 .../spark/api/dotnet/DotnetBackendTest.scala | 11 +-
 .../api/dotnet/JVMObjectTrackerTest.scala | 24 +-
 .../apache/spark/api/dotnet/SerDeTest.scala | 278 ++++++++++++++++++
 12 files changed, 490 insertions(+), 186 deletions(-)
 create mode 100644 src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
 create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala

diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs
index ef6c0407a..055493b42 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs
@@ -64,9 +64,25 @@ internal sealed class CallbackServer
 private bool _isRunning = false;

 private ISocketWrapper _listener;
+
+ private JvmObjectReference _jvmCallbackClient;

 internal int CurrentNumConnections => _connections.Count;

+ internal JvmObjectReference JvmCallbackClient
+ {
+ get
+ {
+ if (_jvmCallbackClient is null)
+ {
+ throw new InvalidOperationException(
+ "Please start CallbackServer before accessing JvmCallbackClient.");
+ }
+
+ return _jvmCallbackClient;
+ }
+ }
+
 internal CallbackServer(IJvmBridge jvm, bool run = true)
 {
 AppDomain.CurrentDomain.ProcessExit += (s, e) => Shutdown();
@@ -113,7 +129,7 @@ internal void Run(ISocketWrapper listener)
 // Communicate with the JVM the callback server's address and port.
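 // The connectCallback call below now returns a JvmObjectReference to the JVM-side
 // CallbackClient. It is cached in _jvmCallbackClient so that callback-based APIs
 // (for example DataStreamWriter.ForeachBatch) can hand the client back to the JVM.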
 var localEndPoint = (IPEndPoint)_listener.LocalEndPoint;
- _jvm.CallStaticJavaMethod(
+ _jvmCallbackClient = (JvmObjectReference)_jvm.CallStaticJavaMethod(
 "DotnetHandler",
 "connectCallback",
 localEndPoint.Address.ToString(),

diff --git a/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs b/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs
index f371b0665..07cf658be 100644
--- a/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs
+++ b/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs
@@ -228,6 +228,7 @@ public DataStreamWriter ForeachBatch(Action<DataFrame, long> func)
 _jvmObject.Jvm.CallStaticJavaMethod(
 "org.apache.spark.sql.api.dotnet.DotnetForeachBatchHelper",
 "callForeachBatch",
+ SparkEnvironment.CallbackServer.JvmCallbackClient,
 this,
 callbackId);
 return this;

diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
index 90ad92439..aea355dfa 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
@@ -20,12 +20,12 @@ import scala.collection.mutable.Queue
 * @param address The address of the Dotnet CallbackServer
 * @param port The port of the Dotnet CallbackServer
 */
-class CallbackClient(address: String, port: Int) extends Logging {
+class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
 private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()

 private[this] var isShutdown: Boolean = false

- final def send(callbackId: Int, writeBody: DataOutputStream => Unit): Unit =
+ final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
 getOrCreateConnection() match {
 case Some(connection) =>
 try {
@@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging {
 return Some(connectionPool.dequeue())
 }

- Some(new CallbackConnection(address, port))
+ Some(new CallbackConnection(serDe, address, port))
 }

 private def addConnection(connection: CallbackConnection): Unit = synchronized {

diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
index 36726181e..604cf029b 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
@@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging
 * @param address The address of the Dotnet CallbackServer
 * @param port The port of the Dotnet CallbackServer
 */
-class CallbackConnection(address: String, port: Int) extends Logging {
+class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
 private[this] val socket: Socket = new Socket(address, port)
 private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
 private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)

 def send(
 callbackId: Int,
- writeBody: DataOutputStream => Unit): Unit = {
+ writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
 logInfo(s"Calling callback [callback id = $callbackId] ...")

 try {
- SerDe.writeInt(outputStream, CallbackFlags.CALLBACK)
- SerDe.writeInt(outputStream,
callbackId) + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) val byteArrayOutputStream = new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream)) - SerDe.writeInt(outputStream, byteArrayOutputStream.size) + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) byteArrayOutputStream.writeTo(outputStream); } catch { case e: Exception => { @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Signaling END_OF_STREAM.") try { - SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) outputStream.flush() val endOfStreamResponse = readFlag(inputStream) @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { def close(): Unit = { try { - SerDe.writeInt(outputStream, CallbackFlags.CLOSE) + serDe.writeInt(outputStream, CallbackFlags.CLOSE) outputStream.flush() } catch { case e: Exception => logInfo("Unable to send close to .NET callback server.", e) @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging { } private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = SerDe.readInt(inputStream) + val callbackFlag = serDe.readInt(inputStream) if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = SerDe.readString(inputStream) + val exceptionMessage = serDe.readString(inputStream) throw new DotnetException(exceptionMessage) } callbackFlag @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging { val DOTNET_EXCEPTION_THROWN: Int = -3 val END_OF_STREAM: Int = -4 } -} \ No newline at end of file +} diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index c51f502f3..0d254afac 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -8,7 +8,6 @@ package org.apache.spark.api.dotnet import java.net.InetSocketAddress import java.util.concurrent.TimeUnit - import io.netty.bootstrap.ServerBootstrap import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel @@ -30,6 +29,10 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectsTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None def init(portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -55,7 +58,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. 
strip out the length field itself
 new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
 .addLast("decoder", new ByteArrayDecoder())
- .addLast("handler", new DotnetBackendHandler(self))
+ .addLast("handler", new DotnetBackendHandler(self, objectsTracker))
 }
 })
@@ -64,6 +67,22 @@ class DotnetBackend extends Logging {
 channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
 }

+ private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
+ callbackClient = callbackClient match {
+ case Some(_) => throw new Exception("Callback client already set.")
+ case None =>
+ logInfo(s"Connecting to a callback server at $address:$port")
+ Some(new CallbackClient(new SerDe(objectsTracker), address, port))
+ }
+ }
+
+ private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
+ callbackClient match {
+ case Some(client) => client.shutdown()
+ case None => logInfo("Callback client has already been shut down.")
+ }
+ }
+
 def run(): Unit = {
 channelFuture.channel.closeFuture().syncUninterruptibly()
 }
@@ -82,33 +101,12 @@ class DotnetBackend extends Logging {
 }
 bootstrap = null
- // Release references to JVM objects to let them be collected by the GC
- JVMObjectTracker.clear()
+ objectsTracker.clear()

 // Send close to .NET callback server.
- DotnetBackend.shutdownCallbackClient()
+ shutdownCallbackClient()

 // Shutdown the thread pool whose executors could still be running.
 ThreadPool.shutdown()
 }
 }
-
-object DotnetBackend extends Logging {
- @volatile private[spark] var callbackClient: CallbackClient = null
-
- private[spark] def setCallbackClient(address: String, port: Int) = synchronized {
- if (DotnetBackend.callbackClient == null) {
- logInfo(s"Connecting to a callback server at $address:$port")
- DotnetBackend.callbackClient = new CallbackClient(address, port)
- } else {
- throw new Exception("Callback client already set.")
- }
- }
-
- private[spark] def shutdownCallbackClient(): Unit = synchronized {
- if (callbackClient != null) {
- callbackClient.shutdown()
- callbackClient = null
- }
- }
-}

diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index ea50044db..30c002471 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -6,13 +6,11 @@ package org.apache.spark.api.dotnet

-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
-
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-import org.apache.spark.api.dotnet.SerDe._
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils

+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import scala.collection.mutable.HashMap
 import scala.language.existentials

@@ -20,10 +18,12 @@ import scala.language.existentials
 * Handler for DotnetBackend.
 * This implementation is similar to RBackendHandler.
*/ -class DotnetBackendHandler(server: DotnetBackend) +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + private[this] val serDe = new SerDe(objectsTracker) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { val reply = handleBackendRequest(msg) ctx.write(reply) @@ -41,57 +41,60 @@ class DotnetBackendHandler(server: DotnetBackend) val dos = new DataOutputStream(bos) // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) + val isStatic = serDe.readBoolean(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) if (objId == "DotnetHandler") { methodName match { case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") server.close() case "rm" => try { - val t = readObjectType(dis) + val t = serDe.readObjectType(dis) assert(t == 'c') - val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) - writeInt(dos, 0) - writeObject(dos, null) + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) } catch { case e: Exception => logError(s"Removing $objId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "rmThread" => try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, result.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port) - writeInt(dos, 0) - writeType(dos, "void") + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Sends reference of CallbackClient to dotnet side, + // so that dotnet process can send the client back to Java side + // when calling any API containing callback functions + serDe.writeObject(dos, server.callbackClient) case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") - + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case _ => dos.writeInt(-1) } } else { @@ -131,7 +134,7 @@ class DotnetBackendHandler(server: DotnetBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { + objectsTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) case Some(o) => obj = o @@ -159,8 +162,8 @@ class DotnetBackendHandler(server: DotnetBackend) val ret = selectedMethods(index.get).invoke(obj, args: _*) // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) 
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, ret.asInstanceOf[AnyRef])
 } else if (methodName == "<init>") {
 // methodName should be "<init>" for constructor
 val ctor = cls.getConstructors.filter { x =>
@@ -169,15 +172,15 @@ class DotnetBackendHandler(server: DotnetBackend)
 val obj = ctor.newInstance(args: _*)

- writeInt(dos, 0)
- writeObject(dos, obj.asInstanceOf[AnyRef])
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, obj.asInstanceOf[AnyRef])
 } else {
 throw new IllegalArgumentException(
 "invalid method " + methodName + " for object " + objId)
 }
 } catch {
 case e: Throwable =>
- val jvmObj = JVMObjectTracker.get(objId)
+ val jvmObj = objectsTracker.get(objId)
 val jvmObjName = jvmObj match {
 case Some(jObj) => jObj.getClass.getName
 case None => "NullObject"
@@ -199,15 +202,15 @@ class DotnetBackendHandler(server: DotnetBackend)
 methods.foreach(m => logDebug(m.toString))
 }

- writeInt(dos, -1)
- writeString(dos, Utils.exceptionString(e.getCause))
+ serDe.writeInt(dos, -1)
+ serDe.writeString(dos, Utils.exceptionString(e.getCause))
 }
 }

 // Read a number of arguments from the data input stream
 def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
 (0 until numArgs).map { arg =>
- readObject(dis)
+ serDe.readObject(dis)
 }.toArray
 }

@@ -326,47 +329,4 @@ class DotnetBackendHandler(server: DotnetBackend)
 def logError(id: String, e: Exception): Unit = {}
 }

-/**
- * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects.
- */
-private object JVMObjectTracker {
-
- // Multiple threads may access objMap and increase objCounter. Because get method return Option,
- // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
- private[this] val objMap = new HashMap[String, Object]
- private[this] var objCounter: Int = 1
-
- def getObject(id: String): Object = {
- synchronized {
- objMap(id)
- }
- }
- def get(id: String): Option[Object] = {
- synchronized {
- objMap.get(id)
- }
- }
-
- def put(obj: Object): String = {
- synchronized {
- val objId = objCounter.toString
- objCounter = objCounter + 1
- objMap.put(objId, obj)
- objId
- }
- }
-
- def remove(id: String): Option[Object] = {
- synchronized {
- objMap.remove(id)
- }
- }
-
- def clear(): Unit = {
- synchronized {
- objMap.clear()
- objCounter = 1
- }
- }
-}

diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
new file mode 100644
index 000000000..8cd2d53be
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
@@ -0,0 +1,48 @@
+package org.apache.spark.api.dotnet
+
+import scala.collection.mutable.HashMap
+
+/**
+ * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects.
+ */
+private[dotnet] class JVMObjectTracker {
+
+ // Multiple threads may access objMap and increase objCounter. Because the get method returns an
+ // Option, it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
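+ // Note: every member below synchronizes on this tracker instance, so a single lock
+ // guards both objMap and objCounter, keeping id generation and lookup atomic.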
+ private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 427df61b6..838f7fccf 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -18,7 +18,8 @@ import scala.collection.JavaConverters._ * Functions to serialize and deserialize between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. */ -object SerDe { +class SerDe(val tracker: JVMObjectTracker) { + def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -28,7 +29,7 @@ object SerDe { readTypedObject(dis, dataType) } - def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) @@ -41,14 +42,14 @@ object SerDe { case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => tracker.getObject(readString(dis)) case 'R' => readRowArr(dis) case 'O' => readObjectArr(dis) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } - def readBytes(in: DataInputStream): Array[Byte] = { + private def readBytes(in: DataInputStream): Array[Byte] = { val len = readInt(in) val out = new Array[Byte](len) in.readFully(out) @@ -59,15 +60,15 @@ object SerDe { in.readInt() } - def readLong(in: DataInputStream): Long = { + private def readLong(in: DataInputStream): Long = { in.readLong() } - def readDouble(in: DataInputStream): Double = { + private def readDouble(in: DataInputStream): Double = { in.readDouble() } - def readStringBytes(in: DataInputStream, len: Int): String = { + private def readStringBytes(in: DataInputStream, len: Int): String = { val bytes = new Array[Byte](len) in.readFully(bytes) val str = new String(bytes, "UTF-8") @@ -83,11 +84,11 @@ object SerDe { in.readBoolean() } - def readDate(in: DataInputStream): Date = { + private def readDate(in: DataInputStream): Date = { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Timestamp = { + private def readTime(in: DataInputStream): Timestamp = { val seconds = in.readDouble() val sec = Math.floor(seconds).toLong val t = new Timestamp(sec * 1000L) @@ -95,57 +96,57 @@ object SerDe { t } - def readRow(in: DataInputStream): Row = { + private def readRow(in: DataInputStream): Row = { val len = readInt(in) Row.fromSeq((0 until len).map(_ => readObject(in))) } - def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray } - def readIntArr(in: 
DataInputStream): Array[Int] = { + private def readIntArr(in: DataInputStream): Array[Int] = { val len = readInt(in) (0 until len).map(_ => readInt(in)).toArray } - def readLongArr(in: DataInputStream): Array[Long] = { + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray } - def readDoubleArr(in: DataInputStream): Array[Double] = { + private def readDoubleArr(in: DataInputStream): Array[Double] = { val len = readInt(in) (0 until len).map(_ => readDouble(in)).toArray } - def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { val len = readInt(in) (0 until len).map(_ => readDoubleArr(in)).toArray } - def readBooleanArr(in: DataInputStream): Array[Boolean] = { + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { val len = readInt(in) (0 until len).map(_ => readBoolean(in)).toArray } - def readStringArr(in: DataInputStream): Array[String] = { + private def readStringArr(in: DataInputStream): Array[String] = { val len = readInt(in) (0 until len).map(_ => readString(in)).toArray } - def readRowArr(in: DataInputStream): java.util.List[Row] = { + private def readRowArr(in: DataInputStream): java.util.List[Row] = { val len = readInt(in) (0 until len).map(_ => readRow(in)).toList.asJava } - def readObjectArr(in: DataInputStream): Seq[Any] = { + private def readObjectArr(in: DataInputStream): Seq[Any] = { val len = readInt(in) (0 until len).map(_ => readObject(in)) } - def readList(dis: DataInputStream): Array[_] = { + private def readList(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -154,13 +155,13 @@ object SerDe { case 'd' => readDoubleArr(dis) case 'A' => readDoubleArrArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { val keysType = readObjectType(in) @@ -299,23 +300,23 @@ object SerDe { out.writeLong(value) } - def writeDouble(out: DataOutputStream, value: Double): Unit = { + private def writeDouble(out: DataOutputStream, value: Double): Unit = { out.writeDouble(value) } - def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { out.writeBoolean(value) } - def writeDate(out: DataOutputStream, value: Date): Unit = { + private def writeDate(out: DataOutputStream, value: Date): Unit = { writeString(out, value.toString) } - def writeTime(out: DataOutputStream, value: Time): Unit = { + private def writeTime(out: DataOutputStream, value: Time): Unit = { out.writeDouble(value.getTime.toDouble / 1000.0) } - def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } @@ -326,53 +327,53 @@ object SerDe { out.write(utf8, 0, len) } - def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { 
out.writeInt(value.length) out.write(value) } def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) + val objId = tracker.put(value) writeString(out, objId) } - def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { writeType(out, "integer") out.writeInt(value.length) value.foreach(v => out.writeInt(v)) } - def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { writeType(out, "long") out.writeInt(value.length) value.foreach(v => out.writeLong(v)) } - def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { writeType(out, "double") out.writeInt(value.length) value.foreach(v => out.writeDouble(v)) } - def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { writeType(out, "doublearray") out.writeInt(value.length) value.foreach(v => writeDoubleArr(out, v)) } - def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { writeType(out, "logical") out.writeInt(value.length) value.foreach(v => writeBoolean(out, v)) } - def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { writeType(out, "character") out.writeInt(value.length) value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { writeType(out, "raw") out.writeInt(value.length) value.foreach(v => writeBytes(out, v)) diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala index c0de9c7bc..5d06d4304 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -6,7 +6,7 @@ package org.apache.spark.sql.api.dotnet -import org.apache.spark.api.dotnet.{CallbackClient, DotnetBackend, SerDe} +import org.apache.spark.api.dotnet.CallbackClient import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.streaming.DataStreamWriter @@ -15,20 +15,19 @@ class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int def call(batchDF: DataFrame, batchId: Long): Unit = callbackClient.send( callbackId, - dos => { - SerDe.writeJObj(dos, batchDF) - SerDe.writeLong(dos, batchId) + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) }) } object DotnetForeachBatchHelper { - def callForeachBatch(dsw: DataStreamWriter[Row], callbackId: Int): Unit = { - val callbackClient = DotnetBackend.callbackClient - if (callbackClient == null) { - throw new Exception("DotnetBackend.callbackClient is null.") + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + 
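+ // client is the Option[CallbackClient] reference that connectCallback handed to
+ // the .NET side; None means no .NET callback server has connected yet.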
 case Some(value) => new DotnetForeachBatchFunction(value, callbackId)
+ case None => throw new Exception("CallbackClient is null.")
+ }
- val dotnetForeachFunc = new DotnetForeachBatchFunction(callbackClient, callbackId)

 dsw.foreachBatch(dotnetForeachFunc.call _)
 }
 }

diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
index 9dee936d7..11caed44b 100644
--- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
@@ -6,20 +6,21 @@ package org.apache.spark.api.dotnet

-import org.junit.Test
+import org.junit.{Ignore, Test}

 @Test
 class DotnetBackendTest {

- @Test
+ @Ignore
 def shouldReleaseJVMReferencesWhenClose(): Unit = {
 val backend = new DotnetBackend
- val objectId = JVMObjectTracker.put(new Object)
+ val tracker = new JVMObjectTracker
+ val objectId = tracker.put(new Object)

 backend.close()

 assert(
- JVMObjectTracker.get(objectId).isEmpty,
+ tracker.get(objectId).isEmpty,
 "JVMObjectTracker must be cleaned up during backend shutdown.")
 }
 }

diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
index dd1b05c79..43ae79005 100644
--- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -13,25 +13,27 @@ class JVMObjectTrackerTest {

 @Test
 def shouldReleaseAllReferences(): Unit = {
- val firstId = JVMObjectTracker.put(new Object)
- val secondId = JVMObjectTracker.put(new Object)
- val thirdId = JVMObjectTracker.put(new Object)
+ val tracker = new JVMObjectTracker
+ val firstId = tracker.put(new Object)
+ val secondId = tracker.put(new Object)
+ val thirdId = tracker.put(new Object)

- JVMObjectTracker.clear()
+ tracker.clear()

- assert(JVMObjectTracker.get(firstId).isEmpty)
- assert(JVMObjectTracker.get(secondId).isEmpty)
- assert(JVMObjectTracker.get(thirdId).isEmpty)
+ assert(tracker.get(firstId).isEmpty)
+ assert(tracker.get(secondId).isEmpty)
+ assert(tracker.get(thirdId).isEmpty)
 }

 @Test
 def shouldResetCounter(): Unit = {
- val firstId = JVMObjectTracker.put(new Object)
- val secondId = JVMObjectTracker.put(new Object)
+ val tracker = new JVMObjectTracker
+ val firstId = tracker.put(new Object)
+ val secondId = tracker.put(new Object)

- JVMObjectTracker.clear()
+ tracker.clear()

- val thirdId = JVMObjectTracker.put(new Object)
+ val thirdId = tracker.put(new Object)

 assert(firstId.equals("1"))
 assert(secondId.equals("2"))

diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
new file mode 100644
index 000000000..d39ef89ec
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters._ + +@Test +class SerDeTest { + private var sut: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + sut = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val in = withStreamState(out => { + out.writeByte('n') + }) + + assertEquals(null, sut.readObject(in)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val in = withStreamState(out => { + out.writeByte('_') + }) + + assertThrows(classOf[IllegalArgumentException], () => { + sut.readObject(in) + }) + } + + @Test + def shouldReadInteger(): Unit = { + val in = withStreamState(out => { + out.writeByte('i') + out.writeInt(42) + }) + + assertEquals(42, sut.readObject(in)) + } + + @Test + def shouldReadLong(): Unit = { + val in = withStreamState(state => { + state.writeByte('g') + state.writeLong(42) + }) + + assertEquals(42L, sut.readObject(in)) + } + + @Test + def shouldReadDouble(): Unit = { + val in = withStreamState(state => { + state.writeByte('d') + state.writeDouble(42.42) + }) + + assertEquals(42.42, sut.readObject(in)) + } + + @Test + def shouldReadBoolean(): Unit = { + val in = withStreamState(state => { + state.writeByte('b') + state.writeBoolean(true) + }) + + assertEquals(true, sut.readObject(in)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val in = withStreamState(state => { + state.writeByte('c') + state.writeInt(payload.getBytes("UTF-8").length) + state.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, sut.readObject(in)) + } + + @Test + def shouldReadMap(): Unit = { + val in = withStreamState(state => { + state.writeByte('e') // map type descriptor + state.writeInt(3) // size + state.writeByte('i') // key type + state.writeInt(3) // number of keys + state.writeInt(11) // first key + state.writeInt(22) // second key + state.writeInt(33) // third key + state.writeInt(3) // number of values + state.writeByte('b') // first value type + state.writeBoolean(true) // first value + state.writeByte('d') // second value type + state.writeDouble(42.42) // second value + state.writeByte('n') // third type & value + }) + + assertEquals( + mapAsJavaMap(Map( + 11 -> true, + 22 -> 42.42, + 33 -> null)), + sut.readObject(in)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val in = withStreamState(state => { + state.writeByte('e') // map type descriptor + state.writeInt(0) // size + }) + + assertEquals(mapAsJavaMap(Map()), sut.readObject(in)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val in = withStreamState(state => { + state.writeByte('r') // byte array type descriptor + state.writeInt(3) // length + state.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(in).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val in = withStreamState(state => { + state.writeByte('r') // byte array type descriptor + state.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), sut.readObject(in).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val in = withStreamState(state => { + state.writeByte('l') // type descriptor + state.writeByte('i') // element type + 
state.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), sut.readObject(in).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val in = withStreamState(state => { + state.writeByte('l') // type descriptor + state.writeByte('b') // element type + state.writeInt(3) // length + state.writeBoolean(true) + state.writeBoolean(false) + state.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), sut.readObject(in).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val in = withStreamState(state => { + state.writeByte('l') // type descriptor + state.writeByte('_') // unsupported element type + }) + + assertThrows(classOf[IllegalArgumentException], () => { + sut.readObject(in) + }) + } + + @Test + def shouldReadDate(): Unit = { + val in = withStreamState(state => { + val date = "2020-12-31" + state.writeByte('D') // type descriptor + state.writeInt(date.getBytes("UTF-8").length) // date string size + state.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), sut.readObject(in)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val in = withStreamState(state => { + val objectIndex = "1" + state.writeByte('j') // type descriptor + state.writeInt(objectIndex.getBytes("UTF-8").length) // size + state.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, sut.readObject(in)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val in = withStreamState(state => { + val objectIndex = "42" + state.writeByte('j') // type descriptor + state.writeInt(objectIndex.getBytes("UTF-8").length) // size + state.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows(classOf[NoSuchElementException], () => { + sut.readObject(in) + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val in = withStreamState(state => { + state.writeByte('R') // type descriptor + state.writeInt(2) // number of rows + state.writeInt(1) // number of elements in 1st row + state.writeByte('i') // type of 1st element in 1st row + state.writeInt(11) + state.writeInt(3) // number of elements in 2st row + state.writeByte('b') // type of 1st element in 2nd row + state.writeBoolean(true) + state.writeByte('d') // type of 2nd element in 2nd row + state.writeDouble(42.24) + state.writeByte('g') // type of 3nd element in 2nd row + state.writeLong(99) + }) + + assertEquals( + seqAsJavaList(Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99)))), + sut.readObject(in)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val in = withStreamState(state => { + state.writeByte('O') // type descriptor + state.writeInt(2) // number of elements + state.writeByte('i') // type of 1st element + state.writeInt(42) + state.writeByte('b') // type of 2nd element + state.writeBoolean(true) + }) + + assertEquals(Seq[Any](42, true), sut.readObject(in).asInstanceOf[Seq[Any]]) + } + + private def withStreamState(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream(); + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) + } +} From b389da0ee9730b54f5aee354755f7d6444117267 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sat, 2 Jan 2021 01:36:02 +0300 Subject: [PATCH 05/15] Tests for DotnetBackendHandler added --- .../api/dotnet/DotnetBackendHandlerTest.scala | 44 +++ 
.../spark/api/dotnet/DotnetBackendTest.scala | 26 -- .../apache/spark/api/dotnet/SerDeTest.scala | 319 ++++++++++++------ 3 files changed, 251 insertions(+), 138 deletions(-) create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala delete mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..a43fbf845 --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,44 @@ +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.Test +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val tracker = new JVMObjectTracker + val sut = new DotnetBackendHandler(new DotnetBackend, tracker) + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + m.writeInt(1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = sut.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readAllBytes(), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala deleted file mode 100644 index 11caed44b..000000000 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import org.junit.{Ignore, Test} - -@Test -class DotnetBackendTest extends { - - @Ignore - def shouldReleaseJVMReferencesWhenClose(): Unit = { - val backend = new DotnetBackend - val tracker = new JVMObjectTracker - val objectId = tracker.put(new Object) - - backend.close() - - assert( - tracker.get(objectId).isEmpty, - "JVMObjectTracker must be cleaned up during backend shutdown.") - } -} diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala index d39ef89ec..a8607e011 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -10,7 +10,7 @@ import org.apache.spark.sql.Row import org.junit.Assert._ import org.junit.{Before, Test} -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException} import java.sql.Date import scala.collection.JavaConverters._ @@ -27,92 +27,92 @@ class SerDeTest { @Test def shouldReadNull(): Unit = { - val in = withStreamState(out => { - out.writeByte('n') + val input = givenInput(in => { + in.writeByte('n') }) - assertEquals(null, sut.readObject(in)) + assertEquals(null, sut.readObject(input)) } @Test def shouldThrowForUnsupportedTypes(): Unit = { - val in = withStreamState(out => { - out.writeByte('_') + val input = givenInput(in => { + in.writeByte('_') }) assertThrows(classOf[IllegalArgumentException], () => { - sut.readObject(in) + sut.readObject(input) }) } @Test def shouldReadInteger(): Unit = { - val in = withStreamState(out => { - out.writeByte('i') - out.writeInt(42) + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) }) - assertEquals(42, sut.readObject(in)) + assertEquals(42, sut.readObject(input)) } @Test def shouldReadLong(): Unit = { - val in = withStreamState(state => { - state.writeByte('g') - state.writeLong(42) + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) }) - assertEquals(42L, sut.readObject(in)) + assertEquals(42L, sut.readObject(input)) } @Test def shouldReadDouble(): Unit = { - val in = withStreamState(state => { - state.writeByte('d') - state.writeDouble(42.42) + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) }) - assertEquals(42.42, sut.readObject(in)) + assertEquals(42.42, sut.readObject(input)) } @Test def shouldReadBoolean(): Unit = { - val in = withStreamState(state => { - state.writeByte('b') - state.writeBoolean(true) + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) }) - assertEquals(true, sut.readObject(in)) + assertEquals(true, sut.readObject(input)) } @Test def shouldReadString(): Unit = { val payload = "Spark Dotnet" - val in = withStreamState(state => { - state.writeByte('c') - state.writeInt(payload.getBytes("UTF-8").length) - state.write(payload.getBytes("UTF-8")) + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) }) - assertEquals(payload, sut.readObject(in)) + assertEquals(payload, sut.readObject(input)) } @Test def shouldReadMap(): Unit = { - val in = withStreamState(state => { - state.writeByte('e') // map type descriptor - state.writeInt(3) // size - state.writeByte('i') // 
key type - state.writeInt(3) // number of keys - state.writeInt(11) // first key - state.writeInt(22) // second key - state.writeInt(33) // third key - state.writeInt(3) // number of values - state.writeByte('b') // first value type - state.writeBoolean(true) // first value - state.writeByte('d') // second value type - state.writeDouble(42.42) // second value - state.writeByte('n') // third type & value + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value }) assertEquals( @@ -120,159 +120,254 @@ class SerDeTest { 11 -> true, 22 -> 42.42, 33 -> null)), - sut.readObject(in)) + sut.readObject(input)) } @Test def shouldReadEmptyMap(): Unit = { - val in = withStreamState(state => { - state.writeByte('e') // map type descriptor - state.writeInt(0) // size + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size }) - assertEquals(mapAsJavaMap(Map()), sut.readObject(in)) + assertEquals(mapAsJavaMap(Map()), sut.readObject(input)) } @Test def shouldReadBytesArray(): Unit = { - val in = withStreamState(state => { - state.writeByte('r') // byte array type descriptor - state.writeInt(3) // length - state.write(Array[Byte](1, 2, 3)) // payload + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload }) - assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(in).asInstanceOf[Array[Byte]]) + assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(input).asInstanceOf[Array[Byte]]) } @Test def shouldReadEmptyBytesArray(): Unit = { - val in = withStreamState(state => { - state.writeByte('r') // byte array type descriptor - state.writeInt(0) // length + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length }) - assertArrayEquals(Array[Byte](), sut.readObject(in).asInstanceOf[Array[Byte]]) + assertArrayEquals(Array[Byte](), sut.readObject(input).asInstanceOf[Array[Byte]]) } @Test def shouldReadEmptyList(): Unit = { - val in = withStreamState(state => { - state.writeByte('l') // type descriptor - state.writeByte('i') // element type - state.writeInt(0) // length + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length }) - assertArrayEquals(Array[Int](), sut.readObject(in).asInstanceOf[Array[Int]]) + assertArrayEquals(Array[Int](), sut.readObject(input).asInstanceOf[Array[Int]]) } @Test def shouldReadList(): Unit = { - val in = withStreamState(state => { - state.writeByte('l') // type descriptor - state.writeByte('b') // element type - state.writeInt(3) // length - state.writeBoolean(true) - state.writeBoolean(false) - state.writeBoolean(true) + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) }) - assertArrayEquals(Array(true, false, true), sut.readObject(in).asInstanceOf[Array[Boolean]]) + assertArrayEquals(Array(true, false, true), 
sut.readObject(input).asInstanceOf[Array[Boolean]]) } @Test def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { - val in = withStreamState(state => { - state.writeByte('l') // type descriptor - state.writeByte('_') // unsupported element type + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type }) assertThrows(classOf[IllegalArgumentException], () => { - sut.readObject(in) + sut.readObject(input) }) } @Test def shouldReadDate(): Unit = { - val in = withStreamState(state => { + val input = givenInput(in => { val date = "2020-12-31" - state.writeByte('D') // type descriptor - state.writeInt(date.getBytes("UTF-8").length) // date string size - state.write(date.getBytes("UTF-8")) + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) }) - assertEquals(Date.valueOf("2020-12-31"), sut.readObject(in)) + assertEquals(Date.valueOf("2020-12-31"), sut.readObject(input)) } @Test def shouldReadObject(): Unit = { val trackingObject = new Object tracker.put(trackingObject) - val in = withStreamState(state => { + val input = givenInput(in => { val objectIndex = "1" - state.writeByte('j') // type descriptor - state.writeInt(objectIndex.getBytes("UTF-8").length) // size - state.write(objectIndex.getBytes("UTF-8")) + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) }) - assertSame(trackingObject, sut.readObject(in)) + assertSame(trackingObject, sut.readObject(input)) } @Test def shouldThrowWhenReadingNonTrackingObject(): Unit = { - val in = withStreamState(state => { + val input = givenInput(in => { val objectIndex = "42" - state.writeByte('j') // type descriptor - state.writeInt(objectIndex.getBytes("UTF-8").length) // size - state.write(objectIndex.getBytes("UTF-8")) + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) }) assertThrows(classOf[NoSuchElementException], () => { - sut.readObject(in) + sut.readObject(input) }) } @Test def shouldReadSparkRows(): Unit = { - val in = withStreamState(state => { - state.writeByte('R') // type descriptor - state.writeInt(2) // number of rows - state.writeInt(1) // number of elements in 1st row - state.writeByte('i') // type of 1st element in 1st row - state.writeInt(11) - state.writeInt(3) // number of elements in 2st row - state.writeByte('b') // type of 1st element in 2nd row - state.writeBoolean(true) - state.writeByte('d') // type of 2nd element in 2nd row - state.writeDouble(42.24) - state.writeByte('g') // type of 3nd element in 2nd row - state.writeLong(99) + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2nd row + in.writeByte('b') // type of 1st element in 2nd row + in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3rd element in 2nd row + in.writeLong(99) }) assertEquals( seqAsJavaList(Seq( Row.fromSeq(Seq(11)), Row.fromSeq(Seq(true, 42.24, 99)))), sut.readObject(input)) } @Test def shouldReadArrayOfObjects(): Unit = { - val in = withStreamState(state => { - state.writeByte('O') // type descriptor -
state.writeInt(2) // number of elements - state.writeByte('i') // type of 1st element - state.writeInt(42) - state.writeByte('b') // type of 2nd element - state.writeBoolean(true) + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) }) - assertEquals(Seq[Any](42, true), sut.readObject(in).asInstanceOf[Seq[Any]]) + assertEquals(Seq(42, true), sut.readObject(input).asInstanceOf[Seq[Any]]) } - private def withStreamState(func: DataOutputStream => Unit): DataInputStream = { - val buffer = new ByteArrayOutputStream(); + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + sut.writeObject(out, null) + sut.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + sut.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type + assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readAllBytes(), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + sut.writeObject(out, 42.24f.asInstanceOf[Object]) + sut.writeObject(out, 42L.asInstanceOf[Object]) + sut.writeObject(out, 42.asInstanceOf[Object]) + sut.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + sut.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + + assertArrayEquals(in.readAllBytes(), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + sut.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readAllBytes(), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + sut.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifier + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() val out = new DataOutputStream(buffer) func(out) new DataInputStream(new 
ByteArrayInputStream(buffer.toByteArray)) } + + private def whenOutput = givenInput _ + + private def assertEndOfStream(in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } } From 7378726b4078f079ebbb36faa2e5e17e2694c6a9 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sat, 2 Jan 2021 02:31:40 +0300 Subject: [PATCH 06/15] Test for CallbackServer added. --- .../Microsoft.Spark.UnitTest/CallbackTests.cs | 11 +++++++ .../spark/api/dotnet/DotnetBackend.scala | 4 +-- .../api/dotnet/DotnetBackendHandlerTest.scala | 21 ++++++++++-- .../spark/api/dotnet/DotnetBackendTest.scala | 32 +++++++++++++++++++ 4 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala diff --git a/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs b/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs index 776e54ba2..6150f448b 100644 --- a/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs +++ b/src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs @@ -138,6 +138,17 @@ public void TestCallbackHandlers() Assert.Empty(callbackHandler.Inputs); } } + + [Fact] + public void TestJvmCallbackClientProperty() + { + var server = new CallbackServer(_mockJvm.Object, run: false); + Assert.Throws<InvalidOperationException>(() => server.JvmCallbackClient); + + using ISocketWrapper callbackSocket = SocketFactory.CreateSocket(); + server.Run(callbackSocket); + Assert.NotNull(server.JvmCallbackClient); + } private void TestCallbackConnection( ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict, diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index 0d254afac..0073d28bf 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -67,7 +67,7 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } - private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = { callbackClient = callbackClient match { case Some(_) => throw new Exception("Callback client already set.") case None => @@ -76,7 +76,7 @@ - private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + private[dotnet] def shutdownCallbackClient(): Unit = { callbackClient match { case Some(client) => client.shutdown() case None => logInfo("Callback server has already been shutdown.") diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index a43fbf845..8c573abcc 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -1,15 +1,30 @@ package org.apache.spark.api.dotnet import org.junit.Assert._ -import org.junit.Test +import org.junit.{After, Before, Test} + import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} @Test class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var
tracker: JVMObjectTracker = _ + private var sut: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + sut = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + @Test def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { - val tracker = new JVMObjectTracker - val sut = new DotnetBackendHandler(new DotnetBackend, tracker) val message = givenMessage(m => { val serDe = new SerDe(null) m.writeBoolean(true) // static method diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..6d858cbe4 --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,32 @@ +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var sut: DotnetBackend = _ + + @Before + def before(): Unit = { + sut = new DotnetBackend + } + + @After + def after(): Unit = { + sut.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // specifying port = 0 to select port dynamically + sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(sut.callbackClient.isDefined) + assertThrows(classOf[Exception], () => { + sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + }) + } +} From 683931c23cbff3063de5c0529e98acca983321e1 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sat, 2 Jan 2021 02:37:53 +0300 Subject: [PATCH 07/15] Revert changes for 2.3 and 2.4 versions --- .../spark/api/dotnet/DotnetBackend.scala | 3 -- .../api/dotnet/DotnetBackendHandler.scala | 7 ---- .../scala/com/microsoft/scala/AppTest.scala | 23 +++++++++++ .../spark/api/dotnet/DotnetBackendTest.scala | 25 ------------ .../api/dotnet/JVMObjectTrackerTest.scala | 40 ------------------- .../spark/api/dotnet/DotnetBackend.scala | 3 -- .../api/dotnet/DotnetBackendHandler.scala | 7 ---- .../scala/com/microsoft/scala/AppTest.scala | 23 +++++++++++ .../spark/api/dotnet/DotnetBackendTest.scala | 25 ------------ .../api/dotnet/JVMObjectTrackerTest.scala | 40 ------------------- .../scala/com/microsoft/scala/AppTest.scala | 23 +++++++++++ 11 files changed, 69 insertions(+), 150 deletions(-) create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala delete mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala delete mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index c51f502f3..f7ee92f0f 100644 --- 
a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -82,9 +82,6 @@ class DotnetBackend extends Logging { } bootstrap = null - // Release references to JVM objects to let them collected by GC - JVMObjectTracker.clear() - // Send close to .NET callback server. DotnetBackend.shutdownCallbackClient() diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 7c80c3a1d..e632589e4 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -362,11 +362,4 @@ private object JVMObjectTracker { objMap.remove(id) } } - - def clear(): Unit = { - synchronized { - objMap.clear() - objCounter = 1 - } - } } diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala new file mode 100644 index 000000000..230042b8a --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package com.microsoft.scala + +import org.junit._ +import Assert._ + +@Test +class AppTest { + + @Test + def testOK() = assertTrue(true) + +// @Test +// def testKO() = assertTrue(false) + +} + + diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala deleted file mode 100644 index 9dee936d7..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import org.junit.Test - -@Test -class DotnetBackendTest { - - @Test - def shouldReleaseJVMReferencesWhenClose(): Unit = { - val backend = new DotnetBackend - val objectId = JVMObjectTracker.put(new Object) - - backend.close() - - assert( - JVMObjectTracker.get(objectId).isEmpty, - "JVMObjectTracker must be cleaned up during backend shutdown.") - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala deleted file mode 100644 index dd1b05c79..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import org.junit.Test - -@Test -class JVMObjectTrackerTest { - - @Test - def shouldReleaseAllReferences(): Unit = { - val firstId = JVMObjectTracker.put(new Object) - val secondId = JVMObjectTracker.put(new Object) - val thirdId = JVMObjectTracker.put(new Object) - - JVMObjectTracker.clear() - - assert(JVMObjectTracker.get(firstId).isEmpty) - assert(JVMObjectTracker.get(secondId).isEmpty) - assert(JVMObjectTracker.get(thirdId).isEmpty) - } - - @Test - def shouldResetCounter(): Unit = { - val firstId = JVMObjectTracker.put(new Object) - val secondId = JVMObjectTracker.put(new Object) - - JVMObjectTracker.clear() - - val thirdId = JVMObjectTracker.put(new Object) - - assert(firstId.equals("1")) - assert(secondId.equals("2")) - assert(thirdId.equals("1")) - } -} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index c51f502f3..f7ee92f0f 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -82,9 +82,6 @@ class DotnetBackend extends Logging { } bootstrap = null - // Release references to JVM objects to let them collected by GC - JVMObjectTracker.clear() - // Send close to .NET callback server. DotnetBackend.shutdownCallbackClient() diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 7c80c3a1d..e632589e4 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -362,11 +362,4 @@ private object JVMObjectTracker { objMap.remove(id) } } - - def clear(): Unit = { - synchronized { - objMap.clear() - objCounter = 1 - } - } } diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala new file mode 100644 index 000000000..230042b8a --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package com.microsoft.scala + +import org.junit._ +import Assert._ + +@Test +class AppTest { + + @Test + def testOK() = assertTrue(true) + +// @Test +// def testKO() = assertTrue(false) + +} + + diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala deleted file mode 100644 index 9dee936d7..000000000 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import org.junit.Test - -@Test -class DotnetBackendTest { - - @Test - def shouldReleaseJVMReferencesWhenClose(): Unit = { - val backend = new DotnetBackend - val objectId = JVMObjectTracker.put(new Object) - - backend.close() - - assert( - JVMObjectTracker.get(objectId).isEmpty, - "JVMObjectTracker must be cleaned up during backend shutdown.") - } -} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala deleted file mode 100644 index dd1b05c79..000000000 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import org.junit.Test - -@Test -class JVMObjectTrackerTest { - - @Test - def shouldReleaseAllReferences(): Unit = { - val firstId = JVMObjectTracker.put(new Object) - val secondId = JVMObjectTracker.put(new Object) - val thirdId = JVMObjectTracker.put(new Object) - - JVMObjectTracker.clear() - - assert(JVMObjectTracker.get(firstId).isEmpty) - assert(JVMObjectTracker.get(secondId).isEmpty) - assert(JVMObjectTracker.get(thirdId).isEmpty) - } - - @Test - def shouldResetCounter(): Unit = { - val firstId = JVMObjectTracker.put(new Object) - val secondId = JVMObjectTracker.put(new Object) - - JVMObjectTracker.clear() - - val thirdId = JVMObjectTracker.put(new Object) - - assert(firstId.equals("1")) - assert(secondId.equals("2")) - assert(thirdId.equals("1")) - } -} diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala new file mode 100644 index 000000000..230042b8a --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package com.microsoft.scala + +import org.junit._ +import Assert._ + +@Test +class AppTest { + + @Test + def testOK() = assertTrue(true) + +// @Test +// def testKO() = assertTrue(false) + +} + + From 325afcba10fae40080154d39306947be5c2aecd4 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sat, 2 Jan 2021 02:38:53 +0300 Subject: [PATCH 08/15] AppTest for 3.0 removed --- .../scala/com/microsoft/scala/AppTest.scala | 23 ------------------- 1 file changed, 23 deletions(-) delete mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-3-0/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - From 2dfe7792c001110dfb32312e533d57bf8e50fd5f Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sat, 2 Jan 2021 03:20:37 +0300 Subject: [PATCH 09/15] Add readN extension for DataInputStream, as the Java 8 standard library doesn't contain readNBytes. --- .../api/dotnet/DotnetBackendHandlerTest.scala | 3 ++- .../org/apache/spark/api/dotnet/Extensions.scala | 13 +++++++++++++ .../org/apache/spark/api/dotnet/SerDeTest.scala | 14 +++++++------- 3 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index 8c573abcc..95872bbfb 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -1,5 +1,6 @@ package org.apache.spark.api.dotnet +import Extensions._ import org.junit.Assert._ import org.junit.{After, Before, Test} @@ -45,7 +46,7 @@ class DotnetBackendHandlerTest { "status code must be successful.", 0, reply.readInt()) assertEquals('j', reply.readByte()) assertEquals(1, reply.readInt()) - val trackingId = new String(reply.readAllBytes(), "UTF-8") + val trackingId = new String(reply.readN(1), "UTF-8") assertEquals("1", trackingId) val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull assertEquals(classOf[CallbackClient], client.getClass) diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..3ca57dfea --- /dev/null +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,13 @@ +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readN(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala index a8607e011..84f094d0a 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -13,6 +13,7 @@ import org.junit.{Before, Test} import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException} import java.sql.Date import scala.collection.JavaConverters._ +import Extensions._ @Test class SerDeTest { @@ -290,7 +291,7 @@ class SerDeTest { assertEquals(in.readByte(), 'c') // object type assertEquals(in.readInt(), sparkDotnet.length) // length - assertArrayEquals(in.readAllBytes(), sparkDotnet.getBytes("UTF-8")) + assertArrayEquals(in.readN(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) assertEndOfStream(in) } @@ -323,8 +324,7 @@ class SerDeTest { assertEquals(in.readByte(), 'D') // type assertEquals(in.readInt(), 10) // size - - assertArrayEquals(in.readAllBytes(), date.getBytes("UTF-8")) // content + assertArrayEquals(in.readN(10), date.getBytes("UTF-8")) // content } @Test @@ -336,7 +336,7 @@ class SerDeTest { assertEquals(in.readByte(), 'j') assertEquals(in.readInt(), 1) - assertArrayEquals(in.readAllBytes(), "1".getBytes("UTF-8")) + assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) assertSame(tracker.get("1").get, customObject) } @@ -350,10 +350,10 @@ class SerDeTest { assertEquals(in.readByte(), 'l') // array type assertEquals(in.readByte(), 'j') // type of element in array assertEquals(in.readInt(), 2) // array length - assertEquals(in.readInt(), 1) // size of 1st element's identifier - assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 1st element's identifier + assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) // identifier of 1st element assertEquals(in.readInt(), 1) // size of 2nd element's identifier - assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertArrayEquals(in.readN(1), "2".getBytes("UTF-8")) // identifier of 2nd element assertSame(tracker.get("1").get, payload(0)) assertSame(tracker.get("2").get, payload(1)) } From ff532ec2ca3f5dd79c5c16ad2f7fa618c9eae656 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sat, 2 Jan 2021 03:22:08 +0300 Subject: [PATCH 10/15] License headers added --- .../org/apache/spark/api/dotnet/JVMObjectTracker.scala | 7 +++++++ .../apache/spark/api/dotnet/DotnetBackendHandlerTest.scala | 7 +++++++ .../org/apache/spark/api/dotnet/DotnetBackendTest.scala | 7 +++++++ .../scala/org/apache/spark/api/dotnet/Extensions.scala | 7 +++++++ 4 files changed, 28 insertions(+) diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala index 8cd2d53be..81cfaf88b 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -1,3 +1,10 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + package org.apache.spark.api.dotnet import scala.collection.mutable.HashMap diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index 95872bbfb..231c2fecf 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -1,3 +1,10 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information.
+ */ + + package org.apache.spark.api.dotnet import Extensions._ diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala index 6d858cbe4..bf4001c3b 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -1,3 +1,10 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + package org.apache.spark.api.dotnet import org.junit.Assert._ diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala index 3ca57dfea..704c248b9 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -1,3 +1,10 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + package org.apache.spark.api.dotnet import java.io.DataInputStream From 5fa52181544893212d84f768ae0fd783347718ce Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Sun, 3 Jan 2021 19:57:16 +0300 Subject: [PATCH 11/15] Apply 3.0 changes for 2.3 --- .../Interop/Ipc/CallbackServer.cs | 2 +- .../spark/api/dotnet/CallbackClient.scala | 6 +- .../spark/api/dotnet/CallbackConnection.scala | 22 +- .../spark/api/dotnet/DotnetBackend.scala | 47 ++- .../api/dotnet/DotnetBackendHandler.scala | 117 ++---- .../spark/api/dotnet/JVMObjectTracker.scala | 54 +++ .../org/apache/spark/api/dotnet/SerDe.scala | 72 ++-- .../scala/com/microsoft/scala/AppTest.scala | 23 -- .../api/dotnet/DotnetBackendHandlerTest.scala | 61 +++ .../spark/api/dotnet/DotnetBackendTest.scala | 43 ++ .../apache/spark/api/dotnet/Extensions.scala | 19 + .../api/dotnet/JVMObjectTrackerTest.scala | 42 ++ .../apache/spark/api/dotnet/SerDeTest.scala | 386 ++++++++++++++++++ .../spark/api/dotnet/DotnetBackend.scala | 5 +- .../apache/spark/api/dotnet/SerDeTest.scala | 4 +- 15 files changed, 726 insertions(+), 177 deletions(-) create mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs index 055493b42..d86fd7305 100644 --- 
a/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs @@ -76,7 +76,7 @@ internal JvmObjectReference JvmCallbackClient if (_jvmCallbackClient is null) { throw new InvalidOperationException( - "Please start CallbackServer before accessing JvmCallbackClient."); + "Please make sure that CallbackServer was started before accessing JvmCallbackClient."); } return _jvmCallbackClient; diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala index 90ad92439..aea355dfa 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -20,12 +20,12 @@ import scala.collection.mutable.Queue * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackClient(address: String, port: Int) extends Logging { +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() private[this] var isShutdown: Boolean = false - final def send(callbackId: Int, writeBody: DataOutputStream => Unit): Unit = + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = getOrCreateConnection() match { case Some(connection) => try { @@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging { return Some(connectionPool.dequeue()) } - Some(new CallbackConnection(address, port)) + Some(new CallbackConnection(serDe, address, port)) } private def addConnection(connection: CallbackConnection): Unit = synchronized { diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala index 36726181e..604cf029b 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackConnection(address: String, port: Int) extends Logging { +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val socket: Socket = new Socket(address, port) private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) def send( callbackId: Int, - writeBody: DataOutputStream => Unit): Unit = { + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { logInfo(s"Calling callback [callback id = $callbackId] ...") try { - SerDe.writeInt(outputStream, CallbackFlags.CALLBACK) - SerDe.writeInt(outputStream, callbackId) + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) val byteArrayOutputStream = new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream)) - SerDe.writeInt(outputStream, byteArrayOutputStream.size) + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + 
serDe.writeInt(outputStream, byteArrayOutputStream.size) byteArrayOutputStream.writeTo(outputStream); } catch { case e: Exception => { @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Signaling END_OF_STREAM.") try { - SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) outputStream.flush() val endOfStreamResponse = readFlag(inputStream) @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { def close(): Unit = { try { - SerDe.writeInt(outputStream, CallbackFlags.CLOSE) + serDe.writeInt(outputStream, CallbackFlags.CLOSE) outputStream.flush() } catch { case e: Exception => logInfo("Unable to send close to .NET callback server.", e) @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging { } private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = SerDe.readInt(inputStream) + val callbackFlag = serDe.readInt(inputStream) if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = SerDe.readString(inputStream) + val exceptionMessage = serDe.readString(inputStream) throw new DotnetException(exceptionMessage) } callbackFlag @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging { val DOTNET_EXCEPTION_THROWN: Int = -3 val END_OF_STREAM: Int = -4 } -} \ No newline at end of file +} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index f7ee92f0f..0bd08745a 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -30,6 +30,10 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectsTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None def init(portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -55,7 +59,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. 
strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self)) + .addLast("handler", new DotnetBackendHandler(self, objectsTracker)) } }) @@ -64,6 +68,23 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectsTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback server has already been shutdown.") + } + callbackClient = None + } + def run(): Unit = { channelFuture.channel.closeFuture().syncUninterruptibly() } @@ -82,30 +103,12 @@ class DotnetBackend extends Logging { } bootstrap = null + objectsTracker.clear() + // Send close to .NET callback server. - DotnetBackend.shutdownCallbackClient() + shutdownCallbackClient() // Shutdown the thread pool whose executors could still be running. ThreadPool.shutdown() } } - -object DotnetBackend extends Logging { - @volatile private[spark] var callbackClient: CallbackClient = null - - private[spark] def setCallbackClient(address: String, port: Int) = synchronized { - if (DotnetBackend.callbackClient == null) { - logInfo(s"Connecting to a callback server at $address:$port") - DotnetBackend.callbackClient = new CallbackClient(address, port) - } else { - throw new Exception("Callback client already set.") - } - } - - private[spark] def shutdownCallbackClient(): Unit = synchronized { - if (callbackClient != null) { - callbackClient.shutdown() - callbackClient = null - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index e632589e4..400d95ed4 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -7,12 +7,9 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} - -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -20,10 +17,12 @@ import org.apache.spark.util.Utils * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. 
*/ -class DotnetBackendHandler(server: DotnetBackend) +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + private[this] val serDe = new SerDe(objectsTracker) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { val reply = handleBackendRequest(msg) ctx.write(reply) @@ -41,56 +40,56 @@ class DotnetBackendHandler(server: DotnetBackend) val dos = new DataOutputStream(bos) // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) + val isStatic = serDe.readBoolean(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) if (objId == "DotnetHandler") { methodName match { case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") server.close() case "rm" => try { - val t = readObjectType(dis) + val t = serDe.readObjectType(dis) assert(t == 'c') - val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) - writeInt(dos, 0) - writeObject(dos, null) + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) } catch { case e: Exception => logError(s"Removing $objId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "rmThread" => try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, result.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port) - writeInt(dos, 0) - writeType(dos, "void") + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case _ => dos.writeInt(-1) } @@ -131,7 +130,7 @@ class DotnetBackendHandler(server: DotnetBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { + objectsTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) case Some(o) => obj = o @@ -159,8 +158,8 @@ class DotnetBackendHandler(server: DotnetBackend) val ret = selectedMethods(index.get).invoke(obj, args: _*) // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor val ctor = cls.getConstructors.filter { x => @@ -169,15 
+168,15 @@ class DotnetBackendHandler(server: DotnetBackend) val obj = ctor.newInstance(args: _*) - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) } else { throw new IllegalArgumentException( "invalid method " + methodName + " for object " + objId) } } catch { case e: Throwable => - val jvmObj = JVMObjectTracker.get(objId) + val jvmObj = objectsTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" @@ -199,15 +198,15 @@ class DotnetBackendHandler(server: DotnetBackend) methods.foreach(m => logDebug(m.toString)) } - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) } } // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { arg => - readObject(dis) + serDe.readObject(dis) }.toArray } @@ -326,40 +325,4 @@ class DotnetBackendHandler(server: DotnetBackend) def logError(id: String, e: Exception): Unit = {} } -/** - * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. - */ -private object JVMObjectTracker { - - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. - private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 - def getObject(id: String): Object = { - synchronized { - objMap(id) - } - } - - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - } - - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..aceb58c01 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET, which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because the get method returns an Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
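+  //
+  // Illustrative lifecycle (a sketch for readers, not part of the contract; the ids
+  // shown assume a freshly created tracker instance named `tracker`):
+  //   val id = tracker.put(new Object)  // returns "1"
+  //   tracker.get(id)                   // Some(obj)
+  //   tracker.clear()                   // drops all entries and resets objCounter
+  //   tracker.put(new Object)           // returns "1" again: ids are reused after clear()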
+ private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 427df61b6..6cf1ddc73 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -18,7 +18,7 @@ import scala.collection.JavaConverters._ * Functions to serialize and deserialize between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. */ -object SerDe { +class SerDe(var tracker: JVMObjectTracker) { def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -28,7 +28,7 @@ object SerDe { readTypedObject(dis, dataType) } - def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) @@ -41,14 +41,14 @@ object SerDe { case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => tracker.getObject(readString(dis)) case 'R' => readRowArr(dis) case 'O' => readObjectArr(dis) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } - def readBytes(in: DataInputStream): Array[Byte] = { + private def readBytes(in: DataInputStream): Array[Byte] = { val len = readInt(in) val out = new Array[Byte](len) in.readFully(out) @@ -59,15 +59,15 @@ object SerDe { in.readInt() } - def readLong(in: DataInputStream): Long = { + private def readLong(in: DataInputStream): Long = { in.readLong() } - def readDouble(in: DataInputStream): Double = { + private def readDouble(in: DataInputStream): Double = { in.readDouble() } - def readStringBytes(in: DataInputStream, len: Int): String = { + private def readStringBytes(in: DataInputStream, len: Int): String = { val bytes = new Array[Byte](len) in.readFully(bytes) val str = new String(bytes, "UTF-8") @@ -83,11 +83,11 @@ object SerDe { in.readBoolean() } - def readDate(in: DataInputStream): Date = { + private def readDate(in: DataInputStream): Date = { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Timestamp = { + private def readTime(in: DataInputStream): Timestamp = { val seconds = in.readDouble() val sec = Math.floor(seconds).toLong val t = new Timestamp(sec * 1000L) @@ -95,57 +95,57 @@ object SerDe { t } - def readRow(in: DataInputStream): Row = { + private def readRow(in: DataInputStream): Row = { val len = readInt(in) Row.fromSeq((0 until len).map(_ => readObject(in))) } - def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray } - def readIntArr(in: 
DataInputStream): Array[Int] = { + private def readIntArr(in: DataInputStream): Array[Int] = { val len = readInt(in) (0 until len).map(_ => readInt(in)).toArray } - def readLongArr(in: DataInputStream): Array[Long] = { + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray } - def readDoubleArr(in: DataInputStream): Array[Double] = { + private def readDoubleArr(in: DataInputStream): Array[Double] = { val len = readInt(in) (0 until len).map(_ => readDouble(in)).toArray } - def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { val len = readInt(in) (0 until len).map(_ => readDoubleArr(in)).toArray } - def readBooleanArr(in: DataInputStream): Array[Boolean] = { + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { val len = readInt(in) (0 until len).map(_ => readBoolean(in)).toArray } - def readStringArr(in: DataInputStream): Array[String] = { + private def readStringArr(in: DataInputStream): Array[String] = { val len = readInt(in) (0 until len).map(_ => readString(in)).toArray } - def readRowArr(in: DataInputStream): java.util.List[Row] = { + private def readRowArr(in: DataInputStream): java.util.List[Row] = { val len = readInt(in) (0 until len).map(_ => readRow(in)).toList.asJava } - def readObjectArr(in: DataInputStream): Seq[Any] = { + private def readObjectArr(in: DataInputStream): Seq[Any] = { val len = readInt(in) (0 until len).map(_ => readObject(in)) } - def readList(dis: DataInputStream): Array[_] = { + private def readList(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -154,13 +154,13 @@ object SerDe { case 'd' => readDoubleArr(dis) case 'A' => readDoubleArrArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { val keysType = readObjectType(in) @@ -299,23 +299,23 @@ object SerDe { out.writeLong(value) } - def writeDouble(out: DataOutputStream, value: Double): Unit = { + private def writeDouble(out: DataOutputStream, value: Double): Unit = { out.writeDouble(value) } - def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { out.writeBoolean(value) } - def writeDate(out: DataOutputStream, value: Date): Unit = { + private def writeDate(out: DataOutputStream, value: Date): Unit = { writeString(out, value.toString) } - def writeTime(out: DataOutputStream, value: Time): Unit = { + private def writeTime(out: DataOutputStream, value: Time): Unit = { out.writeDouble(value.getTime.toDouble / 1000.0) } - def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } @@ -326,53 +326,53 @@ object SerDe { out.write(utf8, 0, len) } - def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { 
out.writeInt(value.length) out.write(value) } def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) + val objId = tracker.put(value) writeString(out, objId) } - def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { writeType(out, "integer") out.writeInt(value.length) value.foreach(v => out.writeInt(v)) } - def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { writeType(out, "long") out.writeInt(value.length) value.foreach(v => out.writeLong(v)) } - def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { writeType(out, "double") out.writeInt(value.length) value.foreach(v => out.writeDouble(v)) } - def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { writeType(out, "doublearray") out.writeInt(value.length) value.foreach(v => writeDoubleArr(out, v)) } - def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { writeType(out, "logical") out.writeInt(value.length) value.foreach(v => writeBoolean(out, v)) } - def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { writeType(out, "character") out.writeInt(value.length) value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { writeType(out, "raw") out.writeInt(value.length) value.foreach(v => writeBytes(out, v)) diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..848b0f3ca --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var sut: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + sut = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + m.writeInt(1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = sut.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('n', reply.readByte()) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..3acdbf274 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var sut: DotnetBackend = _ + + @Before + def before(): Unit = { + sut = new DotnetBackend + } + + @After + def after(): Unit = { + sut.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // specifying port = 0 to select port dynamically + sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(sut.callbackClient.isDefined) + assertThrows( + classOf[Exception], + new ThrowingRunnable { + override def run(): Unit = { + sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + } + }) + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..53b2705b0 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. 
+ * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readN(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..1199114a6 --- /dev/null +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConverter} + +@Test +class SerDeTest { + private var sut: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + sut = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, sut.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + sut.readObject(input) + } + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, sut.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, sut.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, sut.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, sut.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, sut.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + Map( + 11 -> true, + 22 -> 42.42, + 33 -> null).asJava, + sut.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(Map().asJava, sut.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), sut.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = givenInput(in => { + 
in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), sut.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), sut.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + sut.readObject(input) + } + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), sut.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, sut.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows( + classOf[NoSuchElementException], + new ThrowingRunnable { + override def run(): Unit = { + sut.readObject(input) + } + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2nd row + in.writeByte('b') // type of 1st element in 2nd row + in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3rd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99))).asJava, + sut.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), sut.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + sut.writeObject(out, null) + sut.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + sut.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type + assertEquals(in.readInt(),
sparkDotnet.length) // length + assertArrayEquals(in.readN(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + sut.writeObject(out, 42.24f.asInstanceOf[Object]) + sut.writeObject(out, 42L.asInstanceOf[Object]) + sut.writeObject(out, 42.asInstanceOf[Object]) + sut.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + sut.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readN(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + sut.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + sut.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifier + assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readN(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream(in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +} diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index 0073d28bf..d058653cb 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -67,7 +67,7 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } - private[dotnet] def setCallbackClient(address: String, port: Int): Unit = { + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { callbackClient = callbackClient match { case Some(_) => throw new Exception("Callback client already set.") case None => @@ -76,11 +76,12 @@ class DotnetBackend extends Logging { } } - private[dotnet] def shutdownCallbackClient(): Unit = { + private[dotnet] def
shutdownCallbackClient(): Unit = synchronized { callbackClient match { case Some(client) => client.shutdown() case None => logInfo("Callback server has already been shutdown.") } + callbackClient = None } def run(): Unit = { diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala index 84f094d0a..4343d8e39 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -6,14 +6,14 @@ package org.apache.spark.api.dotnet +import org.apache.spark.api.dotnet.Extensions._ import org.apache.spark.sql.Row import org.junit.Assert._ import org.junit.{Before, Test} -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.sql.Date import scala.collection.JavaConverters._ -import Extensions._ @Test class SerDeTest { From 0421621be5b075464af0a84d2bc276145130db24 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Mon, 4 Jan 2021 17:53:15 +0300 Subject: [PATCH 12/15] Apply 3.0 changes for 2.4 --- .../api/dotnet/DotnetBackendHandler.scala | 2 - .../api/dotnet/DotnetBackendHandlerTest.scala | 1 - .../spark/api/dotnet/CallbackClient.scala | 6 +- .../spark/api/dotnet/CallbackConnection.scala | 24 +- .../spark/api/dotnet/DotnetBackend.scala | 47 ++- .../api/dotnet/DotnetBackendHandler.scala | 121 ++---- .../spark/api/dotnet/JVMObjectTracker.scala | 54 +++ .../org/apache/spark/api/dotnet/SerDe.scala | 72 ++-- .../sql/api/dotnet/DotnetForeachBatch.scala | 15 +- .../scala/com/microsoft/scala/AppTest.scala | 23 -- .../api/dotnet/DotnetBackendHandlerTest.scala | 66 +++ .../spark/api/dotnet/DotnetBackendTest.scala | 43 ++ .../apache/spark/api/dotnet/Extensions.scala | 19 + .../api/dotnet/JVMObjectTrackerTest.scala | 42 ++ .../apache/spark/api/dotnet/SerDeTest.scala | 386 ++++++++++++++++++ 15 files changed, 737 insertions(+), 184 deletions(-) create mode 100644 src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala delete mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala create mode 100644 src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 400d95ed4..4d32e43fb 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -324,5 +324,3 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack def logError(id: 
String, e: Exception): Unit = {} } - - diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index 848b0f3ca..6c9b6d2f2 100644 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -4,7 +4,6 @@ * See the LICENSE file in the project root for more information. */ - package org.apache.spark.api.dotnet import org.junit.Assert._ diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala index 90ad92439..aea355dfa 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -20,12 +20,12 @@ import scala.collection.mutable.Queue * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackClient(address: String, port: Int) extends Logging { +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() private[this] var isShutdown: Boolean = false - final def send(callbackId: Int, writeBody: DataOutputStream => Unit): Unit = + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = getOrCreateConnection() match { case Some(connection) => try { @@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging { return Some(connectionPool.dequeue()) } - Some(new CallbackConnection(address, port)) + Some(new CallbackConnection(serDe, address, port)) } private def addConnection(connection: CallbackConnection): Unit = synchronized { diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala index 36726181e..258f02aeb 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging * @param address The address of the Dotnet CallbackServer * @param port The port of the Dotnet CallbackServer */ -class CallbackConnection(address: String, port: Int) extends Logging { +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { private[this] val socket: Socket = new Socket(address, port) private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) def send( callbackId: Int, - writeBody: DataOutputStream => Unit): Unit = { + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { logInfo(s"Calling callback [callback id = $callbackId] ...") try { - SerDe.writeInt(outputStream, CallbackFlags.CALLBACK) - SerDe.writeInt(outputStream, callbackId) + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) val byteArrayOutputStream = 
new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream)) - SerDe.writeInt(outputStream, byteArrayOutputStream.size) + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) byteArrayOutputStream.writeTo(outputStream); } catch { case e: Exception => { @@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Signaling END_OF_STREAM.") try { - SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) outputStream.flush() val endOfStreamResponse = readFlag(inputStream) @@ -53,7 +53,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.") case _ => { throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " + - s"Received: $endOfStreamResponse") + s"Received: $endOfStreamResponse") } } } catch { @@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging { def close(): Unit = { try { - SerDe.writeInt(outputStream, CallbackFlags.CLOSE) + serDe.writeInt(outputStream, CallbackFlags.CLOSE) outputStream.flush() } catch { case e: Exception => logInfo("Unable to send close to .NET callback server.", e) @@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging { } private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = SerDe.readInt(inputStream) + val callbackFlag = serDe.readInt(inputStream) if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = SerDe.readString(inputStream) + val exceptionMessage = serDe.readString(inputStream) throw new DotnetException(exceptionMessage) } callbackFlag @@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging { val DOTNET_EXCEPTION_THROWN: Int = -3 val END_OF_STREAM: Int = -4 } -} \ No newline at end of file +} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index f7ee92f0f..0bd08745a 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -30,6 +30,10 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectsTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None def init(portNumber: Int): Int = { val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -55,7 +59,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. 
strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self)) + .addLast("handler", new DotnetBackendHandler(self, objectsTracker)) } }) @@ -64,6 +68,23 @@ channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectsTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback client has already been shut down.") + } + callbackClient = None + } + def run(): Unit = { channelFuture.channel.closeFuture().syncUninterruptibly() } @@ -82,30 +103,12 @@ } bootstrap = null + objectsTracker.clear() + // Send close to .NET callback server. - DotnetBackend.shutdownCallbackClient() + shutdownCallbackClient() // Shutdown the thread pool whose executors could still be running. ThreadPool.shutdown() } } - -object DotnetBackend extends Logging { - @volatile private[spark] var callbackClient: CallbackClient = null - - private[spark] def setCallbackClient(address: String, port: Int) = synchronized { - if (DotnetBackend.callbackClient == null) { - logInfo(s"Connecting to a callback server at $address:$port") - DotnetBackend.callbackClient = new CallbackClient(address, port) - } else { - throw new Exception("Callback client already set.") - } - } - - private[spark] def shutdownCallbackClient(): Unit = synchronized { - if (callbackClient != null) { - callbackClient.shutdown() - callbackClient = null - } - } -} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index e632589e4..bdac858a6 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -7,12 +7,9 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} - -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -20,10 +17,12 @@ import org.apache.spark.util.Utils * Handler for DotnetBackend. * This implementation is similar to RBackendHandler.
*/ -class DotnetBackendHandler(server: DotnetBackend) +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + private[this] val serDe = new SerDe(objectsTracker) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { val reply = handleBackendRequest(msg) ctx.write(reply) @@ -41,56 +40,60 @@ class DotnetBackendHandler(server: DotnetBackend) val dos = new DataOutputStream(bos) // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) + val isStatic = serDe.readBoolean(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) if (objId == "DotnetHandler") { methodName match { case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") server.close() case "rm" => try { - val t = readObjectType(dis) + val t = serDe.readObjectType(dis) assert(t == 'c') - val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) - writeInt(dos, 0) - writeObject(dos, null) + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) } catch { case e: Exception => logError(s"Removing $objId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "rmThread" => try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, result.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) + serDe.writeInt(dos, -1) } case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port) - writeInt(dos, 0) - writeType(dos, "void") + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Sends a reference to the CallbackClient to the .NET side, so that + // the .NET process can send the client back to the Java side when + // calling any API that takes callback functions + serDe.writeObject(dos, server.callbackClient) case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") case _ => dos.writeInt(-1) } @@ -131,7 +134,7 @@ class DotnetBackendHandler(server: DotnetBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { + objectsTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) case Some(o) => obj = o @@ -159,8 +162,8 @@ class DotnetBackendHandler(server: DotnetBackend) val ret = selectedMethods(index.get).invoke(obj, args: _*) // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) +
serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor val ctor = cls.getConstructors.filter { x => @@ -169,15 +172,15 @@ class DotnetBackendHandler(server: DotnetBackend) val obj = ctor.newInstance(args: _*) - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) } else { throw new IllegalArgumentException( "invalid method " + methodName + " for object " + objId) } } catch { case e: Throwable => - val jvmObj = JVMObjectTracker.get(objId) + val jvmObj = objectsTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" @@ -199,15 +202,15 @@ class DotnetBackendHandler(server: DotnetBackend) methods.foreach(m => logDebug(m.toString)) } - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) } } // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { arg => - readObject(dis) + serDe.readObject(dis) }.toArray } @@ -326,40 +329,4 @@ class DotnetBackendHandler(server: DotnetBackend) def logError(id: String, e: Exception): Unit = {} } -/** - * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. - */ -private object JVMObjectTracker { - - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. - private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 - def getObject(id: String): Object = { - synchronized { - objMap(id) - } - } - - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - } - - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } - } -} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..aceb58c01 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET, which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because the get method returns an Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
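+  //
+  // Illustrative lifecycle (a sketch for readers, not part of the contract; the ids
+  // shown assume a freshly created tracker instance named `tracker`):
+  //   val id = tracker.put(new Object)  // returns "1"
+  //   tracker.get(id)                   // Some(obj)
+  //   tracker.clear()                   // drops all entries and resets objCounter
+  //   tracker.put(new Object)           // returns "1" again: ids are reused after clear()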
+ private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 427df61b6..6cf1ddc73 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -18,7 +18,7 @@ import scala.collection.JavaConverters._ * Functions to serialize and deserialize between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. */ -object SerDe { +class SerDe(var tracker: JVMObjectTracker) { def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -28,7 +28,7 @@ object SerDe { readTypedObject(dis, dataType) } - def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) @@ -41,14 +41,14 @@ object SerDe { case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => tracker.getObject(readString(dis)) case 'R' => readRowArr(dis) case 'O' => readObjectArr(dis) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } - def readBytes(in: DataInputStream): Array[Byte] = { + private def readBytes(in: DataInputStream): Array[Byte] = { val len = readInt(in) val out = new Array[Byte](len) in.readFully(out) @@ -59,15 +59,15 @@ object SerDe { in.readInt() } - def readLong(in: DataInputStream): Long = { + private def readLong(in: DataInputStream): Long = { in.readLong() } - def readDouble(in: DataInputStream): Double = { + private def readDouble(in: DataInputStream): Double = { in.readDouble() } - def readStringBytes(in: DataInputStream, len: Int): String = { + private def readStringBytes(in: DataInputStream, len: Int): String = { val bytes = new Array[Byte](len) in.readFully(bytes) val str = new String(bytes, "UTF-8") @@ -83,11 +83,11 @@ object SerDe { in.readBoolean() } - def readDate(in: DataInputStream): Date = { + private def readDate(in: DataInputStream): Date = { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Timestamp = { + private def readTime(in: DataInputStream): Timestamp = { val seconds = in.readDouble() val sec = Math.floor(seconds).toLong val t = new Timestamp(sec * 1000L) @@ -95,57 +95,57 @@ object SerDe { t } - def readRow(in: DataInputStream): Row = { + private def readRow(in: DataInputStream): Row = { val len = readInt(in) Row.fromSeq((0 until len).map(_ => readObject(in))) } - def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray } - def readIntArr(in: 
DataInputStream): Array[Int] = { + private def readIntArr(in: DataInputStream): Array[Int] = { val len = readInt(in) (0 until len).map(_ => readInt(in)).toArray } - def readLongArr(in: DataInputStream): Array[Long] = { + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray } - def readDoubleArr(in: DataInputStream): Array[Double] = { + private def readDoubleArr(in: DataInputStream): Array[Double] = { val len = readInt(in) (0 until len).map(_ => readDouble(in)).toArray } - def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { val len = readInt(in) (0 until len).map(_ => readDoubleArr(in)).toArray } - def readBooleanArr(in: DataInputStream): Array[Boolean] = { + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { val len = readInt(in) (0 until len).map(_ => readBoolean(in)).toArray } - def readStringArr(in: DataInputStream): Array[String] = { + private def readStringArr(in: DataInputStream): Array[String] = { val len = readInt(in) (0 until len).map(_ => readString(in)).toArray } - def readRowArr(in: DataInputStream): java.util.List[Row] = { + private def readRowArr(in: DataInputStream): java.util.List[Row] = { val len = readInt(in) (0 until len).map(_ => readRow(in)).toList.asJava } - def readObjectArr(in: DataInputStream): Seq[Any] = { + private def readObjectArr(in: DataInputStream): Seq[Any] = { val len = readInt(in) (0 until len).map(_ => readObject(in)) } - def readList(dis: DataInputStream): Array[_] = { + private def readList(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -154,13 +154,13 @@ object SerDe { case 'd' => readDoubleArr(dis) case 'A' => readDoubleArrArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { val keysType = readObjectType(in) @@ -299,23 +299,23 @@ object SerDe { out.writeLong(value) } - def writeDouble(out: DataOutputStream, value: Double): Unit = { + private def writeDouble(out: DataOutputStream, value: Double): Unit = { out.writeDouble(value) } - def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { out.writeBoolean(value) } - def writeDate(out: DataOutputStream, value: Date): Unit = { + private def writeDate(out: DataOutputStream, value: Date): Unit = { writeString(out, value.toString) } - def writeTime(out: DataOutputStream, value: Time): Unit = { + private def writeTime(out: DataOutputStream, value: Time): Unit = { out.writeDouble(value.getTime.toDouble / 1000.0) } - def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } @@ -326,53 +326,53 @@ object SerDe { out.write(utf8, 0, len) } - def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { 
out.writeInt(value.length) out.write(value) } def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) + val objId = tracker.put(value) writeString(out, objId) } - def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { writeType(out, "integer") out.writeInt(value.length) value.foreach(v => out.writeInt(v)) } - def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { writeType(out, "long") out.writeInt(value.length) value.foreach(v => out.writeLong(v)) } - def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { writeType(out, "double") out.writeInt(value.length) value.foreach(v => out.writeDouble(v)) } - def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { writeType(out, "doublearray") out.writeInt(value.length) value.foreach(v => writeDoubleArr(out, v)) } - def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { writeType(out, "logical") out.writeInt(value.length) value.foreach(v => writeBoolean(out, v)) } - def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { writeType(out, "character") out.writeInt(value.length) value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { writeType(out, "raw") out.writeInt(value.length) value.foreach(v => writeBytes(out, v)) diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala index c0de9c7bc..afde1931e 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -15,20 +15,19 @@ class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int def call(batchDF: DataFrame, batchId: Long): Unit = callbackClient.send( callbackId, - dos => { - SerDe.writeJObj(dos, batchDF) - SerDe.writeLong(dos, batchId) + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) }) } object DotnetForeachBatchHelper { - def callForeachBatch(dsw: DataStreamWriter[Row], callbackId: Int): Unit = { - val callbackClient = DotnetBackend.callbackClient - if (callbackClient == null) { - throw new Exception("DotnetBackend.callbackClient is null.") + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + case Some(value) => new DotnetForeachBatchFunction(value, callbackId) + case None => throw new Exception("CallbackClient is null.") } - val dotnetForeachFunc = new DotnetForeachBatchFunction(callbackClient, callbackId) dsw.foreachBatch(dotnetForeachFunc.call _) } } diff --git 
a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala deleted file mode 100644 index 230042b8a..000000000 --- a/src/scala/microsoft-spark-2-4/src/test/scala/com/microsoft/scala/AppTest.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package com.microsoft.scala - -import org.junit._ -import Assert._ - -@Test -class AppTest { - - @Test - def testOK() = assertTrue(true) - -// @Test -// def testKO() = assertTrue(false) - -} - - diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..7d721965a --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions.DataInputStreamExt +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var sut: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + sut = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + m.writeInt(1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = sut.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readN(1), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 
000000000..3acdbf274 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var sut: DotnetBackend = _ + + @Before + def before(): Unit = { + sut = new DotnetBackend + } + + @After + def after(): Unit = { + sut.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // specifying port = 0 to select port dynamically + sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(sut.callbackClient.isDefined) + assertThrows( + classOf[Exception], + new ThrowingRunnable { + override def run(): Unit = { + sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + } + }) + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..53b2705b0 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readN(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..1199114a6 --- /dev/null +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.function.ThrowingRunnable +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConverter} + +@Test +class SerDeTest { + private var sut: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + sut = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, sut.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + sut.readObject(input) + } + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, sut.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, sut.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, sut.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, sut.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, sut.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + 
in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + Map( + 11 -> true, + 22 -> 42.42, + 33 -> null).asJava, + sut.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(Map().asJava, sut.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), sut.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), sut.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), sut.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows( + classOf[IllegalArgumentException], + new ThrowingRunnable { + override def run(): Unit = { + sut.readObject(input) + } + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), sut.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, sut.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows( + classOf[NoSuchElementException], + new ThrowingRunnable { + override def run(): Unit = { + sut.readObject(input) + } + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number 
of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2nd row + in.writeByte('b') // type of 1st element in 2nd row + in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3rd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99))).asJava, + sut.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), sut.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + sut.writeObject(out, null) + sut.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + sut.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type + assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readN(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + sut.writeObject(out, 42.24f.asInstanceOf[Object]) + sut.writeObject(out, 42L.asInstanceOf[Object]) + sut.writeObject(out, 42.asInstanceOf[Object]) + sut.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + sut.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readN(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + sut.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + sut.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifiers + assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readN(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit):
DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream(in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +} From 04a2fafcccb5f829144aae562288ba6c404a3387 Mon Sep 17 00:00:00 2001 From: Aleksandr Popitich Date: Tue, 12 Jan 2021 23:43:32 +0300 Subject: [PATCH 13/15] Address code review comments --- .../spark/api/dotnet/DotnetBackend.scala | 12 +-- .../org/apache/spark/api/dotnet/SerDe.scala | 4 +- .../api/dotnet/DotnetBackendHandlerTest.scala | 6 +- .../spark/api/dotnet/DotnetBackendTest.scala | 14 ++-- .../apache/spark/api/dotnet/Extensions.scala | 19 ----- .../apache/spark/api/dotnet/SerDeTest.scala | 79 ++++++++++--------- .../spark/api/dotnet/DotnetBackend.scala | 12 +-- .../org/apache/spark/api/dotnet/SerDe.scala | 4 +- .../api/dotnet/DotnetBackendHandlerTest.scala | 8 +- .../spark/api/dotnet/DotnetBackendTest.scala | 14 ++-- .../apache/spark/api/dotnet/Extensions.scala | 4 +- .../apache/spark/api/dotnet/SerDeTest.scala | 72 ++++++++--------- .../spark/api/dotnet/DotnetBackend.scala | 8 +- .../org/apache/spark/api/dotnet/SerDe.scala | 2 +- .../api/dotnet/DotnetBackendHandlerTest.scala | 8 +- .../spark/api/dotnet/DotnetBackendTest.scala | 14 ++-- .../apache/spark/api/dotnet/Extensions.scala | 4 +- .../apache/spark/api/dotnet/SerDeTest.scala | 72 ++++++++--------- 18 files changed, 171 insertions(+), 185 deletions(-) delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index 0bd08745a..1d8215d44 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -30,7 +30,7 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ - private[this] val objectsTracker = new JVMObjectTracker + private[this] val objectTracker = new JVMObjectTracker @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None @@ -59,7 +59,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. 
strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self, objectsTracker)) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) } }) @@ -68,16 +68,16 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } - private[dotnet] def setCallbackClient(address: String, port: Int): Unit = { + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { callbackClient = callbackClient match { case Some(_) => throw new Exception("Callback client already set.") case None => logInfo(s"Connecting to a callback server at $address:$port") - Some(new CallbackClient(new SerDe(objectsTracker), address, port)) + Some(new CallbackClient(new SerDe(objectTracker), address, port)) } } - private[dotnet] def shutdownCallbackClient(): Unit = { + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { callbackClient match { case Some(client) => client.shutdown() case None => logInfo("Callback server has already been shutdown.") @@ -103,7 +103,7 @@ class DotnetBackend extends Logging { } bootstrap = null - objectsTracker.clear() + objectTracker.clear() // Send close to .NET callback server. shutdownCallbackClient() diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 6cf1ddc73..44cad97c1 100644 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -15,10 +15,10 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ /** - * Functions to serialize and deserialize between CLR & JVM. + * Class responsible for serialization and deserialization between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. 
*/ -class SerDe(var tracker: JVMObjectTracker) { +class SerDe(val tracker: JVMObjectTracker) { def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index 6c9b6d2f2..990887276 100644 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -15,13 +15,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da class DotnetBackendHandlerTest { private var backend: DotnetBackend = _ private var tracker: JVMObjectTracker = _ - private var sut: DotnetBackendHandler = _ + private var handler: DotnetBackendHandler = _ @Before def before(): Unit = { backend = new DotnetBackend tracker = new JVMObjectTracker - sut = new DotnetBackendHandler(backend, tracker) + handler = new DotnetBackendHandler(backend, tracker) } @After @@ -44,7 +44,7 @@ class DotnetBackendHandlerTest { m.writeInt(0) // 2nd argument value (port) }) - val payload = sut.handleBackendRequest(message) + val payload = handler.handleBackendRequest(message) val reply = new DataInputStream(new ByteArrayInputStream(payload)) assertEquals( diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala index 3acdbf274..1abf10e20 100644 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -14,29 +14,29 @@ import java.net.InetAddress @Test class DotnetBackendTest { - private var sut: DotnetBackend = _ + private var backend: DotnetBackend = _ @Before def before(): Unit = { - sut = new DotnetBackend + backend = new DotnetBackend } @After def after(): Unit = { - sut.close() + backend.close() } @Test def shouldNotResetCallbackClient(): Unit = { - // specifying port = 0 to select port dynamically - sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + // Specifying port = 0 to select port dynamically. + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) - assertTrue(sut.callbackClient.isDefined) + assertTrue(backend.callbackClient.isDefined) assertThrows( classOf[Exception], new ThrowingRunnable { override def run(): Unit = { - sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) } }) } diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala deleted file mode 100644 index 53b2705b0..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import java.io.DataInputStream - -object Extensions { - implicit class DataInputStreamExt(stream: DataInputStream) { - def readN(n: Int): Array[Byte] = { - val buf = new Array[Byte](n) - stream.readFully(buf) - buf - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala index 1199114a6..f5a392775 100644 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala +++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -6,7 +6,6 @@ package org.apache.spark.api.dotnet -import org.apache.spark.api.dotnet.Extensions._ import org.apache.spark.sql.Row import org.junit.Assert._ import org.junit.function.ThrowingRunnable @@ -18,13 +17,13 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConv @Test class SerDeTest { - private var sut: SerDe = _ + private var serDe: SerDe = _ private var tracker: JVMObjectTracker = _ @Before def before(): Unit = { tracker = new JVMObjectTracker - sut = new SerDe(tracker) + serDe = new SerDe(tracker) } @Test @@ -33,7 +32,7 @@ class SerDeTest { in.writeByte('n') }) - assertEquals(null, sut.readObject(input)) + assertEquals(null, serDe.readObject(input)) } @Test @@ -46,7 +45,7 @@ class SerDeTest { classOf[IllegalArgumentException], new ThrowingRunnable { override def run(): Unit = { - sut.readObject(input) + serDe.readObject(input) } }) } @@ -58,7 +57,7 @@ class SerDeTest { in.writeInt(42) }) - assertEquals(42, sut.readObject(input)) + assertEquals(42, serDe.readObject(input)) } @Test @@ -68,7 +67,7 @@ class SerDeTest { in.writeLong(42) }) - assertEquals(42L, sut.readObject(input)) + assertEquals(42L, serDe.readObject(input)) } @Test @@ -78,7 +77,7 @@ class SerDeTest { in.writeDouble(42.42) }) - assertEquals(42.42, sut.readObject(input)) + assertEquals(42.42, serDe.readObject(input)) } @Test @@ -88,7 +87,7 @@ class SerDeTest { in.writeBoolean(true) }) - assertEquals(true, sut.readObject(input)) + assertEquals(true, serDe.readObject(input)) } @Test @@ -100,7 +99,7 @@ class SerDeTest { in.write(payload.getBytes("UTF-8")) }) - assertEquals(payload, sut.readObject(input)) + assertEquals(payload, serDe.readObject(input)) } @Test @@ -126,7 +125,7 @@ class SerDeTest { 11 -> true, 22 -> 42.42, 33 -> null).asJava, - sut.readObject(input)) + serDe.readObject(input)) } @Test @@ -136,7 +135,7 @@ class SerDeTest { in.writeInt(0) // size }) - assertEquals(Map().asJava, sut.readObject(input)) + assertEquals(Map().asJava, serDe.readObject(input)) } @Test @@ -147,7 +146,7 @@ class SerDeTest { in.write(Array[Byte](1, 2, 3)) // payload }) - assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(input).asInstanceOf[Array[Byte]]) + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) } @Test @@ -157,7 +156,7 @@ class SerDeTest { in.writeInt(0) // length }) - assertArrayEquals(Array[Byte](), sut.readObject(input).asInstanceOf[Array[Byte]]) + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) } @Test @@ -168,7 +167,7 @@ class SerDeTest { in.writeInt(0) // length }) - assertArrayEquals(Array[Int](), sut.readObject(input).asInstanceOf[Array[Int]]) + assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) } @Test @@ -182,7 +181,7 @@ class SerDeTest { in.writeBoolean(true) }) - 
assertArrayEquals(Array(true, false, true), sut.readObject(input).asInstanceOf[Array[Boolean]]) + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) } @Test @@ -196,7 +195,7 @@ class SerDeTest { classOf[IllegalArgumentException], new ThrowingRunnable { override def run(): Unit = { - sut.readObject(input) + serDe.readObject(input) } }) } @@ -210,7 +209,7 @@ class SerDeTest { in.write(date.getBytes("UTF-8")) }) - assertEquals(Date.valueOf("2020-12-31"), sut.readObject(input)) + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) } @Test @@ -224,7 +223,7 @@ class SerDeTest { in.write(objectIndex.getBytes("UTF-8")) }) - assertSame(trackingObject, sut.readObject(input)) + assertSame(trackingObject, serDe.readObject(input)) } @Test @@ -240,7 +239,7 @@ class SerDeTest { classOf[NoSuchElementException], new ThrowingRunnable { override def run(): Unit = { - sut.readObject(input) + serDe.readObject(input) } }) } @@ -266,7 +265,7 @@ class SerDeTest { Seq( Row.fromSeq(Seq(11)), Row.fromSeq(Seq(true, 42.24, 99))).asJava, - sut.readObject(input)) + serDe.readObject(input)) } @Test @@ -280,14 +279,14 @@ class SerDeTest { in.writeBoolean(true) }) - assertEquals(Seq(42, true), sut.readObject(input).asInstanceOf[Seq[Any]]) + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) } @Test def shouldWriteNull(): Unit = { val in = whenOutput(out => { - sut.writeObject(out, null) - sut.writeObject(out, Unit) + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) }) assertEquals(in.readByte(), 'n') @@ -299,22 +298,22 @@ class SerDeTest { def shouldWriteString(): Unit = { val sparkDotnet = "Spark Dotnet" val in = whenOutput(out => { - sut.writeObject(out, sparkDotnet) + serDe.writeObject(out, sparkDotnet) }) assertEquals(in.readByte(), 'c') // object type assertEquals(in.readInt(), sparkDotnet.length) // length - assertArrayEquals(in.readN(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertArrayEquals(readNBytes(in, sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) assertEndOfStream(in) } @Test def shouldWritePrimitiveTypes(): Unit = { val in = whenOutput(out => { - sut.writeObject(out, 42.24f.asInstanceOf[Object]) - sut.writeObject(out, 42L.asInstanceOf[Object]) - sut.writeObject(out, 42.asInstanceOf[Object]) - sut.writeObject(out, true.asInstanceOf[Object]) + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) }) assertEquals(in.readByte(), 'd') @@ -332,24 +331,24 @@ class SerDeTest { def shouldWriteDate(): Unit = { val date = "2020-12-31" val in = whenOutput(out => { - sut.writeObject(out, Date.valueOf(date)) + serDe.writeObject(out, Date.valueOf(date)) }) assertEquals(in.readByte(), 'D') // type assertEquals(in.readInt(), 10) // size - assertArrayEquals(in.readN(10), date.getBytes("UTF-8")) // content + assertArrayEquals(readNBytes(in, 10), date.getBytes("UTF-8")) // content } @Test def shouldWriteCustomObjects(): Unit = { val customObject = new Object val in = whenOutput(out => { - sut.writeObject(out, customObject) + serDe.writeObject(out, customObject) }) assertEquals(in.readByte(), 'j') assertEquals(in.readInt(), 1) - assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) + assertArrayEquals(readNBytes(in, 1), "1".getBytes("UTF-8")) assertSame(tracker.get("1").get, customObject) } @@ -357,16 +356,16 @@ class SerDeTest { def shouldWriteArrayOfCustomObjects(): 
Unit = { val payload = Array(new Object, new Object) val in = whenOutput(out => { - sut.writeObject(out, payload) + serDe.writeObject(out, payload) }) assertEquals(in.readByte(), 'l') // array type assertEquals(in.readByte(), 'j') // type of element in array assertEquals(in.readInt(), 2) // array length assertEquals(in.readInt(), 1) // size of 1st element's identifiers - assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertArrayEquals(readNBytes(in, 1), "1".getBytes("UTF-8")) // identifier of 1st element assertEquals(in.readInt(), 1) // size of 2nd element's identifier - assertArrayEquals(in.readN(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertArrayEquals(readNBytes(in, 1), "2".getBytes("UTF-8")) // identifier of 2nd element assertSame(tracker.get("1").get, payload(0)) assertSame(tracker.get("2").get, payload(1)) } @@ -383,4 +382,10 @@ class SerDeTest { private def assertEndOfStream(in: DataInputStream): Unit = { assertEquals(-1, in.read()) } + + def readNBytes(in: DataInputStream, n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + in.readFully(buf) + buf + } } diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index 0bd08745a..1d8215d44 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -30,7 +30,7 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ - private[this] val objectsTracker = new JVMObjectTracker + private[this] val objectTracker = new JVMObjectTracker @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None @@ -59,7 +59,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self, objectsTracker)) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) } }) @@ -68,16 +68,16 @@ class DotnetBackend extends Logging { channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort } - private[dotnet] def setCallbackClient(address: String, port: Int): Unit = { + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { callbackClient = callbackClient match { case Some(_) => throw new Exception("Callback client already set.") case None => logInfo(s"Connecting to a callback server at $address:$port") - Some(new CallbackClient(new SerDe(objectsTracker), address, port)) + Some(new CallbackClient(new SerDe(objectTracker), address, port)) } } - private[dotnet] def shutdownCallbackClient(): Unit = { + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { callbackClient match { case Some(client) => client.shutdown() case None => logInfo("Callback server has already been shutdown.") @@ -103,7 +103,7 @@ class DotnetBackend extends Logging { } bootstrap = null - objectsTracker.clear() + objectTracker.clear() // Send close to .NET callback server. 
shutdownCallbackClient() diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 6cf1ddc73..44cad97c1 100644 --- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -15,10 +15,10 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ /** - * Functions to serialize and deserialize between CLR & JVM. + * Class responsible for serialization and deserialization between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. */ -class SerDe(var tracker: JVMObjectTracker) { +class SerDe(val tracker: JVMObjectTracker) { def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index 7d721965a..79c32d6dc 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -16,13 +16,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da class DotnetBackendHandlerTest { private var backend: DotnetBackend = _ private var tracker: JVMObjectTracker = _ - private var sut: DotnetBackendHandler = _ + private var handler: DotnetBackendHandler = _ @Before def before(): Unit = { backend = new DotnetBackend tracker = new JVMObjectTracker - sut = new DotnetBackendHandler(backend, tracker) + handler = new DotnetBackendHandler(backend, tracker) } @After @@ -45,14 +45,14 @@ class DotnetBackendHandlerTest { m.writeInt(0) // 2nd argument value (port) }) - val payload = sut.handleBackendRequest(message) + val payload = handler.handleBackendRequest(message) val reply = new DataInputStream(new ByteArrayInputStream(payload)) assertEquals( "status code must be successful.", 0, reply.readInt()) assertEquals('j', reply.readByte()) assertEquals(1, reply.readInt()) - val trackingId = new String(reply.readN(1), "UTF-8") + val trackingId = new String(reply.readNBytes(1), "UTF-8") assertEquals("1", trackingId) val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull assertEquals(classOf[CallbackClient], client.getClass) diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala index 3acdbf274..1abf10e20 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -14,29 +14,29 @@ import java.net.InetAddress @Test class DotnetBackendTest { - private var sut: DotnetBackend = _ + private var backend: DotnetBackend = _ @Before def before(): Unit = { - sut = new DotnetBackend + backend = new DotnetBackend } @After def after(): Unit = { - sut.close() + backend.close() } @Test def shouldNotResetCallbackClient(): Unit = { - // specifying port = 0 to select port dynamically - sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + // 
Specifying port = 0 to select port dynamically. + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) - assertTrue(sut.callbackClient.isDefined) + assertTrue(backend.callbackClient.isDefined) assertThrows( classOf[Exception], new ThrowingRunnable { override def run(): Unit = { - sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) } }) } diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala index 53b2705b0..8c6e51608 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -8,9 +8,9 @@ package org.apache.spark.api.dotnet import java.io.DataInputStream -object Extensions { +private[dotnet] object Extensions { implicit class DataInputStreamExt(stream: DataInputStream) { - def readN(n: Int): Array[Byte] = { + def readNBytes(n: Int): Array[Byte] = { val buf = new Array[Byte](n) stream.readFully(buf) buf diff --git a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala index 1199114a6..78ca905bb 100644 --- a/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala +++ b/src/scala/microsoft-spark-2-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -18,13 +18,13 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConv @Test class SerDeTest { - private var sut: SerDe = _ + private var serDe: SerDe = _ private var tracker: JVMObjectTracker = _ @Before def before(): Unit = { tracker = new JVMObjectTracker - sut = new SerDe(tracker) + serDe = new SerDe(tracker) } @Test @@ -33,7 +33,7 @@ class SerDeTest { in.writeByte('n') }) - assertEquals(null, sut.readObject(input)) + assertEquals(null, serDe.readObject(input)) } @Test @@ -46,7 +46,7 @@ class SerDeTest { classOf[IllegalArgumentException], new ThrowingRunnable { override def run(): Unit = { - sut.readObject(input) + serDe.readObject(input) } }) } @@ -58,7 +58,7 @@ class SerDeTest { in.writeInt(42) }) - assertEquals(42, sut.readObject(input)) + assertEquals(42, serDe.readObject(input)) } @Test @@ -68,7 +68,7 @@ class SerDeTest { in.writeLong(42) }) - assertEquals(42L, sut.readObject(input)) + assertEquals(42L, serDe.readObject(input)) } @Test @@ -78,7 +78,7 @@ class SerDeTest { in.writeDouble(42.42) }) - assertEquals(42.42, sut.readObject(input)) + assertEquals(42.42, serDe.readObject(input)) } @Test @@ -88,7 +88,7 @@ class SerDeTest { in.writeBoolean(true) }) - assertEquals(true, sut.readObject(input)) + assertEquals(true, serDe.readObject(input)) } @Test @@ -100,7 +100,7 @@ class SerDeTest { in.write(payload.getBytes("UTF-8")) }) - assertEquals(payload, sut.readObject(input)) + assertEquals(payload, serDe.readObject(input)) } @Test @@ -126,7 +126,7 @@ class SerDeTest { 11 -> true, 22 -> 42.42, 33 -> null).asJava, - sut.readObject(input)) + serDe.readObject(input)) } @Test @@ -136,7 +136,7 @@ class SerDeTest { in.writeInt(0) // size }) - assertEquals(Map().asJava, sut.readObject(input)) + assertEquals(Map().asJava, serDe.readObject(input)) } @Test @@ -147,7 +147,7 @@ class SerDeTest { in.write(Array[Byte](1, 2, 3)) // payload }) - 
assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(input).asInstanceOf[Array[Byte]]) + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) } @Test @@ -157,7 +157,7 @@ class SerDeTest { in.writeInt(0) // length }) - assertArrayEquals(Array[Byte](), sut.readObject(input).asInstanceOf[Array[Byte]]) + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) } @Test @@ -168,7 +168,7 @@ class SerDeTest { in.writeInt(0) // length }) - assertArrayEquals(Array[Int](), sut.readObject(input).asInstanceOf[Array[Int]]) + assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) } @Test @@ -182,7 +182,7 @@ class SerDeTest { in.writeBoolean(true) }) - assertArrayEquals(Array(true, false, true), sut.readObject(input).asInstanceOf[Array[Boolean]]) + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) } @Test @@ -196,7 +196,7 @@ class SerDeTest { classOf[IllegalArgumentException], new ThrowingRunnable { override def run(): Unit = { - sut.readObject(input) + serDe.readObject(input) } }) } @@ -210,7 +210,7 @@ class SerDeTest { in.write(date.getBytes("UTF-8")) }) - assertEquals(Date.valueOf("2020-12-31"), sut.readObject(input)) + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) } @Test @@ -224,7 +224,7 @@ class SerDeTest { in.write(objectIndex.getBytes("UTF-8")) }) - assertSame(trackingObject, sut.readObject(input)) + assertSame(trackingObject, serDe.readObject(input)) } @Test @@ -240,7 +240,7 @@ class SerDeTest { classOf[NoSuchElementException], new ThrowingRunnable { override def run(): Unit = { - sut.readObject(input) + serDe.readObject(input) } }) } @@ -266,7 +266,7 @@ class SerDeTest { Seq( Row.fromSeq(Seq(11)), Row.fromSeq(Seq(true, 42.24, 99))).asJava, - sut.readObject(input)) + serDe.readObject(input)) } @Test @@ -280,14 +280,14 @@ class SerDeTest { in.writeBoolean(true) }) - assertEquals(Seq(42, true), sut.readObject(input).asInstanceOf[Seq[Any]]) + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) } @Test def shouldWriteNull(): Unit = { val in = whenOutput(out => { - sut.writeObject(out, null) - sut.writeObject(out, Unit) + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) }) assertEquals(in.readByte(), 'n') @@ -299,22 +299,22 @@ class SerDeTest { def shouldWriteString(): Unit = { val sparkDotnet = "Spark Dotnet" val in = whenOutput(out => { - sut.writeObject(out, sparkDotnet) + serDe.writeObject(out, sparkDotnet) }) assertEquals(in.readByte(), 'c') // object type assertEquals(in.readInt(), sparkDotnet.length) // length - assertArrayEquals(in.readN(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) assertEndOfStream(in) } @Test def shouldWritePrimitiveTypes(): Unit = { val in = whenOutput(out => { - sut.writeObject(out, 42.24f.asInstanceOf[Object]) - sut.writeObject(out, 42L.asInstanceOf[Object]) - sut.writeObject(out, 42.asInstanceOf[Object]) - sut.writeObject(out, true.asInstanceOf[Object]) + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) }) assertEquals(in.readByte(), 'd') @@ -332,24 +332,24 @@ class SerDeTest { def shouldWriteDate(): Unit = { val date = "2020-12-31" val in = whenOutput(out => { - sut.writeObject(out, Date.valueOf(date)) + serDe.writeObject(out, 
Date.valueOf(date)) }) assertEquals(in.readByte(), 'D') // type assertEquals(in.readInt(), 10) // size - assertArrayEquals(in.readN(10), date.getBytes("UTF-8")) // content + assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content } @Test def shouldWriteCustomObjects(): Unit = { val customObject = new Object val in = whenOutput(out => { - sut.writeObject(out, customObject) + serDe.writeObject(out, customObject) }) assertEquals(in.readByte(), 'j') assertEquals(in.readInt(), 1) - assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) assertSame(tracker.get("1").get, customObject) } @@ -357,16 +357,16 @@ class SerDeTest { def shouldWriteArrayOfCustomObjects(): Unit = { val payload = Array(new Object, new Object) val in = whenOutput(out => { - sut.writeObject(out, payload) + serDe.writeObject(out, payload) }) assertEquals(in.readByte(), 'l') // array type assertEquals(in.readByte(), 'j') // type of element in array assertEquals(in.readInt(), 2) // array length assertEquals(in.readInt(), 1) // size of 1st element's identifiers - assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element assertEquals(in.readInt(), 1) // size of 2nd element's identifier - assertArrayEquals(in.readN(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element assertSame(tracker.get("1").get, payload(0)) assertSame(tracker.get("2").get, payload(1)) } diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala index d058653cb..c6f528aee 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -29,7 +29,7 @@ class DotnetBackend extends Logging { private[this] var channelFuture: ChannelFuture = _ private[this] var bootstrap: ServerBootstrap = _ private[this] var bossGroup: EventLoopGroup = _ - private[this] val objectsTracker = new JVMObjectTracker + private[this] val objectTracker = new JVMObjectTracker @volatile private[dotnet] var callbackClient: Option[CallbackClient] = None @@ -58,7 +58,7 @@ class DotnetBackend extends Logging { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self, objectsTracker)) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) } }) @@ -72,7 +72,7 @@ class DotnetBackend extends Logging { case Some(_) => throw new Exception("Callback client already set.") case None => logInfo(s"Connecting to a callback server at $address:$port") - Some(new CallbackClient(new SerDe(objectsTracker), address, port)) + Some(new CallbackClient(new SerDe(objectTracker), address, port)) } } @@ -102,7 +102,7 @@ class DotnetBackend extends Logging { } bootstrap = null - objectsTracker.clear() + objectTracker.clear() // Send close to .NET callback server. 
shutdownCallbackClient() diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index 838f7fccf..a3df3788a 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.Row import scala.collection.JavaConverters._ /** - * Functions to serialize and deserialize between CLR & JVM. + * Class responsible for serialization and deserialization between CLR & JVM. * This implementation of methods is mostly identical to the SerDe implementation in R. */ class SerDe(val tracker: JVMObjectTracker) { diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala index 231c2fecf..672455349 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -17,13 +17,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da class DotnetBackendHandlerTest { private var backend: DotnetBackend = _ private var tracker: JVMObjectTracker = _ - private var sut: DotnetBackendHandler = _ + private var handler: DotnetBackendHandler = _ @Before def before(): Unit = { backend = new DotnetBackend tracker = new JVMObjectTracker - sut = new DotnetBackendHandler(backend, tracker) + handler = new DotnetBackendHandler(backend, tracker) } @After @@ -46,14 +46,14 @@ class DotnetBackendHandlerTest { m.writeInt(0) // 2nd argument value (port) }) - val payload = sut.handleBackendRequest(message) + val payload = handler.handleBackendRequest(message) val reply = new DataInputStream(new ByteArrayInputStream(payload)) assertEquals( "status code must be successful.", 0, reply.readInt()) assertEquals('j', reply.readByte()) assertEquals(1, reply.readInt()) - val trackingId = new String(reply.readN(1), "UTF-8") + val trackingId = new String(reply.readNBytes(1), "UTF-8") assertEquals("1", trackingId) val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull assertEquals(classOf[CallbackClient], client.getClass) diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala index bf4001c3b..445486bbd 100644 --- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -14,26 +14,26 @@ import java.net.InetAddress @Test class DotnetBackendTest { - private var sut: DotnetBackend = _ + private var backend: DotnetBackend = _ @Before def before(): Unit = { - sut = new DotnetBackend + backend = new DotnetBackend } @After def after(): Unit = { - sut.close() + backend.close() } @Test def shouldNotResetCallbackClient(): Unit = { - // specifying port = 0 to select port dynamically - sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + // Specifying port = 0 to select port dynamically. 
+    backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
 
-    assertTrue(sut.callbackClient.isDefined)
+    assertTrue(backend.callbackClient.isDefined)
     assertThrows(classOf[Exception], () => {
-      sut.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+      backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
     })
   }
 }
diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
index 704c248b9..c6904403b 100644
--- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
@@ -9,9 +9,9 @@ package org.apache.spark.api.dotnet
 
 import java.io.DataInputStream
 
-object Extensions {
+private[dotnet] object Extensions {
   implicit class DataInputStreamExt(stream: DataInputStream) {
-    def readN(n: Int): Array[Byte] = {
+    def readNBytes(n: Int): Array[Byte] = {
       val buf = new Array[Byte](n)
       stream.readFully(buf)
       buf
diff --git a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
index 4343d8e39..41401d680 100644
--- a/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
+++ b/src/scala/microsoft-spark-3-0/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
@@ -17,13 +17,13 @@ import scala.collection.JavaConverters._
 
 @Test
 class SerDeTest {
-  private var sut: SerDe = _
+  private var serDe: SerDe = _
   private var tracker: JVMObjectTracker = _
 
   @Before
   def before(): Unit = {
     tracker = new JVMObjectTracker
-    sut = new SerDe(tracker)
+    serDe = new SerDe(tracker)
   }
 
   @Test
@@ -32,7 +32,7 @@ class SerDeTest {
       in.writeByte('n')
     })
 
-    assertEquals(null, sut.readObject(input))
+    assertEquals(null, serDe.readObject(input))
   }
 
   @Test
@@ -42,7 +42,7 @@ class SerDeTest {
     })
 
     assertThrows(classOf[IllegalArgumentException], () => {
-      sut.readObject(input)
+      serDe.readObject(input)
     })
   }
 
@@ -53,7 +53,7 @@ class SerDeTest {
       in.writeInt(42)
     })
 
-    assertEquals(42, sut.readObject(input))
+    assertEquals(42, serDe.readObject(input))
   }
 
   @Test
@@ -63,7 +63,7 @@ class SerDeTest {
       in.writeLong(42)
     })
 
-    assertEquals(42L, sut.readObject(input))
+    assertEquals(42L, serDe.readObject(input))
   }
 
   @Test
@@ -73,7 +73,7 @@ class SerDeTest {
       in.writeDouble(42.42)
     })
 
-    assertEquals(42.42, sut.readObject(input))
+    assertEquals(42.42, serDe.readObject(input))
   }
 
   @Test
@@ -83,7 +83,7 @@ class SerDeTest {
       in.writeBoolean(true)
     })
 
-    assertEquals(true, sut.readObject(input))
+    assertEquals(true, serDe.readObject(input))
   }
 
   @Test
@@ -95,7 +95,7 @@ class SerDeTest {
       in.write(payload.getBytes("UTF-8"))
     })
 
-    assertEquals(payload, sut.readObject(input))
+    assertEquals(payload, serDe.readObject(input))
   }
 
   @Test
@@ -121,7 +121,7 @@ class SerDeTest {
           11 -> true,
           22 -> 42.42,
           33 -> null)),
-      sut.readObject(input))
+      serDe.readObject(input))
   }
 
   @Test
@@ -131,7 +131,7 @@ class SerDeTest {
       in.writeInt(0) // size
     })
 
-    assertEquals(mapAsJavaMap(Map()), sut.readObject(input))
+    assertEquals(mapAsJavaMap(Map()), serDe.readObject(input))
   }
 
   @Test
@@ -142,7 +142,7 @@ class SerDeTest {
       in.write(Array[Byte](1, 2, 3)) // payload
     })
 
-    assertArrayEquals(Array[Byte](1, 2, 3), sut.readObject(input).asInstanceOf[Array[Byte]])
+    assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]])
   }
 
   @Test
@@ -152,7 +152,7 @@ class SerDeTest {
      in.writeInt(0) // length
     })
 
-    assertArrayEquals(Array[Byte](), sut.readObject(input).asInstanceOf[Array[Byte]])
+    assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]])
   }
 
   @Test
@@ -163,7 +163,7 @@ class SerDeTest {
       in.writeInt(0) // length
     })
 
-    assertArrayEquals(Array[Int](), sut.readObject(input).asInstanceOf[Array[Int]])
+    assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]])
   }
 
   @Test
@@ -177,7 +177,7 @@ class SerDeTest {
       in.writeBoolean(true)
     })
 
-    assertArrayEquals(Array(true, false, true), sut.readObject(input).asInstanceOf[Array[Boolean]])
+    assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]])
   }
 
   @Test
@@ -188,7 +188,7 @@ class SerDeTest {
     })
 
     assertThrows(classOf[IllegalArgumentException], () => {
-      sut.readObject(input)
+      serDe.readObject(input)
     })
   }
 
@@ -201,7 +201,7 @@ class SerDeTest {
       in.write(date.getBytes("UTF-8"))
     })
 
-    assertEquals(Date.valueOf("2020-12-31"), sut.readObject(input))
+    assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input))
   }
 
   @Test
@@ -215,7 +215,7 @@ class SerDeTest {
       in.write(objectIndex.getBytes("UTF-8"))
     })
 
-    assertSame(trackingObject, sut.readObject(input))
+    assertSame(trackingObject, serDe.readObject(input))
   }
 
   @Test
@@ -228,7 +228,7 @@ class SerDeTest {
     })
 
     assertThrows(classOf[NoSuchElementException], () => {
-      sut.readObject(input)
+      serDe.readObject(input)
     })
   }
 
@@ -253,7 +253,7 @@ class SerDeTest {
       seqAsJavaList(Seq(
         Row.fromSeq(Seq(11)),
         Row.fromSeq(Seq(true, 42.24, 99)))),
-      sut.readObject(input))
+      serDe.readObject(input))
   }
 
   @Test
@@ -267,14 +267,14 @@ class SerDeTest {
       in.writeBoolean(true)
     })
 
-    assertEquals(Seq(42, true), sut.readObject(input).asInstanceOf[Seq[Any]])
+    assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]])
  }
 
   @Test
   def shouldWriteNull(): Unit = {
     val in = whenOutput(out => {
-      sut.writeObject(out, null)
-      sut.writeObject(out, Unit)
+      serDe.writeObject(out, null)
+      serDe.writeObject(out, Unit)
     })
 
     assertEquals(in.readByte(), 'n')
@@ -286,22 +286,22 @@ class SerDeTest {
   def shouldWriteString(): Unit = {
     val sparkDotnet = "Spark Dotnet"
     val in = whenOutput(out => {
-      sut.writeObject(out, sparkDotnet)
+      serDe.writeObject(out, sparkDotnet)
     })
 
     assertEquals(in.readByte(), 'c') // object type
     assertEquals(in.readInt(), sparkDotnet.length) // length
-    assertArrayEquals(in.readN(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
+    assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
     assertEndOfStream(in)
   }
 
   @Test
   def shouldWritePrimitiveTypes(): Unit = {
     val in = whenOutput(out => {
-      sut.writeObject(out, 42.24f.asInstanceOf[Object])
-      sut.writeObject(out, 42L.asInstanceOf[Object])
-      sut.writeObject(out, 42.asInstanceOf[Object])
-      sut.writeObject(out, true.asInstanceOf[Object])
+      serDe.writeObject(out, 42.24f.asInstanceOf[Object])
+      serDe.writeObject(out, 42L.asInstanceOf[Object])
+      serDe.writeObject(out, 42.asInstanceOf[Object])
+      serDe.writeObject(out, true.asInstanceOf[Object])
     })
 
     assertEquals(in.readByte(), 'd')
@@ -319,24 +319,24 @@ class SerDeTest {
   def shouldWriteDate(): Unit = {
     val date = "2020-12-31"
     val in = whenOutput(out => {
-      sut.writeObject(out, Date.valueOf(date))
+      serDe.writeObject(out, Date.valueOf(date))
     })
 
     assertEquals(in.readByte(), 'D') // type
     assertEquals(in.readInt(), 10) // size
-    assertArrayEquals(in.readN(10), date.getBytes("UTF-8")) // content
+    assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
   }
 
   @Test
   def shouldWriteCustomObjects(): Unit = {
     val customObject = new Object
     val in = whenOutput(out => {
-      sut.writeObject(out, customObject)
+      serDe.writeObject(out, customObject)
     })
 
     assertEquals(in.readByte(), 'j')
     assertEquals(in.readInt(), 1)
-    assertArrayEquals(in.readN(1), "1".getBytes("UTF-8"))
+    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
     assertSame(tracker.get("1").get, customObject)
   }
 
@@ -344,16 +344,16 @@ class SerDeTest {
   def shouldWriteArrayOfCustomObjects(): Unit = {
     val payload = Array(new Object, new Object)
     val in = whenOutput(out => {
-      sut.writeObject(out, payload)
+      serDe.writeObject(out, payload)
     })
 
     assertEquals(in.readByte(), 'l') // array type
     assertEquals(in.readByte(), 'j') // type of element in array
     assertEquals(in.readInt(), 2) // array length
     assertEquals(in.readInt(), 1) // size of 1st element's identifiers
-    assertArrayEquals(in.readN(1), "1".getBytes("UTF-8")) // identifier of 1st element
+    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
     assertEquals(in.readInt(), 1) // size of 2nd element's identifier
-    assertArrayEquals(in.readN(1), "2".getBytes("UTF-8")) // identifier of 2nd element
+    assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
     assertSame(tracker.get("1").get, payload(0))
     assertSame(tracker.get("2").get, payload(1))
   }

From 8f1d0bed868e64f2aaf0da73777c70f044cc2dd2 Mon Sep 17 00:00:00 2001
From: Aleksandr Popitich
Date: Wed, 13 Jan 2021 00:10:54 +0300
Subject: [PATCH 14/15] Adding dots for the code comments

---
 .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +-
 .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index bdac858a6..d95b18313 100644
--- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -87,7 +87,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
 
           // Sends reference of CallbackClient to dotnet side,
           // so that dotnet process can send the client back to Java side
-          // when calling any API containing callback functions
+          // when calling any API containing callback functions.
           serDe.writeObject(dos, server.callbackClient)
         case "closeCallback" =>
           logInfo("Requesting to close callback client")
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index 30c002471..29729657b 100644
--- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -88,7 +88,7 @@ class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTrack
 
           // Sends reference of CallbackClient to dotnet side,
           // so that dotnet process can send the client back to Java side
-          // when calling any API containing callback functions
+          // when calling any API containing callback functions.
           serDe.writeObject(dos, server.callbackClient)
         case "closeCallback" =>
           logInfo("Requesting to close callback client")

From d98ac6c46712a757f9905bea06dc00b9ada26047 Mon Sep 17 00:00:00 2001
From: Aleksandr Popitich
Date: Wed, 13 Jan 2021 20:03:22 +0300
Subject: [PATCH 15/15] Adding Extensions.scala back for 2.3 version.

---
 .../apache/spark/api/dotnet/Extensions.scala  | 19 +++++++++++++++++++
 .../apache/spark/api/dotnet/SerDeTest.scala   | 17 ++++++-----------
 2 files changed, 25 insertions(+), 11 deletions(-)
 create mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala

diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
new file mode 100644
index 000000000..8c6e51608
--- /dev/null
+++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.DataInputStream
+
+private[dotnet] object Extensions {
+  implicit class DataInputStreamExt(stream: DataInputStream) {
+    def readNBytes(n: Int): Array[Byte] = {
+      val buf = new Array[Byte](n)
+      stream.readFully(buf)
+      buf
+    }
+  }
+}
diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
index f5a392775..78ca905bb 100644
--- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
+++ b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
@@ -6,6 +6,7 @@
 
 package org.apache.spark.api.dotnet
 
+import org.apache.spark.api.dotnet.Extensions._
 import org.apache.spark.sql.Row
 import org.junit.Assert._
 import org.junit.function.ThrowingRunnable
@@ -303,7 +304,7 @@ class SerDeTest {
 
     assertEquals(in.readByte(), 'c') // object type
     assertEquals(in.readInt(), sparkDotnet.length) // length
-    assertArrayEquals(readNBytes(in, sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
+    assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
     assertEndOfStream(in)
   }
 
@@ -336,7 +337,7 @@ class SerDeTest {
 
     assertEquals(in.readByte(), 'D') // type
     assertEquals(in.readInt(), 10) // size
-    assertArrayEquals(readNBytes(in, 10), date.getBytes("UTF-8")) // content
+    assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
   }
 
@@ -348,7 +349,7 @@ class SerDeTest {
 
     assertEquals(in.readByte(), 'j')
     assertEquals(in.readInt(), 1)
-    assertArrayEquals(readNBytes(in, 1), "1".getBytes("UTF-8"))
+    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
     assertSame(tracker.get("1").get, customObject)
   }
 
@@ -363,9 +364,9 @@ class SerDeTest {
     assertEquals(in.readByte(), 'j') // type of element in array
     assertEquals(in.readInt(), 2) // array length
     assertEquals(in.readInt(), 1) // size of 1st element's identifiers
-    assertArrayEquals(readNBytes(in, 1), "1".getBytes("UTF-8")) // identifier of 1st element
+    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
     assertEquals(in.readInt(), 1) // size of 2nd element's identifier
-    assertArrayEquals(readNBytes(in, 1), "2".getBytes("UTF-8")) // identifier of 2nd element
+    assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
     assertSame(tracker.get("1").get, payload(0))
     assertSame(tracker.get("2").get, payload(1))
   }
@@ -382,10 +383,4 @@ class SerDeTest {
   private def assertEndOfStream(in: DataInputStream): Unit = {
     assertEquals(-1, in.read())
   }
-
-  def readNBytes(in: DataInputStream, n: Int): Array[Byte] = {
-    val buf = new Array[Byte](n)
-    in.readFully(buf)
-    buf
-  }
 }
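
Reviewer's note (not part of the patch series): the series introduces a consistent object-lifecycle contract — a per-backend JVMObjectTracker hands out string identifiers, SerDe writes those identifiers on the wire, the test-side readNBytes extension reads them back, and clear() releases every tracked reference at once, which DotnetBackend.close() now relies on during shutdown. The sketch below ties those pieces together end to end. It is a minimal illustration assuming the microsoft-spark-3-0 layout, where JVMObjectTracker is an instantiable class; the TrackerLifecycleSketch wrapper and the in-memory stream plumbing are illustrative assumptions, not code from the repository.

// Illustrative sketch only; compiles against the 3.0 sources as changed above.
package org.apache.spark.api.dotnet

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.apache.spark.api.dotnet.Extensions._

object TrackerLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val tracker = new JVMObjectTracker
    val serDe = new SerDe(tracker)

    // Writing a custom object registers it in the tracker and emits the
    // 'j' type tag, the identifier length, and the identifier bytes,
    // mirroring the assertions in SerDeTest.shouldWriteCustomObjects.
    val bytes = new ByteArrayOutputStream()
    serDe.writeObject(new DataOutputStream(bytes), new Object)

    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    assert(in.readByte() == 'j') // custom-object type tag
    val id = new String(in.readNBytes(in.readInt()), "UTF-8")
    assert(tracker.get(id).isDefined) // the tracker still holds the reference

    // clear() drops every tracked reference in one call, which is what
    // DotnetBackend.close() now does during shutdown so the GC can
    // reclaim the objects.
    tracker.clear()
    assert(tracker.get(id).isEmpty)
  }
}

The assertions above deliberately mirror those in the new SerDeTest and JVMObjectTrackerTest, so if this sketch's expectations ever drift from the actual wire format, the tests in this series are the authority.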