diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala
index e3227fadbe..e0bb4d4caf 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCH.scala
@@ -17,16 +17,21 @@
 package edu.berkeley.cs.rise.opaque.benchmark
 
+import scala.io.Source
+
+import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.SQLContext
 
+import edu.berkeley.cs.rise.opaque.Utils
+
 object TPCH {
+
+  val tableNames = Seq("part", "supplier", "lineitem", "partsupp", "orders", "nation", "region", "customer")
+
   def part(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
+      sqlContext: SQLContext, size: String)
     : DataFrame =
-    securityLevel.applyTo(
     sqlContext.read.schema(
       StructType(Seq(
         StructField("p_partkey", IntegerType),
@@ -41,12 +46,10 @@ object TPCH {
       .format("csv")
       .option("delimiter", "|")
       .load(s"${Benchmark.dataDir}/tpch/$size/part.tbl")
-      .repartition(numPartitions))
 
   def supplier(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
+      sqlContext: SQLContext, size: String)
     : DataFrame =
-    securityLevel.applyTo(
     sqlContext.read.schema(
       StructType(Seq(
         StructField("s_suppkey", IntegerType),
@@ -59,12 +62,10 @@ object TPCH {
       .format("csv")
       .option("delimiter", "|")
       .load(s"${Benchmark.dataDir}/tpch/$size/supplier.tbl")
-      .repartition(numPartitions))
 
   def lineitem(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
+      sqlContext: SQLContext, size: String)
     : DataFrame =
-    securityLevel.applyTo(
     sqlContext.read.schema(
       StructType(Seq(
         StructField("l_orderkey", IntegerType),
@@ -86,12 +87,10 @@ object TPCH {
       .format("csv")
       .option("delimiter", "|")
       .load(s"${Benchmark.dataDir}/tpch/$size/lineitem.tbl")
-      .repartition(numPartitions))
 
   def partsupp(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
+      sqlContext: SQLContext, size: String)
     : DataFrame =
-    securityLevel.applyTo(
     sqlContext.read.schema(
       StructType(Seq(
         StructField("ps_partkey", IntegerType),
@@ -102,12 +101,10 @@ object TPCH {
       .format("csv")
       .option("delimiter", "|")
       .load(s"${Benchmark.dataDir}/tpch/$size/partsupp.tbl")
-      .repartition(numPartitions))
 
   def orders(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
+      sqlContext: SQLContext, size: String)
     : DataFrame =
-    securityLevel.applyTo(
     sqlContext.read.schema(
       StructType(Seq(
         StructField("o_orderkey", IntegerType),
@@ -122,12 +119,10 @@ object TPCH {
       .format("csv")
       .option("delimiter", "|")
       .load(s"${Benchmark.dataDir}/tpch/$size/orders.tbl")
-      .repartition(numPartitions))
 
   def nation(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
+      sqlContext: SQLContext, size: String)
     : DataFrame =
-    securityLevel.applyTo(
     sqlContext.read.schema(
       StructType(Seq(
         StructField("n_nationkey", IntegerType),
@@ -137,60 +132,85 @@ object TPCH {
       .format("csv")
       .option("delimiter", "|")
       .load(s"${Benchmark.dataDir}/tpch/$size/nation.tbl")
-      .repartition(numPartitions))
-
-
-  private def tpch9EncryptedDFs(
-      sqlContext: SQLContext, securityLevel: SecurityLevel, size: String, numPartitions: Int)
-    : (DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame) = {
-    val partDF = part(sqlContext, securityLevel, size, numPartitions)
-    val supplierDF = supplier(sqlContext, securityLevel, size, numPartitions)
-    val lineitemDF = lineitem(sqlContext, securityLevel, size, numPartitions)
-    val partsuppDF = partsupp(sqlContext, securityLevel, size, numPartitions)
-    val ordersDF = orders(sqlContext, securityLevel, size, numPartitions)
-    val nationDF = nation(sqlContext, securityLevel, size, numPartitions)
-    (partDF, supplierDF, lineitemDF, partsuppDF, ordersDF, nationDF)
+
+  def region(
+      sqlContext: SQLContext, size: String)
+    : DataFrame =
+    sqlContext.read.schema(
+      StructType(Seq(
+        StructField("r_regionkey", IntegerType),
+        StructField("r_name", StringType),
+        StructField("r_comment", StringType))))
+      .format("csv")
+      .option("delimiter", "|")
+      .load(s"${Benchmark.dataDir}/tpch/$size/region.tbl")
+
+  def customer(
+      sqlContext: SQLContext, size: String)
+    : DataFrame =
+    sqlContext.read.schema(
+      StructType(Seq(
+        StructField("c_custkey", IntegerType),
+        StructField("c_name", StringType),
+        StructField("c_address", StringType),
+        StructField("c_nationkey", IntegerType),
+        StructField("c_phone", StringType),
+        StructField("c_acctbal", FloatType),
+        StructField("c_mktsegment", StringType),
+        StructField("c_comment", StringType))))
+      .format("csv")
+      .option("delimiter", "|")
+      .load(s"${Benchmark.dataDir}/tpch/$size/customer.tbl")
+
+  def generateMap(
+      sqlContext: SQLContext, size: String)
+    : Map[String, DataFrame] = {
+    Map("part" -> part(sqlContext, size),
+      "supplier" -> supplier(sqlContext, size),
+      "lineitem" -> lineitem(sqlContext, size),
+      "partsupp" -> partsupp(sqlContext, size),
+      "orders" -> orders(sqlContext, size),
+      "nation" -> nation(sqlContext, size),
+      "region" -> region(sqlContext, size),
+      "customer" -> customer(sqlContext, size)
+    )
+  }
+
+  def apply(sqlContext: SQLContext, size: String) : TPCH = {
+    val tpch = new TPCH(sqlContext, size)
+    tpch.tableNames = tableNames
+    tpch.nameToDF = generateMap(sqlContext, size)
+    tpch.ensureCached()
+    tpch
+  }
+}
+
+class TPCH(val sqlContext: SQLContext, val size: String) {
+
+  var tableNames : Seq[String] = Seq()
+  var nameToDF : Map[String, DataFrame] = Map()
+
+  def ensureCached() = {
+    for (name <- tableNames) {
+      nameToDF.get(name).foreach(df => {
+        Utils.ensureCached(df)
+        Utils.ensureCached(Encrypted.applyTo(df))
+      })
+    }
+  }
+
+  def setupViews(securityLevel: SecurityLevel, numPartitions: Int) = {
+    for ((name, df) <- nameToDF) {
+      securityLevel.applyTo(df.repartition(numPartitions)).createOrReplaceTempView(name)
+    }
   }
 
-  /** TPC-H query 9 - Product Type Profit Measure Query */
-  def tpch9(
-      sqlContext: SQLContext,
-      securityLevel: SecurityLevel,
-      size: String,
-      numPartitions: Int,
-      quantityThreshold: Option[Int] = None) : DataFrame = {
-    import sqlContext.implicits._
-    val (partDF, supplierDF, lineitemDF, partsuppDF, ordersDF, nationDF) =
-      tpch9EncryptedDFs(sqlContext, securityLevel, size, numPartitions)
-
-    val df =
-      ordersDF.select($"o_orderkey", year($"o_orderdate").as("o_year")) // 6. orders
-        .join(
-          (nationDF// 4. nation
-            .join(
-              supplierDF // 3. supplier
-                .join(
-                  partDF // 1. part
-                    .filter($"p_name".contains("maroon"))
-                    .select($"p_partkey")
-                    .join(partsuppDF, $"p_partkey" === $"ps_partkey"), // 2. partsupp
-                  $"ps_suppkey" === $"s_suppkey"),
-              $"s_nationkey" === $"n_nationkey"))
-          .join(
-            // 5. lineitem
-            quantityThreshold match {
-              case Some(q) => lineitemDF.filter($"l_quantity" > lit(q))
-              case None => lineitemDF
-            },
-            $"s_suppkey" === $"l_suppkey" && $"p_partkey" === $"l_partkey"),
-          $"l_orderkey" === $"o_orderkey")
-        .select(
-          $"n_name",
-          $"o_year",
-          ($"l_extendedprice" * (lit(1) - $"l_discount") - $"ps_supplycost" * $"l_quantity")
-            .as("amount"))
-        .groupBy("n_name", "o_year").agg(sum($"amount").as("sum_profit"))
-
-    df
-  }
+  def query(queryNumber: Int, securityLevel: SecurityLevel, sqlContext: SQLContext, numPartitions: Int) : DataFrame = {
+    setupViews(securityLevel, numPartitions)
+
+    val queryLocation = sys.env.getOrElse("OPAQUE_HOME", ".") + "/src/test/resources/tpch/"
+    val sqlStr = Source.fromFile(queryLocation + s"q$queryNumber.sql").getLines().mkString("\n")
+
+    sqlContext.sparkSession.sql(sqlStr)
+  }
 }
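With this refactor the per-query loader plumbing collapses into one entry point: TPCH.apply reads and caches every table (plaintext and encrypted), and query re-registers the temp views at the requested security level before executing the SQL text. A minimal usage sketch, assuming a SparkSession named spark whose SQLContext has already been through Utils.initSQLContext, and TPC-H data under Benchmark.dataDir:

    import edu.berkeley.cs.rise.opaque.benchmark.{Encrypted, TPCH}

    // Reads all eight .tbl files and caches both the plaintext and encrypted copies.
    val tpch = TPCH(spark.sqlContext, "sf_small")

    // Repartitions, registers the temp views at the Encrypted level, and runs q9.sql.
    val profit = tpch.query(9, Encrypted, spark.sqlContext, numPartitions = 3)
    profit.show()

One small caveat: query never closes the Source it opens, so each call leaks a file handle; wrapping the read in try/finally would avoid that.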
diff --git a/src/test/resources/tpch/q1.sql b/src/test/resources/tpch/q1.sql
new file mode 100644
index 0000000000..73eb8d8417
--- /dev/null
+++ b/src/test/resources/tpch/q1.sql
@@ -0,0 +1,23 @@
+-- using default substitutions
+
+select
+    l_returnflag,
+    l_linestatus,
+    sum(l_quantity) as sum_qty,
+    sum(l_extendedprice) as sum_base_price,
+    sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+    sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+    avg(l_quantity) as avg_qty,
+    avg(l_extendedprice) as avg_price,
+    avg(l_discount) as avg_disc,
+    count(*) as count_order
+from
+    lineitem
+where
+    l_shipdate <= date '1998-12-01' - interval '90' day
+group by
+    l_returnflag,
+    l_linestatus
+order by
+    l_returnflag,
+    l_linestatus
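q1 is the only query in this set that touches a single table, which makes it a convenient smoke test for the view-registration path. For reference, the same aggregation in the DataFrame API — a sketch only, assuming the lineitem view registered by setupViews:

    import org.apache.spark.sql.functions._

    val q1 = spark.table("lineitem")
      .filter(col("l_shipdate") <= date_sub(lit("1998-12-01").cast("date"), 90))
      .groupBy("l_returnflag", "l_linestatus")
      .agg(
        sum("l_quantity").as("sum_qty"),
        sum("l_extendedprice").as("sum_base_price"),
        sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).as("sum_disc_price"),
        sum(col("l_extendedprice") * (lit(1) - col("l_discount")) * (lit(1) + col("l_tax"))).as("sum_charge"),
        avg("l_quantity").as("avg_qty"),
        avg("l_extendedprice").as("avg_price"),
        avg("l_discount").as("avg_disc"),
        count(lit(1)).as("count_order"))
      .orderBy("l_returnflag", "l_linestatus")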
diff --git a/src/test/resources/tpch/q10.sql b/src/test/resources/tpch/q10.sql
new file mode 100644
index 0000000000..3b2ae588de
--- /dev/null
+++ b/src/test/resources/tpch/q10.sql
@@ -0,0 +1,34 @@
+-- using default substitutions
+
+select
+    c_custkey,
+    c_name,
+    sum(l_extendedprice * (1 - l_discount)) as revenue,
+    c_acctbal,
+    n_name,
+    c_address,
+    c_phone,
+    c_comment
+from
+    customer,
+    orders,
+    lineitem,
+    nation
+where
+    c_custkey = o_custkey
+    and l_orderkey = o_orderkey
+    and o_orderdate >= date '1993-10-01'
+    and o_orderdate < date '1993-10-01' + interval '3' month
+    and l_returnflag = 'R'
+    and c_nationkey = n_nationkey
+group by
+    c_custkey,
+    c_name,
+    c_acctbal,
+    c_phone,
+    n_name,
+    c_address,
+    c_comment
+order by
+    revenue desc
+limit 20
diff --git a/src/test/resources/tpch/q11.sql b/src/test/resources/tpch/q11.sql
new file mode 100644
index 0000000000..531e78c21b
--- /dev/null
+++ b/src/test/resources/tpch/q11.sql
@@ -0,0 +1,29 @@
+-- using default substitutions
+
+select
+    ps_partkey,
+    sum(ps_supplycost * ps_availqty) as value
+from
+    partsupp,
+    supplier,
+    nation
+where
+    ps_suppkey = s_suppkey
+    and s_nationkey = n_nationkey
+    and n_name = 'GERMANY'
+group by
+    ps_partkey having
+        sum(ps_supplycost * ps_availqty) > (
+            select
+                sum(ps_supplycost * ps_availqty) * 0.0001000000
+            from
+                partsupp,
+                supplier,
+                nation
+            where
+                ps_suppkey = s_suppkey
+                and s_nationkey = n_nationkey
+                and n_name = 'GERMANY'
+        )
+order by
+    value desc
diff --git a/src/test/resources/tpch/q12.sql b/src/test/resources/tpch/q12.sql
new file mode 100644
index 0000000000..d3e70eb481
--- /dev/null
+++ b/src/test/resources/tpch/q12.sql
@@ -0,0 +1,30 @@
+-- using default substitutions
+
+select
+    l_shipmode,
+    sum(case
+        when o_orderpriority = '1-URGENT'
+            or o_orderpriority = '2-HIGH'
+            then 1
+        else 0
+    end) as high_line_count,
+    sum(case
+        when o_orderpriority <> '1-URGENT'
+            and o_orderpriority <> '2-HIGH'
+            then 1
+        else 0
+    end) as low_line_count
+from
+    orders,
+    lineitem
+where
+    o_orderkey = l_orderkey
+    and l_shipmode in ('MAIL', 'SHIP')
+    and l_commitdate < l_receiptdate
+    and l_shipdate < l_commitdate
+    and l_receiptdate >= date '1994-01-01'
+    and l_receiptdate < date '1994-01-01' + interval '1' year
+group by
+    l_shipmode
+order by
+    l_shipmode
diff --git a/src/test/resources/tpch/q13.sql b/src/test/resources/tpch/q13.sql
new file mode 100644
index 0000000000..3375002c5f
--- /dev/null
+++ b/src/test/resources/tpch/q13.sql
@@ -0,0 +1,22 @@
+-- using default substitutions
+
+select
+    c_count,
+    count(*) as custdist
+from
+    (
+        select
+            c_custkey,
+            count(o_orderkey) as c_count
+        from
+            customer left outer join orders on
+                c_custkey = o_custkey
+                and o_comment not like '%special%requests%'
+        group by
+            c_custkey
+    ) as c_orders
+group by
+    c_count
+order by
+    custdist desc,
+    c_count desc
diff --git a/src/test/resources/tpch/q14.sql b/src/test/resources/tpch/q14.sql
new file mode 100644
index 0000000000..753ea56891
--- /dev/null
+++ b/src/test/resources/tpch/q14.sql
@@ -0,0 +1,15 @@
+-- using default substitutions
+
+select
+    100.00 * sum(case
+        when p_type like 'PROMO%'
+            then l_extendedprice * (1 - l_discount)
+        else 0
+    end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
+from
+    lineitem,
+    part
+where
+    l_partkey = p_partkey
+    and l_shipdate >= date '1995-09-01'
+    and l_shipdate < date '1995-09-01' + interval '1' month
diff --git a/src/test/resources/tpch/q15.sql b/src/test/resources/tpch/q15.sql
new file mode 100644
index 0000000000..64d0b48ec0
--- /dev/null
+++ b/src/test/resources/tpch/q15.sql
@@ -0,0 +1,35 @@
+-- using default substitutions
+
+with revenue0 as
+    (select
+        l_suppkey as supplier_no,
+        sum(l_extendedprice * (1 - l_discount)) as total_revenue
+    from
+        lineitem
+    where
+        l_shipdate >= date '1996-01-01'
+        and l_shipdate < date '1996-01-01' + interval '3' month
+    group by
+        l_suppkey)
+
+
+select
+    s_suppkey,
+    s_name,
+    s_address,
+    s_phone,
+    total_revenue
+from
+    supplier,
+    revenue0
+where
+    s_suppkey = supplier_no
+    and total_revenue = (
+        select
+            max(total_revenue)
+        from
+            revenue0
+    )
+order by
+    s_suppkey
+
diff --git a/src/test/resources/tpch/q16.sql b/src/test/resources/tpch/q16.sql
new file mode 100644
index 0000000000..a6ac68898e
--- /dev/null
+++ b/src/test/resources/tpch/q16.sql
@@ -0,0 +1,32 @@
+-- using default substitutions
+
+select
+    p_brand,
+    p_type,
+    p_size,
+    count(distinct ps_suppkey) as supplier_cnt
+from
+    partsupp,
+    part
+where
+    p_partkey = ps_partkey
+    and p_brand <> 'Brand#45'
+    and p_type not like 'MEDIUM POLISHED%'
+    and p_size in (49, 14, 23, 45, 19, 3, 36, 9)
+    and ps_suppkey not in (
+        select
+            s_suppkey
+        from
+            supplier
+        where
+            s_comment like '%Customer%Complaints%'
+    )
+group by
+    p_brand,
+    p_type,
+    p_size
+order by
+    supplier_cnt desc,
+    p_brand,
+    p_type,
+    p_size
diff --git a/src/test/resources/tpch/q17.sql b/src/test/resources/tpch/q17.sql
new file mode 100644
index 0000000000..74fb1f653a
--- /dev/null
+++ b/src/test/resources/tpch/q17.sql
@@ -0,0 +1,19 @@
+-- using default substitutions
+
+select
+    sum(l_extendedprice) / 7.0 as avg_yearly
+from
+    lineitem,
+    part
+where
+    p_partkey = l_partkey
+    and p_brand = 'Brand#23'
+    and p_container = 'MED BOX'
+    and l_quantity < (
+        select
+            0.2 * avg(l_quantity)
+        from
+            lineitem
+        where
+            l_partkey = p_partkey
+    )
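q17's inner query is correlated on l_partkey (a per-part quantity threshold), which the optimizer decorrelates into an aggregate joined back on the correlation key. The same plan written by hand — a sketch, again assuming the views registered by setupViews:

    import org.apache.spark.sql.functions._

    val li = spark.table("lineitem")

    // The correlated subquery, precomputed as one aggregate row per part.
    val qtyLimit = li.groupBy(col("l_partkey").as("limit_partkey"))
      .agg((avg("l_quantity") * 0.2).as("qty_limit"))

    val q17 = li
      .join(spark.table("part"), col("p_partkey") === col("l_partkey"))
      .filter(col("p_brand") === "Brand#23" && col("p_container") === "MED BOX")
      .join(qtyLimit, col("limit_partkey") === col("l_partkey"))
      .filter(col("l_quantity") < col("qty_limit"))
      .agg((sum("l_extendedprice") / 7.0).as("avg_yearly"))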
diff --git a/src/test/resources/tpch/q18.sql b/src/test/resources/tpch/q18.sql
new file mode 100644
index 0000000000..210fba19ec
--- /dev/null
+++ b/src/test/resources/tpch/q18.sql
@@ -0,0 +1,35 @@
+-- using default substitutions
+
+select
+    c_name,
+    c_custkey,
+    o_orderkey,
+    o_orderdate,
+    o_totalprice,
+    sum(l_quantity)
+from
+    customer,
+    orders,
+    lineitem
+where
+    o_orderkey in (
+        select
+            l_orderkey
+        from
+            lineitem
+        group by
+            l_orderkey having
+                sum(l_quantity) > 300
+    )
+    and c_custkey = o_custkey
+    and o_orderkey = l_orderkey
+group by
+    c_name,
+    c_custkey,
+    o_orderkey,
+    o_orderdate,
+    o_totalprice
+order by
+    o_totalprice desc,
+    o_orderdate
+limit 100
\ No newline at end of file
diff --git a/src/test/resources/tpch/q19.sql b/src/test/resources/tpch/q19.sql
new file mode 100644
index 0000000000..c07327da3a
--- /dev/null
+++ b/src/test/resources/tpch/q19.sql
@@ -0,0 +1,37 @@
+-- using default substitutions
+
+select
+    sum(l_extendedprice* (1 - l_discount)) as revenue
+from
+    lineitem,
+    part
+where
+    (
+        p_partkey = l_partkey
+        and p_brand = 'Brand#12'
+        and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
+        and l_quantity >= 1 and l_quantity <= 1 + 10
+        and p_size between 1 and 5
+        and l_shipmode in ('AIR', 'AIR REG')
+        and l_shipinstruct = 'DELIVER IN PERSON'
+    )
+    or
+    (
+        p_partkey = l_partkey
+        and p_brand = 'Brand#23'
+        and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
+        and l_quantity >= 10 and l_quantity <= 10 + 10
+        and p_size between 1 and 10
+        and l_shipmode in ('AIR', 'AIR REG')
+        and l_shipinstruct = 'DELIVER IN PERSON'
+    )
+    or
+    (
+        p_partkey = l_partkey
+        and p_brand = 'Brand#34'
+        and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
+        and l_quantity >= 20 and l_quantity <= 20 + 10
+        and p_size between 1 and 15
+        and l_shipmode in ('AIR', 'AIR REG')
+        and l_shipinstruct = 'DELIVER IN PERSON'
+    )
diff --git a/src/test/resources/tpch/q2.sql b/src/test/resources/tpch/q2.sql
new file mode 100644
index 0000000000..d0e3b7e13e
--- /dev/null
+++ b/src/test/resources/tpch/q2.sql
@@ -0,0 +1,46 @@
+-- using default substitutions
+
+select
+    s_acctbal,
+    s_name,
+    n_name,
+    p_partkey,
+    p_mfgr,
+    s_address,
+    s_phone,
+    s_comment
+from
+    part,
+    supplier,
+    partsupp,
+    nation,
+    region
+where
+    p_partkey = ps_partkey
+    and s_suppkey = ps_suppkey
+    and p_size = 15
+    and p_type like '%BRASS'
+    and s_nationkey = n_nationkey
+    and n_regionkey = r_regionkey
+    and r_name = 'EUROPE'
+    and ps_supplycost = (
+        select
+            min(ps_supplycost)
+        from
+            partsupp,
+            supplier,
+            nation,
+            region
+        where
+            p_partkey = ps_partkey
+            and s_suppkey = ps_suppkey
+            and s_nationkey = n_nationkey
+            and n_regionkey = r_regionkey
+            and r_name = 'EUROPE'
+    )
+order by
+    s_acctbal desc,
+    n_name,
+    s_name,
+    p_partkey
+limit 100
diff --git a/src/test/resources/tpch/q20.sql b/src/test/resources/tpch/q20.sql
new file mode 100644
index 0000000000..e161d340b9
--- /dev/null
+++ b/src/test/resources/tpch/q20.sql
@@ -0,0 +1,39 @@
+-- using default substitutions
+
+select
+    s_name,
+    s_address
+from
+    supplier,
+    nation
+where
+    s_suppkey in (
+        select
+            ps_suppkey
+        from
+            partsupp
+        where
+            ps_partkey in (
+                select
+                    p_partkey
+                from
+                    part
+                where
+                    p_name like 'forest%'
+            )
+            and ps_availqty > (
+                select
+                    0.5 * sum(l_quantity)
+                from
+                    lineitem
+                where
+                    l_partkey = ps_partkey
+                    and l_suppkey = ps_suppkey
+                    and l_shipdate >= date '1994-01-01'
+                    and l_shipdate < date '1994-01-01' + interval '1' year
+            )
+    )
+    and s_nationkey = n_nationkey
+    and n_name = 'CANADA'
+order by
+    s_name
diff --git a/src/test/resources/tpch/q21.sql b/src/test/resources/tpch/q21.sql
new file mode 100644
index 0000000000..fdcdfbcf79
--- /dev/null
+++ b/src/test/resources/tpch/q21.sql
@@ -0,0 +1,42 @@
+-- using default substitutions
+
+select
+    s_name,
+    count(*) as numwait
+from
+    supplier,
+    lineitem l1,
+    orders,
+    nation
+where
+    s_suppkey = l1.l_suppkey
+    and o_orderkey = l1.l_orderkey
+    and o_orderstatus = 'F'
+    and l1.l_receiptdate > l1.l_commitdate
+    and exists (
+        select
+            *
+        from
+            lineitem l2
+        where
+            l2.l_orderkey = l1.l_orderkey
+            and l2.l_suppkey <> l1.l_suppkey
+    )
+    and not exists (
+        select
+            *
+        from
+            lineitem l3
+        where
+            l3.l_orderkey = l1.l_orderkey
+            and l3.l_suppkey <> l1.l_suppkey
+            and l3.l_receiptdate > l3.l_commitdate
+    )
+    and s_nationkey = n_nationkey
+    and n_name = 'SAUDI ARABIA'
+group by
+    s_name
+order by
+    numwait desc,
+    s_name
+limit 100
\ No newline at end of file
diff --git a/src/test/resources/tpch/q22.sql b/src/test/resources/tpch/q22.sql
new file mode 100644
index 0000000000..1d7706e9a0
--- /dev/null
+++ b/src/test/resources/tpch/q22.sql
@@ -0,0 +1,39 @@
+-- using default substitutions
+
+select
+    cntrycode,
+    count(*) as numcust,
+    sum(c_acctbal) as totacctbal
+from
+    (
+        select
+            substring(c_phone, 1, 2) as cntrycode,
+            c_acctbal
+        from
+            customer
+        where
+            substring(c_phone, 1, 2) in
+                ('13', '31', '23', '29', '30', '18', '17')
+            and c_acctbal > (
+                select
+                    avg(c_acctbal)
+                from
+                    customer
+                where
+                    c_acctbal > 0.00
+                    and substring(c_phone, 1, 2) in
+                        ('13', '31', '23', '29', '30', '18', '17')
+            )
+            and not exists (
+                select
+                    *
+                from
+                    orders
+                where
+                    o_custkey = c_custkey
+            )
+    ) as custsale
+group by
+    cntrycode
+order by
+    cntrycode
diff --git a/src/test/resources/tpch/q3.sql b/src/test/resources/tpch/q3.sql
new file mode 100644
index 0000000000..948d6bcf12
--- /dev/null
+++ b/src/test/resources/tpch/q3.sql
@@ -0,0 +1,25 @@
+-- using default substitutions
+
+select
+    l_orderkey,
+    sum(l_extendedprice * (1 - l_discount)) as revenue,
+    o_orderdate,
+    o_shippriority
+from
+    customer,
+    orders,
+    lineitem
+where
+    c_mktsegment = 'BUILDING'
+    and c_custkey = o_custkey
+    and l_orderkey = o_orderkey
+    and o_orderdate < date '1995-03-15'
+    and l_shipdate > date '1995-03-15'
+group by
+    l_orderkey,
+    o_orderdate,
+    o_shippriority
+order by
+    revenue desc,
+    o_orderdate
+limit 10
diff --git a/src/test/resources/tpch/q4.sql b/src/test/resources/tpch/q4.sql
new file mode 100644
index 0000000000..67330e36a0
--- /dev/null
+++ b/src/test/resources/tpch/q4.sql
@@ -0,0 +1,23 @@
+-- using default substitutions
+
+select
+    o_orderpriority,
+    count(*) as order_count
+from
+    orders
+where
+    o_orderdate >= date '1993-07-01'
+    and o_orderdate < date '1993-07-01' + interval '3' month
+    and exists (
+        select
+            *
+        from
+            lineitem
+        where
+            l_orderkey = o_orderkey
+            and l_commitdate < l_receiptdate
+    )
+group by
+    o_orderpriority
+order by
+    o_orderpriority
diff --git a/src/test/resources/tpch/q5.sql b/src/test/resources/tpch/q5.sql
new file mode 100644
index 0000000000..b973e9f0a0
--- /dev/null
+++ b/src/test/resources/tpch/q5.sql
@@ -0,0 +1,26 @@
+-- using default substitutions
+
+select
+    n_name,
+    sum(l_extendedprice * (1 - l_discount)) as revenue
+from
+    customer,
+    orders,
+    lineitem,
+    supplier,
+    nation,
+    region
+where
+    c_custkey = o_custkey
+    and l_orderkey = o_orderkey
+    and l_suppkey = s_suppkey
+    and c_nationkey = s_nationkey
+    and s_nationkey = n_nationkey
+    and n_regionkey = r_regionkey
+    and r_name = 'ASIA'
+    and o_orderdate >= date '1994-01-01'
+    and o_orderdate < date '1994-01-01' + interval '1' year
+group by
+    n_name
+order by
+    revenue desc
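Several of these files (q4, q21, q22) drive their filters through EXISTS / NOT EXISTS, which Spark plans as left-semi and left-anti joins — worth keeping in mind when checking which join modes the encrypted operators must support. q4 spelled out that way, as an illustrative sketch against the registered views:

    import org.apache.spark.sql.functions._

    // Orders with at least one late line item (the EXISTS branch).
    val lateItems = spark.table("lineitem")
      .filter(col("l_commitdate") < col("l_receiptdate"))

    val q4 = spark.table("orders")
      .filter(col("o_orderdate") >= lit("1993-07-01").cast("date") &&
        col("o_orderdate") < lit("1993-10-01").cast("date")) // 1993-07-01 + 3 months
      .join(lateItems, col("l_orderkey") === col("o_orderkey"), "left_semi")
      .groupBy("o_orderpriority")
      .agg(count(lit(1)).as("order_count"))
      .orderBy("o_orderpriority")

A NOT EXISTS such as q22's would use "left_anti" in place of "left_semi".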
diff --git a/src/test/resources/tpch/q6.sql b/src/test/resources/tpch/q6.sql
new file mode 100644
index 0000000000..22294579ee
--- /dev/null
+++ b/src/test/resources/tpch/q6.sql
@@ -0,0 +1,11 @@
+-- using default substitutions
+
+select
+    sum(l_extendedprice * l_discount) as revenue
+from
+    lineitem
+where
+    l_shipdate >= date '1994-01-01'
+    and l_shipdate < date '1994-01-01' + interval '1' year
+    and l_discount between .06 - 0.01 and .06 + 0.01
+    and l_quantity < 24
diff --git a/src/test/resources/tpch/q7.sql b/src/test/resources/tpch/q7.sql
new file mode 100644
index 0000000000..21105c0519
--- /dev/null
+++ b/src/test/resources/tpch/q7.sql
@@ -0,0 +1,41 @@
+-- using default substitutions
+
+select
+    supp_nation,
+    cust_nation,
+    l_year,
+    sum(volume) as revenue
+from
+    (
+        select
+            n1.n_name as supp_nation,
+            n2.n_name as cust_nation,
+            year(l_shipdate) as l_year,
+            l_extendedprice * (1 - l_discount) as volume
+        from
+            supplier,
+            lineitem,
+            orders,
+            customer,
+            nation n1,
+            nation n2
+        where
+            s_suppkey = l_suppkey
+            and o_orderkey = l_orderkey
+            and c_custkey = o_custkey
+            and s_nationkey = n1.n_nationkey
+            and c_nationkey = n2.n_nationkey
+            and (
+                (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
+                or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
+            )
+            and l_shipdate between date '1995-01-01' and date '1996-12-31'
+    ) as shipping
+group by
+    supp_nation,
+    cust_nation,
+    l_year
+order by
+    supp_nation,
+    cust_nation,
+    l_year
diff --git a/src/test/resources/tpch/q8.sql b/src/test/resources/tpch/q8.sql
new file mode 100644
index 0000000000..81d81871c4
--- /dev/null
+++ b/src/test/resources/tpch/q8.sql
@@ -0,0 +1,39 @@
+-- using default substitutions
+
+select
+    o_year,
+    sum(case
+        when nation = 'BRAZIL' then volume
+        else 0
+    end) / sum(volume) as mkt_share
+from
+    (
+        select
+            year(o_orderdate) as o_year,
+            l_extendedprice * (1 - l_discount) as volume,
+            n2.n_name as nation
+        from
+            part,
+            supplier,
+            lineitem,
+            orders,
+            customer,
+            nation n1,
+            nation n2,
+            region
+        where
+            p_partkey = l_partkey
+            and s_suppkey = l_suppkey
+            and l_orderkey = o_orderkey
+            and o_custkey = c_custkey
+            and c_nationkey = n1.n_nationkey
+            and n1.n_regionkey = r_regionkey
+            and r_name = 'AMERICA'
+            and s_nationkey = n2.n_nationkey
+            and o_orderdate between date '1995-01-01' and date '1996-12-31'
+            and p_type = 'ECONOMY ANODIZED STEEL'
+    ) as all_nations
+group by
+    o_year
+order by
+    o_year
diff --git a/src/test/resources/tpch/q9.sql b/src/test/resources/tpch/q9.sql
new file mode 100644
index 0000000000..a4e8e8382b
--- /dev/null
+++ b/src/test/resources/tpch/q9.sql
@@ -0,0 +1,34 @@
+-- using default substitutions
+
+select
+    nation,
+    o_year,
+    sum(amount) as sum_profit
+from
+    (
+        select
+            n_name as nation,
+            year(o_orderdate) as o_year,
+            l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount
+        from
+            part,
+            supplier,
+            lineitem,
+            partsupp,
+            orders,
+            nation
+        where
+            s_suppkey = l_suppkey
+            and ps_suppkey = l_suppkey
+            and ps_partkey = l_partkey
+            and p_partkey = l_partkey
+            and o_orderkey = l_orderkey
+            and s_nationkey = n_nationkey
+            and p_name like '%green%'
+    ) as profit
+group by
+    nation,
+    o_year
+order by
+    nation,
+    o_year desc
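q9.sql supersedes the hand-built tpch9 DataFrame query deleted from TPCH.scala above. Both compute per-nation, per-year profit over the same six-table join; the SQL file uses the standard p_name like '%green%' filter where the old code filtered on "maroon" and exposed an optional quantity threshold. The profit expression is identical in both; written as a standalone Column it looks like:

    import org.apache.spark.sql.functions._

    // Same shape as "amount" in q9.sql:
    //   l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount
    val amount =
      (col("l_extendedprice") * (lit(1) - col("l_discount"))
        - col("ps_supplycost") * col("l_quantity")).as("amount")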
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
index 77235e6aa5..79e1bee374 100644
--- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
@@ -19,11 +19,8 @@ package edu.berkeley.cs.rise.opaque
 
 import java.sql.Timestamp
 
-import scala.collection.mutable
 import scala.util.Random
 
-import org.apache.log4j.Level
-import org.apache.log4j.LogManager
 import org.apache.spark.SparkException
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Dataset
@@ -35,10 +32,6 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.unsafe.types.CalendarInterval
-import org.scalactic.Equality
-import org.scalactic.TolerantNumerics
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
 
 import edu.berkeley.cs.rise.opaque.benchmark._
 import edu.berkeley.cs.rise.opaque.execution.EncryptedBlockRDDScanExec
@@ -46,83 +39,14 @@ import edu.berkeley.cs.rise.opaque.expressions.DotProduct.dot
 import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply.vectormultiply
 import edu.berkeley.cs.rise.opaque.expressions.VectorSum
 
-trait OpaqueOperatorTests extends FunSuite with BeforeAndAfterAll { self =>
-  def spark: SparkSession
-  def numPartitions: Int
+trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
 
-  protected object testImplicits extends SQLImplicits {
-    protected override def _sqlContext: SQLContext = self.spark.sqlContext
-  }
-  import testImplicits._
-
-  override def beforeAll(): Unit = {
-    Utils.initSQLContext(spark.sqlContext)
-  }
-
-  override def afterAll(): Unit = {
-    spark.stop()
-  }
-
-  private def equalityToArrayEquality[A : Equality](): Equality[Array[A]] = {
-    new Equality[Array[A]] {
-      def areEqual(a: Array[A], b: Any): Boolean = {
-        b match {
-          case b: Array[_] =>
-            (a.length == b.length
-              && a.zip(b).forall {
-                case (x, y) => implicitly[Equality[A]].areEqual(x, y)
-              })
-          case _ => false
-        }
-      }
-      override def toString: String = s"TolerantArrayEquality"
-    }
-  }
-
-  // Modify the behavior of === for Double and Array[Double] to use a numeric tolerance
-  implicit val tolerantDoubleEquality = TolerantNumerics.tolerantDoubleEquality(1e-6)
-  implicit val tolerantDoubleArrayEquality = equalityToArrayEquality[Double]
-
-  def testAgainstSpark[A : Equality](name: String)(f: SecurityLevel => A): Unit = {
-    test(name + " - encrypted") {
-      // The === operator uses implicitly[Equality[A]], which compares Double and Array[Double]
-      // using the numeric tolerance specified above
-      assert(f(Encrypted) === f(Insecure))
-    }
-  }
-
-  def testOpaqueOnly(name: String)(f: SecurityLevel => Unit): Unit = {
-    test(name + " - encrypted") {
-      f(Encrypted)
-    }
-  }
-
-  def testSparkOnly(name: String)(f: SecurityLevel => Unit): Unit = {
-    test(name + " - Spark") {
-      f(Insecure)
-    }
-  }
-
-  def withLoggingOff[A](f: () => A): A = {
-    val sparkLoggers = Seq(
-      "org.apache.spark",
-      "org.apache.spark.executor.Executor",
-      "org.apache.spark.scheduler.TaskSetManager")
-    val logLevels = new mutable.HashMap[String, Level]
-    for (l <- sparkLoggers) {
-      logLevels(l) = LogManager.getLogger(l).getLevel
-      LogManager.getLogger(l).setLevel(Level.OFF)
+  protected object testImplicits extends SQLImplicits {
+    protected override def _sqlContext: SQLContext = self.spark.sqlContext
   }
-    try {
-      f()
-    } finally {
-      for (l <- sparkLoggers) {
-        LogManager.getLogger(l).setLevel(logLevels(l))
-      }
-    }
-  }
+  import testImplicits._
 
-  /** Modified from https://stackoverflow.com/questions/33193958/change-nullable-property-of-column-in-spark-dataframe
+  /** Modified from https://stackoverflow.com/questions/33193958/change-nullable-property-of-column-in-spark-dataframe
    * and https://stackoverflow.com/questions/32585670/what-is-the-best-way-to-define-custom-methods-on-a-dataframe
    * Set nullable property of column.
    * @param cn is the column name to change
@@ -884,10 +808,6 @@ trait OpaqueOperatorTests extends FunSuite with BeforeAndAfterAll { self =>
     PageRank.run(spark, securityLevel, "256", numPartitions).collect.toSet
   }
 
-  testAgainstSpark("TPC-H 9") { securityLevel =>
-    TPCH.tpch9(spark.sqlContext, securityLevel, "sf_small", numPartitions).collect.toSet
-  }
-
   testAgainstSpark("big data 1") { securityLevel =>
     BigDataBenchmark.q1(spark, securityLevel, "tiny", numPartitions).collect
   }
@@ -911,20 +831,20 @@
 }
 
-class OpaqueSinglePartitionSuite extends OpaqueOperatorTests {
+class OpaqueOperatorSinglePartitionSuite extends OpaqueOperatorTests {
   override val spark = SparkSession.builder()
     .master("local[1]")
-    .appName("QEDSuite")
+    .appName("OpaqueOperatorSinglePartitionSuite")
     .config("spark.sql.shuffle.partitions", 1)
     .getOrCreate()
 
   override def numPartitions: Int = 1
 }
 
-class OpaqueMultiplePartitionSuite extends OpaqueOperatorTests {
+class OpaqueOperatorMultiplePartitionSuite extends OpaqueOperatorTests {
   override val spark = SparkSession.builder()
     .master("local[1]")
-    .appName("QEDSuite")
+    .appName("OpaqueOperatorMultiplePartitionSuite")
     .config("spark.sql.shuffle.partitions", 3)
     .getOrCreate()
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala
new file mode 100644
index 0000000000..8117fb8de1
--- /dev/null
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueTestsBase.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package edu.berkeley.cs.rise.opaque
+
+
+import scala.collection.mutable
+
+import org.apache.log4j.Level
+import org.apache.log4j.LogManager
+import org.apache.spark.sql.SparkSession
+import org.scalactic.TolerantNumerics
+import org.scalactic.Equality
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.Tag
+
+import edu.berkeley.cs.rise.opaque.benchmark._
+
+trait OpaqueTestsBase extends FunSuite with BeforeAndAfterAll { self =>
+
+  def spark: SparkSession
+  def numPartitions: Int
+
+  override def beforeAll(): Unit = {
+    Utils.initSQLContext(spark.sqlContext)
+  }
+
+  override def afterAll(): Unit = {
+    spark.stop()
+  }
+
+  // Modify the behavior of === for Double and Array[Double] to use a numeric tolerance
+  implicit val tolerantDoubleEquality = TolerantNumerics.tolerantDoubleEquality(1e-6)
+
+  def equalityToArrayEquality[A : Equality](): Equality[Array[A]] = {
+    new Equality[Array[A]] {
+      def areEqual(a: Array[A], b: Any): Boolean = {
+        b match {
+          case b: Array[_] =>
+            (a.length == b.length
+              && a.zip(b).forall {
+                case (x, y) => implicitly[Equality[A]].areEqual(x, y)
+              })
+          case _ => false
+        }
+      }
+      override def toString: String = s"TolerantArrayEquality"
+    }
+  }
+
+  def testAgainstSpark[A : Equality](name: String, testFunc: (String, Tag*) => ((=> Any) => Unit) = test)
+      (f: SecurityLevel => A): Unit = {
+    testFunc(name + " - encrypted") {
+      // The === operator uses implicitly[Equality[A]], which compares Double and Array[Double]
+      // using the numeric tolerance specified above
+      assert(f(Encrypted) === f(Insecure))
+    }
+  }
+
+  def testOpaqueOnly(name: String)(f: SecurityLevel => Unit): Unit = {
+    test(name + " - encrypted") {
+      f(Encrypted)
+    }
+  }
+
+  def testSparkOnly(name: String)(f: SecurityLevel => Unit): Unit = {
+    test(name + " - Spark") {
+      f(Insecure)
+    }
+  }
+
+  def withLoggingOff[A](f: () => A): A = {
+    val sparkLoggers = Seq(
+      "org.apache.spark",
+      "org.apache.spark.executor.Executor",
+      "org.apache.spark.scheduler.TaskSetManager")
+    val logLevels = new mutable.HashMap[String, Level]
+    for (l <- sparkLoggers) {
+      logLevels(l) = LogManager.getLogger(l).getLevel
+      LogManager.getLogger(l).setLevel(Level.OFF)
+    }
+    try {
+      f()
+    } finally {
+      for (l <- sparkLoggers) {
+        LogManager.getLogger(l).setLevel(logLevels(l))
+      }
+    }
+  }
+}
\ No newline at end of file
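The second parameter of testAgainstSpark is the genuinely new piece here: FunSuite's test and ignore both have the shape (String, Tag*) => ((=> Any) => Unit) once eta-expanded, so a caller can disable a case by swapping one for the other without touching the body. Both forms appear in TPCHTests below:

    // Runs normally: testFunc defaults to FunSuite's test.
    testAgainstSpark("TPC-H 1") { securityLevel =>
      tpch.query(1, securityLevel, spark.sqlContext, numPartitions).collect.toSet
    }

    // Registered but skipped: ignore is passed in place of test.
    testAgainstSpark("TPC-H 2", ignore) { securityLevel =>
      tpch.query(2, securityLevel, spark.sqlContext, numPartitions).collect.toSet
    }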
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala
new file mode 100644
index 0000000000..d003c835f3
--- /dev/null
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package edu.berkeley.cs.rise.opaque
+
+
+import org.apache.spark.sql.SparkSession
+
+import edu.berkeley.cs.rise.opaque.benchmark._
+import edu.berkeley.cs.rise.opaque.benchmark.TPCH
+
+trait TPCHTests extends OpaqueTestsBase { self =>
+
+  def size = "sf_small"
+  def tpch = TPCH(spark.sqlContext, size)
+
+  testAgainstSpark("TPC-H 1") { securityLevel =>
+    tpch.query(1, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 2", ignore) { securityLevel =>
+    tpch.query(2, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 3") { securityLevel =>
+    tpch.query(3, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 4", ignore) { securityLevel =>
+    tpch.query(4, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 5") { securityLevel =>
+    tpch.query(5, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 6") { securityLevel =>
+    tpch.query(6, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 7") { securityLevel =>
+    tpch.query(7, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 8") { securityLevel =>
+    tpch.query(8, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 9") { securityLevel =>
+    tpch.query(9, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 10") { securityLevel =>
+    tpch.query(10, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 11", ignore) { securityLevel =>
+    tpch.query(11, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 12", ignore) { securityLevel =>
+    tpch.query(12, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 13", ignore) { securityLevel =>
+    tpch.query(13, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 14") { securityLevel =>
+    tpch.query(14, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 15", ignore) { securityLevel =>
+    tpch.query(15, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 16", ignore) { securityLevel =>
+    tpch.query(16, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 17") { securityLevel =>
+    tpch.query(17, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 18", ignore) { securityLevel =>
+    tpch.query(18, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 19", ignore) { securityLevel =>
+    tpch.query(19, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 20", ignore) { securityLevel =>
+    tpch.query(20, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 21", ignore) { securityLevel =>
+    tpch.query(21, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+
+  testAgainstSpark("TPC-H 22", ignore) { securityLevel =>
+    tpch.query(22, securityLevel, spark.sqlContext, numPartitions).collect.toSet
+  }
+}
+
+class TPCHSinglePartitionSuite extends TPCHTests {
+  override def numPartitions: Int = 1
+  override val spark = SparkSession.builder()
+    .master("local[1]")
+    .appName("TPCHSinglePartitionSuite")
+    .config("spark.sql.shuffle.partitions", numPartitions)
+    .getOrCreate()
+}
+
+class TPCHMultiplePartitionSuite extends TPCHTests {
+  override def numPartitions: Int = 3
+  override val spark = SparkSession.builder()
+    .master("local[1]")
+    .appName("TPCHMultiplePartitionSuite")
+    .config("spark.sql.shuffle.partitions", numPartitions)
+    .getOrCreate()
+}
\ No newline at end of file
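The two concrete suites differ only in partition count, so covering another shuffle configuration is a one-class change. A hypothetical third suite, following the exact pattern above (the name and count are illustrative, not part of this patch):

    // Hypothetical: exercise the same 22 queries at five partitions.
    class TPCHFivePartitionSuite extends TPCHTests {
      override def numPartitions: Int = 5
      override val spark = SparkSession.builder()
        .master("local[1]")
        .appName("TPCHFivePartitionSuite")
        .config("spark.sql.shuffle.partitions", numPartitions)
        .getOrCreate()
    }

Note that TPCH.query resolves the .sql files through OPAQUE_HOME, so that variable must point at the checkout when the suites run.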