From 17a37588a81af244d2c91bf37d190280f50f5cc3 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Wed, 18 Oct 2023 15:39:08 +0200 Subject: [PATCH 1/7] SPARK-45592: AQE correctness issue --- .../execution/adaptive/AdaptiveSparkPlanExec.scala | 7 +++---- .../sql/execution/adaptive/QueryStageExec.scala | 5 ++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 36c2160f52282..ae6ec0ed9f93b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -542,13 +542,12 @@ case class AdaptiveSparkPlanExec( } case i: InMemoryTableScanExec => - // There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we - // hit it the first time, we should always create a new query stage. val newStage = newQueryStage(i) + val isMaterialized = newStage.isMaterialized CreateStageResult( newPlan = newStage, - allChildStagesMaterialized = false, - newStages = Seq(newStage)) + allChildStagesMaterialized = isMaterialized, + newStages = if (isMaterialized) Seq.empty else Seq(newStage)) case q: QueryStageExec => CreateStageResult(newPlan = q, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index b941feb12fc05..96e40f03a8cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -86,7 +86,7 @@ abstract class QueryStageExec extends LeafExecNode { protected var _resultOption = new AtomicReference[Option[Any]](None) private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption - final def isMaterialized: Boolean = resultOption.get().isDefined + def isMaterialized: Boolean = resultOption.get().isDefined override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning @@ -294,5 +294,8 @@ case class TableCacheQueryStageExec( override protected def doMaterialize(): Future[Any] = future + override def isMaterialized: Boolean = + super.isMaterialized || inMemoryTableScan.isMaterialized + override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6b00799cabd16..e853a32678197 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -71,6 +71,18 @@ class DatasetSuite extends QueryTest private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b) + test("SPARK-45592: correctness issue") { + val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF() + .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK) + ee.count() + + val minNbrs1 = ee + .groupBy("_1").agg(min(col("_2")).as("min_number")) + .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK) + val join = ee.join(minNbrs1, "_1") + assert(join.count() == 1000000) + } + test("checkAnswer should 
compare map correctly") { val data = Seq((1, "2", Map(1 -> 2, 2 -> 1))) checkAnswer( From f05366d681463a028fe0638c7b6ccbb3975aa996 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Thu, 19 Oct 2023 10:32:39 +0200 Subject: [PATCH 2/7] Revert sloppy fix --- .../sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 7 ++++--- .../spark/sql/execution/adaptive/QueryStageExec.scala | 5 +---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index ae6ec0ed9f93b..36c2160f52282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -542,12 +542,13 @@ case class AdaptiveSparkPlanExec( } case i: InMemoryTableScanExec => + // There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we + // hit it the first time, we should always create a new query stage. val newStage = newQueryStage(i) - val isMaterialized = newStage.isMaterialized CreateStageResult( newPlan = newStage, - allChildStagesMaterialized = isMaterialized, - newStages = if (isMaterialized) Seq.empty else Seq(newStage)) + allChildStagesMaterialized = false, + newStages = Seq(newStage)) case q: QueryStageExec => CreateStageResult(newPlan = q, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 96e40f03a8cc9..b941feb12fc05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -86,7 +86,7 @@ abstract class QueryStageExec extends LeafExecNode { protected var _resultOption = new AtomicReference[Option[Any]](None) private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption - def isMaterialized: Boolean = resultOption.get().isDefined + final def isMaterialized: Boolean = resultOption.get().isDefined override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning @@ -294,8 +294,5 @@ case class TableCacheQueryStageExec( override protected def doMaterialize(): Future[Any] = future - override def isMaterialized: Boolean = - super.isMaterialized || inMemoryTableScan.isMaterialized - override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats() } From cebba4caf64136568e5d2492fdfd63670a13858c Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 20 Oct 2023 08:28:02 +0200 Subject: [PATCH 3/7] Implement CoalescedHashPartitioning --- .../plans/physical/partitioning.scala | 81 +++- .../sql/catalyst/DistributionSuite.scala | 124 +++--- .../spark/sql/catalyst/ShuffleSpecSuite.scala | 401 ++++++++++-------- .../adaptive/AQEShuffleReadExec.scala | 10 +- .../exchange/ShuffleExchangeExec.scala | 6 +- 5 files changed, 365 insertions(+), 257 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index a61bd3b7324be..e8e8ff7bf1a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -258,19 +258,8 @@ case object SinglePartition extends Partitioning {
     SinglePartitionShuffleSpec
 }
 
-/**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
- * in the same partition.
- *
- * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
- * stateful operators to retain the same physical partitioning during the lifetime of the query
- * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
- * across Spark versions. Violation of this requirement may bring silent correctness issue.
- */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression with Partitioning with Unevaluable {
-
+trait HashPartitioningBase extends Expression with Partitioning with Unevaluable {
+  def expressions: Seq[Expression]
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
@@ -295,19 +284,53 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     }
   }
 
-  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
-    HashShuffleSpec(this, distribution)
-
   /**
    * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less
    * than numPartitions) based on hashing expressions.
    */
   def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
+}
+
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition.
+ *
+ * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
+ * stateful operators to retain the same physical partitioning during the lifetime of the query
+ * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
+ * across Spark versions. Violation of this requirement may bring silent correctness issue.
+ */
+case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends HashPartitioningBase {
+
+  override def createShuffleSpec(distribution: ClusteredDistribution): HashShuffleSpec =
+    HashShuffleSpec(this, distribution)
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
 }
 
+case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
+
+/**
+ * Represents a partitioning where partitions have been coalesced from a HashPartitioning into a
+ * smaller number of partitions.
+ */ +case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary]) + extends HashPartitioningBase { + def expressions: Seq[Expression] = from.expressions + + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = + CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = + copy(from = from.copy(expressions = newChildren)) + + override val numPartitions: Int = partitions.length +} + /** * Represents a partitioning where rows are split across partitions based on transforms defined * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in @@ -700,7 +723,7 @@ case class HashShuffleSpec( } } - override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + override def createPartitioning(clustering: Seq[Expression]): HashPartitioning = { val exprs = hashKeyPositions.map(v => clustering(v.head)) HashPartitioning(exprs, partitioning.numPartitions) } @@ -708,6 +731,30 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +case class CoalescedHashShuffleSpec( + from: HashShuffleSpec, + partitions: Seq[CoalescedBoundary]) extends ShuffleSpec { + + override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { + case SinglePartitionShuffleSpec => + numPartitions == 1 + case CoalescedHashShuffleSpec(otherParent, otherPartitions) => + partitions == otherPartitions && + from.isCompatibleWith(otherParent) + case ShuffleSpecCollection(specs) => + specs.exists(isCompatibleWith) + case _ => + false + } + + override def canCreatePartitioning: Boolean = from.canCreatePartitioning + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = + CoalescedHashPartitioning(from.createPartitioning(clustering), partitions) + + override def numPartitions: Int = partitions.length +} + /** * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index a924a9ed02e5d..93ffa1401adcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Murmur3Hash, Pmod} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType @@ -146,63 +146,75 @@ class DistributionSuite extends SparkFunSuite { false) } - test("HashPartitioning is the output partitioning") { - // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of - // the required clustering expressions. 
- checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c")), - true) - - checkSatisfied( - HashPartitioning(Seq($"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c")), - true) - - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"b", $"c")), - false) - - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"d", $"e")), - false) - - // When ClusteredDistribution.requireAllClusterKeys is set to true, - // HashPartitioning can only satisfy ClusteredDistribution iff its hash expressions are - // exactly same as the required clustering expressions. - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), - true) - - checkSatisfied( - HashPartitioning(Seq($"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), - false) - - checkSatisfied( - HashPartitioning(Seq($"b", $"a", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), - false) - - // HashPartitioning cannot satisfy OrderedDistribution - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), - false) + private def testHashPartitioningLike( + partitioningName: String, + create: (Seq[Expression], Int) => Partitioning): Unit = { + + test(s"$partitioningName is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c")), + true) + + checkSatisfied( + create(Seq($"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c")), + true) + + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"b", $"c")), + false) + + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"d", $"e")), + false) + + // When ClusteredDistribution.requireAllClusterKeys is set to true, + // HashPartitioning can only satisfy ClusteredDistribution iff its hash expressions are + // exactly same as the required clustering expressions. + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + true) + + checkSatisfied( + create(Seq($"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + + checkSatisfied( + create(Seq($"b", $"a", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + + // HashPartitioning cannot satisfy OrderedDistribution + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), + false) + + checkSatisfied( + create(Seq($"a", $"b", $"c"), 1), + OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), + false) // TODO: this can be relaxed. + + checkSatisfied( + create(Seq($"b", $"c"), 10), + OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), + false) + } + } - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 1), - OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), - false) // TODO: this can be relaxed. 
+ testHashPartitioningLike("HashPartitioning", + (expressions, numPartitions) => HashPartitioning(expressions, numPartitions)) - checkSatisfied( - HashPartitioning(Seq($"b", $"c"), 10), - OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), - false) - } + testHashPartitioningLike("CoalescedHashPartitioning", (expressions, numPartitions) => + CoalescedHashPartitioning( + HashPartitioning(expressions, numPartitions), Seq(CoalescedBoundary(0, numPartitions)))) test("RangePartitioning is the output partitioning") { // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index 51e7688732265..82d2a6342f05c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -62,211 +62,254 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { } } - test("compatibility: HashShuffleSpec on both sides") { - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = true - ) - - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), - expected = true - ) + private def testHashShuffleSpecLike( + shuffleSpecName: String, + create: (HashPartitioning, ClusteredDistribution) => ShuffleSpec): Unit = { - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), ClusteredDistribution(Seq($"c", $"d"))), - expected = true - ) + test(s"compatibility: $shuffleSpecName on both sides") { + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"c", $"c", $"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"d"), 10), - ClusteredDistribution(Seq($"a", $"c", $"d"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"d"), 10), ClusteredDistribution(Seq($"c", $"d"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10), - ClusteredDistribution(Seq($"a", $"c", $"c"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"a", $"b"), 10), + ClusteredDistribution(Seq($"a", 
$"b"))), + create(HashPartitioning(Seq($"c", $"c", $"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10), - ClusteredDistribution(Seq($"a", $"c", $"d"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"d"), 10), + ClusteredDistribution(Seq($"a", $"c", $"d"))), + expected = true + ) - // negative cases - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"c"), 5), - ClusteredDistribution(Seq($"c", $"d"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"c", $"a"), 10), + ClusteredDistribution(Seq($"a", $"c", $"c"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"c", $"a"), 10), + ClusteredDistribution(Seq($"a", $"c", $"d"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + // negative cases + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"c"), 5), + ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), 
- HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - expected = false - ) - } + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) - test("compatibility: Only one side is HashShuffleSpec") { - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - SinglePartitionShuffleSpec, - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1), - ClusteredDistribution(Seq($"a", $"b"))), - SinglePartitionShuffleSpec, - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + expected = false + ) + } - checkCompatible( - SinglePartitionShuffleSpec, - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1), - ClusteredDistribution(Seq($"a", $"b"))), - expected = true - ) + test(s"compatibility: Only one side is $shuffleSpecName") { + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + SinglePartitionShuffleSpec, + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 1), + ClusteredDistribution(Seq($"a", $"b"))), + SinglePartitionShuffleSpec, + expected = true + ) - checkCompatible( - RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + SinglePartitionShuffleSpec, + create(HashPartitioning(Seq($"a", $"b"), 1), + ClusteredDistribution(Seq($"a", $"b"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), + checkCompatible( + RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - expected = true - ) + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a"), 
10), + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"))))), - expected = false - ) + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + expected = true + ) - checkCompatible( - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))))), - expected = true - ) + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + expected = true + ) - checkCompatible( - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"))), - HashShuffleSpec(HashPartitioning(Seq($"c"), 10), - ClusteredDistribution(Seq($"c", $"d"))))), - expected = false - ) + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"))))), + expected = false + ) + + checkCompatible( + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"))), + create(HashPartitioning(Seq($"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))))), + expected = true + ) + + checkCompatible( + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"))), + create(HashPartitioning(Seq($"c"), 10), + ClusteredDistribution(Seq($"c", $"d"))))), + expected = false + ) + } + } + + testHashShuffleSpecLike("HashShuffleSpec", + (partitioning, distribution) => HashShuffleSpec(partitioning, distribution)) + testHashShuffleSpecLike("CoalescedHashShuffleSpec", + (partitioning, distribution) => { + val partitions = if (partitioning.numPartitions == 1) { + Seq(CoalescedBoundary(0, 1)) + } else { + Seq(CoalescedBoundary(0, 1), CoalescedBoundary(0, partitioning.numPartitions)) + } + CoalescedHashShuffleSpec(HashShuffleSpec(partitioning, distribution), partitions) + }) + + test("compatibility: CoalescedHashShuffleSpec other specs") { + val hashShuffleSpec = HashShuffleSpec( + 
HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))) + checkCompatible( + hashShuffleSpec, + CoalescedHashShuffleSpec(hashShuffleSpec, Seq(CoalescedBoundary(0, 10))), + expected = false + ) + + checkCompatible( + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))), + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))), + expected = true + ) + + checkCompatible( + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 4), CoalescedBoundary(4, 10))), + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))), + expected = false + ) } test("compatibility: other specs") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala index 46ec91dcc0ab2..b5e9ace156db3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} @@ -75,7 +75,13 @@ case class AQEShuffleReadExec private( // partitions is changed. 
child.outputPartitioning match {
       case h: HashPartitioning =>
-        CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length))
+        val partitions = partitionSpecs.map {
+          case CoalescedPartitionSpec(start, end, _) => CoalescedBoundary(start, end)
+          // Cannot happen due to isCoalescedRead
+          case unexpected =>
+            throw new RuntimeException(s"Unexpected ShufflePartitionSpec: $unexpected")
+        }
+        CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, partitions))
       case r: RangePartitioning =>
         CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
       // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 509f1e6a1e4f3..65cde848daf91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -275,10 +275,10 @@ object ShuffleExchangeExec {
     : ShuffleDependency[Int, InternalRow, InternalRow] = {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
-      case HashPartitioning(_, n) =>
+      case h: HashPartitioningBase =>
         // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
         // `HashPartitioning.partitionIdExpression` to produce partitioning key.
-        new PartitionIdPassthrough(n)
+        new PartitionIdPassthrough(h.numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Extract only fields used for sorting to avoid collecting large fields that does not
         // affect sorting result when deciding partition bounds in RangePartitioner
@@ -325,7 +325,7 @@ object ShuffleExchangeExec {
           position += 1
           position
         }
-      case h: HashPartitioning =>
+      case h: HashPartitioningBase =>
        val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
         row => projection(row).getInt(0)
       case RangePartitioning(sortingExpressions, _) =>

From 3ea57d825d8898dddd94f2955c96ca219c69ffe0 Mon Sep 17 00:00:00 2001
From: Emil Ejbyfeldt
Date: Mon, 23 Oct 2023 09:02:34 +0200
Subject: [PATCH 4/7] Clean up: CoalescedHPSpec should not create partitioning

---
 .../plans/physical/partitioning.scala         | 57 ++++++++++---------
 .../exchange/ShuffleExchangeExec.scala        |  6 +-
 .../org/apache/spark/sql/DatasetSuite.scala   |  1 +
 3 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e8e8ff7bf1a27..518bca8d39dbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -258,8 +258,19 @@ case object SinglePartition extends Partitioning {
     SinglePartitionShuffleSpec
 }
 
-trait HashPartitioningBase extends Expression with Partitioning with Unevaluable {
-  def expressions: Seq[Expression]
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition.
+ *
+ * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
+ * stateful operators to retain the same physical partitioning during the lifetime of the query
+ * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
+ * across Spark versions. Violation of this requirement may bring silent correctness issue.
+ */
+case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends Expression with Partitioning with Unevaluable {
+
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
@@ -284,31 +295,18 @@ trait HashPartitioningBase extends Expression with Partitioning with Unevaluable
     }
   }
 
+  override def createShuffleSpec(distribution: ClusteredDistribution): HashShuffleSpec =
+    HashShuffleSpec(this, distribution)
+
   /**
    * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less
    * than numPartitions) based on hashing expressions.
    */
   def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
-}
-
-/**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
- * in the same partition.
- *
- * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
- * stateful operators to retain the same physical partitioning during the lifetime of the query
- * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
- * across Spark versions. Violation of this requirement may bring silent correctness issue.
- */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends HashPartitioningBase {
-
-  override def createShuffleSpec(distribution: ClusteredDistribution): HashShuffleSpec =
-    HashShuffleSpec(this, distribution)
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
+
 }
 
 case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
 
@@ -318,8 +316,13 @@ case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
 
 /**
  * Represents a partitioning where partitions have been coalesced from a HashPartitioning into a
  * smaller number of partitions.
*/ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary]) - extends HashPartitioningBase { - def expressions: Seq[Expression] = from.expressions + extends Expression with Partitioning with Unevaluable { + + override def children: Seq[Expression] = from.expressions + override def nullable: Boolean = from.nullable + override def dataType: DataType = from.dataType + + override def satisfies0(required: Distribution): Boolean = from.satisfies0(required) override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions) @@ -329,6 +332,8 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa copy(from = from.copy(expressions = newChildren)) override val numPartitions: Int = partitions.length + + override def stringArgs: Iterator[Any] = Iterator(from) } /** @@ -723,7 +728,7 @@ case class HashShuffleSpec( } } - override def createPartitioning(clustering: Seq[Expression]): HashPartitioning = { + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { val exprs = hashKeyPositions.map(v => clustering(v.head)) HashPartitioning(exprs, partitioning.numPartitions) } @@ -739,18 +744,14 @@ case class CoalescedHashShuffleSpec( case SinglePartitionShuffleSpec => numPartitions == 1 case CoalescedHashShuffleSpec(otherParent, otherPartitions) => - partitions == otherPartitions && - from.isCompatibleWith(otherParent) + partitions == otherPartitions && from.isCompatibleWith(otherParent) case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) case _ => false } - override def canCreatePartitioning: Boolean = from.canCreatePartitioning - - override def createPartitioning(clustering: Seq[Expression]): Partitioning = - CoalescedHashPartitioning(from.createPartitioning(clustering), partitions) + override def canCreatePartitioning: Boolean = false override def numPartitions: Int = partitions.length } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 65cde848daf91..509f1e6a1e4f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -275,10 +275,10 @@ object ShuffleExchangeExec { : ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case h: HashPartitioningBase => + case HashPartitioning(_, n) => // For HashPartitioning, the partitioning key is already a valid partition ID, as we use // `HashPartitioning.partitionIdExpression` to produce partitioning key. 
-        new PartitionIdPassthrough(h.numPartitions)
+        new PartitionIdPassthrough(n)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Extract only fields used for sorting to avoid collecting large fields that does not
         // affect sorting result when deciding partition bounds in RangePartitioner
@@ -325,7 +325,7 @@ object ShuffleExchangeExec {
           position += 1
           position
         }
-      case h: HashPartitioningBase =>
+      case h: HashPartitioning =>
         val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
         row => projection(row).getInt(0)
       case RangePartitioning(sortingExpressions, _) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index e853a32678197..ec42085a5c1a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -79,6 +79,7 @@ class DatasetSuite extends QueryTest
     val minNbrs1 = ee
       .groupBy("_1").agg(min(col("_2")).as("min_number"))
       .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+
     val join = ee.join(minNbrs1, "_1")
     assert(join.count() == 1000000)
   }

From 408cb9231826527a5111edf5457ff502885e9cfc Mon Sep 17 00:00:00 2001
From: Emil Ejbyfeldt
Date: Mon, 23 Oct 2023 09:48:36 +0200
Subject: [PATCH 5/7] Fix WriteDistributionAndOrderingSuite

---
 .../WriteDistributionAndOrderingSuite.scala   | 53 +++++++++----------
 1 file changed, 24 insertions(+), 29 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 6cab0e0239dc4..40938eb642478 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, RangePartitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
@@ -264,11 +264,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"), attr("id"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -377,11 +374,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      
clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -875,11 +869,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) val writePartitioningExprs = Seq(attr("data")) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -963,11 +954,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) val writePartitioningExprs = Seq(attr("data")) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -1154,11 +1142,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) val writePartitioningExprs = Seq(truncateExpr) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -1422,6 +1407,9 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase case p: physical.HashPartitioning => val resolvedExprs = p.expressions.map(resolveAttrs(_, plan)) p.copy(expressions = resolvedExprs) + case c: physical.CoalescedHashPartitioning => + val resolvedExprs = c.from.expressions.map(resolveAttrs(_, plan)) + c.copy(from = c.from.copy(expressions = resolvedExprs)) case _: UnknownPartitioning => // don't check partitioning if no particular one is expected actualPartitioning @@ -1480,9 +1468,16 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase private def clusteredWritePartitioning( writePartitioningExprs: Seq[catalyst.expressions.Expression], - targetNumPartitions: Option[Int]): physical.Partitioning = { - HashPartitioning(writePartitioningExprs, - targetNumPartitions.getOrElse(conf.numShufflePartitions)) + targetNumPartitions: Option[Int], + coalesce: Boolean): physical.Partitioning = { + val partitioning = HashPartitioning(writePartitioningExprs, + targetNumPartitions.getOrElse(conf.numShufflePartitions)) + if (coalesce) { + CoalescedHashPartitioning( + partitioning, Seq(CoalescedBoundary(0, partitioning.numPartitions))) + } else { + partitioning + } } private def partitionSizes(dataSkew: Boolean, coalesce: Boolean): Seq[Option[Long]] = { From 4f6bd1d5dd9397c4f2a403d030153b8faf7deec0 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Mon, 23 Oct 2023 13:25:50 +0200 Subject: [PATCH 6/7] More PR comments --- .../plans/physical/partitioning.scala | 6 ++--- .../org/apache/spark/sql/DatasetSuite.scala | 27 ++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 518bca8d39dbc..e36c55b3b4f97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -295,7 +295,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     }
   }
 
-  override def createShuffleSpec(distribution: ClusteredDistribution): HashShuffleSpec =
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
     HashShuffleSpec(this, distribution)
 
   /**
@@ -737,8 +737,8 @@ case class HashShuffleSpec(
 }
 
 case class CoalescedHashShuffleSpec(
-    from: HashShuffleSpec,
-    partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {
+    from: ShuffleSpec,
+    @transient partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {
 
   override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
     case SinglePartitionShuffleSpec =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ec42085a5c1a9..bf78e6e11fe99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -71,19 +71,6 @@ class DatasetSuite extends QueryTest
 
   private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b)
 
-  test("SPARK-45592: correctness issue") {
-    val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF()
-      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
-    ee.count()
-
-    val minNbrs1 = ee
-      .groupBy("_1").agg(min(col("_2")).as("min_number"))
-      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
-
-    val join = ee.join(minNbrs1, "_1")
-    assert(join.count() == 1000000)
-  }
-
   test("checkAnswer should compare map correctly") {
     val data = Seq((1, "2", Map(1 -> 2, 2 -> 1)))
     checkAnswer(
@@ -2658,6 +2645,20 @@ class DatasetSuite extends QueryTest
     val ds = Seq(1, 2).toDS().persist(StorageLevel.NONE)
     assert(ds.count() == 2)
   }
+
+  test("SPARK-45592: Coalesced shuffle read is not compatible with hash partitioning") {
+    val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF()
+      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+    ee.count()
+
+    val minNbrs1 = ee
+      .groupBy("_1").agg(min(col("_2")).as("min_number"))
+      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+
+    val join = ee.join(minNbrs1, "_1")
+    assert(join.count() == 1000000)
+  }
+
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest

From 7402821ddd7c8dbfb30e82123b736fd51f623be4 Mon Sep 17 00:00:00 2001
From: Emil Ejbyfeldt
Date: Mon, 30 Oct 2023 11:06:43 +0100
Subject: [PATCH 7/7] More review comments

---
 .../spark/sql/catalyst/plans/physical/partitioning.scala  | 5 +++--
 .../org/apache/spark/sql/catalyst/DistributionSuite.scala | 4 ++--
 .../org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala  | 4 ++--
 .../spark/sql/execution/adaptive/AQEShuffleReadExec.scala | 3 ++-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e36c55b3b4f97..0ae2857161c8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -333,7 +333,8 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
 
   override val numPartitions: Int = partitions.length
 
-  override def stringArgs: Iterator[Any] = Iterator(from)
+  override def toString: String = from.toString
+  override def sql: String = from.sql
 }
 
 /**
@@ -738,7 +739,7 @@ case class HashShuffleSpec(
 
 case class CoalescedHashShuffleSpec(
     from: ShuffleSpec,
-    @transient partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {
+    partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {
 
   override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
     case SinglePartitionShuffleSpec =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index 93ffa1401adcd..7cb4d5f123253 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -147,8 +147,8 @@ class DistributionSuite extends SparkFunSuite {
   }
 
   private def testHashPartitioningLike(
-    partitioningName: String,
-    create: (Seq[Expression], Int) => Partitioning): Unit = {
+      partitioningName: String,
+      create: (Seq[Expression], Int) => Partitioning): Unit = {
 
     test(s"$partitioningName is the output partitioning") {
       // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index 82d2a6342f05c..6b069d1c97363 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -63,8 +63,8 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
   }
 
   private def testHashShuffleSpecLike(
-    shuffleSpecName: String,
-    create: (HashPartitioning, ClusteredDistribution) => ShuffleSpec): Unit = {
+      shuffleSpecName: String,
+      create: (HashPartitioning, ClusteredDistribution) => ShuffleSpec): Unit = {
 
     test(s"compatibility: $shuffleSpecName on both sides") {
       checkCompatible(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index b5e9ace156db3..6b39ac70a62ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
@@ -79,7 +80,7 @@ case class AQEShuffleReadExec private(
           case CoalescedPartitionSpec(start, end, _) => CoalescedBoundary(start, end)
           // Cannot happen due to isCoalescedRead
           case unexpected =>
-            throw new RuntimeException(s"Unexpected ShufflePartitionSpec: $unexpected")
+            throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected")
         }
         CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, partitions))
       case r: 
RangePartitioning =>
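
The boundary bookkeeping introduced in PATCH 3 is the crux of the fix, and it can be illustrated standalone. The following is a minimal sketch assuming only the Scala standard library: `CoalescedBoundary` mirrors the case class added above, while `coalescedPartition`, `CoalescedBoundaryDemo`, and the demo values are illustrative stand-ins, not Spark code. It shows why `CoalescedHashShuffleSpec.isCompatibleWith` must require identical boundary lists (`partitions == otherPartitions`) rather than merely equal partition counts.

// Editor's sketch of the SPARK-45592 co-partitioning argument; not Spark source.
case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)

object CoalescedBoundaryDemo {
  // A row whose key hashes to reducer index `r` lands in the coalesced
  // partition whose [startReducerIndex, endReducerIndex) range contains `r`.
  def coalescedPartition(r: Int, bounds: Seq[CoalescedBoundary]): Int =
    bounds.indexWhere(b => b.startReducerIndex <= r && r < b.endReducerIndex)

  def main(args: Array[String]): Unit = {
    // Two sides of a join, each coalescing 10 reducers into 2 partitions,
    // but at different cut points.
    val left = Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))
    val right = Seq(CoalescedBoundary(0, 4), CoalescedBoundary(4, 10))
    // Partition counts agree, so a check based on numPartitions alone passes...
    assert(left.length == right.length)
    // ...yet a key hashing to reducer 4 sits in partition 0 on one side and
    // partition 1 on the other: the sides are not co-partitioned.
    assert(coalescedPartition(4, left) == 0)
    assert(coalescedPartition(4, right) == 1)
    // Only identical boundary lists make the mapping agree for every reducer,
    // which is what CoalescedHashShuffleSpec.isCompatibleWith enforces.
    val sameAsLeft = Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))
    assert((0 until 10).forall(r =>
      coalescedPartition(r, left) == coalescedPartition(r, sameAsLeft)))
  }
}

This divergence is exactly the failure mode the reverted PATCH 1 left open: when an AQE-coalesced cached read was still reported as plain `HashPartitioning`, a join could skip the shuffle and silently drop matches, which PATCH 3 closes by modeling the coalescing in the partitioning itself.
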