diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
index 9a0bdc6bcfd1..90e3bdcd082c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunctio
 import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.NonFateSharingCache
 
 /**
  * Wraps the [[InternalRow]] with the corresponding [[DataType]] to make it comparable with
@@ -34,9 +35,10 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
  * @param dataTypes the data types for the row
  */
 class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) {
+  import InternalRowComparableWrapper._
 
-  private val structType = StructType(dataTypes.map(t => StructField("f", t)))
-  private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
+  private val structType = structTypeCache.get(dataTypes)
+  private val ordering = orderingCache.get(dataTypes)
 
   override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
 
@@ -53,6 +55,21 @@ class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[Data
 }
 
 object InternalRowComparableWrapper {
+  private final val MAX_CACHE_ENTRIES = 1024
+
+  private val orderingCache = {
+    val loadFunc = (dataTypes: Seq[DataType]) => {
+      RowOrdering.createNaturalAscendingOrdering(dataTypes)
+    }
+    NonFateSharingCache(loadFunc, MAX_CACHE_ENTRIES)
+  }
+
+  private val structTypeCache = {
+    val loadFunc = (dataTypes: Seq[DataType]) => {
+      StructType(dataTypes.map(t => StructField("f", t)))
+    }
+    NonFateSharingCache(loadFunc, MAX_CACHE_ENTRIES)
+  }
 
   def apply(
       partition: InputPartition with HasPartitionKey,
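Note on the change above: key-grouped (storage-partitioned) joins build one InternalRowComparableWrapper per input partition, and previously every instance re-derived the same StructType and row ordering from `dataTypes`. Hoisting both into a bounded NonFateSharingCache keyed by the type sequence makes that work happen once per distinct schema. The sketch below shows the get-or-compute shape this relies on; MemoCache, MemoCacheExample, and builds are hypothetical names, and the plain ConcurrentHashMap stand-in is unbounded and does not isolate loader failures the way NonFateSharingCache does.

import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}

// Hypothetical stand-in for org.apache.spark.util.NonFateSharingCache:
// memoize a pure, expensive derivation keyed by the input value.
final class MemoCache[K, V](load: K => V) {
  private val map = new ConcurrentHashMap[K, V]()
  private val loader: JFunction[K, V] = (k: K) => load(k)
  def get(key: K): V = map.computeIfAbsent(key, loader)
}

object MemoCacheExample {
  def main(args: Array[String]): Unit = {
    var builds = 0
    // Stands in for StructType(dataTypes.map(t => StructField("f", t))).
    val structTypeCache = new MemoCache[Seq[String], String](types => {
      builds += 1
      types.mkString("struct<", ",", ">")
    })
    // 200,000 lookups over one distinct schema: the derivation runs once,
    // which is the amortization the patched wrapper depends on.
    (0 until 200000).foreach(_ => structTypeCache.get(Seq("int", "int")))
    assert(builds == 1)
  }
}

Keying on Seq[DataType] works because Scala sequences define structural equals/hashCode, so every wrapper built over the same schema hits the same entry, while MAX_CACHE_ENTRIES = 1024 bounds the footprint when schemas vary.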
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala
new file mode 100644
index 000000000000..cc28e8552516
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.connector.catalog.PartitionInternalRow
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * Benchmark for [[InternalRowComparableWrapper]].
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class> --jars <spark catalyst test jar>
+ *   2. build/sbt "catalyst/Test/runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/Test/runMain <this class>"
+ *      Results will be written to "benchmarks/InternalRowComparableWrapperBenchmark-results.txt".
+ * }}}
+ */
+object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
+
+  private def constructAndRunBenchmark(): Unit = {
+    val partitionNum = 200_000
+    val bucketNum = 4096
+    val day = 20240401
+    val partitions = (0 until partitionNum).map { i =>
+      val bucketId = i % bucketNum
+      PartitionInternalRow.apply(Array(day, bucketId))
+    }
+    val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output)
+
+    benchmark.addCase("toSet") { _ =>
+      val distinct = partitions
+        .map(new InternalRowComparableWrapper(_, Seq(IntegerType, IntegerType)))
+        .toSet
+      assert(distinct.size == bucketNum)
+    }
+
+    benchmark.addCase("mergePartitions") { _ =>
+      // just to mock the data types
+      val expressions = Seq(Literal(day, IntegerType), Literal(0, IntegerType))
+
+      val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
+      val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
+      val merged = InternalRowComparableWrapper.mergePartitions(
+        leftPartitioning, rightPartitioning, expressions)
+      assert(merged.size == bucketNum)
+    }
+
+    benchmark.run()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    constructAndRunBenchmark()
+  }
+}
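For reference, the benchmark's "toSet" case exercises the wrapper's value semantics: InternalRow instances compare by reference, and it is the wrapper's value-based hashCode and equals that let 200,000 partition rows collapse to 4,096 distinct keys. A self-contained sketch of that dedup contract, with hypothetical KeyWrapper and KeyWrapperExample names standing in for the wrapper and its row:

// Hypothetical illustration of the wrapper's dedup contract: equality and
// hashing follow the wrapped key's values, so duplicates collapse in a Set.
final class KeyWrapper(val key: Array[Int]) {
  override def hashCode(): Int = java.util.Arrays.hashCode(key)
  override def equals(other: Any): Boolean = other match {
    case that: KeyWrapper => java.util.Arrays.equals(key, that.key)
    case _ => false
  }
}

object KeyWrapperExample {
  def main(args: Array[String]): Unit = {
    val (day, bucketNum) = (20240401, 4096)
    // Mirror the benchmark: 200,000 (day, bucket) rows over 4,096 buckets.
    val partitions = (0 until 200000).map(i => new KeyWrapper(Array(day, i % bucketNum)))
    assert(partitions.toSet.size == bucketNum)
  }
}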