
Commit ebf2a64

Merge branch 'tpch-benchmark' of github.com:octaviansima/opaque into tpch-benchmark

2 parents 17f82fa + 26003b0

File tree: 4 files changed, +81 −55 lines

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/Benchmark.scala
Lines changed: 1 addition & 3 deletions

@@ -43,9 +43,7 @@ object Benchmark {
   val spark = SparkSession.builder()
     .appName("Benchmark")
     .getOrCreate()
-  var numPartitions = 2 * spark.sparkContext
-    .getConf
-    .getInt("spark.executor.instances", 2)
+  var numPartitions = spark.sparkContext.defaultParallelism
   var size = "sf_small"

   def dataDir: String = {
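
The new default delegates partition counting to Spark's scheduler instead of deriving it from a hard-coded executor count. A minimal sketch of the difference (illustrative only; the local[4] master and the ParallelismDemo object are assumptions, not part of this commit):

import org.apache.spark.sql.SparkSession

object ParallelismDemo {
  def main(args: Array[String]): Unit = {
    // Assumption: local mode purely for illustration.
    val spark = SparkSession.builder()
      .appName("ParallelismDemo")
      .master("local[4]")
      .getOrCreate()

    // Old approach: guess from the executor count, defaulting to 2 executors.
    val oldDefault = 2 * spark.sparkContext.getConf.getInt("spark.executor.instances", 2)
    // New approach: ask the scheduler directly (total cores here, i.e. 4).
    val newDefault = spark.sparkContext.defaultParallelism

    println(s"old=$oldDefault new=$newDefault")
    spark.stop()
  }
}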

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala
Lines changed: 53 additions & 20 deletions

@@ -17,6 +17,7 @@

 package edu.berkeley.cs.rise.opaque.benchmark

+import java.io.File
 import scala.io.Source

 import org.apache.spark.sql.DataFrame
@@ -162,7 +163,7 @@ object TPCH {
     .option("delimiter", "|")
     .load(s"${Benchmark.dataDir}/tpch/$size/customer.tbl")

-  def generateMap(
+  def generateDFs(
       sqlContext: SQLContext, size: String)
     : Map[String, DataFrame] = {
     Map("part" -> part(sqlContext, size),
@@ -175,38 +176,70 @@ object TPCH {
       "customer" -> customer(sqlContext, size)
     )
   }
-
-  def apply(sqlContext: SQLContext, size: String) : TPCH = {
-    val tpch = new TPCH(sqlContext, size)
-    tpch.tableNames = tableNames
-    tpch.nameToDF = generateMap(sqlContext, size)
-    tpch
-  }
 }

 class TPCH(val sqlContext: SQLContext, val size: String) {

-  var tableNames : Seq[String] = Seq()
-  var nameToDF : Map[String, DataFrame] = Map()
+  val tableNames = TPCH.tableNames
+  val nameToDF = TPCH.generateDFs(sqlContext, size)

-  def setupViews(securityLevel: SecurityLevel, numPartitions: Int) = {
-    for ((name, df) <- nameToDF) {
-      Utils.ensureCached(securityLevel.applyTo(df.repartition(numPartitions))).createOrReplaceTempView(name)
-    }
-  }
+  private var numPartitions: Int = -1
+  private var nameToPath = Map[String, File]()
+  private var nameToEncryptedPath = Map[String, File]()

   def getQuery(queryNumber: Int) : String = {
     val queryLocation = sys.env.getOrElse("OPAQUE_HOME", ".") + "/src/test/resources/tpch/"
     Source.fromFile(queryLocation + s"q$queryNumber.sql").getLines().mkString("\n")
   }

-  def performQuery(sqlContext: SQLContext, sqlStr: String) : DataFrame = {
-    sqlContext.sparkSession.sql(sqlStr);
+  def generateFiles(numPartitions: Int) = {
+    if (numPartitions != this.numPartitions) {
+      this.numPartitions = numPartitions
+      for ((name, df) <- nameToDF) {
+        nameToPath.get(name).foreach { path => Utils.deleteRecursively(path) }
+
+        nameToPath += (name -> createPath(df, Insecure, numPartitions))
+        nameToEncryptedPath += (name -> createPath(df, Encrypted, numPartitions))
+      }
+    }
+  }
+
+  private def createPath(df: DataFrame, securityLevel: SecurityLevel, numPartitions: Int): File = {
+    val partitionedDF = securityLevel.applyTo(df.repartition(numPartitions))
+    val path = Utils.createTempDir()
+    path.delete()
+    securityLevel match {
+      case Insecure => {
+        partitionedDF.write.format("com.databricks.spark.csv").save(path.toString)
+      }
+      case Encrypted => {
+        partitionedDF.write.format("edu.berkeley.cs.rise.opaque.EncryptedSource").save(path.toString)
+      }
+    }
+    path
+  }
+
+  private def loadViews(securityLevel: SecurityLevel) = {
+    val (map, formatStr) = if (securityLevel == Insecure)
+      (nameToPath, "com.databricks.spark.csv") else
+      (nameToEncryptedPath, "edu.berkeley.cs.rise.opaque.EncryptedSource")
+    for ((name, path) <- map) {
+      val df = sqlContext.sparkSession.read
+        .format(formatStr)
+        .schema(nameToDF.get(name).get.schema)
+        .load(path.toString)
+      df.createOrReplaceTempView(name)
+    }
+  }
+
+  def performQuery(sqlStr: String, securityLevel: SecurityLevel): DataFrame = {
+    loadViews(securityLevel)
+    sqlContext.sparkSession.sql(sqlStr)
   }

-  def query(queryNumber: Int, securityLevel: SecurityLevel, sqlContext: SQLContext, numPartitions: Int) : DataFrame = {
-    setupViews(securityLevel, numPartitions)
+  def query(queryNumber: Int, securityLevel: SecurityLevel, numPartitions: Int): DataFrame = {
     val sqlStr = getQuery(queryNumber)
-    performQuery(sqlContext, sqlStr)
+    generateFiles(numPartitions)
+    performQuery(sqlStr, securityLevel)
   }
 }
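
Taken together, the refactor replaces per-query view setup with a write-once/read-per-query flow. A minimal usage sketch of the new API (the runQ1 helper, partition count, and scale factor are illustrative; Insecure and Encrypted are the SecurityLevel values used elsewhere in this package):

import org.apache.spark.sql.{DataFrame, SQLContext}

// Sketch only: assumes the Opaque benchmark package is in scope.
def runQ1(sqlContext: SQLContext): Unit = {
  // The constructor now builds the name -> DataFrame map eagerly,
  // replacing the removed TPCH.apply factory.
  val tpch = new TPCH(sqlContext, "sf_small")

  // Writes each table twice (plaintext CSV and EncryptedSource) to temp
  // dirs; calling again with the same partition count is a no-op.
  tpch.generateFiles(numPartitions = 2)

  // performQuery first re-registers temp views from the saved files for
  // the requested security level, then runs the SQL text.
  val encrypted: DataFrame = tpch.performQuery(tpch.getQuery(1), Encrypted)

  // Or the one-shot entry point used by the tests:
  val plain: DataFrame = tpch.query(1, Insecure, numPartitions = 2)
  plain.collect()
}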

src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
Lines changed: 4 additions & 9 deletions

@@ -24,32 +24,27 @@ import org.apache.spark.sql.SQLContext
 object TPCHBenchmark {
   def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
     val sqlStr = tpch.getQuery(queryNumber)
+    tpch.generateFiles(numPartitions)

-    tpch.setupViews(Encrypted, numPartitions)
     Utils.timeBenchmark(
       "distributed" -> (numPartitions > 1),
       "query" -> s"TPC-H $queryNumber",
       "system" -> Encrypted.name) {

-      val df = tpch.performQuery(sqlContext, sqlStr)
-      Utils.force(df)
-      df
+      tpch.performQuery(sqlStr, Encrypted).collect
     }

-    tpch.setupViews(Insecure, numPartitions)
     Utils.timeBenchmark(
       "distributed" -> (numPartitions > 1),
       "query" -> s"TPC-H $queryNumber",
       "system" -> Insecure.name) {

-      val df = tpch.performQuery(sqlContext, sqlStr)
-      Utils.force(df)
-      df
+      tpch.performQuery(sqlStr, Insecure).collect
     }
   }

   def run(sqlContext: SQLContext, numPartitions: Int, size: String) = {
-    val tpch = TPCH(sqlContext, size)
+    val tpch = new TPCH(sqlContext, size)

     val supportedQueries = Seq(1, 3, 5, 6, 7, 8, 9, 10, 14, 17)
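
With generateFiles hoisted out of the timed blocks, only view loading and query execution are measured. A hedged sketch of driving the benchmark end to end (the TPCHRunner object and SparkSession setup are assumptions for illustration, not part of this commit):

import org.apache.spark.sql.SparkSession

object TPCHRunner {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TPCHRunner")
      .master("local[2]")  // assumption: local mode for illustration
      .getOrCreate()

    // Runs each supported query once per security level, printing timings
    // via Utils.timeBenchmark.
    TPCHBenchmark.run(spark.sqlContext, numPartitions = 2, size = "sf_small")
    spark.stop()
  }
}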

src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala
Lines changed: 23 additions & 23 deletions

@@ -26,94 +26,94 @@ import edu.berkeley.cs.rise.opaque.benchmark.TPCH
 trait TPCHTests extends OpaqueTestsBase { self =>

   def size = "sf_small"
-  def tpch = TPCH(spark.sqlContext, size)
+  def tpch = new TPCH(spark.sqlContext, size)

   testAgainstSpark("TPC-H 1") { securityLevel =>
-    tpch.query(1, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(1, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 2", ignore) { securityLevel =>
-    tpch.query(2, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(2, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 3") { securityLevel =>
-    tpch.query(3, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(3, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 4", ignore) { securityLevel =>
-    tpch.query(4, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(4, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 5") { securityLevel =>
-    tpch.query(5, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(5, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 6") { securityLevel =>
-    tpch.query(6, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+    tpch.query(6, securityLevel, numPartitions).collect.toSet
   }

   testAgainstSpark("TPC-H 7") { securityLevel =>
-    tpch.query(7, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(7, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 8") { securityLevel =>
-    tpch.query(8, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(8, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 9") { securityLevel =>
-    tpch.query(9, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(9, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 10") { securityLevel =>
-    tpch.query(10, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(10, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 11", ignore) { securityLevel =>
-    tpch.query(11, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(11, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 12") { securityLevel =>
-    tpch.query(12, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(12, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 13", ignore) { securityLevel =>
-    tpch.query(13, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(13, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 14") { securityLevel =>
-    tpch.query(14, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+    tpch.query(14, securityLevel, numPartitions).collect.toSet
   }

   testAgainstSpark("TPC-H 15", ignore) { securityLevel =>
-    tpch.query(15, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(15, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 16", ignore) { securityLevel =>
-    tpch.query(16, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(16, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 17") { securityLevel =>
-    tpch.query(17, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+    tpch.query(17, securityLevel, numPartitions).collect.toSet
   }

   testAgainstSpark("TPC-H 18", ignore) { securityLevel =>
-    tpch.query(18, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(18, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 19") { securityLevel =>
-    tpch.query(19, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+    tpch.query(19, securityLevel, numPartitions).collect.toSet
   }

   testAgainstSpark("TPC-H 20") { securityLevel =>
-    tpch.query(20, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+    tpch.query(20, securityLevel, numPartitions).collect.toSet
   }

   testAgainstSpark("TPC-H 21", ignore) { securityLevel =>
-    tpch.query(21, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(21, securityLevel, numPartitions).collect
   }

   testAgainstSpark("TPC-H 22", ignore) { securityLevel =>
-    tpch.query(22, securityLevel, spark.sqlContext, numPartitions).collect
+    tpch.query(22, securityLevel, numPartitions).collect
   }
 }
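
For orientation, a hypothetical sketch of what a harness like testAgainstSpark plausibly does (name, signature, and result type are assumptions; OpaqueTestsBase's actual implementation is not shown in this commit): run the body once at Insecure as the plain-Spark baseline, once at Encrypted, and assert the results match.

// Hypothetical sketch, not the actual OpaqueTestsBase code.
def testAgainstSparkSketch(name: String)(body: SecurityLevel => Seq[Any]): Unit = {
  val expected = body(Insecure)   // plain Spark result as the baseline
  val actual = body(Encrypted)    // same body through Opaque's encrypted path
  assert(expected == actual, s"$name: encrypted result diverged from Spark baseline")
}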
