From 1ece803bdff98163813702944ec48bea2f62f59e Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sat, 8 Apr 2017 02:01:39 +0900
Subject: [PATCH 01/19] initial commit

---
 .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9816b33ae8df..81888807c949 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2230,6 +2230,8 @@ class Analyzer(
     val result = resolved transformDown {
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {
+          case ArrayType(et, false) if cls.isEmpty =>
+            Cast(inputData, inputData.dataType)
           case ArrayType(et, _) =>
             val expr = MapObjects(func, inputData, et, cls) transformUp {
               case UnresolvedExtractValue(child, fieldName) if child.resolved =>

From c042ff27ba4e60e78fc0fd8ec65d80363fec4983 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sat, 8 Apr 2017 16:48:55 +0900
Subject: [PATCH 02/19] addressed review comment

---
 .../apache/spark/sql/catalyst/CatalystTypeConverters.scala | 2 +-
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d4ebdb139fe0..e87f20750ae7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -40,7 +40,7 @@ object CatalystTypeConverters {
   // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
   import scala.collection.Map
 
-  private def isPrimitive(dataType: DataType): Boolean = {
+  def isPrimitive(dataType: DataType): Boolean = {
     dataType match {
       case BooleanType => true
       case ByteType => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 81888807c949..72785a7d4b51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2230,8 +2230,8 @@ class Analyzer(
     val result = resolved transformDown {
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {
-          case ArrayType(et, false) if cls.isEmpty =>
-            Cast(inputData, inputData.dataType)
+          case ArrayType(et, false) if cls.isEmpty &&
+            CatalystTypeConverters.isPrimitive(et) => Cast(inputData, inputData.dataType)
           case ArrayType(et, _) =>
             val expr = MapObjects(func, inputData, et, cls) transformUp {
               case UnresolvedExtractValue(child, fieldName) if child.resolved =>

From 264a18b0be6e562c31b044a90f1dfc5ac3aa1a4d Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 11 Apr 2017 03:40:33 +0900
Subject: [PATCH 03/19] add an optimizer to remove MapObjects

---
 .../sql/catalyst/analysis/Analyzer.scala | 2 -
 .../sql/catalyst/optimizer/Optimizer.scala | 1 +
 .../sql/catalyst/optimizer/objects.scala | 40 +++++++++++++++++++
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 72785a7d4b51..9816b33ae8df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2230,8 +2230,6 @@ class Analyzer(
     val result = resolved transformDown {
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {
-          case ArrayType(et, false) if cls.isEmpty &&
-            CatalystTypeConverters.isPrimitive(et) => Cast(inputData, inputData.dataType)
           case ArrayType(et, _) =>
             val expr = MapObjects(func, inputData, et, cls) transformUp {
               case UnresolvedExtractValue(child, fieldName) if child.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d221b0611a89..a9a7685133ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -91,6 +91,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       CombineLimits,
       CombineUnions,
       // Constant folding and strength reduction
+      EliminateMapObjects,
       NullPropagation(conf),
       FoldablePropagation,
       OptimizeIn(conf),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 257dbfac8c3e..75b179a745cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.api.java.function.FilterFunction
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types._
 
 /*
  * This file defines optimization rules related to object manipulation (for the Dataset API).
@@ -96,3 +99,40 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Removes MapObjects when the following conditions are satisfied
+ * 1. Mapobject(e) where e is lambdavariable
+ * 2. the function will convert an expression MapObjects(e) to AssertNotNull(e)
+ * 3. the inputData is of primitive type array and its element is not nullable.
+ * 4. the outputData is of primitive type array and its element does not have enull
+ * 5. no custom collection class specified
+ * representation of data item. For example back to back map operations.
+ */
+object EliminateMapObjects extends Rule[LogicalPlan] {
+  private def convertDataTypeToArrayClass(dt: DataType): Class[_] = dt match {
+    case IntegerType => classOf[Array[Int]]
+    case LongType => classOf[Array[Long]]
+    case DoubleType => classOf[Array[Double]]
+    case FloatType => classOf[Array[Float]]
+    case ShortType => classOf[Array[Short]]
+    case ByteType => classOf[Array[Byte]]
+    case BooleanType => classOf[Array[Boolean]]
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case _ @ DeserializeToObject(_ @ Invoke(
+      MapObjects(_, _, inputType, args, inputData, customCollectionCls, _),
+      funcName, returnType @ ObjectType(returnCls), arguments, propagateNull, returnNullable),
+      outputObjAttr, child) if CatalystTypeConverters.isPrimitive(inputType) &&
+        returnCls.isAssignableFrom(convertDataTypeToArrayClass(inputType)) &&
+        customCollectionCls.isEmpty =>
+      args match {
+        case _@AssertNotNull(LambdaVariable(_, _, dataType, _), _) if dataType == inputType =>
+          DeserializeToObject(Invoke(
+            inputData, funcName, returnType, arguments, propagateNull, returnNullable),
+            outputObjAttr, child)
+        case _ => plan
+      }
+  }
+}

From 82a1f2bd376b53ac8e4d29f167c2bd12e0bb1961 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 11 Apr 2017 03:41:42 +0900
Subject: [PATCH 04/19] add test suites to check deserializer in DeserializeToObjectExec

---
 .../spark/sql/DatasetPrimitiveSuite.scala | 26 ++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 541565344f75..6b22d5c086ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -20,7 +20,10 @@ package org.apache.spark.sql
 import scala.collection.immutable.Queue
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.execution.DeserializeToObjectExec
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 
 case class IntClass(value: Int)
 
@@ -263,4 +266,27 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
   }
 
+  test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
+    val dsInt = Seq(Array(1, 2, 3)).toDS.cache.map(e => e)
+    val planInt =
dsInt.queryExecution.executedPlan + val deserializeInt = planInt.find(_.isInstanceOf[DeserializeToObjectExec]) + assert(deserializeInt.isDefined) + assert(deserializeInt.get match { + case _ @ DeserializeToObjectExec(_ @ Invoke(_, _ @ "toIntArray", _, _, _, _), _, _) => + true + case _ => + false + }) + + val dsDouble = Seq(Array(1.1, 2.2, 3.3)).toDS.cache.map(e => e) + val planDouble = dsDouble.queryExecution.executedPlan + val deserializeDouble = planDouble.find(_.isInstanceOf[DeserializeToObjectExec]) + assert(deserializeDouble.isDefined) + assert(deserializeDouble.get match { + case _ @ DeserializeToObjectExec(_ @ Invoke(_, _ @ "toDoubleArray", _, _, _, _), _, _) => + true + case _ => + false + }) + } } From cc3c7bdcccc3c5e768ac26103a498da436f9cbac Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 14 Apr 2017 18:27:05 +0900 Subject: [PATCH 05/19] eliminate no-op AssertNotNull --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../spark/sql/catalyst/expressions/objects/objects.scala | 1 + .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9816b33ae8df..d9f36f7f874d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2230,8 +2230,8 @@ class Analyzer( val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => inputData.dataType match { - case ArrayType(et, _) => - val expr = MapObjects(func, inputData, et, cls) transformUp { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f446c3e4a75f..bedc7a96c1fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -458,6 +458,7 @@ object MapObjects { function: Expression => Expression, inputData: Expression, elementType: DataType, + elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8445ee06bd89..0aeedde478e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet - import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import 
org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -368,6 +368,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + case a @ AssertNotNull(c, _) if !c.nullable => c + // For Coalesce, remove null literals. case e @ Coalesce(children) => val newChildren = children.filterNot(isNullLiteral) From 7d43d270d3287d5b50890534456df3ee1c9d844b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 14 Apr 2017 18:29:50 +0900 Subject: [PATCH 06/19] remove unnecessary Mapobjects by checking MapObjects and lambdaVariable --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/objects.scala | 26 +++++++------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a9a7685133ec..0d5450f04187 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -91,8 +91,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineLimits, CombineUnions, // Constant folding and strength reduction - EliminateMapObjects, NullPropagation(conf), + EliminateMapObjects, FoldablePropagation, OptimizeIn(conf), ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 75b179a745cf..b914fa961a5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -102,11 +102,9 @@ object CombineTypedFilters extends Rule[LogicalPlan] { /** * Removes MapObjects when the following conditions are satisfied - * 1. Mapobject(e) where e is lambdavariable - * 2. the function will convert an expression MapObjects(e) to AssertNotNull(e) - * 3. the inputData is of primitive type array and its element is not nullable. - * 4. the outputData is of primitive type array and its element does not have enull - * 5. no custom collection class specified + * 1. Mapobject(e) where e is lambdavariable(), which means types for input output + * are primitive types + * 2. no custom collection class specified * representation of data item. For example back to back map operations. 
*/ object EliminateMapObjects extends Rule[LogicalPlan] { @@ -122,17 +120,11 @@ object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case _ @ DeserializeToObject(_ @ Invoke( - MapObjects(_, _, inputType, args, inputData, customCollectionCls, _), - funcName, returnType @ ObjectType(returnCls), arguments, propagateNull, returnNullable), - outputObjAttr, child) if CatalystTypeConverters.isPrimitive(inputType) && - returnCls.isAssignableFrom(convertDataTypeToArrayClass(inputType)) && - customCollectionCls.isEmpty => - args match { - case _@AssertNotNull(LambdaVariable(_, _, dataType, _), _) if dataType == inputType => - DeserializeToObject(Invoke( - inputData, funcName, returnType, arguments, propagateNull, returnNullable), - outputObjAttr, child) - case _ => plan - } + MapObjects(_, _, _, LambdaVariable(_, _, _, _), inputData, customCollectionCls, _), + funcName, returnType @ ObjectType(returnCls), arguments, propagateNull, returnNullable), + outputObjAttr, child) if customCollectionCls.isEmpty => + DeserializeToObject(Invoke( + inputData, funcName, returnType, arguments, propagateNull, returnNullable), + outputObjAttr, child) } } From 3bcee3f66ff5447d342656a3de892af1fda0fd87 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 14 Apr 2017 19:46:37 +0900 Subject: [PATCH 07/19] fix scala style error --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 0aeedde478e5..8c7ee6c9361e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet + import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ From 1b24be0ca52234997deb549a73cf227339d1ad37 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 15 Apr 2017 02:43:35 +0900 Subject: [PATCH 08/19] use simpler matchings --- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../spark/sql/catalyst/optimizer/objects.scala | 16 +++------------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8c7ee6c9361e..c2b898c97d6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -369,7 +369,7 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case a @ AssertNotNull(c, _) if !c.nullable => c + case _ @ AssertNotNull(c, _) if !c.nullable => c // For Coalesce, remove null literals. 
case e @ Coalesce(children) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index b914fa961a5c..32162082c9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -108,21 +108,11 @@ object CombineTypedFilters extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateMapObjects extends Rule[LogicalPlan] { - private def convertDataTypeToArrayClass(dt: DataType): Class[_] = dt match { - case IntegerType => classOf[Array[Int]] - case LongType => classOf[Array[Long]] - case DoubleType => classOf[Array[Double]] - case FloatType => classOf[Array[Float]] - case ShortType => classOf[Array[Short]] - case ByteType => classOf[Array[Byte]] - case BooleanType => classOf[Array[Boolean]] - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case _ @ DeserializeToObject(_ @ Invoke( - MapObjects(_, _, _, LambdaVariable(_, _, _, _), inputData, customCollectionCls, _), - funcName, returnType @ ObjectType(returnCls), arguments, propagateNull, returnNullable), - outputObjAttr, child) if customCollectionCls.isEmpty => + MapObjects(_, _, _, _ : LambdaVariable, inputData, None, _), + funcName, returnType @ ObjectType(_), arguments, propagateNull, returnNullable), + outputObjAttr, child) => DeserializeToObject(Invoke( inputData, funcName, returnType, arguments, propagateNull, returnNullable), outputObjAttr, child) From 5174b5ea055a90e1d57f999ac5570cc8272557db Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 15 Apr 2017 02:45:19 +0900 Subject: [PATCH 09/19] use non-e2e test suite --- .../optimizer/EliminateMapObjectsSuite.scala | 60 +++++++++++++++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 24 -------- 2 files changed, 60 insertions(+), 24 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala new file mode 100644 index 000000000000..71d886dd790d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + EliminateMapObjects) :: Nil + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6b22d5c086ee..73adc9facb23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -265,28 +265,4 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } - - test("SPARK-20254: Remove unnecessary data conversion for primitive array") { - val dsInt = Seq(Array(1, 2, 3)).toDS.cache.map(e => e) - val planInt = dsInt.queryExecution.executedPlan - val deserializeInt = planInt.find(_.isInstanceOf[DeserializeToObjectExec]) - assert(deserializeInt.isDefined) - assert(deserializeInt.get match { - case _ @ DeserializeToObjectExec(_ @ Invoke(_, _ @ "toIntArray", _, _, _, _), _, _) => - true - case _ => - false - }) - - val dsDouble = Seq(Array(1.1, 2.2, 3.3)).toDS.cache.map(e => e) - val planDouble = dsDouble.queryExecution.executedPlan - val deserializeDouble = planDouble.find(_.isInstanceOf[DeserializeToObjectExec]) - assert(deserializeDouble.isDefined) - assert(deserializeDouble.get match { - case _ @ DeserializeToObjectExec(_ @ Invoke(_, _ @ "toDoubleArray", _, _, _, _), _, _) => - true - case _ => - false - }) - } } From e608c45540426d264f9ed0dfcd559e03c72c157a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 15 Apr 2017 16:48:04 +0900 Subject: [PATCH 10/19] rebase 
with master --- .../org/apache/spark/sql/catalyst/optimizer/objects.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 32162082c9da..7aae52681d56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -110,9 +110,10 @@ object CombineTypedFilters extends Rule[LogicalPlan] { object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case _ @ DeserializeToObject(_ @ Invoke( - MapObjects(_, _, _, _ : LambdaVariable, inputData, None, _), + MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _), + inputData, None, _), funcName, returnType @ ObjectType(_), arguments, propagateNull, returnNullable), - outputObjAttr, child) => + outputObjAttr, child) if dataType == castDataType => DeserializeToObject(Invoke( inputData, funcName, returnType, arguments, propagateNull, returnNullable), outputObjAttr, child) From 2ee4030c5c8e1057fe81cdfc8e92ff6c90fa783f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 16 Apr 2017 02:50:07 +0900 Subject: [PATCH 11/19] fix test failure --- .../optimizer/EliminateMapObjectsSuite.scala | 120 +++++++++--------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala index 71d886dd790d..25cbe3693273 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -1,60 +1,60 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types._ - -class EliminateMapObjectsSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("EliminateMapObjects", FixedPoint(50), - NullPropagation(conf), - EliminateMapObjects) :: Nil - } - - implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() - implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() - - test("SPARK-20254: Remove unnecessary data conversion for primitive array") { - val intObjType = ObjectType(classOf[Array[Int]]) - val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) - val intQuery = intInput.deserialize[Array[Int]].analyze - val intOptimized = Optimize.execute(intQuery) - val intExpected = DeserializeToObject( - Invoke(intInput.output(0), "toIntArray", intObjType), - AttributeReference("obj", intObjType, true)(), intInput) - comparePlans(intOptimized, intExpected) - - val doubleObjType = ObjectType(classOf[Array[Double]]) - val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) - val doubleQuery = doubleInput.deserialize[Array[Double]].analyze - val doubleOptimized = Optimize.execute(doubleQuery) - val doubleExpected = DeserializeToObject( - Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType), - AttributeReference("obj", doubleObjType, true)(), doubleInput) - comparePlans(doubleOptimized, doubleExpected) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + EliminateMapObjects) :: Nil + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} From 1d6ab36cb7dd0400ed6b1f3e8565cb2e45aacdc3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 16 Apr 2017 10:12:22 +0900 Subject: [PATCH 12/19] address review comments --- .../apache/spark/sql/catalyst/CatalystTypeConverters.scala | 2 +- .../spark/sql/catalyst/expressions/objects/objects.scala | 2 ++ .../org/apache/spark/sql/catalyst/optimizer/objects.scala | 5 ++--- .../scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index e87f20750ae7..d4ebdb139fe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -40,7 +40,7 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. 
import scala.collection.Map - def isPrimitive(dataType: DataType): Boolean = { + private def isPrimitive(dataType: DataType): Boolean = { dataType match { case BooleanType => true case ByteType => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bedc7a96c1fc..03b24151f534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -451,6 +451,8 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 7aae52681d56..0062b95e59e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.api.java.function.FilterFunction -import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -109,10 +108,10 @@ object CombineTypedFilters extends Rule[LogicalPlan] { */ object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case _ @ DeserializeToObject(_ @ Invoke( + case _ @ DeserializeToObject(Invoke( MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _), inputData, None, _), - funcName, returnType @ ObjectType(_), arguments, propagateNull, returnNullable), + funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), outputObjAttr, child) if dataType == castDataType => DeserializeToObject(Invoke( inputData, funcName, returnType, arguments, propagateNull, returnNullable), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 73adc9facb23..42ef298d7d08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.execution.DeserializeToObjectExec import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ case class IntClass(value: Int) @@ -265,4 +264,5 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } + } From 0fd8c259ee732377e17667fb8f7336d21d71b425 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 16 Apr 2017 21:40:25 +0900 Subject: [PATCH 
13/19] rebase with master --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 03b24151f534..1a202ecf745c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -465,7 +465,7 @@ object MapObjects { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } From 791aad97338a7fac600ed1f185fb43a73e74d772 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 16 Apr 2017 21:41:37 +0900 Subject: [PATCH 14/19] address review comments --- .../sql/catalyst/optimizer/Optimizer.scala | 4 +-- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../sql/catalyst/optimizer/objects.scala | 9 ++++- .../optimizer/EliminateMapObjectsSuite.scala | 34 ++++++++++++------- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d5450f04187..dd768d18e858 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -92,7 +92,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineUnions, // Constant folding and strength reduction NullPropagation(conf), - EliminateMapObjects, FoldablePropagation, OptimizeIn(conf), ConstantFolding, @@ -120,7 +119,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: - Batch("Typed Filter Optimization", fixedPoint, + Batch("Object Expressions Optimization", fixedPoint, + EliminateMapObjects, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index c2b898c97d6a..ea2c5d241d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -369,7 +369,7 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case _ @ AssertNotNull(c, _) if !c.nullable => c + case AssertNotNull(c, _) if !c.nullable => c // For Coalesce, remove null literals. 
case e @ Coalesce(children) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 0062b95e59e8..a6b1c55a5d75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -110,11 +110,18 @@ object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case _ @ DeserializeToObject(Invoke( MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _), - inputData, None, _), + inputData, None), funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), outputObjAttr, child) if dataType == castDataType => DeserializeToObject(Invoke( inputData, funcName, returnType, arguments, propagateNull, returnNullable), outputObjAttr, child) + case _ @ DeserializeToObject(Invoke( + MapObjects(_, _, _, LambdaVariable(_, _, dataType, _), inputData, None), + funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), + outputObjAttr, child) => + DeserializeToObject(Invoke( + inputData, funcName, returnType, arguments, propagateNull, returnNullable), + outputObjAttr, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala index 25cbe3693273..d274379f2294 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -28,11 +28,17 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ class EliminateMapObjectsSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = + class Optimize(addSimplifyCast: Boolean) extends RuleExecutor[LogicalPlan] { + val batches = if (addSimplifyCast) { Batch("EliminateMapObjects", FixedPoint(50), NullPropagation(conf), + SimplifyCasts, EliminateMapObjects) :: Nil + } else { + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + EliminateMapObjects) :: Nil + } } implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() @@ -42,19 +48,23 @@ class EliminateMapObjectsSuite extends PlanTest { val intObjType = ObjectType(classOf[Array[Int]]) val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) val intQuery = intInput.deserialize[Array[Int]].analyze - val intOptimized = Optimize.execute(intQuery) - val intExpected = DeserializeToObject( - Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), - AttributeReference("obj", intObjType, true)(), intInput) - comparePlans(intOptimized, intExpected) + Seq(true, false).foreach { addSimplifyCast => + val intOptimized = new Optimize(addSimplifyCast).execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + } val doubleObjType = ObjectType(classOf[Array[Double]]) val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) val doubleQuery = doubleInput.deserialize[Array[Double]].analyze - val doubleOptimized = Optimize.execute(doubleQuery) - val 
doubleExpected = DeserializeToObject( - Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), - AttributeReference("obj", doubleObjType, true)(), doubleInput) - comparePlans(doubleOptimized, doubleExpected) + Seq(true, false).foreach { addSimplifyCast => + val doubleOptimized = new Optimize(addSimplifyCast).execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } } } From f695e50e38bd329db3b75951dd7af52fea3b3dde Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 17 Apr 2017 13:31:37 +0900 Subject: [PATCH 15/19] address review comments --- .../sql/catalyst/optimizer/objects.scala | 14 ++------ .../optimizer/EliminateMapObjectsSuite.scala | 32 +++++++------------ .../spark/sql/DatasetPrimitiveSuite.scala | 2 -- 3 files changed, 15 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index a6b1c55a5d75..55288ac654a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -104,20 +104,12 @@ object CombineTypedFilters extends Rule[LogicalPlan] { * 1. Mapobject(e) where e is lambdavariable(), which means types for input output * are primitive types * 2. no custom collection class specified - * representation of data item. For example back to back map operations. + * representation of data item. For example back to back map operations. 
*/ object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case _ @ DeserializeToObject(Invoke( - MapObjects(_, _, _, Cast(LambdaVariable(_, _, dataType, _), castDataType, _), - inputData, None), - funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), - outputObjAttr, child) if dataType == castDataType => - DeserializeToObject(Invoke( - inputData, funcName, returnType, arguments, propagateNull, returnNullable), - outputObjAttr, child) - case _ @ DeserializeToObject(Invoke( - MapObjects(_, _, _, LambdaVariable(_, _, dataType, _), inputData, None), + case DeserializeToObject(Invoke( + MapObjects(_, _, _, _ : LambdaVariable, inputData, None), funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), outputObjAttr, child) => DeserializeToObject(Invoke( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala index d274379f2294..d4f37e2a5e87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -28,16 +28,12 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ class EliminateMapObjectsSuite extends PlanTest { - class Optimize(addSimplifyCast: Boolean) extends RuleExecutor[LogicalPlan] { - val batches = if (addSimplifyCast) { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = { Batch("EliminateMapObjects", FixedPoint(50), NullPropagation(conf), SimplifyCasts, EliminateMapObjects) :: Nil - } else { - Batch("EliminateMapObjects", FixedPoint(50), - NullPropagation(conf), - EliminateMapObjects) :: Nil } } @@ -48,23 +44,19 @@ class EliminateMapObjectsSuite extends PlanTest { val intObjType = ObjectType(classOf[Array[Int]]) val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) val intQuery = intInput.deserialize[Array[Int]].analyze - Seq(true, false).foreach { addSimplifyCast => - val intOptimized = new Optimize(addSimplifyCast).execute(intQuery) - val intExpected = DeserializeToObject( - Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), - AttributeReference("obj", intObjType, true)(), intInput) - comparePlans(intOptimized, intExpected) - } + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) val doubleObjType = ObjectType(classOf[Array[Double]]) val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) val doubleQuery = doubleInput.deserialize[Array[Double]].analyze - Seq(true, false).foreach { addSimplifyCast => - val doubleOptimized = new Optimize(addSimplifyCast).execute(doubleQuery) - val doubleExpected = DeserializeToObject( - Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), - AttributeReference("obj", doubleObjType, true)(), doubleInput) - comparePlans(doubleOptimized, doubleExpected) - } + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), 
doubleInput) + comparePlans(doubleOptimized, doubleExpected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 42ef298d7d08..541565344f75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql import scala.collection.immutable.Queue import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.execution.DeserializeToObjectExec import org.apache.spark.sql.test.SharedSQLContext case class IntClass(value: Int) From ce6927dba70b5ec494bb7e5b1d8d5b51a2062db5 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 18 Apr 2017 16:54:35 +0900 Subject: [PATCH 16/19] remove MapObject with LambdaVariable if it is for primitive type (non-nullable) --- .../spark/sql/catalyst/optimizer/objects.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 55288ac654a6..190564862013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -107,13 +107,17 @@ object CombineTypedFilters extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateMapObjects extends Rule[LogicalPlan] { +/* def apply(plan: LogicalPlan): LogicalPlan = plan transform { case DeserializeToObject(Invoke( - MapObjects(_, _, _, _ : LambdaVariable, inputData, None), - funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), - outputObjAttr, child) => - DeserializeToObject(Invoke( - inputData, funcName, returnType, arguments, propagateNull, returnNullable), - outputObjAttr, child) + MapObjects(_, _, _, _ : LambdaVariable, inputData, None), + funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), + outputObjAttr, child) => + DeserializeToObject(Invoke( + inputData, funcName, returnType, arguments, propagateNull, returnNullable), + outputObjAttr, child) +*/ + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData } } From a6bedeec986da33c1145b1d09b32e4e61d58e422 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 18 Apr 2017 16:56:37 +0900 Subject: [PATCH 17/19] remove unused import --- .../scala/org/apache/spark/sql/catalyst/optimizer/objects.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 190564862013..a44de67309bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types._ /* * This file defines 
optimization rules related to object manipulation (for the Dataset API). From c0dca2ba463b954b4af1611bad17a851088607bf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 18 Apr 2017 21:10:39 +0900 Subject: [PATCH 18/19] address review comment --- .../apache/spark/sql/catalyst/optimizer/objects.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index a44de67309bb..6d6f8b0a26ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -106,16 +106,6 @@ object CombineTypedFilters extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateMapObjects extends Rule[LogicalPlan] { -/* - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case DeserializeToObject(Invoke( - MapObjects(_, _, _, _ : LambdaVariable, inputData, None), - funcName, returnType: ObjectType, arguments, propagateNull, returnNullable), - outputObjAttr, child) => - DeserializeToObject(Invoke( - inputData, funcName, returnType, arguments, propagateNull, returnNullable), - outputObjAttr, child) -*/ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData } From 8de69156614601e59f9621493a919965b4e22510 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 18 Apr 2017 22:38:59 +0900 Subject: [PATCH 19/19] address review comment --- .../org/apache/spark/sql/catalyst/optimizer/objects.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 6d6f8b0a26ec..8cdc6425bcad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -100,10 +100,9 @@ object CombineTypedFilters extends Rule[LogicalPlan] { /** * Removes MapObjects when the following conditions are satisfied - * 1. Mapobject(e) where e is lambdavariable(), which means types for input output - * are primitive types - * 2. no custom collection class specified - * representation of data item. For example back to back map operations. + * 1. Mapobject(... lambdavariable(..., false) ...), which means types for input and output + * are primitive types with non-nullable + * 2. no custom collection class specified representation of data item. */ object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {