diff --git a/core/pom.xml b/core/pom.xml index 8a872dea1de4..4446dbdb5ed0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -378,7 +378,7 @@ net.razorvine pyrolite - 4.23 + 4.30 net.razorvine diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index bceb26cfd4f8..5114cf70e3f2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -94,6 +94,7 @@ private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, listenerBus: LiveListenerBus, conf: SparkConf, + cleaner: Option[ContextCleaner] = None, clock: Clock = new SystemClock()) extends Logging { @@ -148,7 +149,7 @@ private[spark] class ExecutorAllocationManager( // Listener for Spark events that impact the allocation policy val listener = new ExecutorAllocationListener - val executorMonitor = new ExecutorMonitor(conf, client, clock) + val executorMonitor = new ExecutorMonitor(conf, client, listenerBus, clock) // Executor that handles the scheduling task. private val executor = @@ -194,11 +195,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException( s"s${DYN_ALLOCATION_SUSTAINED_SCHEDULER_BACKLOG_TIMEOUT.key} must be > 0!") } - // Require external shuffle service for dynamic allocation - // Otherwise, we may lose shuffle files when killing executors - if (!conf.get(config.SHUFFLE_SERVICE_ENABLED) && !testing) { - throw new SparkException("Dynamic allocation of executors requires the external " + - "shuffle service. You may enable this through spark.shuffle.service.enabled.") + if (!conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + if (conf.get(config.DYN_ALLOCATION_SHUFFLE_TRACKING)) { + logWarning("Dynamic allocation without a shuffle service is an experimental feature.") + } else if (!testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } } if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { @@ -214,6 +217,7 @@ private[spark] class ExecutorAllocationManager( def start(): Unit = { listenerBus.addToManagementQueue(listener) listenerBus.addToManagementQueue(executorMonitor) + cleaner.foreach(_.attachListener(executorMonitor)) val scheduleTask = new Runnable() { override def run(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a0d7aa743223..75182b0c9008 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -553,14 +553,22 @@ class SparkContext(config: SparkConf) extends Logging { None } - // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
+ _cleaner = + if (_conf.get(CLEANER_REFERENCE_TRACKING)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) _executorAllocationManager = if (dynamicAllocationEnabled) { schedulerBackend match { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( - schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, + cleaner = cleaner)) case _ => None } @@ -569,14 +577,6 @@ class SparkContext(config: SparkConf) extends Logging { } _executorAllocationManager.foreach(_.start()) - _cleaner = - if (_conf.get(CLEANER_REFERENCE_TRACKING)) { - Some(new ContextCleaner(this)) - } else { - None - } - _cleaner.foreach(_.start()) - setupAndStartListenerBus() postEnvironmentUpdate() postApplicationStart() @@ -1791,7 +1791,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { - def addJarFile(file: File): String = { + def addLocalJarFile(file: File): String = { try { if (!file.exists()) { throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") @@ -1808,12 +1808,36 @@ class SparkContext(config: SparkConf) extends Logging { } } + def checkRemoteJarFile(path: String): String = { + val hadoopPath = new Path(path) + val scheme = new URI(path).getScheme + if (!Array("http", "https", "ftp").contains(scheme)) { + try { + val fs = hadoopPath.getFileSystem(hadoopConfiguration) + if (!fs.exists(hadoopPath)) { + throw new FileNotFoundException(s"Jar ${path} not found") + } + if (fs.isDirectory(hadoopPath)) { + throw new IllegalArgumentException( + s"Directory ${path} is not allowed for addJar") + } + path + } catch { + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) + null + } + } else { + path + } + } + if (path == null) { logWarning("null specified as parameter to addJar") } else { val key = if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - addJarFile(new File(path)) + addLocalJarFile(new File(path)) } else { val uri = new URI(path) // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies @@ -1822,12 +1846,12 @@ class SparkContext(config: SparkConf) extends Logging { // A JAR file which exists only on the driver node case null => // SPARK-22585 path without schema is not url encoded - addJarFile(new File(uri.getRawPath)) + addLocalJarFile(new File(uri.getRawPath)) // A JAR file which exists only on the driver node - case "file" => addJarFile(new File(uri.getPath)) + case "file" => addLocalJarFile(new File(uri.getPath)) // A JAR file which exists locally on every worker node case "local" => "file:" + uri.getPath - case _ => path + case _ => checkRemoteJarFile(path) } } if (key != null) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index eee6e4b28ac4..62d60475985b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -81,4 +81,8 @@ private[spark] object PythonUtils { def isEncryptionEnabled(sc: JavaSparkContext): Boolean = { sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED) } + + def getBroadcastThreshold(sc: JavaSparkContext): 
Long = { + sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) + } } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 9462dfd950ba..01e64b6972ae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -186,9 +186,6 @@ private[spark] object SerDeUtil extends Logging { val unpickle = new Unpickler iter.flatMap { row => val obj = unpickle.loads(row) - // `Opcodes.MEMOIZE` of Protocol 4 (Python 3.4+) will store objects in internal map - // of `Unpickler`. This map is cleared when calling `Unpickler.close()`. - unpickle.close() if (batched) { obj match { case array: Array[Any] => array.toSeq diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index a70754c6e2c4..f912ed64c80b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.internal.config +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils @@ -135,6 +136,7 @@ private[rest] class StandaloneSubmitRequestServlet( val sparkProperties = request.sparkProperties val driverMemory = sparkProperties.get(config.DRIVER_MEMORY.key) val driverCores = sparkProperties.get(config.DRIVER_CORES.key) + val driverDefaultJavaOptions = sparkProperties.get(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS) val driverExtraJavaOptions = sparkProperties.get(config.DRIVER_JAVA_OPTIONS.key) val driverExtraClassPath = sparkProperties.get(config.DRIVER_CLASS_PATH.key) val driverExtraLibraryPath = sparkProperties.get(config.DRIVER_LIBRARY_PATH.key) @@ -160,9 +162,11 @@ private[rest] class StandaloneSubmitRequestServlet( .set("spark.master", updatedMasters) val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val defaultJavaOpts = driverDefaultJavaOptions.map(Utils.splitCommandString) + .getOrElse(Seq.empty) val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) - val javaOpts = sparkJavaOpts ++ extraJavaOpts + val javaOpts = sparkJavaOpts ++ defaultJavaOpts ++ extraJavaOpts val command = new Command( "org.apache.spark.deploy.worker.DriverWrapper", Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index a33c2874d1a5..759d857d56e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -263,12 +263,14 @@ private[spark] class HadoopDelegationTokenManager( val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully 
logged into KDC.") ugi - } else { + } else if (!SparkHadoopUtil.get.isProxyUser(UserGroupInformation.getCurrentUser())) { logInfo(s"Attempting to load user's ticket cache.") val ccache = sparkConf.getenv("KRB5CCNAME") val user = Option(sparkConf.getenv("KRB5PRINCIPAL")).getOrElse( UserGroupInformation.getCurrentUser().getUserName()) UserGroupInformation.getUGIFromTicketCache(ccache, user) + } else { + UserGroupInformation.getCurrentUser() } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index f27aca03773a..68e1994f0f94 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -127,8 +127,9 @@ private[spark] class TypedConfigBuilder[T]( /** Creates a [[ConfigEntry]] that does not have a default value. */ def createOptional: OptionalConfigEntry[T] = { - val entry = new OptionalConfigEntry[T](parent.key, parent._alternatives, converter, - stringConverter, parent._doc, parent._public) + val entry = new OptionalConfigEntry[T](parent.key, parent._prependedKey, + parent._prependSeparator, parent._alternatives, converter, stringConverter, parent._doc, + parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -141,8 +142,9 @@ private[spark] class TypedConfigBuilder[T]( createWithDefaultString(default.asInstanceOf[String]) } else { val transformedDefault = converter(stringConverter(default)) - val entry = new ConfigEntryWithDefault[T](parent.key, parent._alternatives, - transformedDefault, converter, stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey, + parent._prependSeparator, parent._alternatives, transformedDefault, converter, + stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -150,8 +152,9 @@ private[spark] class TypedConfigBuilder[T]( /** Creates a [[ConfigEntry]] with a function to determine the default value */ def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { - val entry = new ConfigEntryWithDefaultFunction[T](parent.key, parent._alternatives, defaultFunc, - converter, stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, parent._prependedKey, + parent._prependSeparator, parent._alternatives, defaultFunc, converter, stringConverter, + parent._doc, parent._public) parent._onCreate.foreach(_ (entry)) entry } @@ -161,8 +164,9 @@ private[spark] class TypedConfigBuilder[T]( * [[String]] and must be a valid value for the entry. 
*/ def createWithDefaultString(default: String): ConfigEntry[T] = { - val entry = new ConfigEntryWithDefaultString[T](parent.key, parent._alternatives, default, - converter, stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefaultString[T](parent.key, parent._prependedKey, + parent._prependSeparator, parent._alternatives, default, converter, stringConverter, + parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -178,6 +182,8 @@ private[spark] case class ConfigBuilder(key: String) { import ConfigHelpers._ + private[config] var _prependedKey: Option[String] = None + private[config] var _prependSeparator: String = "" private[config] var _public = true private[config] var _doc = "" private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None @@ -202,24 +208,34 @@ private[spark] case class ConfigBuilder(key: String) { this } + def withPrepended(key: String, separator: String = " "): ConfigBuilder = { + _prependedKey = Option(key) + _prependSeparator = separator + this + } + def withAlternative(key: String): ConfigBuilder = { _alternatives = _alternatives :+ key this } def intConf: TypedConfigBuilder[Int] = { + checkPrependConfig new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int")) } def longConf: TypedConfigBuilder[Long] = { + checkPrependConfig new TypedConfigBuilder(this, toNumber(_, _.toLong, key, "long")) } def doubleConf: TypedConfigBuilder[Double] = { + checkPrependConfig new TypedConfigBuilder(this, toNumber(_, _.toDouble, key, "double")) } def booleanConf: TypedConfigBuilder[Boolean] = { + checkPrependConfig new TypedConfigBuilder(this, toBoolean(_, key)) } @@ -228,20 +244,30 @@ private[spark] case class ConfigBuilder(key: String) { } def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = { + checkPrependConfig new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit)) } def bytesConf(unit: ByteUnit): TypedConfigBuilder[Long] = { + checkPrependConfig new TypedConfigBuilder(this, byteFromString(_, unit), byteToString(_, unit)) } def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = { - val entry = new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) + val entry = new FallbackConfigEntry(key, _prependedKey, _prependSeparator, _alternatives, _doc, + _public, fallback) _onCreate.foreach(_(entry)) entry } def regexConf: TypedConfigBuilder[Regex] = { + checkPrependConfig new TypedConfigBuilder(this, regexFromString(_, this.key), _.toString) } + + private def checkPrependConfig = { + if (_prependedKey.isDefined) { + throw new IllegalArgumentException(s"$key type must be string if prepend used") + } + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index ede3ace4f9aa..c5df4c882009 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -28,6 +28,8 @@ package org.apache.spark.internal.config * value declared as a string. * * @param key the key for the configuration + * @param prependedKey the key for the configuration which will be prepended + * @param prependSeparator the separator which is used for prepending * @param valueConverter how to convert a string to the value. It should throw an exception if the * string does not have the required format. 
* @param stringConverter how to convert a value to a string that the user can use it as a valid @@ -41,6 +43,8 @@ package org.apache.spark.internal.config */ private[spark] abstract class ConfigEntry[T] ( val key: String, + val prependedKey: Option[String], + val prependSeparator: String, val alternatives: List[String], val valueConverter: String => T, val stringConverter: T => String, @@ -54,7 +58,15 @@ private[spark] abstract class ConfigEntry[T] ( def defaultValueString: String protected def readString(reader: ConfigReader): Option[String] = { - alternatives.foldLeft(reader.get(key))((res, nextKey) => res.orElse(reader.get(nextKey))) + val values = Seq( + prependedKey.flatMap(reader.get(_)), + alternatives.foldLeft(reader.get(key))((res, nextKey) => res.orElse(reader.get(nextKey))) + ).flatten + if (values.nonEmpty) { + Some(values.mkString(prependSeparator)) + } else { + None + } } def readFrom(reader: ConfigReader): T @@ -68,13 +80,24 @@ private[spark] abstract class ConfigEntry[T] ( private class ConfigEntryWithDefault[T] ( key: String, + prependedKey: Option[String], + prependSeparator: String, alternatives: List[String], _defaultValue: T, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry( + key, + prependedKey, + prependSeparator, + alternatives, + valueConverter, + stringConverter, + doc, + isPublic + ) { override def defaultValue: Option[T] = Some(_defaultValue) @@ -86,14 +109,25 @@ private class ConfigEntryWithDefault[T] ( } private class ConfigEntryWithDefaultFunction[T] ( - key: String, - alternatives: List[String], - _defaultFunction: () => T, - valueConverter: String => T, - stringConverter: T => String, - doc: String, - isPublic: Boolean) - extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { + key: String, + prependedKey: Option[String], + prependSeparator: String, + alternatives: List[String], + _defaultFunction: () => T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry( + key, + prependedKey, + prependSeparator, + alternatives, + valueConverter, + stringConverter, + doc, + isPublic + ) { override def defaultValue: Option[T] = Some(_defaultFunction()) @@ -106,13 +140,24 @@ private class ConfigEntryWithDefaultFunction[T] ( private class ConfigEntryWithDefaultString[T] ( key: String, + prependedKey: Option[String], + prependSeparator: String, alternatives: List[String], _defaultValue: String, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry( + key, + prependedKey, + prependSeparator, + alternatives, + valueConverter, + stringConverter, + doc, + isPublic + ) { override def defaultValue: Option[T] = Some(valueConverter(_defaultValue)) @@ -130,14 +175,23 @@ private class ConfigEntryWithDefaultString[T] ( */ private[spark] class OptionalConfigEntry[T]( key: String, + prependedKey: Option[String], + prependSeparator: String, alternatives: List[String], val rawValueConverter: String => T, val rawStringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry[Option[T]](key, alternatives, + extends ConfigEntry[Option[T]]( + key, + prependedKey, + prependSeparator, + alternatives, s => Some(rawValueConverter(s)), - v => v.map(rawStringConverter).orNull, 
doc, isPublic) { + v => v.map(rawStringConverter).orNull, + doc, + isPublic + ) { override def defaultValueString: String = ConfigEntry.UNDEFINED @@ -151,12 +205,22 @@ private[spark] class OptionalConfigEntry[T]( */ private[spark] class FallbackConfigEntry[T] ( key: String, + prependedKey: Option[String], + prependSeparator: String, alternatives: List[String], doc: String, isPublic: Boolean, val fallback: ConfigEntry[T]) - extends ConfigEntry[T](key, alternatives, - fallback.valueConverter, fallback.stringConverter, doc, isPublic) { + extends ConfigEntry[T]( + key, + prependedKey, + prependSeparator, + alternatives, + fallback.valueConverter, + fallback.stringConverter, + doc, + isPublic + ) { override def defaultValueString: String = s"" diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7c332fdb8572..f2b88fe00cdf 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -48,7 +48,10 @@ package object config { ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional private[spark] val DRIVER_JAVA_OPTIONS = - ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.createOptional + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS) + .withPrepended(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS) + .stringConf + .createOptional private[spark] val DRIVER_LIBRARY_PATH = ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.createOptional @@ -174,7 +177,10 @@ package object config { ConfigBuilder("spark.executor.heartbeat.maxFailures").internal().intConf.createWithDefault(60) private[spark] val EXECUTOR_JAVA_OPTIONS = - ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS) + .withPrepended(SparkLauncher.EXECUTOR_DEFAULT_JAVA_OPTIONS) + .stringConf + .createOptional private[spark] val EXECUTOR_LIBRARY_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.createOptional @@ -363,6 +369,17 @@ package object config { .checkValue(_ >= 0L, "Timeout must be >= 0.") .createWithDefault(60) + private[spark] val DYN_ALLOCATION_SHUFFLE_TRACKING = + ConfigBuilder("spark.dynamicAllocation.shuffleTracking.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val DYN_ALLOCATION_SHUFFLE_TIMEOUT = + ConfigBuilder("spark.dynamicAllocation.shuffleTimeout") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(_ >= 0L, "Timeout must be >= 0.") + .createWithDefault(Long.MaxValue) + private[spark] val DYN_ALLOCATION_SCHEDULER_BACKLOG_TIMEOUT = ConfigBuilder("spark.dynamicAllocation.schedulerBacklogTimeout") .timeConf(TimeUnit.SECONDS).createWithDefault(1) @@ -1240,6 +1257,14 @@ package object config { "mechanisms to guarantee data won't be corrupted during broadcast") .booleanConf.createWithDefault(true) + private[spark] val BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD = + ConfigBuilder("spark.broadcast.UDFCompressionThreshold") + .doc("The threshold at which user-defined functions (UDFs) and Python RDD commands " + + "are compressed by broadcast in bytes unless otherwise specified") + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v >= 0, "The threshold should be non-negative.") + .createWithDefault(1L * 1024 * 1024) + private[spark] val RDD_COMPRESS = ConfigBuilder("spark.rdd.compress") .doc("Whether to compress serialized RDD partitions " + "(e.g. 
for StorageLevel.MEMORY_ONLY_SER in Scala " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 33a68f24bd53..e3216151462b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -37,7 +37,8 @@ class StageInfo( val parentIds: Seq[Int], val details: String, val taskMetrics: TaskMetrics = null, - private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { + private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty, + private[spark] val shuffleDepId: Option[Int] = None) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -90,6 +91,10 @@ private[spark] object StageInfo { ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos + val shuffleDepId = stage match { + case sms: ShuffleMapStage => Option(sms.shuffleDep).map(_.shuffleId) + case _ => None + } new StageInfo( stage.id, attemptId, @@ -99,6 +104,7 @@ private[spark] object StageInfo { stage.parents.map(_.id), stage.details, taskMetrics, - taskLocalityPreferences) + taskLocalityPreferences, + shuffleDepId) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala index 9aac4d2281ec..f5beb403555e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala @@ -36,14 +36,19 @@ import org.apache.spark.util.Clock private[spark] class ExecutorMonitor( conf: SparkConf, client: ExecutorAllocationClient, - clock: Clock) extends SparkListener with Logging { + listenerBus: LiveListenerBus, + clock: Clock) extends SparkListener with CleanerListener with Logging { private val idleTimeoutMs = TimeUnit.SECONDS.toMillis( conf.get(DYN_ALLOCATION_EXECUTOR_IDLE_TIMEOUT)) private val storageTimeoutMs = TimeUnit.SECONDS.toMillis( conf.get(DYN_ALLOCATION_CACHED_EXECUTOR_IDLE_TIMEOUT)) + private val shuffleTimeoutMs = conf.get(DYN_ALLOCATION_SHUFFLE_TIMEOUT) + private val fetchFromShuffleSvcEnabled = conf.get(SHUFFLE_SERVICE_ENABLED) && conf.get(SHUFFLE_SERVICE_FETCH_RDD_ENABLED) + private val shuffleTrackingEnabled = !conf.get(SHUFFLE_SERVICE_ENABLED) && + conf.get(DYN_ALLOCATION_SHUFFLE_TRACKING) private val executors = new ConcurrentHashMap[String, Tracker]() @@ -64,6 +69,26 @@ private[spark] class ExecutorMonitor( private val nextTimeout = new AtomicLong(Long.MaxValue) private var timedOutExecs = Seq.empty[String] + // Active job tracking. + // + // The following state is used when an external shuffle service is not in use, and allows Spark + // to scale down based on whether the shuffle data stored in executors is in use. + // + // The algorithm works as following: when jobs start, some state is kept that tracks which stages + // are part of that job, and which shuffle ID is attached to those stages. As tasks finish, the + // executor tracking code is updated to include the list of shuffles for which it's storing + // shuffle data. 
+ // + // If executors hold shuffle data that is related to an active job, then the executor is + // considered to be in "shuffle busy" state; meaning that the executor is not allowed to be + // removed. If the executor has shuffle data but it doesn't relate to any active job, then it + // may be removed when idle, following the shuffle-specific timeout configuration. + // + // The following fields are not thread-safe and should be only used from the event thread. + private val shuffleToActiveJobs = new mutable.HashMap[Int, mutable.ArrayBuffer[Int]]() + private val stageToShuffleID = new mutable.HashMap[Int, Int]() + private val jobToStageIDs = new mutable.HashMap[Int, Seq[Int]]() + def reset(): Unit = { executors.clear() nextTimeout.set(Long.MaxValue) @@ -85,7 +110,7 @@ private[spark] class ExecutorMonitor( var newNextTimeout = Long.MaxValue timedOutExecs = executors.asScala - .filter { case (_, exec) => !exec.pendingRemoval } + .filter { case (_, exec) => !exec.pendingRemoval && !exec.hasActiveShuffle } .filter { case (_, exec) => val deadline = exec.timeoutAt if (deadline > now) { @@ -124,6 +149,109 @@ private[spark] class ExecutorMonitor( def pendingRemovalCount: Int = executors.asScala.count { case (_, exec) => exec.pendingRemoval } + override def onJobStart(event: SparkListenerJobStart): Unit = { + if (!shuffleTrackingEnabled) { + return + } + + val shuffleStages = event.stageInfos.flatMap { s => + s.shuffleDepId.toSeq.map { shuffleId => + s.stageId -> shuffleId + } + } + + var updateExecutors = false + shuffleStages.foreach { case (stageId, shuffle) => + val jobIDs = shuffleToActiveJobs.get(shuffle) match { + case Some(jobs) => + // If a shuffle is being re-used, we need to re-scan the executors and update their + // tracker with the information that the shuffle data they're storing is in use. + logDebug(s"Reusing shuffle $shuffle in job ${event.jobId}.") + updateExecutors = true + jobs + + case _ => + logDebug(s"Registered new shuffle $shuffle (from stage $stageId).") + val jobs = new mutable.ArrayBuffer[Int]() + shuffleToActiveJobs(shuffle) = jobs + jobs + } + jobIDs += event.jobId + } + + if (updateExecutors) { + val activeShuffleIds = shuffleStages.map(_._2).toSeq + var needTimeoutUpdate = false + val activatedExecs = new mutable.ArrayBuffer[String]() + executors.asScala.foreach { case (id, exec) => + if (!exec.hasActiveShuffle) { + exec.updateActiveShuffles(activeShuffleIds) + if (exec.hasActiveShuffle) { + needTimeoutUpdate = true + activatedExecs += id + } + } + } + + logDebug(s"Activated executors ${activatedExecs.mkString(",")} due to shuffle data " + + s"needed by new job ${event.jobId}.") + + if (needTimeoutUpdate) { + nextTimeout.set(Long.MinValue) + } + } + + stageToShuffleID ++= shuffleStages + jobToStageIDs(event.jobId) = shuffleStages.map(_._1).toSeq + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + if (!shuffleTrackingEnabled) { + return + } + + var updateExecutors = false + val activeShuffles = new mutable.ArrayBuffer[Int]() + shuffleToActiveJobs.foreach { case (shuffleId, jobs) => + jobs -= event.jobId + if (jobs.nonEmpty) { + activeShuffles += shuffleId + } else { + // If a shuffle went idle we need to update all executors to make sure they're correctly + // tracking active shuffles. 
+ updateExecutors = true + } + } + + if (updateExecutors) { + if (log.isDebugEnabled()) { + if (activeShuffles.nonEmpty) { + logDebug( + s"Job ${event.jobId} ended, shuffles ${activeShuffles.mkString(",")} still active.") + } else { + logDebug(s"Job ${event.jobId} ended, no active shuffles remain.") + } + } + + val deactivatedExecs = new mutable.ArrayBuffer[String]() + executors.asScala.foreach { case (id, exec) => + if (exec.hasActiveShuffle) { + exec.updateActiveShuffles(activeShuffles) + if (!exec.hasActiveShuffle) { + deactivatedExecs += id + } + } + } + + logDebug(s"Executors ${deactivatedExecs.mkString(",")} do not have active shuffle data " + + s"after job ${event.jobId} finished.") + } + + jobToStageIDs.remove(event.jobId).foreach { stages => + stages.foreach { id => stageToShuffleID -= id } + } + } + override def onTaskStart(event: SparkListenerTaskStart): Unit = { val executorId = event.taskInfo.executorId // Guard against a late arriving task start event (SPARK-26927). @@ -137,6 +265,21 @@ private[spark] class ExecutorMonitor( val executorId = event.taskInfo.executorId val exec = executors.get(executorId) if (exec != null) { + // If the task succeeded and the stage generates shuffle data, record that this executor + // holds data for the shuffle. This code will track all executors that generate shuffle + // for the stage, even if speculative tasks generate duplicate shuffle data and end up + // being ignored by the map output tracker. + // + // This means that an executor may be marked as having shuffle data, and thus prevented + // from being removed, even though the data may not be used. + if (shuffleTrackingEnabled && event.reason == Success) { + stageToShuffleID.get(event.stageId).foreach { shuffleId => + exec.addShuffle(shuffleId) + } + } + + // Update the number of running tasks after checking for shuffle data, so that the shuffle + // information is up-to-date in case the executor is going idle. exec.updateRunningTasks(-1) } } @@ -171,7 +314,6 @@ private[spark] class ExecutorMonitor( // available. So don't count blocks that can be served by the external service. if (storageLevel.isValid && (!fetchFromShuffleSvcEnabled || !storageLevel.useDisk)) { val hadCachedBlocks = exec.cachedBlocks.nonEmpty - val blocks = exec.cachedBlocks.getOrElseUpdate(blockId.rddId, new mutable.BitSet(blockId.splitIndex)) blocks += blockId.splitIndex @@ -201,6 +343,25 @@ private[spark] class ExecutorMonitor( } } + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case ShuffleCleanedEvent(id) => cleanupShuffle(id) + case _ => + } + + override def rddCleaned(rddId: Int): Unit = { } + + override def shuffleCleaned(shuffleId: Int): Unit = { + // Because this is called in a completely separate thread, we post a custom event to the + // listener bus so that the internal state is safely updated. + listenerBus.post(ShuffleCleanedEvent(shuffleId)) + } + + override def broadcastCleaned(broadcastId: Long): Unit = { } + + override def accumCleaned(accId: Long): Unit = { } + + override def checkpointCleaned(rddId: Long): Unit = { } + // Visible for testing. 
private[dynalloc] def isExecutorIdle(id: String): Boolean = { Option(executors.get(id)).map(_.isIdle).getOrElse(throw new NoSuchElementException(id)) @@ -209,7 +370,7 @@ private[spark] class ExecutorMonitor( // Visible for testing private[dynalloc] def timedOutExecutors(when: Long): Seq[String] = { executors.asScala.flatMap { case (id, tracker) => - if (tracker.timeoutAt <= when) Some(id) else None + if (tracker.isIdle && tracker.timeoutAt <= when) Some(id) else None }.toSeq } @@ -236,6 +397,14 @@ private[spark] class ExecutorMonitor( } } + private def cleanupShuffle(id: Int): Unit = { + logDebug(s"Cleaning up state related to shuffle $id.") + shuffleToActiveJobs -= id + executors.asScala.foreach { case (_, exec) => + exec.removeShuffle(id) + } + } + private class Tracker { @volatile var timeoutAt: Long = Long.MaxValue @@ -244,6 +413,7 @@ private[spark] class ExecutorMonitor( @volatile var timedOut: Boolean = false var pendingRemoval: Boolean = false + var hasActiveShuffle: Boolean = false private var idleStart: Long = -1 private var runningTasks: Int = 0 @@ -252,8 +422,11 @@ private[spark] class ExecutorMonitor( // This should only be used in the event thread. val cachedBlocks = new mutable.HashMap[Int, mutable.BitSet]() - // For testing. - def isIdle: Boolean = idleStart >= 0 + // The set of shuffles for which shuffle data is held by the executor. + // This should only be used in the event thread. + private val shuffleIds = if (shuffleTrackingEnabled) new mutable.HashSet[Int]() else null + + def isIdle: Boolean = idleStart >= 0 && !hasActiveShuffle def updateRunningTasks(delta: Int): Unit = { runningTasks = math.max(0, runningTasks + delta) @@ -264,7 +437,18 @@ private[spark] class ExecutorMonitor( def updateTimeout(): Unit = { val oldDeadline = timeoutAt val newDeadline = if (idleStart >= 0) { - idleStart + (if (cachedBlocks.nonEmpty) storageTimeoutMs else idleTimeoutMs) + val timeout = if (cachedBlocks.nonEmpty || (shuffleIds != null && shuffleIds.nonEmpty)) { + val _cacheTimeout = if (cachedBlocks.nonEmpty) storageTimeoutMs else Long.MaxValue + val _shuffleTimeout = if (shuffleIds != null && shuffleIds.nonEmpty) { + shuffleTimeoutMs + } else { + Long.MaxValue + } + math.min(_cacheTimeout, _shuffleTimeout) + } else { + idleTimeoutMs + } + idleStart + timeout } else { Long.MaxValue } @@ -279,5 +463,32 @@ private[spark] class ExecutorMonitor( updateNextTimeout(newDeadline) } } + + def addShuffle(id: Int): Unit = { + if (shuffleIds.add(id)) { + hasActiveShuffle = true + } + } + + def removeShuffle(id: Int): Unit = { + if (shuffleIds.remove(id) && shuffleIds.isEmpty) { + hasActiveShuffle = false + if (isIdle) { + updateTimeout() + } + } + } + + def updateActiveShuffles(ids: Iterable[Int]): Unit = { + val hadActiveShuffle = hasActiveShuffle + hasActiveShuffle = ids.exists(shuffleIds.contains) + if (hadActiveShuffle && isIdle) { + updateTimeout() + } + } + } + + private case class ShuffleCleanedEvent(id: Int) extends SparkListenerEvent { + override protected[spark] def logEvent: Boolean = false } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 3ba33e358ef0..191b516661e4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -1008,7 +1008,7 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { private def createManager( conf: SparkConf, clock: Clock 
= new SystemClock()): ExecutorAllocationManager = { - val manager = new ExecutorAllocationManager(client, listenerBus, conf, clock) + val manager = new ExecutorAllocationManager(client, listenerBus, conf, clock = clock) managers += manager manager.start() manager diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 6be1fedc123d..202b85dcf569 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -389,6 +389,19 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst """.stripMargin.trim) } + test("SPARK-28355: Use Spark conf for threshold at which UDFs are compressed by broadcast") { + val conf = new SparkConf() + + // Check the default value + assert(conf.get(BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) === 1L * 1024 * 1024) + + // Set the conf + conf.set(BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD, 1L * 1024) + + // Verify that it has been set properly + assert(conf.get(BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) === 1L * 1024) + } + val defaultIllegalValue = "SomeIllegalValue" val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 628ac60fa767..fed3ae35ee0e 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -170,6 +170,17 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("add FS jar files not exists") { + try { + val jarPath = "hdfs:///no/path/to/TestUDTF.jar" + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addJar(jarPath) + assert(sc.listJars().forall(!_.contains("TestUDTF.jar"))) + } finally { + sc.stop() + } + } + test("SPARK-17650: malformed url's throw exceptions before bricking Executors") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 02514dc7daef..c3bfa7ddee5b 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -70,8 +70,8 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: fallback") { val conf = new SparkConf() - val parentConf = ConfigBuilder(testKey("parent")).intConf.createWithDefault(1) - val confWithFallback = ConfigBuilder(testKey("fallback")).fallbackConf(parentConf) + val parentConf = ConfigBuilder(testKey("parent1")).intConf.createWithDefault(1) + val confWithFallback = ConfigBuilder(testKey("fallback1")).fallbackConf(parentConf) assert(conf.get(confWithFallback) === 1) conf.set(confWithFallback, 2) assert(conf.get(parentConf) === 1) @@ -289,6 +289,92 @@ class ConfigEntrySuite extends SparkFunSuite { assert(conf.get(iConf) === 3) } + test("conf entry: prepend with default separator") { + val conf = new SparkConf() + val prependedKey = testKey("prepended1") + val prependedConf = ConfigBuilder(prependedKey).stringConf.createOptional + val derivedConf = ConfigBuilder(testKey("prepend1")) + .withPrepended(prependedKey) + .stringConf + .createOptional + + conf.set(derivedConf, 
"1") + assert(conf.get(derivedConf) === Some("1")) + + conf.set(prependedConf, "2") + assert(conf.get(derivedConf) === Some("2 1")) + } + + test("conf entry: prepend with custom separator") { + val conf = new SparkConf() + val prependedKey = testKey("prepended2") + val prependedConf = ConfigBuilder(prependedKey).stringConf.createOptional + val derivedConf = ConfigBuilder(testKey("prepend2")) + .withPrepended(prependedKey, ",") + .stringConf + .createOptional + + conf.set(derivedConf, "1") + assert(conf.get(derivedConf) === Some("1")) + + conf.set(prependedConf, "2") + assert(conf.get(derivedConf) === Some("2,1")) + } + + test("conf entry: prepend with fallback") { + val conf = new SparkConf() + val prependedKey = testKey("prepended3") + val prependedConf = ConfigBuilder(prependedKey).stringConf.createOptional + val derivedConf = ConfigBuilder(testKey("prepend3")) + .withPrepended(prependedKey) + .stringConf + .createOptional + val confWithFallback = ConfigBuilder(testKey("fallback2")).fallbackConf(derivedConf) + + assert(conf.get(confWithFallback) === None) + + conf.set(derivedConf, "1") + assert(conf.get(confWithFallback) === Some("1")) + + conf.set(prependedConf, "2") + assert(conf.get(confWithFallback) === Some("2 1")) + + conf.set(confWithFallback, Some("3")) + assert(conf.get(confWithFallback) === Some("3")) + } + + test("conf entry: prepend should work only with string type") { + var i = 0 + def testPrependFail(createConf: (String, String) => Unit): Unit = { + intercept[IllegalArgumentException] { + createConf(testKey(s"prependedFail$i"), testKey(s"prependFail$i")) + }.getMessage.contains("type must be string if prepend used") + i += 1 + } + + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).intConf + ) + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).longConf + ) + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).doubleConf + ) + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).booleanConf + ) + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).timeConf(TimeUnit.MILLISECONDS) + ) + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).bytesConf(ByteUnit.BYTE) + ) + testPrependFail( (prependedKey, prependKey) => + ConfigBuilder(testKey(prependKey)).withPrepended(prependedKey).regexConf + ) + } + test("onCreate") { var onCreateCalled = false ConfigBuilder(testKey("oc1")).onCreate(_ => onCreateCalled = true).intConf.createWithDefault(1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala index 8d1577e835d2..e11ee97469b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{doAnswer, mock, when} import org.apache.spark._ import org.apache.spark.internal.config._ @@ -34,10 +34,13 @@ class ExecutorMonitorSuite extends SparkFunSuite { private val idleTimeoutMs = 
TimeUnit.SECONDS.toMillis(60L) private val storageTimeoutMs = TimeUnit.SECONDS.toMillis(120L) + private val shuffleTimeoutMs = TimeUnit.SECONDS.toMillis(240L) private val conf = new SparkConf() .set(DYN_ALLOCATION_EXECUTOR_IDLE_TIMEOUT.key, "60s") .set(DYN_ALLOCATION_CACHED_EXECUTOR_IDLE_TIMEOUT.key, "120s") + .set(DYN_ALLOCATION_SHUFFLE_TIMEOUT.key, "240s") + .set(SHUFFLE_SERVICE_ENABLED, true) private var monitor: ExecutorMonitor = _ private var client: ExecutorAllocationClient = _ @@ -55,7 +58,7 @@ class ExecutorMonitorSuite extends SparkFunSuite { when(client.isExecutorActive(any())).thenAnswer { invocation => knownExecs.contains(invocation.getArguments()(0).asInstanceOf[String]) } - monitor = new ExecutorMonitor(conf, client, clock) + monitor = new ExecutorMonitor(conf, client, null, clock) } test("basic executor timeout") { @@ -205,7 +208,7 @@ class ExecutorMonitorSuite extends SparkFunSuite { assert(monitor.timedOutExecutors(storageDeadline) === Seq("1")) conf.set(SHUFFLE_SERVICE_ENABLED, true).set(SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) - monitor = new ExecutorMonitor(conf, client, clock) + monitor = new ExecutorMonitor(conf, client, null, clock) monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", null)) monitor.onBlockUpdated(rddUpdate(1, 0, "1", level = StorageLevel.MEMORY_ONLY)) @@ -259,8 +262,119 @@ class ExecutorMonitorSuite extends SparkFunSuite { assert(monitor.timedOutExecutors().toSet === Set("2")) } + test("shuffle block tracking") { + val bus = mockListenerBus() + conf.set(DYN_ALLOCATION_SHUFFLE_TRACKING, true).set(SHUFFLE_SERVICE_ENABLED, false) + monitor = new ExecutorMonitor(conf, client, bus, clock) + + // 3 jobs: 2 and 3 share a shuffle, 1 has a separate shuffle. + val stage1 = stageInfo(1, shuffleId = 0) + val stage2 = stageInfo(2) + + val stage3 = stageInfo(3, shuffleId = 1) + val stage4 = stageInfo(4) + + val stage5 = stageInfo(5, shuffleId = 1) + val stage6 = stageInfo(6) + + // Start jobs 1 and 2. Finish a task on each, but don't finish the jobs. This should prevent the + // executor from going idle since there are active shuffles. + monitor.onJobStart(SparkListenerJobStart(1, clock.getTimeMillis(), Seq(stage1, stage2))) + monitor.onJobStart(SparkListenerJobStart(2, clock.getTimeMillis(), Seq(stage3, stage4))) + + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", null)) + assert(monitor.timedOutExecutors(idleDeadline) === Seq("1")) + + // First a failed task, to make sure it does not count. + monitor.onTaskStart(SparkListenerTaskStart(1, 0, taskInfo("1", 1))) + monitor.onTaskEnd(SparkListenerTaskEnd(1, 0, "foo", TaskResultLost, taskInfo("1", 1), null)) + assert(monitor.timedOutExecutors(idleDeadline) === Seq("1")) + + monitor.onTaskStart(SparkListenerTaskStart(1, 0, taskInfo("1", 1))) + monitor.onTaskEnd(SparkListenerTaskEnd(1, 0, "foo", Success, taskInfo("1", 1), null)) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + + monitor.onTaskStart(SparkListenerTaskStart(3, 0, taskInfo("1", 1))) + monitor.onTaskEnd(SparkListenerTaskEnd(3, 0, "foo", Success, taskInfo("1", 1), null)) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + + // Finish the jobs, now the executor should be idle, but with the shuffle timeout, since the + // shuffles are not active. 
+ monitor.onJobEnd(SparkListenerJobEnd(1, clock.getTimeMillis(), JobSucceeded)) + assert(!monitor.isExecutorIdle("1")) + + monitor.onJobEnd(SparkListenerJobEnd(2, clock.getTimeMillis(), JobSucceeded)) + assert(monitor.isExecutorIdle("1")) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + assert(monitor.timedOutExecutors(storageDeadline).isEmpty) + assert(monitor.timedOutExecutors(shuffleDeadline) === Seq("1")) + + // Start job 3. Since it shares a shuffle with job 2, the executor should not be considered + // idle anymore, even if no tasks are run. + monitor.onJobStart(SparkListenerJobStart(3, clock.getTimeMillis(), Seq(stage5, stage6))) + assert(!monitor.isExecutorIdle("1")) + assert(monitor.timedOutExecutors(shuffleDeadline).isEmpty) + + monitor.onJobEnd(SparkListenerJobEnd(3, clock.getTimeMillis(), JobSucceeded)) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + assert(monitor.timedOutExecutors(shuffleDeadline) === Seq("1")) + + // Clean up the shuffles, executor now should now time out at the idle deadline. + monitor.shuffleCleaned(0) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + monitor.shuffleCleaned(1) + assert(monitor.timedOutExecutors(idleDeadline) === Seq("1")) + } + + test("shuffle tracking with multiple executors and concurrent jobs") { + val bus = mockListenerBus() + conf.set(DYN_ALLOCATION_SHUFFLE_TRACKING, true).set(SHUFFLE_SERVICE_ENABLED, false) + monitor = new ExecutorMonitor(conf, client, bus, clock) + + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", null)) + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "2", null)) + + // Two separate jobs with separate shuffles. The first job will only run tasks on + // executor 1, the second on executor 2. Ensures that jobs finishing don't affect + // executors that are active in other jobs. 
+ + val stage1 = stageInfo(1, shuffleId = 0) + val stage2 = stageInfo(2) + monitor.onJobStart(SparkListenerJobStart(1, clock.getTimeMillis(), Seq(stage1, stage2))) + + val stage3 = stageInfo(3, shuffleId = 1) + val stage4 = stageInfo(4) + monitor.onJobStart(SparkListenerJobStart(2, clock.getTimeMillis(), Seq(stage3, stage4))) + + monitor.onTaskStart(SparkListenerTaskStart(1, 0, taskInfo("1", 1))) + monitor.onTaskEnd(SparkListenerTaskEnd(1, 0, "foo", Success, taskInfo("1", 1), null)) + assert(monitor.timedOutExecutors(idleDeadline) === Seq("2")) + + monitor.onTaskStart(SparkListenerTaskStart(3, 0, taskInfo("2", 1))) + monitor.onTaskEnd(SparkListenerTaskEnd(3, 0, "foo", Success, taskInfo("2", 1), null)) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + + monitor.onJobEnd(SparkListenerJobEnd(1, clock.getTimeMillis(), JobSucceeded)) + assert(monitor.isExecutorIdle("1")) + assert(!monitor.isExecutorIdle("2")) + + monitor.onJobEnd(SparkListenerJobEnd(2, clock.getTimeMillis(), JobSucceeded)) + assert(monitor.isExecutorIdle("2")) + assert(monitor.timedOutExecutors(idleDeadline).isEmpty) + + monitor.shuffleCleaned(0) + monitor.shuffleCleaned(1) + assert(monitor.timedOutExecutors(idleDeadline).toSet === Set("1", "2")) + } + private def idleDeadline: Long = clock.getTimeMillis() + idleTimeoutMs + 1 private def storageDeadline: Long = clock.getTimeMillis() + storageTimeoutMs + 1 + private def shuffleDeadline: Long = clock.getTimeMillis() + shuffleTimeoutMs + 1 + + private def stageInfo(id: Int, shuffleId: Int = -1): StageInfo = { + new StageInfo(id, 0, s"stage$id", 1, Nil, Nil, "", + shuffleDepId = if (shuffleId >= 0) Some(shuffleId) else None) + } private def taskInfo( execId: String, @@ -286,4 +400,16 @@ class ExecutorMonitorSuite extends SparkFunSuite { RDDBlockId(rddId, splitIndex), level, 1L, 0L)) } + /** + * Mock the listener bus *only* for the functionality needed by the shuffle tracking code. + * Any other event sent through the mock bus will fail. 
+ */ + private def mockListenerBus(): LiveListenerBus = { + val bus = mock(classOf[LiveListenerBus]) + doAnswer { invocation => + monitor.onOtherEvent(invocation.getArguments()(0).asInstanceOf[SparkListenerEvent]) + }.when(bus).post(any()) + bus + } + } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 242163931f7a..f5f93ece660b 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -128,7 +128,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { val files = testRolling(appender, testOutputStream, textToAppend, 0, isCompressed = true) files.foreach { file => logInfo(file.toString + ": " + file.length + " bytes") - assert(file.length < rolloverSize) + assert(file.length <= rolloverSize) } } diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 2f660ccfd92f..79158bb6edfe 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -170,7 +170,7 @@ parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.1.jar protobuf-java-2.5.0.jar py4j-0.10.8.1.jar -pyrolite-4.23.jar +pyrolite-4.30.jar scala-compiler-2.12.8.jar scala-library-2.12.8.jar scala-parser-combinators_2.12-1.1.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2 index e1e114fa08ca..5e03a5951db0 100644 --- a/dev/deps/spark-deps-hadoop-3.2 +++ b/dev/deps/spark-deps-hadoop-3.2 @@ -189,7 +189,7 @@ parquet-hadoop-1.10.1.jar parquet-jackson-1.10.1.jar protobuf-java-2.5.0.jar py4j-0.10.8.1.jar -pyrolite-4.23.jar +pyrolite-4.30.jar re2j-1.1.jar scala-compiler-2.12.8.jar scala-library-2.12.8.jar diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 593e34983aa0..e51e9560b59d 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -353,7 +353,7 @@ def choose_jira_assignee(issue, asf_jira): except: # assume it's a user id, and try to assign (might fail, we just prompt again) assignee = asf_jira.user(raw_assignee) - asf_jira.assign_issue(issue.key, assignee.key) + asf_jira.assign_issue(issue.key, assignee.name) return assignee except KeyboardInterrupt: raise diff --git a/docs/configuration.md b/docs/configuration.md index 211dfbb3f459..108862416f8d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -410,11 +410,31 @@ Apart from these, the following properties are also available, and may be useful your default properties file. + + spark.driver.defaultJavaOptions + (none) + + A string of default JVM options to prepend to spark.driver.extraJavaOptions. + This is intended to be set by administrators. + + For instance, GC settings or other logging. + Note that it is illegal to set maximum heap size (-Xmx) settings with this option. Maximum heap + size settings can be set with spark.driver.memory in the cluster mode and through + the --driver-memory command line option in the client mode. + +
Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-java-options command line option or in + your default properties file. + + spark.driver.extraJavaOptions (none) - A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + A string of extra JVM options to pass to the driver. This is intended to be set by users. + + For instance, GC settings or other logging. Note that it is illegal to set maximum heap size (-Xmx) settings with this option. Maximum heap size settings can be set with spark.driver.memory in the cluster mode and through the --driver-memory command line option in the client mode. @@ -423,6 +443,8 @@ Apart from these, the following properties are also available, and may be useful directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-java-options command line option or in your default properties file. + + spark.driver.defaultJavaOptions will be prepended to this configuration. @@ -457,11 +479,31 @@ Apart from these, the following properties are also available, and may be useful this option. + + spark.executor.defaultJavaOptions + (none) + + A string of default JVM options to prepend to spark.executor.extraJavaOptions. + This is intended to be set by administrators. + + For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this + option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file + used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc + + spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + A string of extra JVM options to pass to executors. This is intended to be set by users. + + For instance, GC settings or other logging. Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. @@ -470,6 +512,8 @@ Apart from these, the following properties are also available, and may be useful application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc + + spark.executor.defaultJavaOptions will be prepended to this configuration. @@ -2070,6 +2114,26 @@ Apart from these, the following properties are also available, and may be useful description. + + spark.dynamicAllocation.shuffleTracking.enabled + false + + Experimental. Enables shuffle file tracking for executors, which allows dynamic allocation + without the need for an external shuffle service. 
This option will try to keep alive executors + that are storing shuffle data for active jobs. + + + + spark.dynamicAllocation.shuffleTimeout + infinity + + When shuffle tracking is enabled, controls the timeout for executors that are holding shuffle + data. The default value means that Spark will rely on the shuffles being garbage collected to be + able to release executors. If for some reason garbage collection is not cleaning up shuffles + quickly enough, this option can be used to control when to time out executors even when they are + storing shuffle data. + + ### Thread Configurations diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index dc93e9cea5bc..9d9b253a5c8e 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -142,20 +142,20 @@ To use a custom metrics.properties for the application master and executors, upd - spark.yarn.am.resource.{resource-type} + spark.yarn.am.resource.{resource-type}.amount (none) Amount of resource to use for the YARN Application Master in client mode. - In cluster mode, use spark.yarn.driver.resource.<resource-type> instead. + In cluster mode, use spark.yarn.driver.resource.<resource-type>.amount instead. Please note that this feature can be used only with YARN 3.0+ For reference, see YARN Resource Model documentation: https://hadoop.apache.org/docs/r3.0.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html

Example: - To request GPU resources from YARN, use: spark.yarn.am.resource.yarn.io/gpu + To request GPU resources from YARN, use: spark.yarn.am.resource.yarn.io/gpu.amount - spark.yarn.driver.resource.{resource-type} + spark.yarn.driver.resource.{resource-type}.amount (none) Amount of resource to use for the YARN Application Master in cluster mode. @@ -163,11 +163,11 @@ To use a custom metrics.properties for the application master and executors, upd For reference, see YARN Resource Model documentation: https://hadoop.apache.org/docs/r3.0.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html

Example: - To request GPU resources from YARN, use: spark.yarn.driver.resource.yarn.io/gpu + To request GPU resources from YARN, use: spark.yarn.driver.resource.yarn.io/gpu.amount - spark.yarn.executor.resource.{resource-type} + spark.yarn.executor.resource.{resource-type}.amount (none) Amount of resource to use per executor process. @@ -175,7 +175,7 @@ To use a custom metrics.properties for the application master and executors, upd For reference, see YARN Resource Model documentation: https://hadoop.apache.org/docs/r3.0.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html

Example: - To request GPU resources from YARN, use: spark.yarn.executor.resource.yarn.io/gpu + To request GPU resources from YARN, use: spark.yarn.executor.resource.yarn.io/gpu.amount diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index d39bd933427f..f13d298674b2 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -149,6 +149,10 @@ license: | - Since Spark 3.0, if files or subdirectories disappear during recursive directory listing (i.e. they appear in an intermediate listing but then cannot be read or listed during later phases of the recursive directory listing, due to either concurrent file deletions or object store consistency issues) then the listing will fail with an exception unless `spark.sql.files.ignoreMissingFiles` is `true` (default `false`). In previous versions, these missing files or subdirectories would be ignored. Note that this change of behavior only applies during initial table file listing (or during `REFRESH TABLE`), not during query execution: the net change is that `spark.sql.files.ignoreMissingFiles` is now obeyed during table file listing / query planning, not only at query execution time. + - Since Spark 3.0, the substitution order of nested WITH clauses is changed and an inner CTE definition takes precedence over an outer one. In version 2.4 and earlier, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `1` while in version 3.0 it returns `2`. The previous behaviour can be restored by setting `spark.sql.legacy.ctePrecedence.enabled` to `true`. + + - Since Spark 3.0, the `add_months` function adjusts the resulting date to the last day of the month only if it is invalid. For example, `select add_months(DATE'2019-01-31', 1)` results in `2019-02-28`. In Spark version 2.4 and earlier, the resulting date is adjusted when it is invalid or when the original date is the last day of a month. For example, adding a month to `2019-02-28` results in `2019-03-31`. + ## Upgrading from Spark SQL 2.4 to 2.4.1 - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index d5224da2cf3f..fe3c60040d0a 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -818,6 +818,11 @@ Delegation tokens can be obtained from multiple clusters and ${cluster} +#### Kafka Specific Configurations + +Kafka's own configurations can be set with the `kafka.` prefix, e.g., `--conf spark.kafka.clusters.${cluster}.kafka.retries=1`. +For possible Kafka parameters, see [Kafka adminclient config docs](http://kafka.apache.org/documentation.html#adminclientconfigs). + #### Caveats - Obtaining delegation token for proxy user is not yet supported ([KAFKA-6945](https://issues.apache.org/jira/browse/KAFKA-6945)). diff --git a/docs/tuning.md b/docs/tuning.md index 222f8720ce35..1faf7cfe0d68 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -248,7 +248,7 @@ Our experience suggests that the effect of GC tuning depends on your application There are [many more tuning options](https://docs.oracle.com/javase/8/docs/technotes/guides/vm/gctuning/index.html) described online, but at a high level, managing how frequently full GC takes place can help in reducing the overhead.
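As a minimal sketch of the split that the configuration.md entries above and the tuning.md change below describe (the GC flag values are made up): administrator-provided defaults go in spark.executor.defaultJavaOptions and are prepended to the user-facing spark.executor.extraJavaOptions.

    import org.apache.spark.SparkConf

    // Made-up GC flags, shown on a SparkConf for brevity; in practice the default
    // options would normally come from an admin-managed spark-defaults.conf and
    // would be prepended to the user's extraJavaOptions.
    val conf = new SparkConf()
      .set("spark.executor.defaultJavaOptions", "-verbose:gc -XX:+UseG1GC")
      .set("spark.executor.extraJavaOptions", "-XX:+PrintGCDetails")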
-GC tuning flags for executors can be specified by setting `spark.executor.extraJavaOptions` in +GC tuning flags for executors can be specified by setting `spark.executor.defaultJavaOptions` or `spark.executor.extraJavaOptions` in a job's configuration. # Other Considerations diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 40bf3b1530fb..924bf374c737 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -50,7 +50,7 @@ abstract class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUti override protected def beforeAll(): Unit = { super.beforeAll() - spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + spark.conf.set(SQLConf.FILES_MAX_PARTITION_BYTES.key, 1024) } def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 82ce16c2b7e5..efd7ca74c796 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -120,24 +120,24 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { assert(types.length == 12) assert(types(0).equals("class java.lang.Boolean")) assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Short")) assert(types(3).equals("class java.lang.Integer")) assert(types(4).equals("class java.lang.Long")) assert(types(5).equals("class java.lang.Double")) - assert(types(6).equals("class java.lang.Double")) - assert(types(7).equals("class java.lang.Double")) + assert(types(6).equals("class java.lang.Float")) + assert(types(7).equals("class java.lang.Float")) assert(types(8).equals("class java.math.BigDecimal")) assert(types(9).equals("class java.math.BigDecimal")) assert(types(10).equals("class java.math.BigDecimal")) assert(types(11).equals("class java.math.BigDecimal")) assert(row.getBoolean(0) == false) assert(row.getInt(1) == 255) - assert(row.getInt(2) == 32767) + assert(row.getShort(2) == 32767) assert(row.getInt(3) == 2147483647) assert(row.getLong(4) == 9223372036854775807L) assert(row.getDouble(5) == 1.2345678901234512E14) // float = float(53) has 15-digits precision - assert(row.getDouble(6) == 1.23456788103168E14) // float(24) has 7-digits precision - assert(row.getDouble(7) == 1.23456788103168E14) // real = float(24) + assert(row.getFloat(6) == 1.23456788103168E14) // float(24) has 7-digits precision + assert(row.getFloat(7) == 1.23456788103168E14) // real = float(24) assert(row.getAs[BigDecimal](8).equals(new BigDecimal("123.00"))) assert(row.getAs[BigDecimal](9).equals(new BigDecimal("12345.12000"))) assert(row.getAs[BigDecimal](10).equals(new BigDecimal("922337203685477.5800"))) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 64b9837cc5fa..8cdc4a1806b2 100644 --- 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -376,8 +376,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val e = intercept[org.apache.spark.SparkException] { spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", new Properties()).collect() } - assert(e.getMessage.contains( - "requirement failed: Decimal precision 39 exceeds max precision 38")) + assert(e.getCause().isInstanceOf[ArithmeticException]) + assert(e.getMessage.contains("Decimal precision 39 exceeds max precision 38")) // custom schema can read data val props = new Properties() diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 462f88ff14a8..89da9a1de6f7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -206,4 +206,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { """.stripMargin.replaceAll("\n", " ")) assert(sql("select c1, c3 from queryOption").collect.toSet == expectedResult) } + + test("write byte as smallint") { + sqlContext.createDataFrame(Seq((1.toByte, 2.toShort))) + .write.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties) + val df = sqlContext.read.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties) + val schema = df.schema + assert(schema.head.dataType == ShortType) + assert(schema(1).dataType == ShortType) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getShort(0) === 1) + assert(rows(0).getShort(1) === 2) + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 9b3e78c84c34..76c25980fc33 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -21,7 +21,7 @@ import org.apache.kafka.clients.producer.ProducerRecord import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.datasources.v2.ContinuousScanExec -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.execution.streaming.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. 
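The kafka-0-10-token-provider changes that follow collect per-cluster Kafka parameters from a `kafka.`-prefixed namespace; a minimal sketch of the configuration shape they parse, using a hypothetical cluster id and broker address:

    import org.apache.spark.SparkConf

    // Hypothetical values: keys after "...cluster1.kafka." are gathered into
    // specifiedKafkaParams and forwarded to the AdminClient used for
    // delegation token retrieval.
    val conf = new SparkConf()
      .set("spark.kafka.clusters.cluster1.auth.bootstrap.servers", "broker1:9092")
      .set("spark.kafka.clusters.cluster1.kafka.retries", "1")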
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala index e089e36eba5f..ba8340ea59c1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala @@ -135,7 +135,7 @@ class KafkaDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTe test("failOnDataLoss=false should not return duplicated records: microbatch v1") { withSQLConf( - "spark.sql.streaming.disabledV2MicroBatchReaders" -> + SQLConf.DISABLED_V2_STREAMING_MICROBATCH_READERS.key -> classOf[KafkaSourceProvider].getCanonicalName) { verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => val query = df.writeStream.format("memory").queryName(table).start() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 3d14ebe267c4..bb9b3696fe8f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1066,7 +1066,7 @@ class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() spark.conf.set( - "spark.sql.streaming.disabledV2MicroBatchReaders", + SQLConf.DISABLED_V2_STREAMING_MICROBATCH_READERS.key, classOf[KafkaSourceProvider].getCanonicalName) } diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala index 84d58d8c419a..e1f3c800a51f 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala @@ -23,6 +23,7 @@ import org.apache.kafka.common.security.auth.SecurityProtocol.SASL_SSL import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT private[spark] case class KafkaTokenClusterConf( identifier: String, @@ -35,7 +36,8 @@ private[spark] case class KafkaTokenClusterConf( keyStoreLocation: Option[String], keyStorePassword: Option[String], keyPassword: Option[String], - tokenMechanism: String) { + tokenMechanism: String, + specifiedKafkaParams: Map[String, String]) { override def toString: String = s"KafkaTokenClusterConf{" + s"identifier=$identifier, " + s"authBootstrapServers=$authBootstrapServers, " + @@ -43,11 +45,12 @@ private[spark] case class KafkaTokenClusterConf( s"securityProtocol=$securityProtocol, " + s"kerberosServiceName=$kerberosServiceName, " + s"trustStoreLocation=$trustStoreLocation, " + - s"trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + + s"trustStorePassword=${trustStorePassword.map(_ => REDACTION_REPLACEMENT_TEXT)}, " + s"keyStoreLocation=$keyStoreLocation, " + - s"keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + - s"keyPassword=${keyPassword.map(_ => "xxx")}, " + - s"tokenMechanism=$tokenMechanism}" + 
s"keyStorePassword=${keyStorePassword.map(_ => REDACTION_REPLACEMENT_TEXT)}, " + + s"keyPassword=${keyPassword.map(_ => REDACTION_REPLACEMENT_TEXT)}, " + + s"tokenMechanism=$tokenMechanism, " + + s"specifiedKafkaParams=${KafkaRedactionUtil.redactParams(specifiedKafkaParams.toSeq)}}" } private [kafka010] object KafkaTokenSparkConf extends Logging { @@ -59,6 +62,8 @@ private [kafka010] object KafkaTokenSparkConf extends Logging { def getClusterConfig(sparkConf: SparkConf, identifier: String): KafkaTokenClusterConf = { val configPrefix = s"$CLUSTERS_CONFIG_PREFIX$identifier." val sparkClusterConf = sparkConf.getAllWithPrefix(configPrefix).toMap + val configKafkaPrefix = s"${configPrefix}kafka." + val sparkClusterKafkaConf = sparkConf.getAllWithPrefix(configKafkaPrefix).toMap val result = KafkaTokenClusterConf( identifier, sparkClusterConf @@ -76,7 +81,8 @@ private [kafka010] object KafkaTokenSparkConf extends Logging { sparkClusterConf.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG), sparkClusterConf.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG), sparkClusterConf.getOrElse("sasl.token.mechanism", - KafkaTokenSparkConf.DEFAULT_SASL_TOKEN_MECHANISM) + KafkaTokenSparkConf.DEFAULT_SASL_TOKEN_MECHANISM), + sparkClusterKafkaConf ) logDebug(s"getClusterConfig($identifier): $result") result diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala index da21d2e2413d..950df867e1e8 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala @@ -134,6 +134,16 @@ private[spark] object KafkaTokenUtil extends Logging { } } + logDebug("AdminClient params before specified params: " + + s"${KafkaRedactionUtil.redactParams(adminClientProperties.asScala.toSeq)}") + + clusterConf.specifiedKafkaParams.foreach { param => + adminClientProperties.setProperty(param._1, param._2) + } + + logDebug("AdminClient params after specified params: " + + s"${KafkaRedactionUtil.redactParams(adminClientProperties.asScala.toSeq)}") + adminClientProperties } diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala index 74f1cdcf7346..eebbf96afa47 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala @@ -107,7 +107,8 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach { protected def createClusterConf( identifier: String, - securityProtocol: String): KafkaTokenClusterConf = { + securityProtocol: String, + specifiedKafkaParams: Map[String, String] = Map.empty): KafkaTokenClusterConf = { KafkaTokenClusterConf( identifier, bootStrapServers, @@ -119,6 +120,7 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach { Some(keyStoreLocation), Some(keyStorePassword), Some(keyPassword), - KafkaTokenSparkConf.DEFAULT_SASL_TOKEN_MECHANISM) + KafkaTokenSparkConf.DEFAULT_SASL_TOKEN_MECHANISM, + specifiedKafkaParams) } } diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenSparkConfSuite.scala 
b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenSparkConfSuite.scala index 60bb8a2bc6c3..61184a6fac33 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenSparkConfSuite.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenSparkConfSuite.scala @@ -96,6 +96,16 @@ class KafkaTokenSparkConfSuite extends SparkFunSuite with BeforeAndAfterEach { assert(clusterConfig.tokenMechanism === tokenMechanism) } + test("getClusterConfig should return specified kafka params") { + sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", authBootStrapServers) + sparkConf.set(s"spark.kafka.clusters.$identifier1.kafka.customKey", "customValue") + + val clusterConfig = KafkaTokenSparkConf.getClusterConfig(sparkConf, identifier1) + assert(clusterConfig.identifier === identifier1) + assert(clusterConfig.authBootstrapServers === authBootStrapServers) + assert(clusterConfig.specifiedKafkaParams === Map("customKey" -> "customValue")) + } + test("getAllClusterConfigs should return empty list when nothing configured") { assert(KafkaTokenSparkConf.getAllClusterConfigs(sparkConf).isEmpty) } diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala index bcca920eed4e..5496195b4149 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala @@ -155,6 +155,15 @@ class KafkaTokenUtilSuite extends SparkFunSuite with KafkaDelegationTokenTest { assert(saslJaasConfig.contains("useTicketCache=true")) } + test("createAdminClientProperties with specified params should include it") { + val clusterConf = createClusterConf(identifier1, SASL_SSL.name, + Map("customKey" -> "customValue")) + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf, clusterConf) + + assert(adminClientProperties.get("customKey") === "customValue") + } + test("isGlobalJaasConfigurationProvided without global config should return false") { assert(!KafkaTokenUtil.isGlobalJaasConfigurationProvided) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index f86d40015bd2..84940d96b563 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -48,6 +48,8 @@ public class SparkLauncher extends AbstractLauncher { public static final String DRIVER_MEMORY = "spark.driver.memory"; /** Configuration key for the driver class path. */ public static final String DRIVER_EXTRA_CLASSPATH = "spark.driver.extraClassPath"; + /** Configuration key for the default driver VM options. */ + public static final String DRIVER_DEFAULT_JAVA_OPTIONS = "spark.driver.defaultJavaOptions"; /** Configuration key for the driver VM options. */ public static final String DRIVER_EXTRA_JAVA_OPTIONS = "spark.driver.extraJavaOptions"; /** Configuration key for the driver native library path. */ @@ -57,6 +59,8 @@ public class SparkLauncher extends AbstractLauncher { public static final String EXECUTOR_MEMORY = "spark.executor.memory"; /** Configuration key for the executor class path. 
*/ public static final String EXECUTOR_EXTRA_CLASSPATH = "spark.executor.extraClassPath"; + /** Configuration key for the default executor VM options. */ + public static final String EXECUTOR_DEFAULT_JAVA_OPTIONS = "spark.executor.defaultJavaOptions"; /** Configuration key for the executor VM options. */ public static final String EXECUTOR_EXTRA_JAVA_OPTIONS = "spark.executor.extraJavaOptions"; /** Configuration key for the executor native library path. */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index e3ee843f6244..3479e0c3422b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -267,13 +267,10 @@ private List buildSparkSubmitCommand(Map env) // We don't want the client to specify Xmx. These have to be set by their corresponding // memory flag --driver-memory or configuration entry spark.driver.memory + String driverDefaultJavaOptions = config.get(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS); + checkJavaOptions(driverDefaultJavaOptions); String driverExtraJavaOptions = config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS); - if (!isEmpty(driverExtraJavaOptions) && driverExtraJavaOptions.contains("Xmx")) { - String msg = String.format("Not allowed to specify max heap(Xmx) memory settings through " + - "java options (was %s). Use the corresponding --driver-memory or " + - "spark.driver.memory configuration instead.", driverExtraJavaOptions); - throw new IllegalArgumentException(msg); - } + checkJavaOptions(driverExtraJavaOptions); if (isClientMode) { // Figuring out where the memory value come from is a little tricky due to precedence. @@ -289,6 +286,7 @@ private List buildSparkSubmitCommand(Map env) String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); cmd.add("-Xmx" + memory); + addOptionString(cmd, driverDefaultJavaOptions); addOptionString(cmd, driverExtraJavaOptions); mergeEnvPathList(env, getLibPathEnvName(), config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); @@ -299,6 +297,15 @@ private List buildSparkSubmitCommand(Map env) return cmd; } + private void checkJavaOptions(String javaOptions) { + if (!isEmpty(javaOptions) && javaOptions.contains("Xmx")) { + String msg = String.format("Not allowed to specify max heap(Xmx) memory settings through " + + "java options (was %s). Use the corresponding --driver-memory or " + + "spark.driver.memory configuration instead.", javaOptions); + throw new IllegalArgumentException(msg); + } + } + private List buildPySparkShellCommand(Map env) throws IOException { // For backwards compatibility, if a script is specified in // the pyspark command line, then run it using spark-submit. 
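A hedged sketch of the new launcher constant above (option values are made up): in client mode the default JVM options are added to the driver command line before the user-supplied extra options, mirroring the prepend semantics described in the configuration.md changes.

    import org.apache.spark.launcher.SparkLauncher

    // Made-up option values: the admin-style defaults are emitted on the launch
    // command before the user's extraJavaOptions.
    val launcher = new SparkLauncher()
      .setConf(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS, "-verbose:gc")
      .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-XX:+PrintGCDetails")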
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index e694e9066f12..32a91b178941 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -251,6 +251,8 @@ public void testMissingAppResource() { } private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception { + final String DRIVER_DEFAULT_PARAM = "-Ddriver-default"; + final String DRIVER_EXTRA_PARAM = "-Ddriver-extra"; String deployMode = isDriver ? "client" : "cluster"; SparkSubmitCommandBuilder launcher = @@ -270,7 +272,8 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath()); launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver"); + launcher.conf.put(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS, DRIVER_DEFAULT_PARAM); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_PARAM); launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); } else { launcher.childEnv.put("SPARK_CONF_DIR", System.getProperty("spark.test.home") @@ -284,6 +287,9 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th if (isDriver) { assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx1g")); + assertTrue("Driver default options should be configured.", + cmd.contains(DRIVER_DEFAULT_PARAM)); + assertTrue("Driver extra options should be configured.", cmd.contains(DRIVER_EXTRA_PARAM)); } else { boolean found = false; for (String arg : cmd) { @@ -293,6 +299,10 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th } } assertFalse("Memory arguments should not be set.", found); + assertFalse("Driver default options should not be configured.", + cmd.contains(DRIVER_DEFAULT_PARAM)); + assertFalse("Driver extra options should not be configured.", + cmd.contains(DRIVER_EXTRA_PARAM)); } String[] cp = findArgValue(cmd, "-cp").split(Pattern.quote(File.pathSeparator)); diff --git a/launcher/src/test/resources/spark-defaults.conf b/launcher/src/test/resources/spark-defaults.conf index 3a51208c7c24..22c253693dcf 100644 --- a/launcher/src/test/resources/spark-defaults.conf +++ b/launcher/src/test/resources/spark-defaults.conf @@ -17,5 +17,6 @@ spark.driver.memory=1g spark.driver.extraClassPath=/driver -spark.driver.extraJavaOptions=-Ddriver +spark.driver.defaultJavaOptions=-Ddriver-default +spark.driver.extraJavaOptions=-Ddriver-extra spark.driver.extraLibraryPath=/native \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d8f3dfa87443..58815434cbda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -204,8 +204,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, if ($(predictionCol).nonEmpty) { transformImpl(dataset) } else { - this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + - " since no output columns were set.") + this.logWarning(s"$uid: Predictor.transform() does nothing" + + " because no output columns were 
set.") dataset.toDF } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index e35e6ce7fdad..568cdd11a12a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -204,8 +204,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur } if (numColsOutput == 0) { - logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") + logWarning(s"$uid: ClassificationModel.transform() does nothing" + + " because no output columns were set.") } outputData.toDF } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index e1fceb1fc96a..675315e3bb07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -169,9 +169,9 @@ final class OneVsRestModel private[ml] ( // Check schema transformSchema(dataset.schema, logging = true) - if (getPredictionCol == "" && getRawPredictionCol == "") { - logWarning(s"$uid: OneVsRestModel.transform() was called as NOOP" + - " since no output columns were set.") + if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) { + logWarning(s"$uid: OneVsRestModel.transform() does nothing" + + " because no output columns were set.") return dataset.toDF } @@ -218,7 +218,7 @@ final class OneVsRestModel private[ml] ( var predictionColNames = Seq.empty[String] var predictionColumns = Seq.empty[Column] - if (getRawPredictionCol != "") { + if (getRawPredictionCol.nonEmpty) { val numClass = models.length // output the RawPrediction as vector @@ -228,18 +228,18 @@ final class OneVsRestModel private[ml] ( Vectors.dense(predArray) } - predictionColNames = predictionColNames :+ getRawPredictionCol - predictionColumns = predictionColumns :+ rawPredictionUDF(col(accColName)) + predictionColNames :+= getRawPredictionCol + predictionColumns :+= rawPredictionUDF(col(accColName)) } - if (getPredictionCol != "") { + if (getPredictionCol.nonEmpty) { // output the index of the classifier with highest confidence as prediction val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } - predictionColNames = predictionColNames :+ getPredictionCol - predictionColumns = predictionColumns :+ labelUDF(col(accColName)) + predictionColNames :+= getPredictionCol + predictionColumns :+= labelUDF(col(accColName)) .as(getPredictionCol, labelMetadata) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 730fcab333e1..5046caa568d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -147,8 +147,8 @@ abstract class 
ProbabilisticClassificationModel[ } if (numColsOutput == 0) { - this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() does nothing" + + " because no output columns were set.") } outputData.toDF } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index fb4698ab5564..9a51d2f18846 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -110,11 +110,29 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val predUDF = udf((vector: Vector) => predict(vector)) - val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset - .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) - .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + + if ($(predictionCol).nonEmpty) { + val predUDF = udf((vector: Vector) => predict(vector)) + predictionColNames :+= $(predictionCol) + predictionColumns :+= predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + } + + if ($(probabilityCol).nonEmpty) { + val probUDF = udf((vector: Vector) => predictProbability(vector)) + predictionColNames :+= $(probabilityCol) + predictionColumns :+= probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) + } else { + this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() + } } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index aa8103701445..91201e7bd03f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -461,17 +461,10 @@ abstract class LDAModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - if ($(topicDistributionCol).nonEmpty) { - val func = getTopicDistributionMethod - val transformer = udf(func) - - dataset.withColumn($(topicDistributionCol), - transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol))) - } else { - logWarning("LDAModel.transform was called without any output columns. 
Set an output column" + - " such as topicDistributionCol to produce results.") - dataset.toDF() - } + val func = getTopicDistributionMethod + val transformer = udf(func) + dataset.withColumn($(topicDistributionCol), + transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } /** @@ -490,7 +483,7 @@ abstract class LDAModel private[ml] ( Vectors.zeros(k) } else { val (ids: List[Int], cts: Array[Double]) = vector match { - case v: DenseVector => ((0 until v.size).toList, v.values) + case v: DenseVector => (List.range(0, v.size), v.values) case v: SparseVector => (v.indices.toList, v.values) case other => throw new UnsupportedOperationException( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 17f2c17c9552..81cf2e1a4ff7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -169,7 +169,7 @@ class StandardScalerModel private[ml] ( case d: DenseVector => d.values.clone() case v: Vector => v.toArray } - val newValues = scaler.transfromWithMean(values) + val newValues = scaler.transformWithMean(values) Vectors.dense(newValues) } else if ($(withStd)) { vector: Vector => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 067dfa43433e..1565782dd631 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -37,7 +37,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -355,13 +355,28 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val predictUDF = udf { features: Vector => predict(features) } - val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + + if ($(predictionCol).nonEmpty) { + val predictUDF = udf { features: Vector => predict(features) } + predictionColNames :+= $(predictionCol) + predictionColumns :+= predictUDF(col($(featuresCol))) + } + if (hasQuantilesCol) { - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol)))) + val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + predictionColNames :+= $(quantilesCol) + predictionColumns :+= predictQuantilesUDF(col($(featuresCol))) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) } else { - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + this.logWarning(s"$uid: AFTSurvivalRegressionModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() } } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index f4f4e56a3578..6348289de516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType @@ -216,16 +216,28 @@ class DecisionTreeRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector) => predict(features) } - val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset.toDF() + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + val predictUDF = udf { (features: Vector) => predict(features) } + predictionColNames :+= $(predictionCol) + predictionColumns :+= predictUDF(col($(featuresCol))) } + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { - output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + predictionColNames :+= $(varianceCol) + predictionColumns :+= predictVarianceUDF(col($(featuresCol))) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) + } else { + this.logWarning(s"$uid: DecisionTreeRegressionModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() } - output } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 885b13bf8dac..b1a8f95c1261 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1041,18 +1041,31 @@ class GeneralizedLinearRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } - val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) } + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) - var output = dataset + if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset)) + val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } + predictionColNames :+= $(predictionCol) + predictionColumns :+= predictUDF(col($(featuresCol)), offset) } + if (hasLinkPredictionCol) { - output = output.withColumn($(linkPredictionCol), 
predictLinkUDF(col($(featuresCol)), offset)) + val predictLinkUDF = + udf { (features: Vector, offset: Double) => predictLink(features, offset) } + predictionColNames :+= $(linkPredictionCol) + predictionColumns :+= predictLinkUDF(col($(featuresCol)), offset) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) + } else { + this.logWarning(s"$uid: GeneralizedLinearRegressionModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() } - output.toDF() } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 4c478a5477c0..4617073f9dec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1357,10 +1357,6 @@ private[spark] abstract class SerDeBase { val unpickle = new Unpickler iter.flatMap { row => val obj = unpickle.loads(row) - // `Opcodes.MEMOIZE` of Protocol 4 (Python 3.4+) will store objects in internal map - // of `Unpickler`. This map is cleared when calling `Unpickler.close()`. Pyrolite - // doesn't clear it up, so we manually clear it. - unpickle.close() if (batched) { obj match { case list: JArrayList[_] => list.asScala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 578b779cd52d..19e53e7eac84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -141,7 +141,7 @@ class StandardScalerModel @Since("1.3.0") ( case d: DenseVector => d.values.clone() case v: Vector => v.toArray } - val newValues = transfromWithMean(values) + val newValues = transformWithMean(values) Vectors.dense(newValues) } else if (withStd) { vector match { @@ -161,7 +161,7 @@ class StandardScalerModel @Since("1.3.0") ( } } - private[spark] def transfromWithMean(values: Array[Double]): Array[Double] = { + private[spark] def transformWithMean(values: Array[Double]): Array[Double] = { // By default, Scala generates Java methods for member variables. So every time when // the member variables are accessed, `invokespecial` will be called which is expensive. // This can be avoid by having a local reference of `shift`. 
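The ML changes above make several model transform() implementations return their input unchanged (with a warning) when every output column is cleared; a minimal sketch of that behavior from user code, assuming the sample libsvm data file shipped with Spark is available locally:

    import org.apache.spark.ml.clustering.GaussianMixture
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("noop-transform").getOrCreate()
    val df = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
    val model = new GaussianMixture().setK(2).fit(df)

    // With both output columns cleared, transform() only logs a warning and
    // returns the input DataFrame without adding any columns.
    val out = model.setPredictionCol("").setProbabilityCol("").transform(df)
    assert(out.schema == df.schema)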
diff --git a/pom.xml b/pom.xml index 81e6e7956ab7..dae5eb007015 100644 --- a/pom.xml +++ b/pom.xml @@ -832,7 +832,7 @@ org.mockito mockito-core - 2.23.4 + 2.28.2 test diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index cb3b80309643..5978f88d6a46 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -372,7 +372,11 @@ object MimaExcludes { // [SPARK-26616][MLlib] Expose document frequency in IDFModel ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf") + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf"), + + // [SPARK-28199][SS] Remove deprecated ProcessingTime + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime$") ) // Exclude rules for 2.4.x diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8bcc67ab1c3e..96fdf5f33b39 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2490,7 +2490,7 @@ def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M + if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc): # Default 1M # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f8be8ee5d4c3..398471234d2b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,8 +22,10 @@ basestring = unicode = str long = int from functools import reduce + from html import escape as html_escape else: from itertools import imap as map + from cgi import escape as html_escape import warnings @@ -375,7 +377,6 @@ def _repr_html_(self): by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are using support eager evaluation with HTML. """ - import cgi if not self._support_repr_html: self._support_repr_html = True if self.sql_ctx._conf.isReplEagerEvalEnabled(): @@ -390,11 +391,11 @@ def _repr_html_(self): html = "<table border='1'>\n" # generate table head - html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: cgi.escape(x), head)) + html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: html_escape(x), head)) # generate table rows for row in row_data: html += "<tr><td>%s</td></tr>\n" % "</td><td>".join( - map(lambda x: cgi.escape(x), row)) + map(lambda x: html_escape(x), row)) html += "</table>
\n" if has_more_data: html += "only showing top %d %s\n" % ( diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6bb7da6b2edb..3ced81427397 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -36,6 +36,7 @@ from pyspark.sql.types import StringType, DataType # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf +from pyspark.sql.utils import to_str # Note to developers: all of PySpark functions here take string as column names whenever possible. # Namely, if columns are referred as arguments, they can be always both Column or string, @@ -114,6 +115,10 @@ def _(): _.__doc__ = 'Window function: ' + doc return _ + +def _options_to_str(options): + return {key: to_str(value) for (key, value) in options.items()} + _lit_doc = """ Creates a :class:`Column` of literal value. @@ -2343,7 +2348,7 @@ def from_json(col, schema, options={}): schema = schema.json() elif isinstance(schema, Column): schema = _to_java_column(schema) - jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) + jc = sc._jvm.functions.from_json(_to_java_column(col), schema, _options_to_str(options)) return Column(jc) @@ -2384,7 +2389,7 @@ def to_json(col, options={}): """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.to_json(_to_java_column(col), options) + jc = sc._jvm.functions.to_json(_to_java_column(col), _options_to_str(options)) return Column(jc) @@ -2415,7 +2420,7 @@ def schema_of_json(json, options={}): raise TypeError("schema argument should be a column or string") sc = SparkContext._active_spark_context - jc = sc._jvm.functions.schema_of_json(col, options) + jc = sc._jvm.functions.schema_of_json(col, _options_to_str(options)) return Column(jc) @@ -2442,7 +2447,7 @@ def schema_of_csv(csv, options={}): raise TypeError("schema argument should be a column or string") sc = SparkContext._active_spark_context - jc = sc._jvm.functions.schema_of_csv(col, options) + jc = sc._jvm.functions.schema_of_csv(col, _options_to_str(options)) return Column(jc) @@ -2464,7 +2469,7 @@ def to_csv(col, options={}): """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.to_csv(_to_java_column(col), options) + jc = sc._jvm.functions.to_csv(_to_java_column(col), _options_to_str(options)) return Column(jc) @@ -2775,6 +2780,11 @@ def from_csv(col, schema, options={}): >>> value = data[0][0] >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect() [Row(csv=Row(_c0=1, _c1=2, _c2=3))] + >>> data = [(" abc",)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> options = {'ignoreLeadingWhiteSpace': True} + >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect() + [Row(csv=Row(s=u'abc'))] """ sc = SparkContext._active_spark_context @@ -2785,7 +2795,7 @@ def from_csv(col, schema, options={}): else: raise TypeError("schema argument should be a column or string") - jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options) + jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, _options_to_str(options)) return Column(jc) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index aa5bf635d187..f9bc2ff72a50 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -27,23 +27,11 @@ from pyspark.sql.column import _to_seq from pyspark.sql.types import * from pyspark.sql import utils +from pyspark.sql.utils import to_str __all__ = 
["DataFrameReader", "DataFrameWriter"] -def to_str(value): - """ - A wrapper over str(), but converts bool values to lower case strings. - If None is given, just returns None, instead of converting it to string "None". - """ - if isinstance(value, bool): - return str(value).lower() - elif value is None: - return value - else: - return str(value) - - class OptionUtils(object): def _set_opts(self, schema=None, **options): @@ -757,7 +745,7 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): self._jwrite.save(path) @since(1.4) - def insertInto(self, tableName, overwrite=False): + def insertInto(self, tableName, overwrite=None): """Inserts the content of the :class:`DataFrame` to the specified table. It requires that the schema of the class:`DataFrame` is the same as the @@ -765,7 +753,9 @@ def insertInto(self, tableName, overwrite=False): Optionally overwriting any existing data. """ - self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) + if overwrite is not None: + self.mode("overwrite" if overwrite else "append") + self._jwrite.insertInto(tableName) @since(1.4) def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options): diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index e5c3b48d548e..5550a093bf80 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -750,6 +750,7 @@ def test_query_execution_listener_on_collect(self): self.spark._jvm.OnSuccessCall.isCalled(), "The callback from the query execution listener should not be called before 'collect'") self.spark.sql("SELECT * FROM range(1)").collect() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000) self.assertTrue( self.spark._jvm.OnSuccessCall.isCalled(), "The callback from the query execution listener should be called after 'collect'") @@ -764,6 +765,7 @@ def test_query_execution_listener_on_collect_with_arrow(self): "The callback from the query execution listener should not be " "called before 'toPandas'") self.spark.sql("SELECT * FROM range(1)").toPandas() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000) self.assertTrue( self.spark._jvm.OnSuccessCall.isCalled(), "The callback from the query execution listener should be called after 'toPandas'") diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index a70807248960..2530cc2ebf22 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -141,6 +141,27 @@ def count_bucketed_cols(names, table="pyspark_bucket"): .mode("overwrite").saveAsTable("pyspark_bucket")) self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + def test_insert_into(self): + df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"]) + with self.table("test_table"): + df.write.saveAsTable("test_table") + self.assertEqual(2, self.spark.sql("select * from test_table").count()) + + df.write.insertInto("test_table") + self.assertEqual(4, self.spark.sql("select * from test_table").count()) + + df.write.mode("overwrite").insertInto("test_table") + self.assertEqual(2, self.spark.sql("select * from test_table").count()) + + df.write.insertInto("test_table", True) + self.assertEqual(2, self.spark.sql("select * from test_table").count()) + + df.write.insertInto("test_table", False) + self.assertEqual(4, self.spark.sql("select * from test_table").count()) + + 
df.write.mode("overwrite").insertInto("test_table", False) + self.assertEqual(6, self.spark.sql("select * from test_table").count()) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index f9bed7604b13..ea2a686cddaa 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -128,10 +128,6 @@ def test_BinaryType_serialization(self): def test_int_array_serialization(self): # Note that this test seems dependent on parallelism. - # This issue is because internal object map in Pyrolite is not cleared after op code - # STOP. If we use protocol 4 to pickle Python objects, op code MEMOIZE will store - # objects in the map. We need to clear up it to make sure next unpickling works on - # clear map. data = self.spark.sparkContext.parallelize([[1, 2, 3, 4]] * 100, numSlices=12) df = self.spark.createDataFrame(data, "array") self.assertEqual(len(list(filter(lambda r: None in r.value, df.collect()))), 0) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index ca5e85bb3a9b..c30cc1482750 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -207,3 +207,16 @@ def call(self, jdf, batch_id): class Java: implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] + + +def to_str(value): + """ + A wrapper over str(), but converts bool values to lower case strings. + If None is given, just returns None, instead of converting it to string "None". + """ + if isinstance(value, bool): + return str(value).lower() + elif value is None: + return value + else: + return str(value) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 3ff68348be7b..2fd13a590324 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -28,6 +28,7 @@ import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest._ import org.apache.spark.internal.config +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.util.Utils @@ -97,6 +98,7 @@ private[mesos] class MesosSubmitRequestServlet( // Optional fields val sparkProperties = request.sparkProperties + val driverDefaultJavaOptions = sparkProperties.get(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS) val driverExtraJavaOptions = sparkProperties.get(config.DRIVER_JAVA_OPTIONS.key) val driverExtraClassPath = sparkProperties.get(config.DRIVER_CLASS_PATH.key) val driverExtraLibraryPath = sparkProperties.get(config.DRIVER_LIBRARY_PATH.key) @@ -110,9 +112,11 @@ private[mesos] class MesosSubmitRequestServlet( val conf = new SparkConf(false).setAll(sparkProperties) val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val defaultJavaOpts = driverDefaultJavaOptions.map(Utils.splitCommandString) + .getOrElse(Seq.empty) val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) - val javaOpts = sparkJavaOpts ++ extraJavaOpts + val javaOpts = 
sparkJavaOpts ++ defaultJavaOpts ++ extraJavaOpts val command = new Command( mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5b361d17c01a..651e706021fc 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -51,7 +51,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.api.python.PythonUtils import org.apache.spark.deploy.{SparkApplication, SparkHadoopUtil} import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.deploy.yarn.ResourceRequestHelper._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -241,12 +241,12 @@ private[spark] class Client( newApp: YarnClientApplication, containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { - val yarnAMResources = - if (isClusterMode) { - sparkConf.getAllWithPrefix(config.YARN_DRIVER_RESOURCE_TYPES_PREFIX).toMap - } else { - sparkConf.getAllWithPrefix(config.YARN_AM_RESOURCE_TYPES_PREFIX).toMap - } + val componentName = if (isClusterMode) { + config.YARN_DRIVER_RESOURCE_TYPES_PREFIX + } else { + config.YARN_AM_RESOURCE_TYPES_PREFIX + } + val yarnAMResources = getYarnResourcesAndAmounts(sparkConf, componentName) val amResources = yarnAMResources ++ getYarnResourcesFromSparkResources(SPARK_DRIVER_PREFIX, sparkConf) logDebug(s"AM resources: $amResources") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala index cb0c68d1d346..522c16b3a108 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala @@ -26,11 +26,11 @@ import scala.util.Try import org.apache.hadoop.yarn.api.records.Resource import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.resource.ResourceID +import org.apache.spark.resource.ResourceUtils.{AMOUNT, FPGA, GPU} import org.apache.spark.util.{CausedBy, Utils} /** @@ -40,6 +40,45 @@ import org.apache.spark.util.{CausedBy, Utils} private object ResourceRequestHelper extends Logging { private val AMOUNT_AND_UNIT_REGEX = "([0-9]+)([A-Za-z]*)".r private val RESOURCE_INFO_CLASS = "org.apache.hadoop.yarn.api.records.ResourceInformation" + val YARN_GPU_RESOURCE_CONFIG = "yarn.io/gpu" + val YARN_FPGA_RESOURCE_CONFIG = "yarn.io/fpga" + + private[yarn] def getYarnResourcesAndAmounts( + sparkConf: SparkConf, + componentName: String): Map[String, String] = { + sparkConf.getAllWithPrefix(s"$componentName").map { case (key, value) => + val splitIndex = key.lastIndexOf('.') + if (splitIndex == -1) { + val errorMessage = s"Missing suffix for ${componentName}${key}, you must specify" + + s" 
a suffix - $AMOUNT is currently the only supported suffix." + throw new IllegalArgumentException(errorMessage.toString()) + } + val resourceName = key.substring(0, splitIndex) + val resourceSuffix = key.substring(splitIndex + 1) + if (!AMOUNT.equals(resourceSuffix)) { + val errorMessage = s"Unsupported suffix: $resourceSuffix in: ${componentName}${key}, " + + s"only .$AMOUNT is supported." + throw new IllegalArgumentException(errorMessage.toString()) + } + (resourceName, value) + }.toMap + } + + /** + * Convert Spark resources into YARN resources. + * The only resources we know how to map from spark configs to yarn configs are + * gpus and fpgas, everything else the user has to specify them in both the + * spark.yarn.*.resource and the spark.*.resource configs. + */ + private[yarn] def getYarnResourcesFromSparkResources( + confPrefix: String, + sparkConf: SparkConf + ): Map[String, String] = { + Map(GPU -> YARN_GPU_RESOURCE_CONFIG, FPGA -> YARN_FPGA_RESOURCE_CONFIG).map { + case (rName, yarnName) => + (yarnName -> sparkConf.get(ResourceID(confPrefix, rName).amountConf, "0")) + }.filter { case (_, count) => count.toLong > 0 } + } /** * Validates sparkConf and throws a SparkException if any of standard resources (memory or cores) @@ -81,8 +120,9 @@ private object ResourceRequestHelper extends Logging { val errorMessage = new mutable.StringBuilder() resourceDefinitions.foreach { case (sparkName, resourceRequest) => - if (sparkConf.contains(resourceRequest)) { - errorMessage.append(s"Error: Do not use $resourceRequest, " + + val resourceRequestAmount = s"${resourceRequest}.${AMOUNT}" + if (sparkConf.contains(resourceRequestAmount)) { + errorMessage.append(s"Error: Do not use $resourceRequestAmount, " + s"please use $sparkName instead!\n") } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 6e634b921fcd..8ec7bd66b250 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.deploy.yarn.ResourceRequestHelper._ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging @@ -142,8 +143,8 @@ private[yarn] class YarnAllocator( protected val executorCores = sparkConf.get(EXECUTOR_CORES) private val executorResourceRequests = - sparkConf.getAllWithPrefix(config.YARN_EXECUTOR_RESOURCE_TYPES_PREFIX).toMap ++ - getYarnResourcesFromSparkResources(SPARK_EXECUTOR_PREFIX, sparkConf) + getYarnResourcesAndAmounts(sparkConf, config.YARN_EXECUTOR_RESOURCE_TYPES_PREFIX) ++ + getYarnResourcesFromSparkResources(SPARK_EXECUTOR_PREFIX, sparkConf) // Resource capability requested for each executor private[yarn] val resource: Resource = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 6b87eec795f9..11035520ae18 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ 
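The parsing rule introduced by getYarnResourcesAndAmounts above, together with the gpu/fpga mapping next to it, can be illustrated with a small standalone sketch. This is plain Scala with no Spark dependency; the literal prefix string is written out for illustration and is assumed to match config.YARN_EXECUTOR_RESOURCE_TYPES_PREFIX, and the parser is a simplified model rather than the Spark implementation:

```scala
object YarnResourceConfigSketch {
  // Every key under the prefix must end in ".amount"; the part before the
  // suffix is the YARN resource name, mirroring getYarnResourcesAndAmounts.
  def parse(prefix: String, conf: Map[String, String]): Map[String, String] =
    conf.collect { case (key, value) if key.startsWith(prefix) =>
      val rest = key.stripPrefix(prefix)
      val splitIndex = rest.lastIndexOf('.')
      require(splitIndex != -1, s"Missing suffix for $key, expected .amount")
      val (name, suffix) = (rest.substring(0, splitIndex), rest.substring(splitIndex + 1))
      require(suffix == "amount", s"Unsupported suffix: $suffix in $key")
      name -> value
    }

  def main(args: Array[String]): Unit = {
    val conf = Map(
      "spark.yarn.executor.resource.yarn.io/gpu.amount" -> "2",
      "spark.yarn.executor.resource.yarn.io/fpga.amount" -> "1",
      // spark.executor.resource.gpu.amount would instead be translated to
      // yarn.io/gpu by getYarnResourcesFromSparkResources.
      "spark.executor.cores" -> "4")
    // Map(yarn.io/gpu -> 2, yarn.io/fpga -> 1)
    println(parse("spark.yarn.executor.resource.", conf))
  }
}
```

With this change, keys under the yarn resource prefixes that lack the .amount suffix fail fast with an error instead of being passed through to YARN verbatim.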
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -41,24 +41,6 @@ object YarnSparkHadoopUtil { val MEMORY_OVERHEAD_MIN = 384L val ANY_HOST = "*" - val YARN_GPU_RESOURCE_CONFIG = "yarn.io/gpu" - val YARN_FPGA_RESOURCE_CONFIG = "yarn.io/fpga" - - /** - * Convert Spark resources into YARN resources. - * The only resources we know how to map from spark configs to yarn configs are - * gpus and fpgas, everything else the user has to specify them in both the - * spark.yarn.*.resource and the spark.*.resource configs. - */ - private[yarn] def getYarnResourcesFromSparkResources( - confPrefix: String, - sparkConf: SparkConf - ): Map[String, String] = { - Map(GPU -> YARN_GPU_RESOURCE_CONFIG, FPGA -> YARN_FPGA_RESOURCE_CONFIG).map { - case (rName, yarnName) => - (yarnName -> sparkConf.get(ResourceID(confPrefix, rName).amountConf, "0")) - }.filter { case (_, count) => count.toLong > 0 } - } // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index d5f1992a09f5..847fc3773de5 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -38,10 +38,11 @@ import org.mockito.Mockito.{spy, verify} import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TestUtils} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.deploy.yarn.ResourceRequestHelper._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.config._ import org.apache.spark.resource.ResourceID +import org.apache.spark.resource.ResourceUtils.AMOUNT import org.apache.spark.util.{SparkConfWithEnv, Utils} class ClientSuite extends SparkFunSuite with Matchers { @@ -372,7 +373,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val conf = new SparkConf().set(SUBMIT_DEPLOY_MODE, deployMode) resources.foreach { case (name, v) => - conf.set(prefix + name, v.toString) + conf.set(s"${prefix}${name}.${AMOUNT}", v.toString) } val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) @@ -397,7 +398,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val conf = new SparkConf().set(SUBMIT_DEPLOY_MODE, "cluster") resources.keys.foreach { yarnName => - conf.set(s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${yarnName}", "2") + conf.set(s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${yarnName}.${AMOUNT}", "2") } resources.values.foreach { rName => conf.set(ResourceID(SPARK_DRIVER_PREFIX, rName).amountConf, "3") @@ -407,9 +408,9 @@ class ClientSuite extends SparkFunSuite with Matchers { ResourceRequestHelper.validateResources(conf) }.getMessage() - assert(error.contains("Do not use spark.yarn.driver.resource.yarn.io/fpga," + + assert(error.contains("Do not use spark.yarn.driver.resource.yarn.io/fpga.amount," + " please use spark.driver.resource.fpga.amount")) - assert(error.contains("Do not use spark.yarn.driver.resource.yarn.io/gpu," + + assert(error.contains("Do not use spark.yarn.driver.resource.yarn.io/gpu.amount," + " please use spark.driver.resource.gpu.amount")) } @@ -420,7 +421,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val conf = new SparkConf().set(SUBMIT_DEPLOY_MODE, 
"cluster") resources.keys.foreach { yarnName => - conf.set(s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${yarnName}", "2") + conf.set(s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${yarnName}.${AMOUNT}", "2") } resources.values.foreach { rName => conf.set(ResourceID(SPARK_EXECUTOR_PREFIX, rName).amountConf, "3") @@ -430,9 +431,9 @@ class ClientSuite extends SparkFunSuite with Matchers { ResourceRequestHelper.validateResources(conf) }.getMessage() - assert(error.contains("Do not use spark.yarn.executor.resource.yarn.io/fpga," + + assert(error.contains("Do not use spark.yarn.executor.resource.yarn.io/fpga.amount," + " please use spark.executor.resource.fpga.amount")) - assert(error.contains("Do not use spark.yarn.executor.resource.yarn.io/gpu," + + assert(error.contains("Do not use spark.yarn.executor.resource.yarn.io/gpu.amount," + " please use spark.executor.resource.gpu.amount")) } @@ -450,7 +451,7 @@ class ClientSuite extends SparkFunSuite with Matchers { conf.set(ResourceID(SPARK_DRIVER_PREFIX, rName).amountConf, "3") } // also just set yarn one that we don't convert - conf.set(YARN_DRIVER_RESOURCE_TYPES_PREFIX + yarnMadeupResource, "5") + conf.set(s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${yarnMadeupResource}.${AMOUNT}", "5") val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext]) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala index 9e3cc6ec01df..f5ec531e26e0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala @@ -22,9 +22,11 @@ import org.apache.hadoop.yarn.util.Records import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.deploy.yarn.ResourceRequestHelper._ import org.apache.spark.deploy.yarn.ResourceRequestTestHelper.ResourceInformation import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.config.{DRIVER_CORES, DRIVER_MEMORY, EXECUTOR_CORES, EXECUTOR_MEMORY} +import org.apache.spark.resource.ResourceUtils.AMOUNT class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { @@ -32,16 +34,18 @@ class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { private val CUSTOM_RES_2 = "custom-resource-type-2" private val MEMORY = "memory" private val CORES = "cores" - private val NEW_CONFIG_EXECUTOR_MEMORY = YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + MEMORY - private val NEW_CONFIG_EXECUTOR_CORES = YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + CORES - private val NEW_CONFIG_AM_MEMORY = YARN_AM_RESOURCE_TYPES_PREFIX + MEMORY - private val NEW_CONFIG_AM_CORES = YARN_AM_RESOURCE_TYPES_PREFIX + CORES - private val NEW_CONFIG_DRIVER_MEMORY = YARN_DRIVER_RESOURCE_TYPES_PREFIX + MEMORY - private val NEW_CONFIG_DRIVER_CORES = YARN_DRIVER_RESOURCE_TYPES_PREFIX + CORES + private val NEW_CONFIG_EXECUTOR_MEMORY = + s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${MEMORY}.${AMOUNT}" + private val NEW_CONFIG_EXECUTOR_CORES = + s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${CORES}.${AMOUNT}" + private val NEW_CONFIG_AM_MEMORY = s"${YARN_AM_RESOURCE_TYPES_PREFIX}${MEMORY}.${AMOUNT}" + private val NEW_CONFIG_AM_CORES = 
s"${YARN_AM_RESOURCE_TYPES_PREFIX}${CORES}.${AMOUNT}" + private val NEW_CONFIG_DRIVER_MEMORY = s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${MEMORY}.${AMOUNT}" + private val NEW_CONFIG_DRIVER_CORES = s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${CORES}.${AMOUNT}" test("empty SparkConf should be valid") { val sparkConf = new SparkConf() - ResourceRequestHelper.validateResources(sparkConf) + validateResources(sparkConf) } test("just normal resources are defined") { @@ -50,7 +54,44 @@ class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { sparkConf.set(DRIVER_CORES.key, "4") sparkConf.set(EXECUTOR_MEMORY.key, "4G") sparkConf.set(EXECUTOR_CORES.key, "2") - ResourceRequestHelper.validateResources(sparkConf) + validateResources(sparkConf) + } + + test("get yarn resources from configs") { + val sparkConf = new SparkConf() + val resources = Map(YARN_GPU_RESOURCE_CONFIG -> "2G", + YARN_FPGA_RESOURCE_CONFIG -> "3G", "custom" -> "4") + resources.foreach { case (name, value) => + sparkConf.set(s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${name}.${AMOUNT}", value) + sparkConf.set(s"${YARN_DRIVER_RESOURCE_TYPES_PREFIX}${name}.${AMOUNT}", value) + sparkConf.set(s"${YARN_AM_RESOURCE_TYPES_PREFIX}${name}.${AMOUNT}", value) + } + var parsedResources = getYarnResourcesAndAmounts(sparkConf, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX) + assert(parsedResources === resources) + parsedResources = getYarnResourcesAndAmounts(sparkConf, YARN_DRIVER_RESOURCE_TYPES_PREFIX) + assert(parsedResources === resources) + parsedResources = getYarnResourcesAndAmounts(sparkConf, YARN_AM_RESOURCE_TYPES_PREFIX) + assert(parsedResources === resources) + } + + test("get invalid yarn resources from configs") { + val sparkConf = new SparkConf() + + val missingAmountConfig = s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}missingAmount" + // missing .amount + sparkConf.set(missingAmountConfig, "2g") + var thrown = intercept[IllegalArgumentException] { + getYarnResourcesAndAmounts(sparkConf, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX) + } + thrown.getMessage should include("Missing suffix for") + + sparkConf.remove(missingAmountConfig) + sparkConf.set(s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}customResource.invalid", "2g") + + thrown = intercept[IllegalArgumentException] { + getYarnResourcesAndAmounts(sparkConf, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX) + } + thrown.getMessage should include("Unsupported suffix") } Seq( @@ -60,14 +101,14 @@ class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { ResourceInformation(CUSTOM_RES_2, 10, "G")) ).foreach { case (name, resources) => test(s"valid request: $name") { - assume(ResourceRequestHelper.isYarnResourceTypesAvailable()) + assume(isYarnResourceTypesAvailable()) val resourceDefs = resources.map { r => r.name } val requests = resources.map { r => (r.name, r.value.toString + r.unit) }.toMap ResourceRequestTestHelper.initializeResourceTypes(resourceDefs) val resource = createResource() - ResourceRequestHelper.setResourceRequests(requests, resource) + setResourceRequests(requests, resource) resources.foreach { r => val requested = ResourceRequestTestHelper.getResourceInformationByName(resource, r.name) @@ -82,12 +123,12 @@ class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { ("invalid unit", CUSTOM_RES_1, "123ppp") ).foreach { case (name, key, value) => test(s"invalid request: $name") { - assume(ResourceRequestHelper.isYarnResourceTypesAvailable()) + assume(isYarnResourceTypesAvailable()) ResourceRequestTestHelper.initializeResourceTypes(Seq(key)) val resource = createResource() val 
thrown = intercept[IllegalArgumentException] { - ResourceRequestHelper.setResourceRequests(Map(key -> value), resource) + setResourceRequests(Map(key -> value), resource) } thrown.getMessage should include (key) } @@ -95,20 +136,20 @@ class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { Seq( NEW_CONFIG_EXECUTOR_MEMORY -> "30G", - YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "memory-mb" -> "30G", - YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "mb" -> "30G", + s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}memory-mb.$AMOUNT" -> "30G", + s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}mb.$AMOUNT" -> "30G", NEW_CONFIG_EXECUTOR_CORES -> "5", - YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "vcores" -> "5", + s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}vcores.$AMOUNT" -> "5", NEW_CONFIG_AM_MEMORY -> "1G", NEW_CONFIG_DRIVER_MEMORY -> "1G", NEW_CONFIG_AM_CORES -> "3", NEW_CONFIG_DRIVER_CORES -> "1G" ).foreach { case (key, value) => test(s"disallowed resource request: $key") { - assume(ResourceRequestHelper.isYarnResourceTypesAvailable()) + assume(isYarnResourceTypesAvailable()) val conf = new SparkConf(false).set(key, value) val thrown = intercept[SparkException] { - ResourceRequestHelper.validateResources(conf) + validateResources(conf) } thrown.getMessage should include (key) } @@ -126,7 +167,7 @@ class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { sparkConf.set(NEW_CONFIG_DRIVER_MEMORY, "2G") val thrown = intercept[SparkException] { - ResourceRequestHelper.validateResources(sparkConf) + validateResources(sparkConf) } thrown.getMessage should ( include(NEW_CONFIG_EXECUTOR_MEMORY) and diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index ca89af26230f..4ac27ede6483 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -31,9 +31,11 @@ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.yarn.ResourceRequestHelper._ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.config._ +import org.apache.spark.resource.ResourceUtils.{AMOUNT, GPU} import org.apache.spark.resource.TestResourceIDs._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo @@ -160,12 +162,12 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter } test("custom resource requested from yarn") { - assume(ResourceRequestHelper.isYarnResourceTypesAvailable()) + assume(isYarnResourceTypesAvailable()) ResourceRequestTestHelper.initializeResourceTypes(List("gpu")) val mockAmClient = mock(classOf[AMRMClient[ContainerRequest]]) val handler = createAllocator(1, mockAmClient, - Map(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "gpu" -> "2G")) + Map(s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${GPU}.${AMOUNT}" -> "2G")) handler.updateResourceRequests() val container = createContainer("host1", resource = handler.resource) @@ -174,7 +176,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter // get amount of memory and vcores from resource, so effectively skipping their validation val expectedResources = 
Resource.newInstance(handler.resource.getMemory(), handler.resource.getVirtualCores) - ResourceRequestHelper.setResourceRequests(Map("gpu" -> "2G"), expectedResources) + setResourceRequests(Map("gpu" -> "2G"), expectedResources) val captor = ArgumentCaptor.forClass(classOf[ContainerRequest]) verify(mockAmClient).addContainerRequest(captor.capture()) @@ -183,15 +185,16 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter } test("custom spark resource mapped to yarn resource configs") { - assume(ResourceRequestHelper.isYarnResourceTypesAvailable()) + assume(isYarnResourceTypesAvailable()) val yarnMadeupResource = "yarn.io/madeup" val yarnResources = Seq(YARN_GPU_RESOURCE_CONFIG, YARN_FPGA_RESOURCE_CONFIG, yarnMadeupResource) ResourceRequestTestHelper.initializeResourceTypes(yarnResources) val mockAmClient = mock(classOf[AMRMClient[ContainerRequest]]) + val madeupConfigName = s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${yarnMadeupResource}.${AMOUNT}" val sparkResources = Map(EXECUTOR_GPU_ID.amountConf -> "3", EXECUTOR_FPGA_ID.amountConf -> "2", - s"${YARN_EXECUTOR_RESOURCE_TYPES_PREFIX}${yarnMadeupResource}" -> "5") + madeupConfigName -> "5") val handler = createAllocator(1, mockAmClient, sparkResources) handler.updateResourceRequests() diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java index 9b87e676d9b2..7eef6aea8812 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java @@ -227,6 +227,10 @@ public String property() { } } + interface ColumnChange extends TableChange { + String[] fieldNames(); + } + /** * A TableChange to add a field. *
<p>
@@ -234,7 +238,7 @@ public String property() { * If the new field is nested and its parent does not exist or is not a struct, the change must * result in an {@link IllegalArgumentException}. */ - final class AddColumn implements TableChange { + final class AddColumn implements ColumnChange { private final String[] fieldNames; private final DataType dataType; private final boolean isNullable; @@ -247,6 +251,7 @@ private AddColumn(String[] fieldNames, DataType dataType, boolean isNullable, St this.comment = comment; } + @Override public String[] fieldNames() { return fieldNames; } @@ -272,7 +277,7 @@ public String comment() { *
<p>
* If the field does not exist, the change must result in an {@link IllegalArgumentException}. */ - final class RenameColumn implements TableChange { + final class RenameColumn implements ColumnChange { private final String[] fieldNames; private final String newName; @@ -281,6 +286,7 @@ private RenameColumn(String[] fieldNames, String newName) { this.newName = newName; } + @Override public String[] fieldNames() { return fieldNames; } @@ -297,7 +303,7 @@ public String newName() { *
<p>
* If the field does not exist, the change must result in an {@link IllegalArgumentException}. */ - final class UpdateColumnType implements TableChange { + final class UpdateColumnType implements ColumnChange { private final String[] fieldNames; private final DataType newDataType; private final boolean isNullable; @@ -308,6 +314,7 @@ private UpdateColumnType(String[] fieldNames, DataType newDataType, boolean isNu this.isNullable = isNullable; } + @Override public String[] fieldNames() { return fieldNames; } @@ -328,7 +335,7 @@ public boolean isNullable() { *
<p>
* If the field does not exist, the change must result in an {@link IllegalArgumentException}. */ - final class UpdateColumnComment implements TableChange { + final class UpdateColumnComment implements ColumnChange { private final String[] fieldNames; private final String newComment; @@ -337,6 +344,7 @@ private UpdateColumnComment(String[] fieldNames, String newComment) { this.newComment = newComment; } + @Override public String[] fieldNames() { return fieldNames; } @@ -351,13 +359,14 @@ public String newComment() { *
<p>
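As a rough sketch of what the shared ColumnChange interface enables (assuming the catalyst module from this branch is on the classpath; the factory signatures are the ones used in the hunks above, and affectedPaths is a hypothetical helper, not part of the API):

```scala
import org.apache.spark.sql.catalog.v2.TableChange
import org.apache.spark.sql.catalog.v2.TableChange.ColumnChange
import org.apache.spark.sql.types.LongType

object ColumnChangeSketch {
  // With the shared interface, callers can treat every column-level change
  // uniformly instead of matching the five concrete classes.
  def affectedPaths(changes: Seq[TableChange]): Seq[String] =
    changes.collect { case c: ColumnChange => c.fieldNames.mkString(".") }

  def main(args: Array[String]): Unit = {
    val changes = Seq(
      TableChange.addColumn(Array("point", "z"), LongType, true, null),
      TableChange.renameColumn(Array("point", "x"), "first"),
      TableChange.deleteColumn(Array("ts")),
      TableChange.setProperty("comment", "a table comment"))
    // point.z, point.x, ts -- the property change has no column and is skipped
    println(affectedPaths(changes).mkString(", "))
  }
}
```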
* If the field does not exist, the change must result in an {@link IllegalArgumentException}. */ - final class DeleteColumn implements TableChange { + final class DeleteColumn implements ColumnChange { private final String[] fieldNames; private DeleteColumn(String[] fieldNames) { this.fieldNames = fieldNames; } + @Override public String[] fieldNames() { return fieldNames; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala index 6de1ef5660e5..7cc80c41a901 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, TableChange} import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, DeleteColumn, RemoveProperty, RenameColumn, SetProperty, UpdateColumnComment, UpdateColumnType} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.sources.v2.Table -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType} object CatalogV2Util { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ @@ -132,16 +132,45 @@ object CatalogV2Util { val pos = struct.getFieldIndex(fieldNames.head) .getOrElse(throw new IllegalArgumentException(s"Cannot find field: ${fieldNames.head}")) val field = struct.fields(pos) - val replacement: Option[StructField] = if (fieldNames.tail.isEmpty) { - update(field) - } else { - field.dataType match { - case nestedStruct: StructType => - val updatedType: StructType = replace(nestedStruct, fieldNames.tail, update) - Some(StructField(field.name, updatedType, field.nullable, field.metadata)) - case _ => - throw new IllegalArgumentException(s"Not a struct: ${fieldNames.head}") - } + val replacement: Option[StructField] = (fieldNames.tail, field.dataType) match { + case (Seq(), _) => + update(field) + + case (names, struct: StructType) => + val updatedType: StructType = replace(struct, names, update) + Some(StructField(field.name, updatedType, field.nullable, field.metadata)) + + case (Seq("key"), map @ MapType(keyType, _, _)) => + val updated = update(StructField("key", keyType, nullable = false)) + .getOrElse(throw new IllegalArgumentException(s"Cannot delete map key")) + Some(field.copy(dataType = map.copy(keyType = updated.dataType))) + + case (Seq("key", names @ _*), map @ MapType(keyStruct: StructType, _, _)) => + Some(field.copy(dataType = map.copy(keyType = replace(keyStruct, names, update)))) + + case (Seq("value"), map @ MapType(_, mapValueType, isNullable)) => + val updated = update(StructField("value", mapValueType, nullable = isNullable)) + .getOrElse(throw new IllegalArgumentException(s"Cannot delete map value")) + Some(field.copy(dataType = map.copy( + valueType = updated.dataType, + valueContainsNull = updated.nullable))) + + case (Seq("value", names @ _*), map @ MapType(_, valueStruct: StructType, _)) => + Some(field.copy(dataType = map.copy(valueType = replace(valueStruct, names, update)))) + + case (Seq("element"), array @ ArrayType(elementType, isNullable)) => + val updated = update(StructField("element", elementType, nullable = isNullable)) + .getOrElse(throw new IllegalArgumentException(s"Cannot delete array element")) + Some(field.copy(dataType = array.copy( + elementType = 
updated.dataType, + containsNull = updated.nullable))) + + case (Seq("element", names @ _*), array @ ArrayType(elementStruct: StructType, _)) => + Some(field.copy(dataType = array.copy(elementType = replace(elementStruct, names, update)))) + + case (names, dataType) => + throw new IllegalArgumentException( + s"Cannot find field: ${names.head} in ${dataType.simpleString}") } val newFields = struct.fields.zipWithIndex.flatMap { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1d0dba262c10..e55cdfedd323 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -165,6 +166,7 @@ class Analyzer( new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: + ResolveAlterTable :: ResolveTables :: ResolveRelations :: ResolveReferences :: @@ -211,38 +213,6 @@ class Analyzer( CleanupAliases) ) - /** - * Analyze cte definitions and substitute child plan with analyzed cte definitions. - */ - object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case With(child, relations) => - // substitute CTE expressions right-to-left to resolve references to previous CTEs: - // with a as (select * from t), b as (select * from a) select * from b - relations.foldRight(child) { - case ((cteName, ctePlan), currentPlan) => - substituteCTE(currentPlan, cteName, ctePlan) - } - case other => other - } - - private def substituteCTE( - plan: LogicalPlan, - cteName: String, - ctePlan: LogicalPlan): LogicalPlan = { - plan resolveOperatorsUp { - case UnresolvedRelation(Seq(table)) if resolver(cteName, table) => - ctePlan - case other => - // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. - other transformExpressions { - case e: SubqueryExpression => - e.withNewPlan(substituteCTE(e.plan, cteName, ctePlan)) - } - } - } - } - /** * Substitute child plan with WindowSpecDefinitions. */ @@ -787,6 +757,86 @@ class Analyzer( } } + /** + * Resolve ALTER TABLE statements that use a DSv2 catalog. 
+ * + * This rule converts unresolved ALTER TABLE statements to v2 when a v2 catalog is responsible + * for the table identifier. A v2 catalog is responsible for an identifier when the identifier + * has a catalog specified, like prod_catalog.db.table, or when a default v2 catalog is set and + * the table identifier does not include a catalog. + */ + object ResolveAlterTable extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case alter @ AlterTableAddColumnsStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), cols) => + val changes = cols.map { col => + TableChange.addColumn(col.name.toArray, col.dataType, true, col.comment.orNull) + } + + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + changes) + + case alter @ AlterTableAlterColumnStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), colName, dataType, comment) => + val typeChange = dataType.map { newDataType => + TableChange.updateColumnType(colName.toArray, newDataType, true) + } + + val commentChange = comment.map { newComment => + TableChange.updateColumnComment(colName.toArray, newComment) + } + + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + typeChange.toSeq ++ commentChange.toSeq) + + case alter @ AlterTableRenameColumnStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), col, newName) => + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + Seq(TableChange.renameColumn(col.toArray, newName))) + + case alter @ AlterTableDropColumnsStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), cols) => + val changes = cols.map(col => TableChange.deleteColumn(col.toArray)) + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + changes) + + case alter @ AlterTableSetPropertiesStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), props) => + val changes = props.map { + case (key, value) => + TableChange.setProperty(key, value) + } + + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + changes.toSeq) + + case alter @ AlterTableUnsetPropertiesStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), keys, _) => + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + keys.map(key => TableChange.removeProperty(key))) + + case alter @ AlterTableSetLocationStatement( + CatalogObjectIdentifier(Some(v2Catalog), ident), newLoc) => + AlterTable( + v2Catalog.asTableCatalog, ident, + UnresolvedRelation(alter.tableName), + Seq(TableChange.setProperty("location", newLoc))) + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala new file mode 100644 index 000000000000..60e6bf8db06d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, With} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LEGACY_CTE_PRECEDENCE_ENABLED + +/** + * Analyze WITH nodes and substitute child plan with CTE definitions. + */ +object CTESubstitution extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.getConf(LEGACY_CTE_PRECEDENCE_ENABLED)) { + legacyTraverseAndSubstituteCTE(plan) + } else { + traverseAndSubstituteCTE(plan, false) + } + } + + private def legacyTraverseAndSubstituteCTE(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsUp { + case With(child, relations) => + // substitute CTE expressions right-to-left to resolve references to previous CTEs: + // with a as (select * from t), b as (select * from a) select * from b + relations.foldRight(child) { + case ((cteName, ctePlan), currentPlan) => substituteCTE(currentPlan, cteName, ctePlan) + } + } + } + + /** + * Traverse the plan and expression nodes as a tree and replace matching references to CTE + * definitions. + * - If the rule encounters a WITH node then it substitutes the child of the node with CTE + * definitions of the node right-to-left order as a definition can reference to a previous + * one. + * For example the following query is valid: + * WITH + * t AS (SELECT 1), + * t2 AS (SELECT * FROM t) + * SELECT * FROM t2 + * - If a CTE definition contains an inner WITH node then substitution of inner should take + * precedence because it can shadow an outer CTE definition. + * For example the following query should return 2: + * WITH + * t AS (SELECT 1), + * t2 AS ( + * WITH t AS (SELECT 2) + * SELECT * FROM t + * ) + * SELECT * FROM t2 + * - If a CTE definition contains a subquery that contains an inner WITH node then substitution + * of inner should take precedence because it can shadow an outer CTE definition. + * For example the following query should return 2: + * WITH t AS (SELECT 1 AS c) + * SELECT max(c) FROM ( + * WITH t AS (SELECT 2 AS c) + * SELECT * FROM t + * ) + * - If a CTE definition contains a subquery expression that contains an inner WITH node then + * substitution of inner should take precedence because it can shadow an outer CTE + * definition. 
+ * For example the following query should return 2: + * WITH t AS (SELECT 1) + * SELECT ( + * WITH t AS (SELECT 2) + * SELECT * FROM t + * ) + * @param plan the plan to be traversed + * @param inTraverse whether the current traverse is called from another traverse, only in this + * case name collision can occur + * @return the plan where CTE substitution is applied + */ + private def traverseAndSubstituteCTE(plan: LogicalPlan, inTraverse: Boolean): LogicalPlan = { + plan.resolveOperatorsUp { + case With(child: LogicalPlan, relations) => + // child might contain an inner CTE that has priority so traverse and substitute inner CTEs + // in child first + val traversedChild: LogicalPlan = child transformExpressions { + case e: SubqueryExpression => e.withNewPlan(traverseAndSubstituteCTE(e.plan, true)) + } + + // Substitute CTE definitions from last to first as a CTE definition can reference a + // previous one + relations.foldRight(traversedChild) { + case ((cteName, ctePlan), currentPlan) => + // A CTE definition might contain an inner CTE that has priority, so traverse and + // substitute CTE defined in ctePlan. + // A CTE definition might not be used at all or might be used multiple times. To avoid + // computation if it is not used and to avoid multiple recomputation if it is used + // multiple times we use a lazy construct with call-by-name parameter passing. + lazy val substitutedCTEPlan = traverseAndSubstituteCTE(ctePlan, true) + substituteCTE(currentPlan, cteName, substitutedCTEPlan) + } + + // CTE name collision can occur only when inTraverse is true, it helps to avoid eager CTE + // substitution in a subquery expression. + case other if inTraverse => + other.transformExpressions { + case e: SubqueryExpression => e.withNewPlan(traverseAndSubstituteCTE(e.plan, true)) + } + } + } + + private def substituteCTE( + plan: LogicalPlan, + cteName: String, + ctePlan: => LogicalPlan): LogicalPlan = + plan resolveOperatorsUp { + case UnresolvedRelation(Seq(table)) if plan.conf.resolver(cteName, table) => ctePlan + + case other => + // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. 
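The right-to-left foldRight in traverseAndSubstituteCTE can be illustrated with a toy model in which a plan is just a SQL string and substitution is naive text replacement; this is only a sketch of the ordering argument, not of Catalyst plan handling:

```scala
object CteSubstitutionSketch {
  // Definitions are folded right-to-left, so a later CTE can reference an
  // earlier one, mirroring the foldRight in traverseAndSubstituteCTE.
  def substituteAll(body: String, relations: Seq[(String, String)]): String =
    relations.foldRight(body) { case ((name, definition), current) =>
      current.replace(name, s"($definition)")
    }

  def main(args: Array[String]): Unit = {
    // WITH t AS (SELECT 1), t2 AS (SELECT * FROM t) SELECT * FROM t2
    val relations = Seq("t" -> "SELECT 1", "t2" -> "SELECT * FROM t")
    // Substituting t2 first splices in a reference to t, which the later
    // substitution of t then resolves:
    // SELECT * FROM (SELECT * FROM (SELECT 1))
    println(substituteAll("SELECT * FROM t2", relations))
  }
}
```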
+ other transformExpressions { + case e: SubqueryExpression => e.withNewPlan(substituteCTE(e.plan, cteName, ctePlan)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 02031e758d83..ae19d02e4475 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -353,6 +354,59 @@ trait CheckAnalysis extends PredicateHelper { case _ => } + case alter: AlterTable if alter.childrenResolved => + val table = alter.table + def findField(operation: String, fieldName: Array[String]): StructField = { + // include collections because structs nested in maps and arrays may be altered + val field = table.schema.findNestedField(fieldName, includeCollections = true) + if (field.isEmpty) { + throw new AnalysisException( + s"Cannot $operation missing field in ${table.name} schema: ${fieldName.quoted}") + } + field.get + } + + alter.changes.foreach { + case add: AddColumn => + val parent = add.fieldNames.init + if (parent.nonEmpty) { + findField("add to", parent) + } + case update: UpdateColumnType => + val field = findField("update", update.fieldNames) + val fieldName = update.fieldNames.quoted + update.newDataType match { + case _: StructType => + throw new AnalysisException( + s"Cannot update ${table.name} field $fieldName type: " + + s"update a struct by adding, deleting, or updating its fields") + case _: MapType => + throw new AnalysisException( + s"Cannot update ${table.name} field $fieldName type: " + + s"update a map by updating $fieldName.key or $fieldName.value") + case _: ArrayType => + throw new AnalysisException( + s"Cannot update ${table.name} field $fieldName type: " + + s"update the element by updating $fieldName.element") + case _: AtomicType => + // update is okay + } + if (!Cast.canUpCast(field.dataType, update.newDataType)) { + throw new AnalysisException( + s"Cannot update ${table.name} field $fieldName: " + + s"${field.dataType.simpleString} cannot be cast to " + + s"${update.newDataType.simpleString}") + } + case rename: RenameColumn => + findField("rename", rename.fieldNames) + case update: UpdateColumnComment => + findField("update", update.fieldNames) + case delete: DeleteColumn => + findField("delete", delete.fieldNames) + case _ => + // no validation needed for set and remove property + } + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9fe95671cda0..c72400a8b72c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -222,9 +222,12 @@ object FunctionRegistry { // math functions 
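The Cast.canUpCast guard above determines which ALTER COLUMN type updates the analyzer accepts; a small sketch probing it directly (assumes the catalyst module on the classpath; the specific type pairs are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

object AlterColumnTypeSketch {
  def main(args: Array[String]): Unit = {
    // Widening updates are accepted because they cannot lose information.
    println(Cast.canUpCast(IntegerType, LongType))          // true
    println(Cast.canUpCast(FloatType, DoubleType))          // true
    // Narrowing or lossy updates are rejected by the AlterTable check above.
    println(Cast.canUpCast(LongType, IntegerType))          // false
    println(Cast.canUpCast(DoubleType, DecimalType(10, 2))) // false
  }
}
```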
expression[Acos]("acos"), + expression[Acosh]("acosh"), expression[Asin]("asin"), + expression[Asinh]("asinh"), expression[Atan]("atan"), expression[Atan2]("atan2"), + expression[Atanh]("atanh"), expression[Bin]("bin"), expression[BRound]("bround"), expression[Cbrt]("cbrt"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1fdec89e258a..3125f8cb732d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { CaseWhenCoercion :: IfCoercion :: StackCoercion :: - Division :: + Division(conf) :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -666,7 +666,7 @@ object TypeCoercion { * Hive only performs integral division with the DIV operator. The arguments to / are always * converted to fractional types. */ - object Division extends TypeCoercionRule { + case class Division(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, @@ -677,7 +677,12 @@ object TypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => - Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + (left.dataType, right.dataType) match { + case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision => + IntegralDivide(left, right) + case _ => + Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + } } private def isNumericOrNull(ex: Expression): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index b700c336e6ae..9e0e0d528a96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -40,12 +40,15 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str * * @param multipartIdentifier table name */ -case class UnresolvedRelation(multipartIdentifier: Seq[String]) extends LeafNode { +case class UnresolvedRelation( + multipartIdentifier: Seq[String]) extends LeafNode with NamedRelation { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ /** Returns a `.` separated name for this relation. 
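The Division(conf) change above, together with the spark.sql.function.preferIntegralDivision flag added to SQLConf later in this diff, changes how `/` behaves on integral operands. A minimal sketch, assuming a local session:

```scala
import org.apache.spark.sql.SparkSession

object IntegralDivisionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("div-sketch").getOrCreate()

    spark.conf.set("spark.sql.function.preferIntegralDivision", "false")
    spark.sql("SELECT 7 / 2").show() // 3.5 -- both sides are cast to double

    // With the flag on, integral / integral stays integral (PostgreSQL-style).
    spark.conf.set("spark.sql.function.preferIntegralDivision", "true")
    spark.sql("SELECT 7 / 2").show() // 3

    spark.stop()
  }
}
```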
*/ def tableName: String = multipartIdentifier.quoted + override def name: String = tableName + override def output: Seq[Attribute] = Nil override lazy val resolved = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8477e63135e3..f671ede21782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2641,8 +2641,8 @@ object Sequence { while (t < exclusiveItem ^ stepSign < 0) { arr(i) = fromLong(t / scale) - t = timestampAddInterval(t, stepMonths, stepMicros, timeZone) i += 1 + t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, timeZone) } // truncate array to the correct length @@ -2676,12 +2676,6 @@ object Sequence { |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} """.stripMargin - val timestampAddIntervalCode = - s""" - |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( - | $t, $stepMonths, $stepMicros, $genTimeZone); - """.stripMargin - s""" |final int $stepMonths = $step.months; |final long $stepMicros = $step.microseconds; @@ -2705,8 +2699,9 @@ object Sequence { | | while ($t < $exclusiveItem ^ $stepSign < 0) { | $arr[$i] = ($elemType) ($t / ${scale}L); - | $timestampAddIntervalCode | $i += 1; + | $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $startMicros, $i * $stepMonths, $i * $stepMicros, $genTimeZone); | } | | if ($arr.length > $i) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bdeb9ed29e0a..e873f8ed1a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -287,6 +287,29 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") """) case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns inverse hyperbolic cosine of `expr`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, + examples = """ + Examples: + > SELECT _FUNC_(1); + 0.0 + > SELECT _FUNC_(0); + NaN + """, + since = "3.0.0") +case class Acosh(child: Expression) + extends UnaryMathExpression((x: Double) => math.log(x + math.sqrt(x * x - 1.0)), "ACOSH") { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, c => s"java.lang.Math.log($c + java.lang.Math.sqrt($c * $c - 1.0))") + } +} + /** * Convert a num from one base to another * @@ -557,6 +580,31 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") """) case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns inverse hyperbolic sine of `expr`. 
+ """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, + examples = """ + Examples: + > SELECT _FUNC_(0); + 0.0 + """, + since = "3.0.0") +case class Asinh(child: Expression) + extends UnaryMathExpression((x: Double) => x match { + case Double.NegativeInfinity => Double.NegativeInfinity + case _ => math.log(x + math.sqrt(x * x + 1.0)) }, "ASINH") { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, c => + s"$c == Double.NEGATIVE_INFINITY ? Double.NEGATIVE_INFINITY : " + + s"java.lang.Math.log($c + java.lang.Math.sqrt($c * $c + 1.0))") + } +} + @ExpressionDescription( usage = "_FUNC_(expr) - Returns the square root of `expr`.", examples = """ @@ -617,6 +665,29 @@ case class Cot(child: Expression) """) case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns inverse hyperbolic tangent of `expr`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, + examples = """ + Examples: + > SELECT _FUNC_(0); + 0.0 + > SELECT _FUNC_(2); + NaN + """, + since = "3.0.0") +case class Atanh(child: Expression) + extends UnaryMathExpression((x: Double) => 0.5 * math.log((1.0 + x) / (1.0 - x)), "ATANH") { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, c => s"0.5 * java.lang.Math.log((1.0 + $c)/(1.0 - $c))") + } +} + @ExpressionDescription( usage = "_FUNC_(expr) - Converts radians to degrees.", arguments = """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 06985ac85b70..02d5a1f27aa7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -116,6 +116,10 @@ trait PredicateHelper { // non-correlated subquery will be replaced as literal e.children.isEmpty case a: AttributeReference => true + // PythonUDF will be executed by dedicated physical operator later. + // For PythonUDFs that can't be evaluated in join condition, `PullOutPythonUDFInJoinCondition` + // will pull them out later. 
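The new Acosh, Asinh and Atanh expressions evaluate the standard log-based identities shown in the hunks above; a quick standalone check of those identities (plain Scala, no Spark required):

```scala
object InverseHyperbolicSketch {
  // The same identities the new expressions code-generate; the real Asinh also
  // special-cases Double.NegativeInfinity, which is omitted here.
  def acosh(x: Double): Double = math.log(x + math.sqrt(x * x - 1.0))
  def asinh(x: Double): Double = math.log(x + math.sqrt(x * x + 1.0))
  def atanh(x: Double): Double = 0.5 * math.log((1.0 + x) / (1.0 - x))

  def main(args: Array[String]): Unit = {
    println(acosh(1.0))            // 0.0
    println(acosh(0.0))            // NaN, matching the SQL example above
    println(math.cosh(asinh(2.0))) // ~2.2360679..., i.e. sqrt(1 + 2*2)
    println(atanh(0.0))            // 0.0
    println(atanh(2.0))            // NaN, |x| > 1 is outside the domain
  }
}
```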
+ case _: PythonUDF => true case e: Unevaluable => false case e => e.children.forall(canEvaluateWithinJoin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index a816922f49ae..51d2a73ea97b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -118,19 +118,23 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // Replace null with default value for joining key, then those rows with null in it could // be joined together case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => - Some((Coalesce(Seq(l, Literal.default(l.dataType))), - Coalesce(Seq(r, Literal.default(r.dataType))))) + Seq((Coalesce(Seq(l, Literal.default(l.dataType))), + Coalesce(Seq(r, Literal.default(r.dataType)))), + (IsNull(l), IsNull(r)) + ) case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => - Some((Coalesce(Seq(r, Literal.default(r.dataType))), - Coalesce(Seq(l, Literal.default(l.dataType))))) + Seq((Coalesce(Seq(r, Literal.default(r.dataType))), + Coalesce(Seq(l, Literal.default(l.dataType)))), + (IsNull(r), IsNull(l)) + ) case other => None } val otherPredicates = predicates.filterNot { case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false - case EqualTo(l, r) => + case Equality(l, r) => canEvaluate(l, left) && canEvaluate(r, right) || canEvaluate(l, right) && canEvaluate(r, left) - case other => false + case _ => false } if (joinKeys.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 72f098354776..2cb04c9ec70c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TableChange} +import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, ColumnChange} import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} @@ -507,6 +508,40 @@ case class DropTable( ident: Identifier, ifExists: Boolean) extends Command +/** + * Alter a table. 
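The extra (IsNull, IsNull) key pair that ExtractEquiJoinKeys now emits keeps NULL from colliding with the type's default value under `<=>`. A simplified model of the composite key, not Catalyst expressions:

```scala
object NullSafeJoinKeySketch {
  // Model of the composite key produced for `a <=> b` on INT columns:
  // (coalesce(a, 0), isnull(a)) joined against (coalesce(b, 0), isnull(b)).
  def key(v: Option[Int]): (Int, Boolean) = (v.getOrElse(0), v.isEmpty)

  def main(args: Array[String]): Unit = {
    // With only the coalesce component, NULL and the default value 0 collide.
    println(key(None) == key(Some(0))) // false: the isnull component differs
    println(key(None) == key(None))    // true: NULL <=> NULL still matches
  }
}
```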
+ */ +case class AlterTable( + catalog: TableCatalog, + ident: Identifier, + table: NamedRelation, + changes: Seq[TableChange]) extends Command { + + override def children: Seq[LogicalPlan] = Seq(table) + + override lazy val resolved: Boolean = childrenResolved && { + changes.forall { + case add: AddColumn => + add.fieldNames match { + case Array(_) => + // a top-level field can always be added + true + case _ => + // the parent field must exist + table.schema.findNestedField(add.fieldNames.init, includeCollections = true).isDefined + } + + case colChange: ColumnChange => + // the column that will be changed must exist + table.schema.findNestedField(colChange.fieldNames, includeCollections = true).isDefined + + case _ => + // property changes require no resolution checks + true + } + } +} + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 63e778af889a..1daf65a0c560 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -45,17 +45,19 @@ object DateTimeUtils { // it's 2440587.5, rounding up to compatible with Hive final val JULIAN_DAY_OF_EPOCH = 2440588 - final val NANOS_PER_MICROS = MICROSECONDS.toNanos(1) - final val NANOS_PER_MILLIS = MILLISECONDS.toNanos(1) - final val NANOS_PER_SECOND = SECONDS.toNanos(1) - final val MICROS_PER_MILLIS = MILLISECONDS.toMicros(1) - final val MICROS_PER_SECOND = SECONDS.toMicros(1) - final val MICROS_PER_DAY = DAYS.toMicros(1) - final val MILLIS_PER_SECOND = SECONDS.toMillis(1) - final val MILLIS_PER_MINUTE = MINUTES.toMillis(1) - final val MILLIS_PER_HOUR = HOURS.toMillis(1) - final val MILLIS_PER_DAY = DAYS.toMillis(1) - final val SECONDS_PER_DAY = DAYS.toSeconds(1) + // Pre-calculated values can provide an opportunity of additional optimizations + // to the compiler like constants propagation and folding. + final val NANOS_PER_MICROS: Long = 1000 + final val MICROS_PER_MILLIS: Long = 1000 + final val MILLIS_PER_SECOND: Long = 1000 + final val SECONDS_PER_DAY: Long = 24 * 60 * 60 + final val MICROS_PER_SECOND: Long = MILLIS_PER_SECOND * MICROS_PER_MILLIS + final val NANOS_PER_MILLIS: Long = NANOS_PER_MICROS * MICROS_PER_MILLIS + final val NANOS_PER_SECOND: Long = NANOS_PER_MICROS * MICROS_PER_SECOND + final val MICROS_PER_DAY: Long = SECONDS_PER_DAY * MICROS_PER_SECOND + final val MILLIS_PER_MINUTE: Long = 60 * MILLIS_PER_SECOND + final val MILLIS_PER_HOUR: Long = 60 * MILLIS_PER_MINUTE + final val MILLIS_PER_DAY: Long = SECONDS_PER_DAY * MILLIS_PER_SECOND // number of days between 1.1.1970 and 1.1.2001 final val to2001 = -11323 @@ -503,60 +505,12 @@ object DateTimeUtils { LocalDate.ofEpochDay(date).getDayOfMonth } - /** - * The number of days for each month (not leap year) - */ - private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) - - /** - * Returns the date value for the first day of the given month. - * The month is expressed in months since year zero (17999 BC), starting from 0. 
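The rewritten DateTimeUtils constants are now plain literals, which gives the compiler a chance to constant-fold them; a standalone check that they agree with the TimeUnit-derived values they replace, plus the java.time behaviour the dateAddMonths rewrite below relies on:

```scala
import java.time.LocalDate
import java.util.concurrent.TimeUnit._

object TimeConstantsSketch {
  final val MICROS_PER_SECOND: Long = 1000L * 1000L
  final val MICROS_PER_DAY: Long = 24L * 60L * 60L * MICROS_PER_SECOND

  def main(args: Array[String]): Unit = {
    // The literal definitions fold to the same values as the old TimeUnit calls.
    assert(MICROS_PER_SECOND == SECONDS.toMicros(1))
    assert(MICROS_PER_DAY == DAYS.toMicros(1))

    // dateAddMonths now delegates to java.time, which performs the
    // end-of-month clamping the removed hand-written code implemented.
    println(LocalDate.of(2019, 1, 31).plusMonths(1)) // 2019-02-28
  }
}
```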
- */ - private def firstDayOfMonth(absoluteMonth: Int): SQLDate = { - val absoluteYear = absoluteMonth / 12 - var monthInYear = absoluteMonth - absoluteYear * 12 - var date = getDateFromYear(absoluteYear) - if (monthInYear >= 2 && isLeap(absoluteYear + YearZero)) { - date += 1 - } - while (monthInYear > 0) { - date += monthDays(monthInYear - 1) - monthInYear -= 1 - } - date - } - - /** - * Returns the date value for January 1 of the given year. - * The year is expressed in years since year zero (17999 BC), starting from 0. - */ - private def getDateFromYear(absoluteYear: Int): SQLDate = { - val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100 - + absoluteYear / 4) - absoluteDays - toYearZero - } - /** * Add date and year-month interval. * Returns a date value, expressed in days since 1.1.1970. */ def dateAddMonths(days: SQLDate, months: Int): SQLDate = { - val (year, monthInYear, dayOfMonth, daysToMonthEnd) = splitDate(days) - val absoluteMonth = (year - YearZero) * 12 + monthInYear - 1 + months - val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0 - val currentMonthInYear = nonNegativeMonth % 12 - val currentYear = nonNegativeMonth / 12 - - val leapDay = if (currentMonthInYear == 1 && isLeap(currentYear + YearZero)) 1 else 0 - val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay - - val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) { - // last day of the month - lastDayOfMonth - } else { - dayOfMonth - } - firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1 + LocalDate.ofEpochDay(days).plusMonths(months).toEpochDay.toInt } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index feb3b46df0cd..57f5128fd4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -234,6 +234,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED = + buildConf("spark.sql.inMemoryTableScanStatistics.enable") + .internal() + .doc("When true, enable in-memory table scan accumulators.") + .booleanConf + .createWithDefault(false) + val CACHE_VECTORIZED_READER_ENABLED = buildConf("spark.sql.inMemoryColumnarStorage.enableVectorizedReader") .doc("Enables vectorized reader for columnar caching.") @@ -1024,6 +1031,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ENABLE_VECTORIZED_HASH_MAP = + buildConf("spark.sql.codegen.aggregate.map.vectorized.enable") + .internal() + .doc("Enable vectorized aggregate hash map. 
This is for testing/benchmarking only.") + .booleanConf + .createWithDefault(false) + val MAX_NESTED_VIEW_DEPTH = buildConf("spark.sql.view.maxNestedViewDepth") .internal() @@ -1510,6 +1524,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision") + .doc("When true, the / operator performs integral division " + + "if both sides are integral types.") + .booleanConf + .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation") .internal() @@ -1837,6 +1857,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_CTE_PRECEDENCE_ENABLED = buildConf("spark.sql.legacy.ctePrecedence.enabled") + .internal() + .doc("When true, outer CTE definitions take precedence over inner definitions.") + .booleanConf + .createWithDefault(false) + val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC = buildConf("spark.sql.legacy.arrayExistsFollowsThreeValuedLogic") .doc("When true, the ArrayExists will follow the three-valued boolean logic.") .booleanConf .createWithDefault(false) @@ -2109,6 +2135,8 @@ class SQLConf extends Serializable with Logging { def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) + def inMemoryTableScanStatisticsEnabled: Boolean = getConf(IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED) + def offHeapColumnVectorEnabled: Boolean = getConf(COLUMN_VECTOR_OFFHEAP_ENABLED) def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) @@ -2148,6 +2176,8 @@ class SQLConf extends Serializable with Logging { def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) + def enableVectorizedHashMap: Boolean = getConf(ENABLE_VECTORIZED_HASH_MAP) + def useObjectHashAggregation: Boolean = getConf(USE_OBJECT_HASH_AGG) def objectAggSortBasedFallbackThreshold: Int = getConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD) @@ -2270,6 +2300,8 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index edf8d2c1b31a..236f73ba3832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -310,20 +310,46 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * Returns a field in this struct and its child structs. * - * This does not support finding fields nested in maps or arrays. + * If includeCollections is true, this will return fields that are nested in maps and arrays. 
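+ * Map keys, map values, and array elements are addressed with the names "key", "value", and "element"; for example, Seq("m", "key") returns the key type of a map field named "m", wrapped in a StructField that carries its nullability.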
*/ - private[sql] def findNestedField(fieldNames: Seq[String]): Option[StructField] = { + private[sql] def findNestedField( + fieldNames: Seq[String], + includeCollections: Boolean = false): Option[StructField] = { fieldNames.headOption.flatMap(nameToField.get) match { case Some(field) => - if (fieldNames.tail.isEmpty) { - Some(field) - } else { - field.dataType match { - case struct: StructType => - struct.findNestedField(fieldNames.tail) - case _ => - None - } + (fieldNames.tail, field.dataType, includeCollections) match { + case (Seq(), _, _) => + Some(field) + + case (names, struct: StructType, _) => + struct.findNestedField(names, includeCollections) + + case (_, _, false) => + None // types nested in maps and arrays are not used + + case (Seq("key"), MapType(keyType, _, _), true) => + // return the key type as a struct field to include nullability + Some(StructField("key", keyType, nullable = false)) + + case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) => + struct.findNestedField(names, includeCollections) + + case (Seq("value"), MapType(_, valueType, isNullable), true) => + // return the value type as a struct field to include nullability + Some(StructField("value", valueType, nullable = isNullable)) + + case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) => + struct.findNestedField(names, includeCollections) + + case (Seq("element"), ArrayType(elementType, isNullable), true) => + // return the element type as a struct field to include nullability + Some(StructField("element", elementType, nullable = isNullable)) + + case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) => + struct.findNestedField(names, includeCollections) + + case _ => + None } case _ => None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2c3ba1b0daf4..949bb30d1550 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1126,14 +1126,14 @@ class TypeCoercionSuite extends AnalysisTest { Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType), Cast(Literal(new Timestamp(0)), StringType)))) - withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") { + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "true") { ruleTest(rule, Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), Concat(Seq(Cast(Literal("123".getBytes), StringType), Cast(Literal("456".getBytes), StringType)))) } - withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") { + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") { ruleTest(rule, Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) @@ -1180,14 +1180,14 @@ class TypeCoercionSuite extends AnalysisTest { Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), Cast(Literal(new Timestamp(0)), StringType)))) - withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "true") { ruleTest(rule, Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), Cast(Literal("456".getBytes), StringType)))) } - withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + 
withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "false") { ruleTest(rule, Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) @@ -1456,7 +1456,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(FunctionArgumentConversion, Division(conf)) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1475,12 +1475,35 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division(conf), ImplicitTypeCasts) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) } + test("SPARK-28395 Division operator support integral division") { + val rules = Seq(FunctionArgumentConversion, Division(conf)) + Seq(true, false).foreach { preferIntegralDivision => + withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") { + val result1 = if (preferIntegralDivision) { + IntegralDivide(1L, 1L) + } else { + Divide(Cast(1L, DoubleType), Cast(1L, DoubleType)) + } + ruleTest(rules, Divide(1L, 1L), result1) + val result2 = if (preferIntegralDivision) { + IntegralDivide(1, Cast(1, ShortType)) + } else { + Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType)) + } + ruleTest(rules, Divide(1, Cast(1, ShortType)), result2) + + ruleTest(rules, Divide(1L, 1D), Divide(Cast(1L, DoubleType), Cast(1D, DoubleType))) + ruleTest(rules, Divide(Decimal(1.1), 1L), Divide(Decimal(1.1), 1L)) + } + } + } + test("binary comparison with string promotion") { val rule = TypeCoercion.PromoteStrings(conf) ruleTest(rule, @@ -1498,7 +1521,7 @@ class TypeCoercionSuite extends AnalysisTest { DoubleType))) Seq(true, false).foreach { convertToTS => withSQLConf( - "spark.sql.legacy.compareDateTimestampInTimestamp" -> convertToTS.toString) { + SQLConf.COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP.key -> convertToTS.toString) { val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 04bb61a7486e..4e8322d3c55d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -462,13 +462,19 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null) checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)), null) + // Valid range of DateType is [0001-01-01, 9999-12-31] + val maxMonthInterval = 10000 * 12 
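+ // (10000 years, i.e. 120000 months, covers the whole span between the minimum and maximum supported dates)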
checkEvaluation( - AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498) + AddMonths(Literal(Date.valueOf("0001-01-01")), Literal(maxMonthInterval)), 2933261) checkEvaluation( - AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213) - checkEvaluation( - AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528) - checkConsistencyBetweenInterpretedAndCodegen(AddMonths, DateType, IntegerType) + AddMonths(Literal(Date.valueOf("9999-12-31")), Literal(-1 * maxMonthInterval)), -719529) + // Test evaluation results between Interpreted mode and Codegen mode + forAll ( + LiteralGenerator.randomGen(DateType), + LiteralGenerator.monthIntervalLiterGen + ) { (l1: Literal, l2: Literal) => + cmpInterpretWithCodegen(EmptyRow, AddMonths(l1, l2)) + } } test("months_between") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1c91adab7137..a2c0ce35df23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -398,7 +398,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa } } - private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { + def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try { evaluateWithoutCodegen(expr, inputRow) } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index be5fdb5b42ea..b111797c3588 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate} +import java.util.concurrent.TimeUnit import org.scalacheck.{Arbitrary, Gen} @@ -100,23 +102,44 @@ object LiteralGenerator { lazy val booleanLiteralGen: Gen[Literal] = for { b <- Arbitrary.arbBool.arbitrary } yield Literal.create(b, BooleanType) - lazy val dateLiteralGen: Gen[Literal] = - for { d <- Arbitrary.arbInt.arbitrary } yield Literal.create(new Date(d), DateType) + lazy val dateLiteralGen: Gen[Literal] = { + // Valid range for DateType is [0001-01-01, 9999-12-31] + val minDay = LocalDate.of(1, 1, 1).toEpochDay + val maxDay = LocalDate.of(9999, 12, 31).toEpochDay + for { day <- Gen.choose(minDay, maxDay) } + yield Literal.create(new Date(day * DateTimeUtils.MILLIS_PER_DAY), DateType) + } lazy val timestampLiteralGen: Gen[Literal] = { // Catalyst's Timestamp type stores number of microseconds since epoch in // a variable of Long type. To prevent arithmetic overflow of Long on // conversion from milliseconds to microseconds, the range of random milliseconds // since epoch is restricted here. 
- val maxMillis = Long.MaxValue / DateTimeUtils.MICROS_PER_MILLIS - val minMillis = Long.MinValue / DateTimeUtils.MICROS_PER_MILLIS + // Valid range for TimestampType is [0001-01-01T00:00:00.000000Z, 9999-12-31T23:59:59.999999Z] + val minMillis = Instant.parse("0001-01-01T00:00:00.000000Z").toEpochMilli + val maxMillis = Instant.parse("9999-12-31T23:59:59.999999Z").toEpochMilli for { millis <- Gen.choose(minMillis, maxMillis) } yield Literal.create(new Timestamp(millis), TimestampType) } - lazy val calendarIntervalLiterGen: Gen[Literal] = - for { m <- Arbitrary.arbInt.arbitrary; s <- Arbitrary.arbLong.arbitrary} - yield Literal.create(new CalendarInterval(m, s), CalendarIntervalType) + // Valid range for DateType and TimestampType is [0001-01-01, 9999-12-31] + private val maxIntervalInMonths: Int = 10000 * 12 + + lazy val monthIntervalLiterGen: Gen[Literal] = { + for { months <- Gen.choose(-1 * maxIntervalInMonths, maxIntervalInMonths) } + yield Literal.create(months, IntegerType) + } + + lazy val calendarIntervalLiterGen: Gen[Literal] = { + val maxDurationInSec = Duration.between( + Instant.parse("0001-01-01T00:00:00.000000Z"), + Instant.parse("9999-12-31T23:59:59.999999Z")).getSeconds + val maxMicros = TimeUnit.SECONDS.toMicros(maxDurationInSec) + for { + months <- Gen.choose(-1 * maxIntervalInMonths, maxIntervalInMonths) + micros <- Gen.choose(-1 * maxMicros, maxMicros) + } yield Literal.create(new CalendarInterval(months, micros), CalendarIntervalType) + } // Sometimes, it would be quite expensive when unlimited value is used, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 48105571b279..4c048f79741b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -199,6 +199,18 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) } + test("asinh") { + testUnary(Asinh, (x: Double) => math.log(x + math.sqrt(x * x + 1.0))) + checkConsistencyBetweenInterpretedAndCodegen(Asinh, DoubleType) + + checkEvaluation(Asinh(Double.NegativeInfinity), Double.NegativeInfinity) + + val nullLit = Literal.create(null, NullType) + val doubleNullLit = Literal.create(null, DoubleType) + checkEvaluation(checkDataTypeAndCast(Asinh(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Asinh(doubleNullLit)), null, EmptyRow) + } + test("cos") { testUnary(Cos, math.cos) checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) @@ -215,6 +227,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) } + test("acosh") { + testUnary(Acosh, (x: Double) => math.log(x + math.sqrt(x * x - 1.0))) + checkConsistencyBetweenInterpretedAndCodegen(Acosh, DoubleType) + + val nullLit = Literal.create(null, NullType) + val doubleNullLit = Literal.create(null, DoubleType) + checkEvaluation(checkDataTypeAndCast(Acosh(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Acosh(doubleNullLit)), null, EmptyRow) + } + test("tan") { testUnary(Tan, math.tan) checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) @@ -244,6 +266,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { 
checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) } + test("atanh") { + testUnary(Atanh, (x: Double) => 0.5 * math.log((1.0 + x) / (1.0 - x))) + checkConsistencyBetweenInterpretedAndCodegen(Atanh, DoubleType) + + val nullLit = Literal.create(null, NullType) + val doubleNullLit = Literal.create(null, DoubleType) + checkEvaluation(checkDataTypeAndCast(Atanh(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Atanh(doubleNullLit)), null, EmptyRow) + } + test("toDegrees") { testUnary(ToDegrees, math.toDegrees) checkConsistencyBetweenInterpretedAndCodegen(ToDegrees, DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 2db4667fd056..3ec8d18bc871 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{BooleanType, IntegerType} import org.apache.spark.unsafe.types.CalendarInterval class FilterPushdownSuite extends PlanTest { @@ -41,9 +42,14 @@ class FilterPushdownSuite extends PlanTest { CollapseProject) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val attrA = 'a.int + val attrB = 'b.int + val attrC = 'c.int + val attrD = 'd.int - val testRelation1 = LocalRelation('d.int) + val testRelation = LocalRelation(attrA, attrB, attrC) + + val testRelation1 = LocalRelation(attrD) // This test already passes. 
test("eliminate subqueries") { @@ -1202,4 +1208,26 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) } + + test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") { + val pythonUDFJoinCond = { + val pythonUDF = PythonUDF("pythonUDF", null, + IntegerType, + Seq(attrA), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + pythonUDF === attrD + } + + val query = testRelation.join( + testRelation1, + joinType = Cross).where(pythonUDFJoinCond) + + val expected = testRelation.join( + testRelation1, + joinType = Cross, + condition = Some(pythonUDFJoinCond)).analyze + + comparePlans(Optimize.execute(query.analyze), expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala index 5f616da2978b..f5af416602c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized +import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -78,5 +78,18 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest { comparePlans(doubleOptimized, correctAnswer) } + + test("normalize floating points in join keys (equal null safe) - idempotence") { + val query = testRelation1.join(testRelation2, condition = Some(a <=> b)) + + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + val joinCond = IsNull(a) === IsNull(b) && + KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) === + KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0))) + val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond)) + + comparePlans(doubleOptimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index c77c9aec6887..4f8353922319 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -359,18 +359,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { test("date add months") { val input = days(1997, 2, 28, 10, 30) - assert(dateAddMonths(input, 36) === days(2000, 2, 29)) - assert(dateAddMonths(input, -13) === days(1996, 1, 31)) + assert(dateAddMonths(input, 36) === days(2000, 2, 28)) + assert(dateAddMonths(input, -13) === days(1996, 1, 28)) } test("timestamp add months") { val ts1 = date(1997, 2, 28, 10, 30, 0) - val ts2 = date(2000, 2, 29, 10, 30, 0, 123000) + val ts2 = date(2000, 2, 28, 10, 30, 0, 123000) assert(timestampAddInterval(ts1, 36, 123000, defaultTz) === ts2) val ts3 = date(1997, 2, 27, 16, 0, 0, 0, TimeZonePST) val 
ts4 = date(2000, 2, 27, 16, 0, 0, 123000, TimeZonePST) - val ts5 = date(2000, 2, 29, 0, 0, 0, 123000, TimeZoneGMT) + val ts5 = date(2000, 2, 28, 0, 0, 0, 123000, TimeZoneGMT) assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST) === ts4) assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT) === ts5) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index fd6f7be2abc5..1bd7b825328d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -20,9 +20,10 @@ import java.util.concurrent.TimeUnit; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger; import scala.concurrent.duration.Duration; -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger; +import org.apache.spark.sql.execution.streaming.ContinuousTrigger; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; /** @@ -40,7 +41,7 @@ public class Trigger { * @since 2.2.0 */ public static Trigger ProcessingTime(long intervalMs) { - return ProcessingTime.create(intervalMs, TimeUnit.MILLISECONDS); + return ProcessingTimeTrigger.create(intervalMs, TimeUnit.MILLISECONDS); } /** @@ -56,7 +57,7 @@ public static Trigger ProcessingTime(long intervalMs) { * @since 2.2.0 */ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { - return ProcessingTime.create(interval, timeUnit); + return ProcessingTimeTrigger.create(interval, timeUnit); } /** @@ -71,7 +72,7 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { * @since 2.2.0 */ public static Trigger ProcessingTime(Duration interval) { - return ProcessingTime.apply(interval); + return ProcessingTimeTrigger.apply(interval); } /** @@ -84,7 +85,7 @@ public static Trigger ProcessingTime(Duration interval) { * @since 2.2.0 */ public static Trigger ProcessingTime(String interval) { - return ProcessingTime.apply(interval); + return ProcessingTimeTrigger.apply(interval); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a3cbea9021f2..0da52d432d25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -520,6 +520,71 @@ class KeyValueGroupedDataset[K, V] private[sql]( col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. 
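+ * + * A usage sketch (assuming a KeyValueGroupedDataset `grouped` whose value type has a Double field "amount"): + * {{{ + *   grouped.agg(count("*").as[Long], sum("amount").as[Double], avg("amount").as[Double], + *     max("amount").as[Double], min("amount").as[Double], stddev("amount").as[Double]) + * }}}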
+ * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + aggUntyped(col1, col2, col3, col4, col5, col6) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] + /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f0ef6e19b0aa..bb05c76cfee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -142,16 +142,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") val version = if (i == 0) "2.3.0" else "1.3.0" - val funcCall = if (i == 0) "() => func" else "func" + val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)" println(s""" |/** | * Register a deterministic Java UDF$i instance as user-defined function (UDF). | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { - | val func = f$anyCast.call($anyParams) + | val func = $funcCall | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name)) + | ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". 
Expected: $i; Found: " + e.length) @@ -717,9 +717,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { - val func = f.asInstanceOf[UDF0[Any]].call() + val func = () => f.asInstanceOf[UDF0[Any]].call() def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala index 315eba6635aa..4385843d9011 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala @@ -53,8 +53,8 @@ class ColumnarRule { * Provides a common executor to translate an [[RDD]] of [[ColumnarBatch]] into an [[RDD]] of * [[InternalRow]]. This is inserted whenever such a transition is determined to be needed. * - * The implementation is based off of similar implementations in [[ColumnarBatchScan]], - * [[org.apache.spark.sql.execution.python.ArrowEvalPythonExec]], and + * The implementation is based off of similar implementations in + * [[org.apache.spark.sql.execution.python.ArrowEvalPythonExec]] and * [[MapPartitionsInRWithArrowExec]]. Eventually this should replace those implementations. */ case class ColumnarToRowExec(child: SparkPlan) @@ -96,9 +96,6 @@ case class ColumnarToRowExec(child: SparkPlan) /** * Generate [[ColumnVector]] expressions for our parent to consume as rows. * This is called once per [[ColumnVector]] in the batch. - * - * This code came unchanged from [[ColumnarBatchScan]] and will hopefully replace it - * at some point. */ private def genCodeColumnVector( ctx: CodegenContext, @@ -130,9 +127,6 @@ case class ColumnarToRowExec(child: SparkPlan) * Produce code to process the input iterator as [[ColumnarBatch]]es. * This produces an [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] for each row in * each batch. - * - * This code came almost completely unchanged from [[ColumnarBatchScan]] and will - * hopefully replace it at some point. */ override protected def doProduce(ctx: CodegenContext): String = { // PhysicalRDD always just has one input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala deleted file mode 100644 index b2e9f760d27c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} - - -/** - * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. - */ -private[sql] trait ColumnarBatchScan extends CodegenSupport { - - protected def supportsBatch: Boolean = true - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - - /** - * Generate [[ColumnVector]] expressions for our parent to consume as rows. - * This is called once per [[ColumnarBatch]]. - */ - private def genCodeColumnVector( - ctx: CodegenContext, - columnVar: String, - ordinal: String, - dataType: DataType, - nullable: Boolean): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) - val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { - JavaCode.isNullVariable(ctx.freshName("isNull")) - } else { - FalseLiteral - } - val valueVar = ctx.freshName("value") - val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = code"${ctx.registerComment(str)}" + (if (nullable) { - code""" - boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); - """ - } else { - code"$javaType $valueVar = $value;" - }) - ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) - } - - /** - * Produce code to process the input iterator as [[ColumnarBatch]]es. - * This produces an [[UnsafeRow]] for each row in each batch. 
- */ - // TODO: return ColumnarBatch.Rows instead - override protected def doProduce(ctx: CodegenContext): String = { - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", - v => s"$v = inputs[0];") - if (supportsBatch) { - produceBatches(ctx, input) - } else { - produceRows(ctx, input) - } - } - - private def produceBatches(ctx: CodegenContext, input: String): String = { - // metrics - val numOutputRows = metricTerm(ctx, "numOutputRows") - val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = - ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 - - val columnarBatchClz = classOf[ColumnarBatch].getName - val batch = ctx.addMutableState(columnarBatchClz, "batch") - - val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 - val columnVectorClzs = vectorTypes.getOrElse( - Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) - val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { - case (columnVectorClz, i) => - val name = ctx.addMutableState(columnVectorClz, s"colInstance$i") - (name, s"$name = ($columnVectorClz) $batch.column($i);") - }.unzip - - val nextBatch = ctx.freshName("nextBatch") - val nextBatchFuncName = ctx.addNewFunction(nextBatch, - s""" - |private void $nextBatch() throws java.io.IOException { - | long getBatchStart = System.nanoTime(); - | if ($input.hasNext()) { - | $batch = ($columnarBatchClz)$input.next(); - | $numOutputRows.add($batch.numRows()); - | $idx = 0; - | ${columnAssigns.mkString("", "\n", "\n")} - | } - | $scanTimeTotalNs += System.nanoTime() - getBatchStart; - |}""".stripMargin) - - ctx.currentVars = null - val rowidx = ctx.freshName("rowIdx") - val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => - genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) - } - val localIdx = ctx.freshName("localIdx") - val localEnd = ctx.freshName("localEnd") - val numRows = ctx.freshName("numRows") - val shouldStop = if (parent.needStopCheck) { - s"if (shouldStop()) { $idx = $rowidx + 1; return; }" - } else { - "// shouldStop check is eliminated" - } - s""" - |if ($batch == null) { - | $nextBatchFuncName(); - |} - |while ($limitNotReachedCond $batch != null) { - | int $numRows = $batch.numRows(); - | int $localEnd = $numRows - $idx; - | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; - | ${consume(ctx, columnsBatchInput).trim} - | $shouldStop - | } - | $idx = $numRows; - | $batch = null; - | $nextBatchFuncName(); - |} - |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); - |$scanTimeTotalNs = 0; - """.stripMargin - } - - private def produceRows(ctx: CodegenContext, input: String): String = { - val numOutputRows = metricTerm(ctx, "numOutputRows") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - s""" - |while ($limitNotReachedCond $input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, null, row).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 518460d98f05..728ac3a466fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -37,10 +37,11 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet -trait DataSourceScanExec extends LeafExecNode with CodegenSupport { +trait DataSourceScanExec extends LeafExecNode { val relation: BaseRelation val tableIdentifier: Option[TableIdentifier] @@ -69,6 +70,12 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { private def redact(text: String): String = { Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } + + /** + * The data being read in. This is to provide input to the tests in a way compatible with + * [[InputRDDCodegen]] which all implementations used to extend. + */ + def inputRDDs(): Seq[RDD[InternalRow]] } /** Physical plan node for scanning data from a relation. */ @@ -141,11 +148,11 @@ case class FileSourceScanExec( optionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) - extends DataSourceScanExec with ColumnarBatchScan { + extends DataSourceScanExec { // Note that some vals referring the file-based relation are lazy intentionally // so that this plan can be canonicalized on executor side too. See SPARK-23731. - override lazy val supportsBatch: Boolean = { + override lazy val supportsColumnar: Boolean = { relation.fileFormat.supportBatch(relation.sparkSession, schema) } @@ -275,7 +282,7 @@ case class FileSourceScanExec( Map( "Format" -> relation.fileFormat.toString, "ReadSchema" -> requiredSchema.catalogString, - "Batched" -> supportsBatch.toString, + "Batched" -> supportsColumnar.toString, "PartitionFilters" -> seqToString(partitionFilters), "PushedFilters" -> seqToString(pushedDownFilters), "DataFilters" -> seqToString(dataFilters), @@ -302,7 +309,7 @@ case class FileSourceScanExec( withSelectedBucketsCount } - private lazy val inputRDD: RDD[InternalRow] = { + lazy val inputRDD: RDD[InternalRow] = { val readFile: (PartitionedFile) => Iterator[InternalRow] = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -334,29 +341,30 @@ case class FileSourceScanExec( "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { - if (supportsBatch) { - // in the case of fallback, this batched scan should never fail because of: - // 1) only primitive types are supported - // 2) the number of columns should be smaller than spark.sql.codegen.maxFields - WholeStageCodegenExec(this)(codegenStageId = 0).execute() - } else { - val numOutputRows = longMetric("numOutputRows") - - if (needsUnsafeRowConversion) { - inputRDD.mapPartitionsWithIndexInternal { (index, iter) => - val proj = UnsafeProjection.create(schema) - proj.initialize(index) - iter.map( r => { - numOutputRows += 1 - proj(r) - }) - } - } else { - inputRDD.map { r => + val numOutputRows = longMetric("numOutputRows") + + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + val proj = UnsafeProjection.create(schema) + proj.initialize(index) + iter.map( r => { numOutputRows += 1 - r - } + proj(r) + }) } + } else { + inputRDD.map { r => + numOutputRows += 1 + r + } + } 
+ } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].map { batch => + numOutputRows += batch.numRows() + batch } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 981ecae80a72..1ab183fe843f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -69,9 +69,8 @@ case class ExternalRDDScanExec[T]( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val outputDataType = outputObjAttr.dataType rdd.mapPartitionsInternal { iter => - val outputObject = ObjectOperator.wrapObjectToRow(outputDataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) iter.map { value => numOutputRows += 1 outputObject(value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4d5a2b9b3f0..550094193644 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -703,7 +703,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), + planLater(child), canChangeNumPartitions = false) :: Nil } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } @@ -736,7 +737,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case r: logical.RepartitionByExpression => - exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil + exchange.ShuffleExchangeExec( + r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 94a5ede75145..a0afa9a26fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -709,11 +709,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + s"for this plan (id=$codegenStageId). 
To avoid this, you can raise the limit " + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") - child match { - // The fallback solution of batch file source scan still uses WholeStageCodegenExec - case f: FileSourceScanExec if f.supportsBatch => // do nothing - case _ => return child.execute() - } + return child.execute() } val references = ctx.references.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 0708878ece46..61dbc5829738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -84,6 +84,8 @@ case class AdaptiveSparkPlanExec( // optimizations should be stage-independent. @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( ReduceNumShufflePartitions(conf), + ApplyColumnarRulesAndInsertTransitions(session.sessionState.conf, + session.sessionState.columnarRules), CollapseCodegenStages(conf) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala index d93eb76b9fbc..78923433eaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala @@ -61,12 +61,18 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { // If not all leaf nodes are query stages, it's not safe to reduce the number of // shuffle partitions, because we may break the assumption that all children of a spark plan // have same number of output partitions. + return plan + } + + val shuffleStages = plan.collect { + case stage: ShuffleQueryStageExec => stage + case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage + } + // ShuffleExchanges introduced by repartition do not support changing the number of partitions. + // We change the number of partitions in the stage only if all the ShuffleExchanges support it. + if (!shuffleStages.forall(_.plan.canChangeNumPartitions)) { plan } else { - val shuffleStages = plan.collect { - case stage: ShuffleQueryStageExec => stage - case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage - } val shuffleMetrics = shuffleStages.map { stage => val metricsFuture = stage.mapOutputStatisticsFuture assert(metricsFuture.isCompleted, "ShuffleQueryStageExec should already be ready") @@ -76,12 +82,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { // `ShuffleQueryStageExec` gives null mapOutputStatistics when the input RDD has 0 partitions, // we should skip it when calculating the `partitionStartIndices`. val validMetrics = shuffleMetrics.filter(_ != null) - // We may get different pre-shuffle partition number if user calls repartition manually. - // We don't reduce shuffle partition number in that case. 
- val distinctNumPreShufflePartitions = - validMetrics.map(stats => stats.bytesByPartitionId.length).distinct - - if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) { + if (validMetrics.nonEmpty) { val partitionStartIndices = estimatePartitionStartIndices(validMetrics.toArray) // This transformation adds new nodes, so we must use `transformUp` here. plan.transformUp { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 25ff6584360e..4a95f7638133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -559,7 +560,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = { if (!checkIfFastHashMapSupported(ctx)) { if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { - logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but" + logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } } else { @@ -567,8 +568,7 @@ case class HashAggregateExec( // This is for testing/benchmarking only. // We enforce to first level to be a vectorized hashmap, instead of the default row-based one. 
- isVectorizedHashMapEnabled = sqlContext.getConf( - "spark.sql.codegen.aggregate.map.vectorized.enable", "false") == "true" + isVectorizedHashMapEnabled = sqlContext.conf.enableVectorizedHashMap } } @@ -576,12 +576,8 @@ case class HashAggregateExec( val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) - } else { - sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match { - case "true" => - logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.") - case _ => - } + } else if (sqlContext.conf.enableVectorizedHashMap) { + logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.") } val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 06634c13ec43..3566ab1aa5a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -34,7 +35,10 @@ case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode with ColumnarBatchScan { + extends LeafExecNode { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override val nodeName: String = { relation.cacheBuilder.tableName match { @@ -65,7 +69,7 @@ case class InMemoryTableScanExec( * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. * If false, get data from UnsafeRow build from CachedBatch */ - override val supportsBatch: Boolean = { + override val supportsColumnar: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields conf.cacheVectorizedReaderEnabled && relation.schema.fields.forall(f => f.dataType match { @@ -75,9 +79,6 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } - // TODO: revisit this. Shall we always turn off whole stage codegen if the output data are rows? 
- override def supportCodegen: Boolean = supportsBatch - private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -108,59 +109,61 @@ case class InMemoryTableScanExec( columnarBatch } + private lazy val columnarInputRDD: RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .map(b => { + numOutputRows += b.numRows() + b + }) + } + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled - if (supportsBatch) { - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - buffers - .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) - .asInstanceOf[RDD[InternalRow]] - } else { - val numOutputRows = longMetric("numOutputRows") + val numOutputRows = longMetric("numOutputRows") - if (enableAccumulatorsForTest) { - readPartitions.setValue(0) - readBatches.setValue(0) - } + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } - // Using these variables here to avoid serialization of entire objects (if referenced - // directly) within the map Partitions closure. - val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. - val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulatorsForTest) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns. 
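+ // requestedColumnIndices locates each attribute by exprId in the cached relation's output, + // and requestedColumnDataTypes drives the generated column accessor below.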
+ val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) } + numOutputRows += batch.numRows + batch + } - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulatorsForTest && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) } + columnarIterator } } - override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { @@ -294,8 +297,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulatorsForTest: Boolean = - sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + lazy val enableAccumulatorsForTest: Boolean = sqlContext.conf.inMemoryTableScanStatisticsEnabled // Accumulators used for testing purposes lazy val readPartitions = sparkContext.longAccumulator @@ -339,10 +341,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - if (supportsBatch) { - WholeStageCodegenExec(this)(codegenStageId = 0).execute() - } else { - inputRDD - } + inputRDD + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + columnarInputRDD } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 8508322f54e8..b9b86adb438e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -613,7 +613,7 @@ class ParquetFilters( } override def keep(value: Binary): Boolean = { - UTF8String.fromBytes(value.getBytes).startsWith( + value != null && UTF8String.fromBytes(value.getBytes).startsWith( UTF8String.fromBytes(strToBinary.getBytes)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala new file mode 100644 index 000000000000..a3fa82b12e93 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TableChange} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.LeafExecNode + +/** + * Physical plan node for altering a table. + */ +case class AlterTableExec( + catalog: TableCatalog, + ident: Identifier, + changes: Seq[TableChange]) extends LeafExecNode { + + override def output: Seq[Attribute] = Seq.empty + + override protected def doExecute(): RDD[InternalRow] = { + try { + catalog.alterTable(ident, changes: _*) + } catch { + case e: IllegalArgumentException => + throw new SparkException(s"Unsupported table change: ${e.getMessage}", e) + } + + sqlContext.sparkContext.parallelize(Seq.empty, 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 3276ab506750..c3cbb9d2af4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -45,7 +45,7 @@ case class BatchScanExec( override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() override lazy val inputRDD: RDD[InternalRow] = { - new DataSourceRDD(sparkContext, partitions, readerFactory, supportsBatch) + new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar) } override def doCanonicalize(): BatchScanExec = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 9ad683fbe1df..c5c902ffc410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -23,11 +23,16 @@ import org.apache.spark.sql.catalyst.expressions.AttributeMap import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, Scan, SupportsReportPartitioning} +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils -trait DataSourceV2ScanExecBase extends LeafExecNode with ColumnarBatchScan { +trait DataSourceV2ScanExecBase extends LeafExecNode { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) def 
scan: Scan @@ -52,7 +57,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode with ColumnarBatchScan { case _ => super.outputPartitioning } - override def supportsBatch: Boolean = { + override def supportsColumnar: Boolean = { require(partitions.forall(readerFactory.supportColumnarReads) || !partitions.exists(readerFactory.supportColumnarReads), "Cannot mix row-based and columnar input partitions.") @@ -62,17 +67,22 @@ trait DataSourceV2ScanExecBase extends LeafExecNode with ColumnarBatchScan { def inputRDD: RDD[InternalRow] - override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) override def doExecute(): RDD[InternalRow] = { - if (supportsBatch) { - WholeStageCodegenExec(this)(codegenStageId = 0).execute() - } else { - val numOutputRows = longMetric("numOutputRows") - inputRDD.map { r => - numOutputRows += 1 - r - } + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].map { + b => + numOutputRows += b.numRows() + b } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 27d87960edb3..4f8507da3924 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -202,6 +202,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case DropTable(catalog, ident, ifExists) => DropTableExec(catalog, ident, ifExists) :: Nil + case AlterTable(catalog, ident, _, changes) => + AlterTableExec(catalog, ident, changes) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala index d2e33d4fa77c..a9b0f5bce1b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala @@ -46,6 +46,6 @@ case class MicroBatchScanExec( override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory() override lazy 
val inputRDD: RDD[InternalRow] = { - new DataSourceRDD(sparkContext, partitions, readerFactory, supportsBatch) + new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 8184baf50b04..c56a5c015f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, - SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf /** @@ -94,7 +93,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) child match { // If child is an exchange, we replace it with a new one having defaultPartitioning. - case ShuffleExchangeExec(_, c) => ShuffleExchangeExec(defaultPartitioning, c) + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c) case _ => ShuffleExchangeExec(defaultPartitioning, child) } } @@ -117,25 +116,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } private def reorder( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], + leftKeys: IndexedSeq[Expression], + rightKeys: IndexedSeq[Expression], expectedOrderOfKeys: Seq[Expression], currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - val leftKeysBuffer = ArrayBuffer[Expression]() - val rightKeysBuffer = ArrayBuffer[Expression]() - val pickedIndexes = mutable.Set[Int]() - val keysAndIndexes = currentOrderOfKeys.zipWithIndex - - expectedOrderOfKeys.foreach(expression => { - val index = keysAndIndexes.find { case (e, idx) => - // As we may have the same key used many times, we need to filter out its occurrence we - // have already used. - e.semanticEquals(expression) && !pickedIndexes.contains(idx) - }.map(_._2).get - pickedIndexes += index - leftKeysBuffer.append(leftKeys(index)) - rightKeysBuffer.append(rightKeys(index)) - }) + if (expectedOrderOfKeys.size != currentOrderOfKeys.size) { + return (leftKeys, rightKeys) + } + + // Build a lookup between an expression and the positions its holds in the current key seq. + val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet] + currentOrderOfKeys.zipWithIndex.foreach { + case (key, index) => + keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index) + } + + // Reorder the keys. + val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size) + val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size) + val iterator = expectedOrderOfKeys.iterator + while (iterator.hasNext) { + // Lookup the current index of this key. + keyToIndexMap.get(iterator.next().canonicalized) match { + case Some(indices) if indices.nonEmpty => + // Take the first available index from the map. + val index = indices.firstKey + indices.remove(index) + + // Add the keys for that index to the reordered keys. 
+ leftKeysBuffer += leftKeys(index) + rightKeysBuffer += rightKeys(index) + case _ => + // The expression cannot be found, or we have exhausted all indices for that expression. + return (leftKeys, rightKeys) + } + } (leftKeysBuffer, rightKeysBuffer) } @@ -145,20 +160,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftPartitioning: Partitioning, rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - leftPartitioning match { - case HashPartitioning(leftExpressions, _) - if leftExpressions.length == leftKeys.length && - leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => - reorder(leftKeys, rightKeys, leftExpressions, leftKeys) - - case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) - if rightExpressions.length == rightKeys.length && - rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => - reorder(leftKeys, rightKeys, rightExpressions, rightKeys) - - case _ => (leftKeys, rightKeys) - } + (leftPartitioning, rightPartitioning) match { + case (HashPartitioning(leftExpressions, _), _) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) + case (_, HashPartitioning(rightExpressions, _)) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) + case _ => + (leftKeys, rightKeys) } } else { (leftKeys, rightKeys) @@ -191,7 +199,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = plan.transformUp { // TODO: remove this after we create a physical operator for `RepartitionByExpression`. - case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) => + case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => child.outputPartitioning match { case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 5d0208f1ecc4..fec05a76b451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -43,7 +43,8 @@ import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordCo */ case class ShuffleExchangeExec( override val outputPartitioning: Partitioning, - child: SparkPlan) extends Exchange { + child: SparkPlan, + canChangeNumPartitions: Boolean = true) extends Exchange { // NOTE: coordinator can be null after serialization/deserialization, // e.g. 
it can be null on the Executor side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 9a76e144b885..d05113431df4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -191,7 +191,7 @@ case class MapPartitionsExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) func(iter.map(getObject)).map(outputObject) } } @@ -278,10 +278,10 @@ case class MapElementsExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val (funcClass, methodName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" - case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType) + case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) - val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) + val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output) val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) @@ -296,7 +296,7 @@ case class MapElementsExec( child.execute().mapPartitionsInternal { iter => val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) iter.map(row => outputObject(callFunc(getObject(row)))) } } @@ -395,7 +395,7 @@ case class MapGroupsExec( val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) grouped.flatMap { case (key, rowIter) => val result = func( @@ -447,8 +447,6 @@ case class FlatMapGroupsInRExec( outputObjAttr: Attribute, child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = @@ -473,7 +471,7 @@ case class FlatMapGroupsInRExec( val grouped = GroupedIterator(iter, groupingAttributes, child.output) val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]]( func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars, isDataFrame = true, colNames = inputSchema.fieldNames, @@ -606,7 +604,7 @@ case class CoGroupExec( val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup) val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr) val 
getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr) - val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 4f352782067c..02bfbc4949b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -81,10 +81,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) - // `Opcodes.MEMOIZE` of Protocol 4 (Python 3.4+) will store objects in internal map - // of `Unpickler`. This map is cleared when calling `Unpickler.close()`. Pyrolite - // doesn't clear it up, so we manually clear it. - unpickle.close() unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => if (udfs.length == 1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index f4c2d02ee942..41521bfae1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -32,7 +32,7 @@ case class UserDefinedPythonFunction( pythonEvalType: Int, udfDeterministic: Boolean) { - def builder(e: Seq[Expression]): PythonUDF = { + def builder(e: Seq[Expression]): Expression = { PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 6b6eb78404e3..fe91d2491222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -28,14 +28,14 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.util.CompletionIterator /** - * Physical operator for executing `FlatMapGroupsWithState.` + * Physical operator for executing `FlatMapGroupsWithState` * * @param func function called on each group * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. 
* @param groupingAttributes used to group the data * @param dataAttributes used to read the data - * @param outputObjAttr used to define the output object + * @param outputObjAttr Defines the output object * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` * @param timeoutConf used to timeout groups that have not received data in a while @@ -154,7 +154,7 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) private val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index fd2638f30469..e7eb2cb558cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming.sources.{RateControlMicroBatchSt import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2, SparkDataStream} -import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.util.Clock class MicroBatchExecution( @@ -51,7 +51,7 @@ class MicroBatchExecution( @volatile protected var sources: Seq[SparkDataStream] = Seq.empty private val triggerExecutor = trigger match { - case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) + case t: ProcessingTimeTrigger => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index d188566f822b..088471053b6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging -import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{Clock, SystemClock} trait TriggerExecutor { @@ -43,10 +42,12 @@ case class OneTimeExecutor() extends TriggerExecutor { /** * A trigger executor that runs a batch every `intervalMs` milliseconds. 
*/ -case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock()) +case class ProcessingTimeExecutor( + processingTimeTrigger: ProcessingTimeTrigger, + clock: Clock = new SystemClock()) extends TriggerExecutor with Logging { - private val intervalMs = processingTime.intervalMs + private val intervalMs = processingTimeTrigger.intervalMs require(intervalMs >= 0) override def execute(triggerHandler: () => Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 4c0db3cb42a8..aede08820503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -17,8 +17,31 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.unsafe.types.CalendarInterval + +private object Triggers { + def validate(intervalMs: Long): Unit = { + require(intervalMs >= 0, "the interval of trigger should not be negative") + } + + def convert(interval: String): Long = { + val cal = CalendarInterval.fromCaseInsensitiveString(interval) + if (cal.months > 0) { + throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") + } + TimeUnit.MICROSECONDS.toMillis(cal.microseconds) + } + + def convert(interval: Duration): Long = interval.toMillis + + def convert(interval: Long, unit: TimeUnit): Long = unit.toMillis(interval) +} /** * A [[Trigger]] that processes only one batch of data in a streaming query then terminates @@ -26,4 +49,62 @@ import org.apache.spark.sql.streaming.Trigger */ @Experimental @Evolving -case object OneTimeTrigger extends Trigger +private[sql] case object OneTimeTrigger extends Trigger + +/** + * A [[Trigger]] that runs a query periodically based on the processing time. If `interval` is 0, + * the query will run as fast as possible. + */ +@Evolving +private[sql] case class ProcessingTimeTrigger(intervalMs: Long) extends Trigger { + Triggers.validate(intervalMs) +} + +private[sql] object ProcessingTimeTrigger { + import Triggers._ + + def apply(interval: String): ProcessingTimeTrigger = { + ProcessingTimeTrigger(convert(interval)) + } + + def apply(interval: Duration): ProcessingTimeTrigger = { + ProcessingTimeTrigger(convert(interval)) + } + + def create(interval: String): ProcessingTimeTrigger = { + apply(interval) + } + + def create(interval: Long, unit: TimeUnit): ProcessingTimeTrigger = { + ProcessingTimeTrigger(convert(interval, unit)) + } +} + +/** + * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. 
+ */ +@Evolving +private[sql] case class ContinuousTrigger(intervalMs: Long) extends Trigger { + Triggers.validate(intervalMs) +} + +private[sql] object ContinuousTrigger { + import Triggers._ + + def apply(interval: String): ContinuousTrigger = { + ContinuousTrigger(convert(interval)) + } + + def apply(interval: Duration): ContinuousTrigger = { + ContinuousTrigger(convert(interval)) + } + + def create(interval: String): ContinuousTrigger = { + apply(interval) + } + + def create(interval: Long, unit: TimeUnit): ContinuousTrigger = { + ContinuousTrigger(convert(interval, unit)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 509b103faa0d..f6d156ded766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} -import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.util.Clock class ContinuousExecution( @@ -93,7 +93,7 @@ class ContinuousExecution( } private val triggerExecutor = trigger match { - case ContinuousTrigger(t) => ProcessingTimeExecutor(ProcessingTime(t), triggerClock) + case ContinuousTrigger(t) => ProcessingTimeExecutor(ProcessingTimeTrigger(t), triggerClock) case _ => throw new IllegalStateException(s"Unsupported type of trigger: $trigger") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index d55f71c7be83..e1b7a8fc283d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -134,8 +134,10 @@ class RateStreamContinuousPartitionReader( nextReadTime += readTimeIncrement try { - while (System.currentTimeMillis < nextReadTime) { - Thread.sleep(nextReadTime - System.currentTimeMillis) + var toWaitMs = nextReadTime - System.currentTimeMillis + while (toWaitMs > 0) { + Thread.sleep(toWaitMs) + toWaitMs = nextReadTime - System.currentTimeMillis } } catch { case _: InterruptedException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala deleted file mode 100644 index bd343f380603..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - */ -@Evolving -case class ContinuousTrigger(intervalMs: Long) extends Trigger { - require(intervalMs >= 0, "the interval of trigger should not be negative") -} - -private[sql] object ContinuousTrigger { - def apply(interval: String): ContinuousTrigger = { - val cal = CalendarInterval.fromCaseInsensitiveString(interval) - if (cal.months > 0) { - throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") - } - new ContinuousTrigger(TimeUnit.MICROSECONDS.toMillis(cal.microseconds)) - } - - def apply(interval: Duration): ContinuousTrigger = { - ContinuousTrigger(interval.toMillis) - } - - def create(interval: String): ContinuousTrigger = { - apply(interval) - } - - def create(interval: Long, unit: TimeUnit): ContinuousTrigger = { - ContinuousTrigger(unit.toMillis(interval)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5fa3fd0a37a6..72a197bdbcfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3932,7 +3932,7 @@ object functions { val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") - val funcCall = if (i == 0) "() => func" else "func" + val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)" println(s""" |/** | * Defines a Java UDF$i instance as user-defined function (UDF). 
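// --- Editor's aside (illustrative sketch, not part of the patch) ---------------------
// The functions.scala hunks above and below change the UDF0 wrappers from precomputing
// `f.call()` to wrapping the call itself. With the old shape
// (`val func = f.call(); SparkUserDefinedFunction(() => func, ...)`) the Java UDF0 ran
// once at definition time and every row saw that single cached value; `() => f.call()`
// defers the call to each invocation. The UDF0 below is hypothetical and only
// demonstrates the observable difference between the two shapes.
import org.apache.spark.sql.api.java.UDF0

val f: UDF0[java.lang.Double] = () => java.lang.Double.valueOf(math.random)

val eager: () => Any = { val v = f.call(); () => v }  // old: call() runs once, value reused
val lazyCall: () => Any = () => f.call()              // new: call() runs on every invocation

assert(eager() == eager())        // always the same cached value
// lazyCall() evaluated twice will (almost surely) produce different values
// --------------------------------------------------------------------------------------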
@@ -3944,8 +3944,8 @@ object functions { | * @since 2.3.0 | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { - | val func = f$anyCast.call($anyParams) - | SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None)) + | val func = $funcCall + | SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None)) |}""".stripMargin) } @@ -4145,8 +4145,8 @@ object functions { * @since 2.3.0 */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { - val func = f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None)) + val func = () => f.asInstanceOf[UDF0[Any]].call() + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 29500cf2afbc..805f73dee141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -30,7 +30,11 @@ private object MsSqlServerDialect extends JdbcDialect { // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients Option(StringType) } else { - None + sqlType match { + case java.sql.Types.SMALLINT => Some(ShortType) + case java.sql.Types.REAL => Some(FloatType) + case _ => None + } } } @@ -39,6 +43,7 @@ private object MsSqlServerDialect extends JdbcDialect { case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR)) case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT)) case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY)) + case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 5be45c973a5f..2645e4c9d528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -73,14 +73,13 @@ private object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) - case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT)) + case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT)) case t: DecimalType => Some( JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) case ArrayType(et, _) if et.isInstanceOf[AtomicType] => getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) - case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt"); case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d051cf9c1d4a..36104d7a7044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.command.DDLUtils import 
org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.sources.v2.{SupportsWrite, TableProvider} import org.apache.spark.sql.sources.v2.TableCapability._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala deleted file mode 100644 index 417d698bdbb0..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.apache.spark.annotation.Evolving -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * A trigger that runs a query periodically based on the processing time. If `interval` is 0, - * the query will run as fast as possible. - * - * Scala Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.writeStream.trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 2.0.0 - */ -@Evolving -@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") -case class ProcessingTime(intervalMs: Long) extends Trigger { - require(intervalMs >= 0, "the interval of trigger should not be negative") -} - -/** - * Used to create [[ProcessingTime]] triggers for [[StreamingQuery]]s. - * - * @since 2.0.0 - */ -@Evolving -@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") -object ProcessingTime { - - /** - * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. - * - * Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * }}} - * - * @since 2.0.0 - * @deprecated use Trigger.ProcessingTime(interval) - */ - @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") - def apply(interval: String): ProcessingTime = { - val cal = CalendarInterval.fromCaseInsensitiveString(interval) - if (cal.months > 0) { - throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") - } - new ProcessingTime(TimeUnit.MICROSECONDS.toMillis(cal.microseconds)) - } - - /** - * Create a [[ProcessingTime]]. 
If `interval` is 0, the query will run as fast as possible. - * - * Example: - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * @since 2.0.0 - * @deprecated use Trigger.ProcessingTime(interval) - */ - @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") - def apply(interval: Duration): ProcessingTime = { - new ProcessingTime(interval.toMillis) - } - - /** - * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. - * - * Example: - * {{{ - * df.writeStream.trigger(ProcessingTime.create("10 seconds")) - * }}} - * - * @since 2.0.0 - * @deprecated use Trigger.ProcessingTime(interval) - */ - @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") - def create(interval: String): ProcessingTime = { - apply(interval) - } - - /** - * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. - * - * Example: - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 2.0.0 - * @deprecated use Trigger.ProcessingTime(interval, unit) - */ - @deprecated("use Trigger.ProcessingTime(interval, unit)", "2.2.0") - def create(interval: Long, unit: TimeUnit): ProcessingTime = { - new ProcessingTime(unit.toMillis(interval)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 1705d5624409..abee5f6017df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-legacy.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-legacy.sql new file mode 100644 index 000000000000..2f2606d44d91 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-legacy.sql @@ -0,0 +1,115 @@ +create temporary view t as select * from values 0, 1, 2 as t(id); +create temporary view t2 as select * from values 0, 1 as t(id); + +-- CTE legacy substitution +SET spark.sql.legacy.ctePrecedence.enabled=true; + +-- CTE in CTE definition +WITH t as ( + WITH t2 AS (SELECT 1) + SELECT * FROM t2 +) +SELECT * FROM t; + +-- CTE in subquery +SELECT max(c) FROM ( + WITH t(c) AS (SELECT 1) + SELECT * FROM t +); + +-- CTE in subquery expression +SELECT ( + WITH t AS (SELECT 1) + SELECT * FROM t +); + +-- CTE in CTE definition shadows outer +WITH + t AS (SELECT 1), + t2 AS ( + WITH t AS (SELECT 2) + SELECT * FROM t + ) +SELECT * FROM t2; + +-- CTE in CTE definition shadows outer 2 +WITH + t(c) AS (SELECT 1), + t2 AS ( + SELECT ( + SELECT max(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t + ) + ) + ) +SELECT * FROM t2; + +-- CTE in CTE definition shadows outer 3 
+WITH + t AS (SELECT 1), + t2 AS ( + WITH t AS (SELECT 2), + t2 AS ( + WITH t AS (SELECT 3) + SELECT * FROM t + ) + SELECT * FROM t2 + ) +SELECT * FROM t2; + +-- CTE in subquery shadows outer +WITH t(c) AS (SELECT 1) +SELECT max(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t +); + +-- CTE in subquery shadows outer 2 +WITH t(c) AS (SELECT 1) +SELECT sum(c) FROM ( + SELECT max(c) AS c FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t + ) +); + +-- CTE in subquery shadows outer 3 +WITH t(c) AS (SELECT 1) +SELECT sum(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT max(c) AS c FROM ( + WITH t(c) AS (SELECT 3) + SELECT * FROM t + ) +); + +-- CTE in subquery expression shadows outer +WITH t AS (SELECT 1) +SELECT ( + WITH t AS (SELECT 2) + SELECT * FROM t +); + +-- CTE in subquery expression shadows outer 2 +WITH t AS (SELECT 1) +SELECT ( + SELECT ( + WITH t AS (SELECT 2) + SELECT * FROM t + ) +); + +-- CTE in subquery expression shadows outer 3 +WITH t AS (SELECT 1) +SELECT ( + WITH t AS (SELECT 2) + SELECT ( + WITH t AS (SELECT 3) + SELECT * FROM t + ) +); + +-- Clean up +DROP VIEW IF EXISTS t; +DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/boolean.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/boolean.sql index 4e621c68e1ec..fd0d299d7b0f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/boolean.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/boolean.sql @@ -17,19 +17,20 @@ SELECT 1 AS one; SELECT true AS true; -SELECT false AS false; +-- [SPARK-28349] We do not need to follow PostgreSQL to support reserved words in column alias +SELECT false AS `false`; SELECT boolean('t') AS true; -- [SPARK-27931] Trim the string when cast string type to boolean type -SELECT boolean(' f ') AS false; +SELECT boolean(' f ') AS `false`; SELECT boolean('true') AS true; -- [SPARK-27923] PostgreSQL does not accept 'test' but Spark SQL accepts it and sets it to NULL SELECT boolean('test') AS error; -SELECT boolean('false') AS false; +SELECT boolean('false') AS `false`; -- [SPARK-27923] PostgreSQL does not accept 'foo' but Spark SQL accepts it and sets it to NULL SELECT boolean('foo') AS error; @@ -41,9 +42,9 @@ SELECT boolean('yes') AS true; -- [SPARK-27923] PostgreSQL does not accept 'yeah' but Spark SQL accepts it and sets it to NULL SELECT boolean('yeah') AS error; -SELECT boolean('n') AS false; +SELECT boolean('n') AS `false`; -SELECT boolean('no') AS false; +SELECT boolean('no') AS `false`; -- [SPARK-27923] PostgreSQL does not accept 'nay' but Spark SQL accepts it and sets it to NULL SELECT boolean('nay') AS error; @@ -51,10 +52,10 @@ SELECT boolean('nay') AS error; -- [SPARK-27931] Accept 'on' and 'off' as input for boolean data type SELECT boolean('on') AS true; -SELECT boolean('off') AS false; +SELECT boolean('off') AS `false`; -- [SPARK-27931] Accept unique prefixes thereof -SELECT boolean('of') AS false; +SELECT boolean('of') AS `false`; -- [SPARK-27923] PostgreSQL does not accept 'o' but Spark SQL accepts it and sets it to NULL SELECT boolean('o') AS error; @@ -70,7 +71,7 @@ SELECT boolean('1') AS true; -- [SPARK-27923] PostgreSQL does not accept '11' but Spark SQL accepts it and sets it to NULL SELECT boolean('11') AS error; -SELECT boolean('0') AS false; +SELECT boolean('0') AS `false`; -- [SPARK-27923] PostgreSQL does not accept '000' but Spark SQL accepts it and sets it to NULL SELECT boolean('000') AS error; @@ -82,11 +83,11 @@ SELECT boolean('') AS error; SELECT boolean('t') or boolean('f') AS true; -SELECT 
boolean('t') and boolean('f') AS false; +SELECT boolean('t') and boolean('f') AS `false`; SELECT not boolean('f') AS true; -SELECT boolean('t') = boolean('f') AS false; +SELECT boolean('t') = boolean('f') AS `false`; SELECT boolean('t') <> boolean('f') AS true; @@ -99,11 +100,11 @@ SELECT boolean('f') < boolean('t') AS true; SELECT boolean('f') <= boolean('t') AS true; -- explicit casts to/from text -SELECT boolean(string('TrUe')) AS true, boolean(string('fAlse')) AS false; +SELECT boolean(string('TrUe')) AS true, boolean(string('fAlse')) AS `false`; -- [SPARK-27931] Trim the string when cast to boolean type SELECT boolean(string(' true ')) AS true, - boolean(string(' FALSE')) AS false; -SELECT string(boolean(true)) AS true, string(boolean(false)) AS false; + boolean(string(' FALSE')) AS `false`; +SELECT string(boolean(true)) AS true, string(boolean(false)) AS `false`; -- [SPARK-27923] PostgreSQL does not accept ' tru e ' but Spark SQL accepts it and sets it to NULL SELECT boolean(string(' tru e ')) AS invalid; -- error diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/case.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/case.sql index 7bb425d3fbe8..6d9c44c67a96 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/case.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/case.sql @@ -6,9 +6,6 @@ -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/case.sql -- Test the CASE statement -- --- This test suite contains two Cartesian products without using explicit CROSS JOIN syntax. --- Thus, we set spark.sql.crossJoin.enabled to true. -set spark.sql.crossJoin.enabled=true; CREATE TABLE CASE_TBL ( i integer, f double @@ -264,4 +261,3 @@ SELECT CASE DROP TABLE CASE_TBL; DROP TABLE CASE2_TBL; -set spark.sql.crossJoin.enabled=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql new file mode 100644 index 000000000000..6f8e3b596e60 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql @@ -0,0 +1,500 @@ +-- +-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group +-- +-- +-- FLOAT8 +-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/float8.sql + +CREATE TABLE FLOAT8_TBL(f1 double) USING parquet; + +INSERT INTO FLOAT8_TBL VALUES (' 0.0 '); +INSERT INTO FLOAT8_TBL VALUES ('1004.30 '); +INSERT INTO FLOAT8_TBL VALUES (' -34.84'); +INSERT INTO FLOAT8_TBL VALUES ('1.2345678901234e+200'); +INSERT INTO FLOAT8_TBL VALUES ('1.2345678901234e-200'); + +-- [SPARK-28024] Incorrect numeric values when out of range +-- test for underflow and overflow handling +SELECT double('10e400'); +SELECT double('-10e400'); +SELECT double('10e-400'); +SELECT double('-10e-400'); + +-- [SPARK-28061] Support for converting float to binary format +-- test smallest normalized input +-- SELECT float8send('2.2250738585072014E-308'::float8); + +-- [SPARK-27923] Spark SQL insert there bad inputs to NULL +-- bad input +-- INSERT INTO FLOAT8_TBL VALUES (''); +-- INSERT INTO FLOAT8_TBL VALUES (' '); +-- INSERT INTO FLOAT8_TBL VALUES ('xyz'); +-- INSERT INTO FLOAT8_TBL VALUES ('5.0.0'); +-- INSERT INTO FLOAT8_TBL VALUES ('5 . 0'); +-- INSERT INTO FLOAT8_TBL VALUES ('5. 
0'); +-- INSERT INTO FLOAT8_TBL VALUES (' - 3'); +-- INSERT INTO FLOAT8_TBL VALUES ('123 5'); + +-- special inputs +SELECT double('NaN'); +-- [SPARK-28060] Double type can not accept some special inputs +SELECT double('nan'); +SELECT double(' NAN '); +SELECT double('infinity'); +SELECT double(' -INFINiTY '); +-- [SPARK-27923] Spark SQL insert there bad special inputs to NULL +-- bad special inputs +SELECT double('N A N'); +SELECT double('NaN x'); +SELECT double(' INFINITY x'); + +SELECT double('Infinity') + 100.0; +-- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner +SELECT double('Infinity') / double('Infinity'); +SELECT double('NaN') / double('NaN'); +-- [SPARK-28315] Decimal can not accept NaN as input +SELECT double(decimal('nan')); + +SELECT '' AS five, * FROM FLOAT8_TBL; + +SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <> '1004.3'; + +SELECT '' AS one, f.* FROM FLOAT8_TBL f WHERE f.f1 = '1004.3'; + +SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE '1004.3' > f.f1; + +SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE f.f1 < '1004.3'; + +SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE '1004.3' >= f.f1; + +SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <= '1004.3'; + +SELECT '' AS three, f.f1, f.f1 * '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0'; + +SELECT '' AS three, f.f1, f.f1 + '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0'; + +SELECT '' AS three, f.f1, f.f1 / '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0'; + +SELECT '' AS three, f.f1, f.f1 - '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0'; +-- [SPARK-28007] Caret operator (^) means bitwise XOR in Spark/Hive and exponentiation in Postgres +-- SELECT '' AS one, f.f1 ^ '2.0' AS square_f1 +-- FROM FLOAT8_TBL f where f.f1 = '1004.3'; + +-- [SPARK-28027] Spark SQL does not support prefix operator @ +-- absolute value +-- SELECT '' AS five, f.f1, @f.f1 AS abs_f1 +-- FROM FLOAT8_TBL f; + +-- [SPARK-23906] Support Truncate number +-- truncate +-- SELECT '' AS five, f.f1, trunc(f.f1) AS trunc_f1 +-- FROM FLOAT8_TBL f; + +-- round +SELECT '' AS five, f.f1, round(f.f1) AS round_f1 + FROM FLOAT8_TBL f; + +-- [SPARK-28135] ceil/ceiling/floor returns incorrect values +-- ceil / ceiling +select ceil(f1) as ceil_f1 from float8_tbl f; +select ceiling(f1) as ceiling_f1 from float8_tbl f; + +-- floor +select floor(f1) as floor_f1 from float8_tbl f; + +-- sign +select sign(f1) as sign_f1 from float8_tbl f; + +-- avoid bit-exact output here because operations may not be bit-exact. 
+-- SET extra_float_digits = 0; + +-- square root +SELECT sqrt(double('64')) AS eight; + +-- [SPARK-28027] Spark SQL does not support prefix operator |/ +-- SELECT |/ float8 '64' AS eight; + +-- SELECT '' AS three, f.f1, |/f.f1 AS sqrt_f1 +-- FROM FLOAT8_TBL f +-- WHERE f.f1 > '0.0'; + +-- power +SELECT power(double('144'), double('0.5')); +SELECT power(double('NaN'), double('0.5')); +SELECT power(double('144'), double('NaN')); +SELECT power(double('NaN'), double('NaN')); +SELECT power(double('-1'), double('NaN')); +-- [SPARK-28135] power returns incorrect values +SELECT power(double('1'), double('NaN')); +SELECT power(double('NaN'), double('0')); + +-- take exp of ln(f.f1) +SELECT '' AS three, f.f1, exp(ln(f.f1)) AS exp_ln_f1 + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0'; + +-- [SPARK-28027] Spark SQL does not support prefix operator ||/ +-- cube root +-- SELECT ||/ float8 '27' AS three; + +-- SELECT '' AS five, f.f1, ||/f.f1 AS cbrt_f1 FROM FLOAT8_TBL f; + + +SELECT '' AS five, * FROM FLOAT8_TBL; + +-- UPDATE FLOAT8_TBL +-- SET f1 = FLOAT8_TBL.f1 * '-1' +-- WHERE FLOAT8_TBL.f1 > '0.0'; +-- Update the FLOAT8_TBL to UPDATED_FLOAT8_TBL +CREATE TEMPORARY VIEW UPDATED_FLOAT8_TBL as +SELECT + CASE WHEN FLOAT8_TBL.f1 > '0.0' THEN FLOAT8_TBL.f1 * '-1' ELSE FLOAT8_TBL.f1 END AS f1 +FROM FLOAT8_TBL; + +-- [SPARK-27923] Out of range, Spark SQL returns Infinity +SELECT '' AS bad, f.f1 * '1e200' from UPDATED_FLOAT8_TBL f; + +-- [SPARK-28007] Caret operator (^) means bitwise XOR in Spark/Hive and exponentiation in Postgres +-- SELECT '' AS bad, f.f1 ^ '1e200' from UPDATED_FLOAT8_TBL f; + +-- SELECT 0 ^ 0 + 0 ^ 1 + 0 ^ 0.0 + 0 ^ 0.5; + +-- [SPARK-27923] Cannot take logarithm of zero +-- SELECT '' AS bad, ln(f.f1) from UPDATED_FLOAT8_TBL f where f.f1 = '0.0' ; + +-- [SPARK-27923] Cannot take logarithm of a negative number +-- SELECT '' AS bad, ln(f.f1) from UPDATED_FLOAT8_TBL f where f.f1 < '0.0' ; + +-- [SPARK-28024] Incorrect numeric values when out of range +-- SELECT '' AS bad, exp(f.f1) from UPDATED_FLOAT8_TBL f; + +-- [SPARK-27923] Divide by zero, Spark SQL returns NULL +-- SELECT '' AS bad, f.f1 / '0.0' from UPDATED_FLOAT8_TBL f; + +SELECT '' AS five, * FROM UPDATED_FLOAT8_TBL; + +-- hyperbolic functions +-- we run these with extra_float_digits = 0 too, since different platforms +-- tend to produce results that vary in the last place. 
+SELECT sinh(double('1')); +SELECT cosh(double('1')); +SELECT tanh(double('1')); +SELECT asinh(double('1')); +SELECT acosh(double('2')); +SELECT atanh(double('0.5')); +-- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner +-- test Inf/NaN cases for hyperbolic functions +SELECT sinh(double('Infinity')); +SELECT sinh(double('-Infinity')); +SELECT sinh(double('NaN')); +SELECT cosh(double('Infinity')); +SELECT cosh(double('-Infinity')); +SELECT cosh(double('NaN')); +SELECT tanh(double('Infinity')); +SELECT tanh(double('-Infinity')); +SELECT tanh(double('NaN')); +SELECT asinh(double('Infinity')); +SELECT asinh(double('-Infinity')); +SELECT asinh(double('NaN')); +-- acosh(Inf) should be Inf, but some mingw versions produce NaN, so skip test +SELECT acosh(double('Infinity')); +SELECT acosh(double('-Infinity')); +SELECT acosh(double('NaN')); +SELECT atanh(double('Infinity')); +SELECT atanh(double('-Infinity')); +SELECT atanh(double('NaN')); + +-- RESET extra_float_digits; + +-- [SPARK-28024] Incorrect numeric values when out of range +-- test for over- and underflow +-- INSERT INTO FLOAT8_TBL VALUES ('10e400'); + +-- INSERT INTO FLOAT8_TBL VALUES ('-10e400'); + +-- INSERT INTO FLOAT8_TBL VALUES ('10e-400'); + +-- INSERT INTO FLOAT8_TBL VALUES ('-10e-400'); + +-- maintain external table consistency across platforms +-- delete all values and reinsert well-behaved ones + +TRUNCATE TABLE FLOAT8_TBL; + +INSERT INTO FLOAT8_TBL VALUES ('0.0'); + +INSERT INTO FLOAT8_TBL VALUES ('-34.84'); + +INSERT INTO FLOAT8_TBL VALUES ('-1004.30'); + +INSERT INTO FLOAT8_TBL VALUES ('-1.2345678901234e+200'); + +INSERT INTO FLOAT8_TBL VALUES ('-1.2345678901234e-200'); + +SELECT '' AS five, * FROM FLOAT8_TBL; + +-- [SPARK-28028] Cast numeric to integral type need round +-- [SPARK-28024] Incorrect numeric values when out of range +-- test edge-case coercions to integer +SELECT smallint(double('32767.4')); +SELECT smallint(double('32767.6')); +SELECT smallint(double('-32768.4')); +SELECT smallint(double('-32768.6')); +SELECT int(double('2147483647.4')); +SELECT int(double('2147483647.6')); +SELECT int(double('-2147483648.4')); +SELECT int(double('-2147483648.6')); +SELECT bigint(double('9223372036854773760')); +SELECT bigint(double('9223372036854775807')); +SELECT bigint(double('-9223372036854775808.5')); +SELECT bigint(double('-9223372036854780000')); + +-- [SPARK-28134] Missing Trigonometric Functions +-- test exact cases for trigonometric functions in degrees + +-- SELECT x, +-- sind(x), +-- sind(x) IN (-1,-0.5,0,0.5,1) AS sind_exact +-- FROM (VALUES (0), (30), (90), (150), (180), +-- (210), (270), (330), (360)) AS t(x); + +-- SELECT x, +-- cosd(x), +-- cosd(x) IN (-1,-0.5,0,0.5,1) AS cosd_exact +-- FROM (VALUES (0), (60), (90), (120), (180), +-- (240), (270), (300), (360)) AS t(x); + +-- SELECT x, +-- tand(x), +-- tand(x) IN ('-Infinity'::float8,-1,0, +-- 1,'Infinity'::float8) AS tand_exact, +-- cotd(x), +-- cotd(x) IN ('-Infinity'::float8,-1,0, +-- 1,'Infinity'::float8) AS cotd_exact +-- FROM (VALUES (0), (45), (90), (135), (180), +-- (225), (270), (315), (360)) AS t(x); + +-- SELECT x, +-- asind(x), +-- asind(x) IN (-90,-30,0,30,90) AS asind_exact, +-- acosd(x), +-- acosd(x) IN (0,60,90,120,180) AS acosd_exact +-- FROM (VALUES (-1), (-0.5), (0), (0.5), (1)) AS t(x); + +-- SELECT x, +-- atand(x), +-- atand(x) IN (-90,-45,0,45,90) AS atand_exact +-- FROM (VALUES ('-Infinity'::float8), (-1), (0), (1), +-- ('Infinity'::float8)) AS t(x); + +-- SELECT x, y, +-- atan2d(y, x), +-- 
atan2d(y, x) IN (-90,0,90,180) AS atan2d_exact +-- FROM (SELECT 10*cosd(a), 10*sind(a) +-- FROM generate_series(0, 360, 90) AS t(a)) AS t(x,y); + +-- We do not support creating types, skip the test below +-- +-- test output (and round-trip safety) of various values. +-- To ensure we're testing what we think we're testing, start with +-- float values specified by bit patterns (as a useful side effect, +-- this means we'll fail on non-IEEE platforms). + +-- create type xfloat8; +-- create function xfloat8in(cstring) returns xfloat8 immutable strict +-- language internal as 'int8in'; +-- create function xfloat8out(xfloat8) returns cstring immutable strict +-- language internal as 'int8out'; +-- create type xfloat8 (input = xfloat8in, output = xfloat8out, like = float8); +-- create cast (xfloat8 as float8) without function; +-- create cast (float8 as xfloat8) without function; +-- create cast (xfloat8 as bigint) without function; +-- create cast (bigint as xfloat8) without function; + +-- float8: seeeeeee eeeeeeee eeeeeeee mmmmmmmm mmmmmmmm(x4) + +-- we don't care to assume the platform's strtod() handles subnormals +-- correctly; those are "use at your own risk". However we do test +-- subnormal outputs, since those are under our control. + +-- with testdata(bits) as (values +-- -- small subnormals +-- (x'0000000000000001'), +-- (x'0000000000000002'), (x'0000000000000003'), +-- (x'0000000000001000'), (x'0000000100000000'), +-- (x'0000010000000000'), (x'0000010100000000'), +-- (x'0000400000000000'), (x'0000400100000000'), +-- (x'0000800000000000'), (x'0000800000000001'), +-- -- these values taken from upstream testsuite +-- (x'00000000000f4240'), +-- (x'00000000016e3600'), +-- (x'0000008cdcdea440'), +-- -- borderline between subnormal and normal +-- (x'000ffffffffffff0'), (x'000ffffffffffff1'), +-- (x'000ffffffffffffe'), (x'000fffffffffffff')) +-- select float8send(flt) as ibits, +-- flt +-- from (select bits::bigint::xfloat8::float8 as flt +-- from testdata +-- offset 0) s; + +-- round-trip tests + +-- with testdata(bits) as (values +-- (x'0000000000000000'), +-- -- smallest normal values +-- (x'0010000000000000'), (x'0010000000000001'), +-- (x'0010000000000002'), (x'0018000000000000'), +-- -- +-- (x'3ddb7cdfd9d7bdba'), (x'3ddb7cdfd9d7bdbb'), (x'3ddb7cdfd9d7bdbc'), +-- (x'3e112e0be826d694'), (x'3e112e0be826d695'), (x'3e112e0be826d696'), +-- (x'3e45798ee2308c39'), (x'3e45798ee2308c3a'), (x'3e45798ee2308c3b'), +-- (x'3e7ad7f29abcaf47'), (x'3e7ad7f29abcaf48'), (x'3e7ad7f29abcaf49'), +-- (x'3eb0c6f7a0b5ed8c'), (x'3eb0c6f7a0b5ed8d'), (x'3eb0c6f7a0b5ed8e'), +-- (x'3ee4f8b588e368ef'), (x'3ee4f8b588e368f0'), (x'3ee4f8b588e368f1'), +-- (x'3f1a36e2eb1c432c'), (x'3f1a36e2eb1c432d'), (x'3f1a36e2eb1c432e'), +-- (x'3f50624dd2f1a9fb'), (x'3f50624dd2f1a9fc'), (x'3f50624dd2f1a9fd'), +-- (x'3f847ae147ae147a'), (x'3f847ae147ae147b'), (x'3f847ae147ae147c'), +-- (x'3fb9999999999999'), (x'3fb999999999999a'), (x'3fb999999999999b'), +-- -- values very close to 1 +-- (x'3feffffffffffff0'), (x'3feffffffffffff1'), (x'3feffffffffffff2'), +-- (x'3feffffffffffff3'), (x'3feffffffffffff4'), (x'3feffffffffffff5'), +-- (x'3feffffffffffff6'), (x'3feffffffffffff7'), (x'3feffffffffffff8'), +-- (x'3feffffffffffff9'), (x'3feffffffffffffa'), (x'3feffffffffffffb'), +-- (x'3feffffffffffffc'), (x'3feffffffffffffd'), (x'3feffffffffffffe'), +-- (x'3fefffffffffffff'), +-- (x'3ff0000000000000'), +-- (x'3ff0000000000001'), (x'3ff0000000000002'), (x'3ff0000000000003'), +-- (x'3ff0000000000004'), (x'3ff0000000000005'), (x'3ff0000000000006'), 
+-- (x'3ff0000000000007'), (x'3ff0000000000008'), (x'3ff0000000000009'), +-- -- +-- (x'3ff921fb54442d18'), +-- (x'4005bf0a8b14576a'), +-- (x'400921fb54442d18'), +-- -- +-- (x'4023ffffffffffff'), (x'4024000000000000'), (x'4024000000000001'), +-- (x'4058ffffffffffff'), (x'4059000000000000'), (x'4059000000000001'), +-- (x'408f3fffffffffff'), (x'408f400000000000'), (x'408f400000000001'), +-- (x'40c387ffffffffff'), (x'40c3880000000000'), (x'40c3880000000001'), +-- (x'40f869ffffffffff'), (x'40f86a0000000000'), (x'40f86a0000000001'), +-- (x'412e847fffffffff'), (x'412e848000000000'), (x'412e848000000001'), +-- (x'416312cfffffffff'), (x'416312d000000000'), (x'416312d000000001'), +-- (x'4197d783ffffffff'), (x'4197d78400000000'), (x'4197d78400000001'), +-- (x'41cdcd64ffffffff'), (x'41cdcd6500000000'), (x'41cdcd6500000001'), +-- (x'4202a05f1fffffff'), (x'4202a05f20000000'), (x'4202a05f20000001'), +-- (x'42374876e7ffffff'), (x'42374876e8000000'), (x'42374876e8000001'), +-- (x'426d1a94a1ffffff'), (x'426d1a94a2000000'), (x'426d1a94a2000001'), +-- (x'42a2309ce53fffff'), (x'42a2309ce5400000'), (x'42a2309ce5400001'), +-- (x'42d6bcc41e8fffff'), (x'42d6bcc41e900000'), (x'42d6bcc41e900001'), +-- (x'430c6bf52633ffff'), (x'430c6bf526340000'), (x'430c6bf526340001'), +-- (x'4341c37937e07fff'), (x'4341c37937e08000'), (x'4341c37937e08001'), +-- (x'4376345785d89fff'), (x'4376345785d8a000'), (x'4376345785d8a001'), +-- (x'43abc16d674ec7ff'), (x'43abc16d674ec800'), (x'43abc16d674ec801'), +-- (x'43e158e460913cff'), (x'43e158e460913d00'), (x'43e158e460913d01'), +-- (x'4415af1d78b58c3f'), (x'4415af1d78b58c40'), (x'4415af1d78b58c41'), +-- (x'444b1ae4d6e2ef4f'), (x'444b1ae4d6e2ef50'), (x'444b1ae4d6e2ef51'), +-- (x'4480f0cf064dd591'), (x'4480f0cf064dd592'), (x'4480f0cf064dd593'), +-- (x'44b52d02c7e14af5'), (x'44b52d02c7e14af6'), (x'44b52d02c7e14af7'), +-- (x'44ea784379d99db3'), (x'44ea784379d99db4'), (x'44ea784379d99db5'), +-- (x'45208b2a2c280290'), (x'45208b2a2c280291'), (x'45208b2a2c280292'), +-- -- +-- (x'7feffffffffffffe'), (x'7fefffffffffffff'), +-- -- round to even tests (+ve) +-- (x'4350000000000002'), +-- (x'4350000000002e06'), +-- (x'4352000000000003'), +-- (x'4352000000000004'), +-- (x'4358000000000003'), +-- (x'4358000000000004'), +-- (x'435f000000000020'), +-- -- round to even tests (-ve) +-- (x'c350000000000002'), +-- (x'c350000000002e06'), +-- (x'c352000000000003'), +-- (x'c352000000000004'), +-- (x'c358000000000003'), +-- (x'c358000000000004'), +-- (x'c35f000000000020'), +-- -- exercise fixed-point memmoves +-- (x'42dc12218377de66'), +-- (x'42a674e79c5fe51f'), +-- (x'4271f71fb04cb74c'), +-- (x'423cbe991a145879'), +-- (x'4206fee0e1a9e061'), +-- (x'41d26580b487e6b4'), +-- (x'419d6f34540ca453'), +-- (x'41678c29dcd6e9dc'), +-- (x'4132d687e3df217d'), +-- (x'40fe240c9fcb68c8'), +-- (x'40c81cd6e63c53d3'), +-- (x'40934a4584fd0fdc'), +-- (x'405edd3c07fb4c93'), +-- (x'4028b0fcd32f7076'), +-- (x'3ff3c0ca428c59f8'), +-- -- these cases come from the upstream's testsuite +-- -- LotsOfTrailingZeros) +-- (x'3e60000000000000'), +-- -- Regression +-- (x'c352bd2668e077c4'), +-- (x'434018601510c000'), +-- (x'43d055dc36f24000'), +-- (x'43e052961c6f8000'), +-- (x'3ff3c0ca2a5b1d5d'), +-- -- LooksLikePow5 +-- (x'4830f0cf064dd592'), +-- (x'4840f0cf064dd592'), +-- (x'4850f0cf064dd592'), +-- -- OutputLength +-- (x'3ff3333333333333'), +-- (x'3ff3ae147ae147ae'), +-- (x'3ff3be76c8b43958'), +-- (x'3ff3c083126e978d'), +-- (x'3ff3c0c1fc8f3238'), +-- (x'3ff3c0c9539b8887'), +-- (x'3ff3c0ca2a5b1d5d'), +-- (x'3ff3c0ca4283de1b'), +-- 
(x'3ff3c0ca43db770a'), +-- (x'3ff3c0ca428abd53'), +-- (x'3ff3c0ca428c1d2b'), +-- (x'3ff3c0ca428c51f2'), +-- (x'3ff3c0ca428c58fc'), +-- (x'3ff3c0ca428c59dd'), +-- (x'3ff3c0ca428c59f8'), +-- (x'3ff3c0ca428c59fb'), +-- -- 32-bit chunking +-- (x'40112e0be8047a7d'), +-- (x'40112e0be815a889'), +-- (x'40112e0be826d695'), +-- (x'40112e0be83804a1'), +-- (x'40112e0be84932ad'), +-- -- MinMaxShift +-- (x'0040000000000000'), +-- (x'007fffffffffffff'), +-- (x'0290000000000000'), +-- (x'029fffffffffffff'), +-- (x'4350000000000000'), +-- (x'435fffffffffffff'), +-- (x'1330000000000000'), +-- (x'133fffffffffffff'), +-- (x'3a6fa7161a4d6e0c') +-- ) +-- select float8send(flt) as ibits, +-- flt, +-- flt::text::float8 as r_flt, +-- float8send(flt::text::float8) as obits, +-- float8send(flt::text::float8) = float8send(flt) as correct +-- from (select bits::bigint::xfloat8::float8 as flt +-- from testdata +-- offset 0) s; + +-- clean up, lest opr_sanity complain +-- drop type xfloat8 cascade; +DROP TABLE FLOAT8_TBL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int2.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int2.sql index 61f350d3e3f4..f64ec5d75afc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int2.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int2.sql @@ -88,11 +88,9 @@ WHERE f1 > -32767; SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT2_TBL i; --- PostgreSQL `/` is the same with Spark `div` since SPARK-2659. -SELECT '' AS five, i.f1, i.f1 div smallint('2') AS x FROM INT2_TBL i; +SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT2_TBL i; --- PostgreSQL `/` is the same with Spark `div` since SPARK-2659. -SELECT '' AS five, i.f1, i.f1 div int('2') AS x FROM INT2_TBL i; +SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT2_TBL i; -- corner cases SELECT string(shiftleft(smallint(-1), 15)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql index cbd587889273..86432a845b6e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql @@ -125,7 +125,8 @@ SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true; SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true; -SELECT int('1000') < int('999') AS false; +-- [SPARK-28349] We do not need to follow PostgreSQL to support reserved words in column alias +SELECT int('1000') < int('999') AS `false`; -- [SPARK-28027] Our ! and !! has different meanings -- SELECT 4! 
AS twenty_four; @@ -134,7 +135,6 @@ SELECT int('1000') < int('999') AS false; SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten; --- [SPARK-2659] HiveQL: Division operator should always perform fractional division SELECT 2 + 2 / 2 AS three; SELECT (2 + 2) / 2 AS two; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int8.sql index 31eef6f34b1d..d29bf3bfad4c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int8.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int8.sql @@ -67,10 +67,11 @@ SELECT * FROM INT8_TBL WHERE smallint('123') <= q1; SELECT * FROM INT8_TBL WHERE smallint('123') >= q1; -SELECT '' AS five, q1 AS plus, -q1 AS minus FROM INT8_TBL; +-- [SPARK-28349] We do not need to follow PostgreSQL to support reserved words in column alias +SELECT '' AS five, q1 AS plus, -q1 AS `minus` FROM INT8_TBL; SELECT '' AS five, q1, q2, q1 + q2 AS plus FROM INT8_TBL; -SELECT '' AS five, q1, q2, q1 - q2 AS minus FROM INT8_TBL; +SELECT '' AS five, q1, q2, q1 - q2 AS `minus` FROM INT8_TBL; SELECT '' AS three, q1, q2, q1 * q2 AS multiply FROM INT8_TBL; SELECT '' AS three, q1, q2, q1 * q2 AS multiply FROM INT8_TBL WHERE q1 < 1000 or (q2 > 0 and q2 < 1000); @@ -84,7 +85,6 @@ SELECT 37 - q1 AS minus4 FROM INT8_TBL; SELECT '' AS five, 2 * q1 AS `twice int4` FROM INT8_TBL; SELECT '' AS five, q1 * 2 AS `twice int4` FROM INT8_TBL; --- [SPARK-2659] HiveQL: Division operator should always perform fractional division -- int8 op int4 SELECT q1 + int(42) AS `8plus4`, q1 - int(42) AS `8minus4`, q1 * int(42) AS `8mul4`, q1 / int(42) AS `8div4` FROM INT8_TBL; -- int4 op int8 diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select.sql new file mode 100644 index 000000000000..1f83d6c41661 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select.sql @@ -0,0 +1,285 @@ +-- +-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group +-- +-- +-- SELECT +-- Test int8 64-bit integers. +-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select.sql +-- +create or replace temporary view onek2 as select * from onek; +create or replace temporary view INT8_TBL as select * from values + (cast(trim(' 123 ') as bigint), cast(trim(' 456') as bigint)), + (cast(trim('123 ') as bigint),cast('4567890123456789' as bigint)), + (cast('4567890123456789' as bigint),cast('123' as bigint)), + (cast(+4567890123456789 as bigint),cast('4567890123456789' as bigint)), + (cast('+4567890123456789' as bigint),cast('-4567890123456789' as bigint)) + as INT8_TBL(q1, q2); + +-- btree index +-- awk '{if($1<10){print;}else{next;}}' onek.data | sort +0n -1 +-- +SELECT * FROM onek + WHERE onek.unique1 < 10 + ORDER BY onek.unique1; + +-- [SPARK-28010] Support ORDER BY ... 
USING syntax +-- +-- awk '{if($1<20){print $1,$14;}else{next;}}' onek.data | sort +0nr -1 +-- +SELECT onek.unique1, onek.stringu1 FROM onek + WHERE onek.unique1 < 20 + ORDER BY unique1 DESC; + +-- +-- awk '{if($1>980){print $1,$14;}else{next;}}' onek.data | sort +1d -2 +-- +SELECT onek.unique1, onek.stringu1 FROM onek + WHERE onek.unique1 > 980 + ORDER BY stringu1 ASC; + +-- +-- awk '{if($1>980){print $1,$16;}else{next;}}' onek.data | +-- sort +1d -2 +0nr -1 +-- +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 > 980 + ORDER BY string4 ASC, unique1 DESC; + +-- +-- awk '{if($1>980){print $1,$16;}else{next;}}' onek.data | +-- sort +1dr -2 +0n -1 +-- +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 > 980 + ORDER BY string4 DESC, unique1 ASC; + +-- +-- awk '{if($1<20){print $1,$16;}else{next;}}' onek.data | +-- sort +0nr -1 +1d -2 +-- +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 < 20 + ORDER BY unique1 DESC, string4 ASC; + +-- +-- awk '{if($1<20){print $1,$16;}else{next;}}' onek.data | +-- sort +0n -1 +1dr -2 +-- +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 < 20 + ORDER BY unique1 ASC, string4 DESC; + +-- +-- test partial btree indexes +-- +-- As of 7.2, planner probably won't pick an indexscan without stats, +-- so ANALYZE first. Also, we want to prevent it from picking a bitmapscan +-- followed by sort, because that could hide index ordering problems. +-- +-- ANALYZE onek2; + +-- SET enable_seqscan TO off; +-- SET enable_bitmapscan TO off; +-- SET enable_sort TO off; + +-- +-- awk '{if($1<10){print $0;}else{next;}}' onek.data | sort +0n -1 +-- +SELECT onek2.* FROM onek2 WHERE onek2.unique1 < 10; + +-- +-- awk '{if($1<20){print $1,$14;}else{next;}}' onek.data | sort +0nr -1 +-- +SELECT onek2.unique1, onek2.stringu1 FROM onek2 + WHERE onek2.unique1 < 20 + ORDER BY unique1 DESC; + +-- +-- awk '{if($1>980){print $1,$14;}else{next;}}' onek.data | sort +1d -2 +-- +SELECT onek2.unique1, onek2.stringu1 FROM onek2 + WHERE onek2.unique1 > 980; + +-- RESET enable_seqscan; +-- RESET enable_bitmapscan; +-- RESET enable_sort; + +-- [SPARK-28329] SELECT INTO syntax +-- SELECT two, stringu1, ten, string4 +-- INTO TABLE tmp +-- FROM onek; +CREATE TABLE tmp USING parquet AS +SELECT two, stringu1, ten, string4 +FROM onek; + +-- Skip the person table because there is a point data type that we don't support. +-- +-- awk '{print $1,$2;}' person.data | +-- awk '{if(NF!=2){print $3,$2;}else{print;}}' - emp.data | +-- awk '{if(NF!=2){print $3,$2;}else{print;}}' - student.data | +-- awk 'BEGIN{FS=" ";}{if(NF!=2){print $4,$5;}else{print;}}' - stud_emp.data +-- +-- SELECT name, age FROM person*; ??? 
check if different +-- SELECT p.name, p.age FROM person* p; + +-- +-- awk '{print $1,$2;}' person.data | +-- awk '{if(NF!=2){print $3,$2;}else{print;}}' - emp.data | +-- awk '{if(NF!=2){print $3,$2;}else{print;}}' - student.data | +-- awk 'BEGIN{FS=" ";}{if(NF!=1){print $4,$5;}else{print;}}' - stud_emp.data | +-- sort +1nr -2 +-- +-- SELECT p.name, p.age FROM person* p ORDER BY age DESC, name; + +-- [SPARK-28330] Enhance query limit +-- +-- Test some cases involving whole-row Var referencing a subquery +-- +select foo.* from (select 1) as foo; +select foo.* from (select null) as foo; +select foo.* from (select 'xyzzy',1,null) as foo; + +-- +-- Test VALUES lists +-- +select * from onek, values(147, 'RFAAAA'), (931, 'VJAAAA') as v (i, j) + WHERE onek.unique1 = v.i and onek.stringu1 = v.j; + +-- [SPARK-28296] Improved VALUES support +-- a more complex case +-- looks like we're coding lisp :-) +-- select * from onek, +-- (values ((select i from +-- (values(10000), (2), (389), (1000), (2000), ((select 10029))) as foo(i) +-- order by i asc limit 1))) bar (i) +-- where onek.unique1 = bar.i; + +-- try VALUES in a subquery +-- select * from onek +-- where (unique1,ten) in (values (1,1), (20,0), (99,9), (17,99)) +-- order by unique1; + +-- VALUES is also legal as a standalone query or a set-operation member +VALUES (1,2), (3,4+4), (7,77.7); + +VALUES (1,2), (3,4+4), (7,77.7) +UNION ALL +SELECT 2+2, 57 +UNION ALL +TABLE int8_tbl; + +-- +-- Test ORDER BY options +-- + +CREATE OR REPLACE TEMPORARY VIEW foo AS +SELECT * FROM (values(42),(3),(10),(7),(null),(null),(1)) as foo (f1); + +-- [SPARK-28333] NULLS FIRST for DESC and NULLS LAST for ASC +SELECT * FROM foo ORDER BY f1; +SELECT * FROM foo ORDER BY f1 ASC; -- same thing +SELECT * FROM foo ORDER BY f1 NULLS FIRST; +SELECT * FROM foo ORDER BY f1 DESC; +SELECT * FROM foo ORDER BY f1 DESC NULLS LAST; + +-- check if indexscans do the right things +-- CREATE INDEX fooi ON foo (f1); +-- SET enable_sort = false; + +-- SELECT * FROM foo ORDER BY f1; +-- SELECT * FROM foo ORDER BY f1 NULLS FIRST; +-- SELECT * FROM foo ORDER BY f1 DESC; +-- SELECT * FROM foo ORDER BY f1 DESC NULLS LAST; + +-- DROP INDEX fooi; +-- CREATE INDEX fooi ON foo (f1 DESC); + +-- SELECT * FROM foo ORDER BY f1; +-- SELECT * FROM foo ORDER BY f1 NULLS FIRST; +-- SELECT * FROM foo ORDER BY f1 DESC; +-- SELECT * FROM foo ORDER BY f1 DESC NULLS LAST; + +-- DROP INDEX fooi; +-- CREATE INDEX fooi ON foo (f1 DESC NULLS LAST); + +-- SELECT * FROM foo ORDER BY f1; +-- SELECT * FROM foo ORDER BY f1 NULLS FIRST; +-- SELECT * FROM foo ORDER BY f1 DESC; +-- SELECT * FROM foo ORDER BY f1 DESC NULLS LAST; + +-- +-- Test planning of some cases with partial indexes +-- + +-- partial index is usable +-- explain (costs off) +-- select * from onek2 where unique2 = 11 and stringu1 = 'ATAAAA'; +select * from onek2 where unique2 = 11 and stringu1 = 'ATAAAA'; +-- actually run the query with an analyze to use the partial index +-- explain (costs off, analyze on, timing off, summary off) +-- select * from onek2 where unique2 = 11 and stringu1 = 'ATAAAA'; +-- explain (costs off) +-- select unique2 from onek2 where unique2 = 11 and stringu1 = 'ATAAAA'; +select unique2 from onek2 where unique2 = 11 and stringu1 = 'ATAAAA'; +-- partial index predicate implies clause, so no need for retest +-- explain (costs off) +-- select * from onek2 where unique2 = 11 and stringu1 < 'B'; +select * from onek2 where unique2 = 11 and stringu1 < 'B'; +-- explain (costs off) +-- select unique2 from onek2 where unique2 = 11 and 
stringu1 < 'B'; +select unique2 from onek2 where unique2 = 11 and stringu1 < 'B'; +-- but if it's an update target, must retest anyway +-- explain (costs off) +-- select unique2 from onek2 where unique2 = 11 and stringu1 < 'B' for update; +-- select unique2 from onek2 where unique2 = 11 and stringu1 < 'B' for update; +-- partial index is not applicable +-- explain (costs off) +-- select unique2 from onek2 where unique2 = 11 and stringu1 < 'C'; +select unique2 from onek2 where unique2 = 11 and stringu1 < 'C'; +-- partial index implies clause, but bitmap scan must recheck predicate anyway +-- SET enable_indexscan TO off; +-- explain (costs off) +-- select unique2 from onek2 where unique2 = 11 and stringu1 < 'B'; +select unique2 from onek2 where unique2 = 11 and stringu1 < 'B'; +-- RESET enable_indexscan; +-- check multi-index cases too +-- explain (costs off) +-- select unique1, unique2 from onek2 +-- where (unique2 = 11 or unique1 = 0) and stringu1 < 'B'; +select unique1, unique2 from onek2 + where (unique2 = 11 or unique1 = 0) and stringu1 < 'B'; +-- explain (costs off) +-- select unique1, unique2 from onek2 +-- where (unique2 = 11 and stringu1 < 'B') or unique1 = 0; +select unique1, unique2 from onek2 + where (unique2 = 11 and stringu1 < 'B') or unique1 = 0; + +-- +-- Test some corner cases that have been known to confuse the planner +-- + +-- ORDER BY on a constant doesn't really need any sorting +SELECT 1 AS x ORDER BY x; + +-- But ORDER BY on a set-valued expression does +-- create function sillysrf(int) returns setof int as +-- 'values (1),(10),(2),($1)' language sql immutable; + +-- select sillysrf(42); +-- select sillysrf(-1) order by 1; + +-- drop function sillysrf(int); + +-- X = X isn't a no-op, it's effectively X IS NOT NULL assuming = is strict +-- (see bug #5084) +select * from (values (2),(null),(1)) v(k) where k = k order by k; +select * from (values (2),(null),(1)) v(k) where k = k; + +-- Test partitioned tables with no partitions, which should be handled the +-- same as the non-inheritance case when expanding its RTE. +-- create table list_parted_tbl (a int,b int) partition by list (a); +-- create table list_parted_tbl1 partition of list_parted_tbl +-- for values in (1) partition by list(b); +-- explain (costs off) select * from list_parted_tbl; +-- drop table list_parted_tbl; +drop table tmp; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select_distinct.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select_distinct.sql new file mode 100644 index 000000000000..5306028e5bd7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select_distinct.sql @@ -0,0 +1,86 @@ +-- +-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group +-- +-- +-- SELECT_DISTINCT +-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_distinct.sql +-- + +CREATE OR REPLACE TEMPORARY VIEW tmp AS +SELECT two, stringu1, ten, string4 +FROM onek; + +-- +-- awk '{print $3;}' onek.data | sort -n | uniq +-- +SELECT DISTINCT two FROM tmp ORDER BY 1; + +-- +-- awk '{print $5;}' onek.data | sort -n | uniq +-- +SELECT DISTINCT ten FROM tmp ORDER BY 1; + +-- +-- awk '{print $16;}' onek.data | sort -d | uniq +-- +SELECT DISTINCT string4 FROM tmp ORDER BY 1; + +-- [SPARK-28010] Support ORDER BY ... 
USING syntax +-- +-- awk '{print $3,$16,$5;}' onek.data | sort -d | uniq | +-- sort +0n -1 +1d -2 +2n -3 +-- +-- SELECT DISTINCT two, string4, ten +-- FROM tmp +-- ORDER BY two using <, string4 using <, ten using <; +SELECT DISTINCT two, string4, ten + FROM tmp + ORDER BY two ASC, string4 ASC, ten ASC; + +-- Skip the person table because there is a point data type that we don't support. +-- +-- awk '{print $2;}' person.data | +-- awk '{if(NF!=1){print $2;}else{print;}}' - emp.data | +-- awk '{if(NF!=1){print $2;}else{print;}}' - student.data | +-- awk 'BEGIN{FS=" ";}{if(NF!=1){print $5;}else{print;}}' - stud_emp.data | +-- sort -n -r | uniq +-- +-- SELECT DISTINCT p.age FROM person* p ORDER BY age using >; + +-- +-- Check mentioning same column more than once +-- + +-- EXPLAIN (VERBOSE, COSTS OFF) +-- SELECT count(*) FROM +-- (SELECT DISTINCT two, four, two FROM tenk1) ss; + +SELECT count(*) FROM + (SELECT DISTINCT two, four, two FROM tenk1) ss; + +-- +-- Also, some tests of IS DISTINCT FROM, which doesn't quite deserve its +-- very own regression file. +-- + +CREATE OR REPLACE TEMPORARY VIEW disttable AS SELECT * FROM + (VALUES (1), (2), (3), (NULL)) + AS v(f1); + +-- basic cases +SELECT f1, f1 IS DISTINCT FROM 2 as `not 2` FROM disttable; +SELECT f1, f1 IS DISTINCT FROM NULL as `not null` FROM disttable; +SELECT f1, f1 IS DISTINCT FROM f1 as `false` FROM disttable; +SELECT f1, f1 IS DISTINCT FROM f1+1 as `not null` FROM disttable; + +-- check that optimizer constant-folds it properly +SELECT 1 IS DISTINCT FROM 2 as `yes`; +SELECT 2 IS DISTINCT FROM 2 as `no`; +SELECT 2 IS DISTINCT FROM null as `yes`; +SELECT null IS DISTINCT FROM null as `no`; + +-- negated form +SELECT 1 IS NOT DISTINCT FROM 2 as `no`; +SELECT 2 IS NOT DISTINCT FROM 2 as `yes`; +SELECT 2 IS NOT DISTINCT FROM null as `no`; +SELECT null IS NOT DISTINCT FROM null as `yes`; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select_having.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select_having.sql new file mode 100644 index 000000000000..2edde8df0804 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/select_having.sql @@ -0,0 +1,55 @@ +-- +-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group +-- +-- +-- SELECT_HAVING +-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql +-- + +-- load test data +CREATE TABLE test_having (a int, b int, c string, d string) USING parquet; +INSERT INTO test_having VALUES (0, 1, 'XXXX', 'A'); +INSERT INTO test_having VALUES (1, 2, 'AAAA', 'b'); +INSERT INTO test_having VALUES (2, 2, 'AAAA', 'c'); +INSERT INTO test_having VALUES (3, 3, 'BBBB', 'D'); +INSERT INTO test_having VALUES (4, 3, 'BBBB', 'e'); +INSERT INTO test_having VALUES (5, 3, 'bbbb', 'F'); +INSERT INTO test_having VALUES (6, 4, 'cccc', 'g'); +INSERT INTO test_having VALUES (7, 4, 'cccc', 'h'); +INSERT INTO test_having VALUES (8, 4, 'CCCC', 'I'); +INSERT INTO test_having VALUES (9, 4, 'CCCC', 'j'); + +SELECT b, c FROM test_having + GROUP BY b, c HAVING count(*) = 1 ORDER BY b, c; + +-- HAVING is effectively equivalent to WHERE in this case +SELECT b, c FROM test_having + GROUP BY b, c HAVING b = 3 ORDER BY b, c; + +-- [SPARK-28386] Cannot resolve ORDER BY columns with GROUP BY and HAVING +-- SELECT lower(c), count(c) FROM test_having +-- GROUP BY lower(c) HAVING count(*) > 2 OR min(a) = max(a) +-- ORDER BY lower(c); + +SELECT c, max(a) FROM test_having + GROUP BY c HAVING count(*) > 2 OR min(a) = max(a) + ORDER BY c; + 
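+-- [NOTE] A possible (untested here) workaround for the SPARK-28386 limitation above
+-- is to alias the grouped expression and sort by the alias instead of repeating the
+-- expression in ORDER BY, e.g.:
+-- SELECT lower(c) AS lc, count(c) FROM test_having
+--   GROUP BY lower(c) HAVING count(*) > 2 OR min(a) = max(a)
+--   ORDER BY lc;
+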
+-- test degenerate cases involving HAVING without GROUP BY +-- Per SQL spec, these should generate 0 or 1 row, even without aggregates + +SELECT min(a), max(a) FROM test_having HAVING min(a) = max(a); +SELECT min(a), max(a) FROM test_having HAVING min(a) < max(a); + +-- errors: ungrouped column references +SELECT a FROM test_having HAVING min(a) < max(a); +SELECT 1 AS one FROM test_having HAVING a > 1; + +-- the really degenerate case: need not scan table at all +SELECT 1 AS one FROM test_having HAVING 1 > 2; +SELECT 1 AS one FROM test_having HAVING 1 < 2; + +-- and just to prove that we aren't scanning the table: +SELECT 1 AS one FROM test_having WHERE 1/a = 1 HAVING 1 < 2; + +DROP TABLE test_having; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/with.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/with.sql new file mode 100644 index 000000000000..83c6fd8cbac9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/with.sql @@ -0,0 +1,1208 @@ +-- +-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group +-- +-- +-- WITH +-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/with.sql +-- +-- This test uses the generate_series(...) function which is rewritten to EXPLODE(SEQUENCE(...)) as +-- it's feature tracking ticket SPARK-27767 is closed as Won't Do. + +-- +-- Tests for common table expressions (WITH query, ... SELECT ...) +-- + +-- Basic WITH +WITH q1(x,y) AS (SELECT 1,2) +SELECT * FROM q1, q1 AS q2; + +-- Multiple uses are evaluated only once +-- [SPARK-28299] Evaluation of multiple CTE uses +-- [ORIGINAL SQL] +--SELECT count(*) FROM ( +-- WITH q1(x) AS (SELECT random() FROM generate_series(1, 5)) +-- SELECT * FROM q1 +-- UNION +-- SELECT * FROM q1 +--) ss; +SELECT count(*) FROM ( + WITH q1(x) AS (SELECT rand() FROM (SELECT EXPLODE(SEQUENCE(1, 5)))) + SELECT * FROM q1 + UNION + SELECT * FROM q1 +) ss; + +-- WITH RECURSIVE + +-- sum of 1..100 +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(n) AS ( +-- VALUES (1) +--UNION ALL +-- SELECT n+1 FROM t WHERE n < 100 +--) +--SELECT sum(n) FROM t; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(n) AS ( +-- SELECT (VALUES(1)) +--UNION ALL +-- SELECT n+1 FROM t WHERE n < 5 +--) +--SELECT * FROM t; + +-- recursive view +-- [SPARK-24497] Support recursive SQL query +--CREATE RECURSIVE VIEW nums (n) AS +-- VALUES (1) +--UNION ALL +-- SELECT n+1 FROM nums WHERE n < 5; +-- +--SELECT * FROM nums; + +-- [SPARK-24497] Support recursive SQL query +--CREATE OR REPLACE RECURSIVE VIEW nums (n) AS +-- VALUES (1) +--UNION ALL +-- SELECT n+1 FROM nums WHERE n < 6; +-- +--SELECT * FROM nums; + +-- This is an infinite loop with UNION ALL, but not with UNION +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(n) AS ( +-- SELECT 1 +--UNION +-- SELECT 10-n FROM t) +--SELECT * FROM t; + +-- This'd be an infinite loop, but outside query reads only as much as needed +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(n) AS ( +-- VALUES (1) +--UNION ALL +-- SELECT n+1 FROM t) +--SELECT * FROM t LIMIT 10; + +-- UNION case should have same property +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(n) AS ( +-- SELECT 1 +--UNION +-- SELECT n+1 FROM t) +--SELECT * FROM t LIMIT 10; + +-- Test behavior with an unknown-type literal in the WITH +-- [SPARK-28146] Support IS OF type predicate +--WITH q AS (SELECT 'foo' AS x) +--SELECT x, x IS OF (text) AS is_text FROM q; + +-- [SPARK-24497] Support recursive SQL 
query +-- [SPARK-28146] Support IS OF type predicate +--WITH RECURSIVE t(n) AS ( +-- SELECT 'foo' +--UNION ALL +-- SELECT n || ' bar' FROM t WHERE length(n) < 20 +--) +--SELECT n, n IS OF (text) AS is_text FROM t; + +-- In a perfect world, this would work and resolve the literal as int ... +-- but for now, we have to be content with resolving to text too soon. +-- [SPARK-24497] Support recursive SQL query +-- [SPARK-28146] Support IS OF type predicate +--WITH RECURSIVE t(n) AS ( +-- SELECT '7' +--UNION ALL +-- SELECT n+1 FROM t WHERE n < 10 +--) +--SELECT n, n IS OF (int) AS is_int FROM t; + +-- +-- Some examples with a tree +-- +-- department structure represented here is as follows: +-- +-- ROOT-+->A-+->B-+->C +-- | | +-- | +->D-+->F +-- +->E-+->G + + +-- [ORIGINAL SQL] +--CREATE TEMP TABLE department ( +-- id INTEGER PRIMARY KEY, -- department ID +-- parent_department INTEGER REFERENCES department, -- upper department ID +-- name string -- department name +--); +CREATE TABLE department ( + id INTEGER, -- department ID + parent_department INTEGER, -- upper department ID + name string -- department name +) USING parquet; + +INSERT INTO department VALUES (0, NULL, 'ROOT'); +INSERT INTO department VALUES (1, 0, 'A'); +INSERT INTO department VALUES (2, 1, 'B'); +INSERT INTO department VALUES (3, 2, 'C'); +INSERT INTO department VALUES (4, 2, 'D'); +INSERT INTO department VALUES (5, 0, 'E'); +INSERT INTO department VALUES (6, 4, 'F'); +INSERT INTO department VALUES (7, 5, 'G'); + + +-- extract all departments under 'A'. Result should be A, B, C, D and F +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE subdepartment AS +--( +-- -- non recursive term +-- SELECT name as root_name, * FROM department WHERE name = 'A' +-- +-- UNION ALL +-- +-- -- recursive term +-- SELECT sd.root_name, d.* FROM department AS d, subdepartment AS sd +-- WHERE d.parent_department = sd.id +--) +--SELECT * FROM subdepartment ORDER BY name; + +-- extract all departments under 'A' with "level" number +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE subdepartment(level, id, parent_department, name) AS +--( +-- -- non recursive term +-- SELECT 1, * FROM department WHERE name = 'A' +-- +-- UNION ALL +-- +-- -- recursive term +-- SELECT sd.level + 1, d.* FROM department AS d, subdepartment AS sd +-- WHERE d.parent_department = sd.id +--) +--SELECT * FROM subdepartment ORDER BY name; + +-- extract all departments under 'A' with "level" number. 
+-- Only shows level 2 or more +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE subdepartment(level, id, parent_department, name) AS +--( +-- -- non recursive term +-- SELECT 1, * FROM department WHERE name = 'A' +-- +-- UNION ALL +-- +-- -- recursive term +-- SELECT sd.level + 1, d.* FROM department AS d, subdepartment AS sd +-- WHERE d.parent_department = sd.id +--) +--SELECT * FROM subdepartment WHERE level >= 2 ORDER BY name; + +-- "RECURSIVE" is ignored if the query has no self-reference +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE subdepartment AS +--( +-- -- note lack of recursive UNION structure +-- SELECT * FROM department WHERE name = 'A' +--) +--SELECT * FROM subdepartment ORDER BY name; + +-- inside subqueries +-- [SPARK-24497] Support recursive SQL query +--SELECT count(*) FROM ( +-- WITH RECURSIVE t(n) AS ( +-- SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 500 +-- ) +-- SELECT * FROM t) AS t WHERE n < ( +-- SELECT count(*) FROM ( +-- WITH RECURSIVE t(n) AS ( +-- SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 100 +-- ) +-- SELECT * FROM t WHERE n < 50000 +-- ) AS t WHERE n < 100); + +-- use same CTE twice at different subquery levels +-- [SPARK-24497] Support recursive SQL query +--WITH q1(x,y) AS ( +-- SELECT hundred, sum(ten) FROM tenk1 GROUP BY hundred +-- ) +--SELECT count(*) FROM q1 WHERE y > (SELECT sum(y)/100 FROM q1 qsub); + +-- via a VIEW +-- [SPARK-24497] Support recursive SQL query +--CREATE TEMPORARY VIEW vsubdepartment AS +-- WITH RECURSIVE subdepartment AS +-- ( +-- -- non recursive term +-- SELECT * FROM department WHERE name = 'A' +-- UNION ALL +-- -- recursive term +-- SELECT d.* FROM department AS d, subdepartment AS sd +-- WHERE d.parent_department = sd.id +-- ) +-- SELECT * FROM subdepartment; +-- +--SELECT * FROM vsubdepartment ORDER BY name; +-- +---- Check reverse listing +--SELECT pg_get_viewdef('vsubdepartment'::regclass); +--SELECT pg_get_viewdef('vsubdepartment'::regclass, true); + +-- Another reverse-listing example +-- [SPARK-24497] Support recursive SQL query +--CREATE VIEW sums_1_100 AS +--WITH RECURSIVE t(n) AS ( +-- VALUES (1) +--UNION ALL +-- SELECT n+1 FROM t WHERE n < 100 +--) +--SELECT sum(n) FROM t; +-- +--\d+ sums_1_100 + +-- corner case in which sub-WITH gets initialized first +-- [SPARK-24497] Support recursive SQL query +--with recursive q as ( +-- select * from department +-- union all +-- (with x as (select * from q) +-- select * from x) +-- ) +--select * from q limit 24; + +-- [SPARK-24497] Support recursive SQL query +--with recursive q as ( +-- select * from department +-- union all +-- (with recursive x as ( +-- select * from department +-- union all +-- (select * from q union all select * from x) +-- ) +-- select * from x) +-- ) +--select * from q limit 32; + +-- recursive term has sub-UNION +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(i,j) AS ( +-- VALUES (1,2) +-- UNION ALL +-- SELECT t2.i, t.j+1 FROM +-- (SELECT 2 AS i UNION ALL SELECT 3 AS i) AS t2 +-- JOIN t ON (t2.i = t.i+1)) +-- +-- SELECT * FROM t; + +-- +-- different tree example +-- +-- [ORIGINAL SQL] +--CREATE TEMPORARY TABLE tree( +-- id INTEGER PRIMARY KEY, +-- parent_id INTEGER REFERENCES tree(id) +--); +CREATE TABLE tree( + id INTEGER, + parent_id INTEGER +) USING parquet; + +INSERT INTO tree +VALUES (1, NULL), (2, 1), (3,1), (4,2), (5,2), (6,2), (7,3), (8,3), + (9,4), (10,4), (11,7), (12,7), (13,7), (14, 9), (15,11), (16,11); + +-- +-- get all paths from "second level" nodes to leaf nodes +-- +-- 
[SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(id, path) AS ( +-- VALUES(1,ARRAY[]::integer[]) +--UNION ALL +-- SELECT tree.id, t.path || tree.id +-- FROM tree JOIN t ON (tree.parent_id = t.id) +--) +--SELECT t1.*, t2.* FROM t AS t1 JOIN t AS t2 ON +-- (t1.path[1] = t2.path[1] AND +-- array_upper(t1.path,1) = 1 AND +-- array_upper(t2.path,1) > 1) +-- ORDER BY t1.id, t2.id; + +-- just count 'em +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(id, path) AS ( +-- VALUES(1,ARRAY[]::integer[]) +--UNION ALL +-- SELECT tree.id, t.path || tree.id +-- FROM tree JOIN t ON (tree.parent_id = t.id) +--) +--SELECT t1.id, count(t2.*) FROM t AS t1 JOIN t AS t2 ON +-- (t1.path[1] = t2.path[1] AND +-- array_upper(t1.path,1) = 1 AND +-- array_upper(t2.path,1) > 1) +-- GROUP BY t1.id +-- ORDER BY t1.id; + +-- this variant tickled a whole-row-variable bug in 8.4devel +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(id, path) AS ( +-- VALUES(1,ARRAY[]::integer[]) +--UNION ALL +-- SELECT tree.id, t.path || tree.id +-- FROM tree JOIN t ON (tree.parent_id = t.id) +--) +--SELECT t1.id, t2.path, t2 FROM t AS t1 JOIN t AS t2 ON +--(t1.id=t2.id); + +-- +-- test cycle detection +-- +-- [ORIGINAL SQL] +--create temp table graph( f int, t int, label text ); +create table graph( f int, t int, label string ) USING parquet; + +insert into graph values + (1, 2, 'arc 1 -> 2'), + (1, 3, 'arc 1 -> 3'), + (2, 3, 'arc 2 -> 3'), + (1, 4, 'arc 1 -> 4'), + (4, 5, 'arc 4 -> 5'), + (5, 1, 'arc 5 -> 1'); + +-- [SPARK-24497] Support recursive SQL query +--with recursive search_graph(f, t, label, path, cycle) as ( +-- select *, array[row(g.f, g.t)], false from graph g +-- union all +-- select g.*, path || row(g.f, g.t), row(g.f, g.t) = any(path) +-- from graph g, search_graph sg +-- where g.f = sg.t and not cycle +--) +--select * from search_graph; + +-- ordering by the path column has same effect as SEARCH DEPTH FIRST +-- [SPARK-24497] Support recursive SQL query +--with recursive search_graph(f, t, label, path, cycle) as ( +-- select *, array[row(g.f, g.t)], false from graph g +-- union all +-- select g.*, path || row(g.f, g.t), row(g.f, g.t) = any(path) +-- from graph g, search_graph sg +-- where g.f = sg.t and not cycle +--) +--select * from search_graph order by path; + +-- +-- test multiple WITH queries +-- +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- y (id) AS (VALUES (1)), +-- x (id) AS (SELECT * FROM y UNION ALL SELECT id+1 FROM x WHERE id < 5) +--SELECT * FROM x; + +-- forward reference OK +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- x(id) AS (SELECT * FROM y UNION ALL SELECT id+1 FROM x WHERE id < 5), +-- y(id) AS (values (1)) +-- SELECT * FROM x; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- x(id) AS +-- (VALUES (1) UNION ALL SELECT id+1 FROM x WHERE id < 5), +-- y(id) AS +-- (VALUES (1) UNION ALL SELECT id+1 FROM y WHERE id < 10) +-- SELECT y.*, x.* FROM y LEFT JOIN x USING (id); + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- x(id) AS +-- (VALUES (1) UNION ALL SELECT id+1 FROM x WHERE id < 5), +-- y(id) AS +-- (VALUES (1) UNION ALL SELECT id+1 FROM x WHERE id < 10) +-- SELECT y.*, x.* FROM y LEFT JOIN x USING (id); + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- x(id) AS +-- (SELECT 1 UNION ALL SELECT id+1 FROM x WHERE id < 3 ), +-- y(id) AS +-- (SELECT * FROM x UNION ALL SELECT * FROM x), +-- z(id) AS +-- (SELECT * FROM x UNION ALL SELECT id+1 FROM z WHERE id < 
10) +-- SELECT * FROM z; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- x(id) AS +-- (SELECT 1 UNION ALL SELECT id+1 FROM x WHERE id < 3 ), +-- y(id) AS +-- (SELECT * FROM x UNION ALL SELECT * FROM x), +-- z(id) AS +-- (SELECT * FROM y UNION ALL SELECT id+1 FROM z WHERE id < 10) +-- SELECT * FROM z; + +-- +-- Test WITH attached to a data-modifying statement +-- + +-- [ORIGINAL SQL] +--CREATE TEMPORARY TABLE y (a INTEGER); +CREATE TABLE y (a INTEGER) USING parquet; +-- [ORIGINAL SQL] +--INSERT INTO y SELECT generate_series(1, 10); +INSERT INTO y SELECT EXPLODE(SEQUENCE(1, 10)); + +-- [SPARK-28147] Support RETURNING clause +--WITH t AS ( +-- SELECT a FROM y +--) +--INSERT INTO y +--SELECT a+20 FROM t RETURNING *; +-- +--SELECT * FROM y; + +-- [NOTE] Spark SQL doesn't support UPDATE statement +--WITH t AS ( +-- SELECT a FROM y +--) +--UPDATE y SET a = y.a-10 FROM t WHERE y.a > 20 AND t.a = y.a RETURNING y.a; +-- +--SELECT * FROM y; + +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH RECURSIVE t(a) AS ( +-- SELECT 11 +-- UNION ALL +-- SELECT a+1 FROM t WHERE a < 50 +--) +--DELETE FROM y USING t WHERE t.a = y.a RETURNING y.a; +-- +--SELECT * FROM y; + +DROP TABLE y; + +-- +-- error cases +-- + +-- INTERSECT +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 INTERSECT SELECT n+1 FROM x) +-- SELECT * FROM x; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 INTERSECT ALL SELECT n+1 FROM x) +-- SELECT * FROM x; + +-- EXCEPT +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 EXCEPT SELECT n+1 FROM x) +-- SELECT * FROM x; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 EXCEPT ALL SELECT n+1 FROM x) +-- SELECT * FROM x; + +-- no non-recursive term +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT n FROM x) +-- SELECT * FROM x; + +-- recursive term in the left hand side (strictly speaking, should allow this) +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT n FROM x UNION ALL SELECT 1) +-- SELECT * FROM x; + +-- [ORIGINAL SQL] +--CREATE TEMPORARY TABLE y (a INTEGER); +CREATE TABLE y (a INTEGER) USING parquet; +-- [ORIGINAL SQL] +--INSERT INTO y SELECT generate_series(1, 10); +INSERT INTO y SELECT EXPLODE(SEQUENCE(1, 10)); + +-- LEFT JOIN + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT a FROM y WHERE a = 1 +-- UNION ALL +-- SELECT x.n+1 FROM y LEFT JOIN x ON x.n = y.a WHERE n < 10) +--SELECT * FROM x; + +-- RIGHT JOIN +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT a FROM y WHERE a = 1 +-- UNION ALL +-- SELECT x.n+1 FROM x RIGHT JOIN y ON x.n = y.a WHERE n < 10) +--SELECT * FROM x; + +-- FULL JOIN +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT a FROM y WHERE a = 1 +-- UNION ALL +-- SELECT x.n+1 FROM x FULL JOIN y ON x.n = y.a WHERE n < 10) +--SELECT * FROM x; + +-- subquery +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM x +-- WHERE n IN (SELECT * FROM x)) +-- SELECT * FROM x; + +-- aggregate functions +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 UNION ALL SELECT count(*) FROM x) +-- SELECT * FROM x; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 UNION ALL SELECT sum(n) FROM x) +-- SELECT * FROM x; + +-- ORDER BY +-- [SPARK-24497] Support recursive 
SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM x ORDER BY 1) +-- SELECT * FROM x; + +-- LIMIT/OFFSET +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM x LIMIT 10 OFFSET 1) +-- SELECT * FROM x; + +-- FOR UPDATE +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(n) AS (SELECT 1 UNION ALL SELECT n+1 FROM x FOR UPDATE) +-- SELECT * FROM x; + +-- target list has a recursive query name +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE x(id) AS (values (1) +-- UNION ALL +-- SELECT (SELECT * FROM x) FROM x WHERE id < 5 +--) SELECT * FROM x; + +-- mutual recursive query (not implemented) +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- x (id) AS (SELECT 1 UNION ALL SELECT id+1 FROM y WHERE id < 5), +-- y (id) AS (SELECT 1 UNION ALL SELECT id+1 FROM x WHERE id < 5) +--SELECT * FROM x; + +-- non-linear recursion is not allowed +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE foo(i) AS +-- (values (1) +-- UNION ALL +-- (SELECT i+1 FROM foo WHERE i < 10 +-- UNION ALL +-- SELECT i+1 FROM foo WHERE i < 5) +--) SELECT * FROM foo; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE foo(i) AS +-- (values (1) +-- UNION ALL +-- SELECT * FROM +-- (SELECT i+1 FROM foo WHERE i < 10 +-- UNION ALL +-- SELECT i+1 FROM foo WHERE i < 5) AS t +--) SELECT * FROM foo; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE foo(i) AS +-- (values (1) +-- UNION ALL +-- (SELECT i+1 FROM foo WHERE i < 10 +-- EXCEPT +-- SELECT i+1 FROM foo WHERE i < 5) +--) SELECT * FROM foo; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE foo(i) AS +-- (values (1) +-- UNION ALL +-- (SELECT i+1 FROM foo WHERE i < 10 +-- INTERSECT +-- SELECT i+1 FROM foo WHERE i < 5) +--) SELECT * FROM foo; + +-- Wrong type induced from non-recursive term +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE foo(i) AS +-- (SELECT i FROM (VALUES(1),(2)) t(i) +-- UNION ALL +-- SELECT (i+1)::numeric(10,0) FROM foo WHERE i < 10) +--SELECT * FROM foo; + +-- rejects different typmod, too (should we allow this?) 
+-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE foo(i) AS +-- (SELECT i::numeric(3,0) FROM (VALUES(1),(2)) t(i) +-- UNION ALL +-- SELECT (i+1)::numeric(10,0) FROM foo WHERE i < 10) +--SELECT * FROM foo; + +-- [NOTE] Spark SQL doesn't support RULEs +-- disallow OLD/NEW reference in CTE +--CREATE TABLE x (n integer) USING parquet; +--CREATE RULE r2 AS ON UPDATE TO x DO INSTEAD +-- WITH t AS (SELECT OLD.*) UPDATE y SET a = t.n FROM t; + +-- +-- test for bug #4902 +-- +-- [SPARK-28296] Improved VALUES support +--with cte(foo) as ( values(42) ) values((select foo from cte)); +with cte(foo) as ( select 42 ) select * from ((select foo from cte)) q; + +-- test CTE referencing an outer-level variable (to see that changed-parameter +-- signaling still works properly after fixing this bug) +-- [SPARK-28296] Improved VALUES support +-- [SPARK-28297] Handling outer links in CTE subquery expressions +--select ( with cte(foo) as ( values(f1) ) +-- select (select foo from cte) ) +--from int4_tbl; + +-- [SPARK-28296] Improved VALUES support +-- [SPARK-28297] Handling outer links in CTE subquery expressions +--select ( with cte(foo) as ( values(f1) ) +-- values((select foo from cte)) ) +--from int4_tbl; + +-- +-- test for nested-recursive-WITH bug +-- +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(j) AS ( +-- WITH RECURSIVE s(i) AS ( +-- VALUES (1) +-- UNION ALL +-- SELECT i+1 FROM s WHERE i < 10 +-- ) +-- SELECT i FROM s +-- UNION ALL +-- SELECT j+1 FROM t WHERE j < 10 +--) +--SELECT * FROM t; + +-- +-- test WITH attached to intermediate-level set operation +-- + +WITH outermost(x) AS ( + SELECT 1 + UNION (WITH innermost as (SELECT 2) + SELECT * FROM innermost + UNION SELECT 3) +) +SELECT * FROM outermost ORDER BY 1; + +WITH outermost(x) AS ( + SELECT 1 + UNION (WITH innermost as (SELECT 2) + SELECT * FROM outermost -- fail + UNION SELECT * FROM innermost) +) +SELECT * FROM outermost ORDER BY 1; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE outermost(x) AS ( +-- SELECT 1 +-- UNION (WITH innermost as (SELECT 2) +-- SELECT * FROM outermost +-- UNION SELECT * FROM innermost) +--) +--SELECT * FROM outermost ORDER BY 1; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE outermost(x) AS ( +-- WITH innermost as (SELECT 2 FROM outermost) -- fail +-- SELECT * FROM innermost +-- UNION SELECT * from outermost +--) +--SELECT * FROM outermost ORDER BY 1; + +-- +-- This test will fail with the old implementation of PARAM_EXEC parameter +-- assignment, because the "q1" Var passed down to A's targetlist subselect +-- looks exactly like the "A.id" Var passed down to C's subselect, causing +-- the old code to give them the same runtime PARAM_EXEC slot. But the +-- lifespans of the two parameters overlap, thanks to B also reading A. 
+-- + +-- [SPARK-27878] Support ARRAY(sub-SELECT) expressions +--with +--A as ( select q2 as id, (select q1) as x from int8_tbl ), +--B as ( select id, row_number() over (partition by id) as r from A ), +--C as ( select A.id, array(select B.id from B where B.id = A.id) from A ) +--select * from C; + +-- +-- Test CTEs read in non-initialization orders +-- + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- tab(id_key,link) AS (VALUES (1,17), (2,17), (3,17), (4,17), (6,17), (5,17)), +-- iter (id_key, row_type, link) AS ( +-- SELECT 0, 'base', 17 +-- UNION ALL ( +-- WITH remaining(id_key, row_type, link, min) AS ( +-- SELECT tab.id_key, 'true'::text, iter.link, MIN(tab.id_key) OVER () +-- FROM tab INNER JOIN iter USING (link) +-- WHERE tab.id_key > iter.id_key +-- ), +-- first_remaining AS ( +-- SELECT id_key, row_type, link +-- FROM remaining +-- WHERE id_key=min +-- ), +-- effect AS ( +-- SELECT tab.id_key, 'new'::text, tab.link +-- FROM first_remaining e INNER JOIN tab ON e.id_key=tab.id_key +-- WHERE e.row_type = 'false' +-- ) +-- SELECT * FROM first_remaining +-- UNION ALL SELECT * FROM effect +-- ) +-- ) +--SELECT * FROM iter; + +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE +-- tab(id_key,link) AS (VALUES (1,17), (2,17), (3,17), (4,17), (6,17), (5,17)), +-- iter (id_key, row_type, link) AS ( +-- SELECT 0, 'base', 17 +-- UNION ( +-- WITH remaining(id_key, row_type, link, min) AS ( +-- SELECT tab.id_key, 'true'::text, iter.link, MIN(tab.id_key) OVER () +-- FROM tab INNER JOIN iter USING (link) +-- WHERE tab.id_key > iter.id_key +-- ), +-- first_remaining AS ( +-- SELECT id_key, row_type, link +-- FROM remaining +-- WHERE id_key=min +-- ), +-- effect AS ( +-- SELECT tab.id_key, 'new'::text, tab.link +-- FROM first_remaining e INNER JOIN tab ON e.id_key=tab.id_key +-- WHERE e.row_type = 'false' +-- ) +-- SELECT * FROM first_remaining +-- UNION ALL SELECT * FROM effect +-- ) +-- ) +--SELECT * FROM iter; + +-- +-- Data-modifying statements in WITH +-- + +-- INSERT ... RETURNING +-- [SPARK-28147] Support RETURNING clause +--WITH t AS ( +-- INSERT INTO y +-- VALUES +-- (11), +-- (12), +-- (13), +-- (14), +-- (15), +-- (16), +-- (17), +-- (18), +-- (19), +-- (20) +-- RETURNING * +--) +--SELECT * FROM t; +-- +--SELECT * FROM y; + +-- UPDATE ... RETURNING +-- [NOTE] Spark SQL doesn't support UPDATE statement +--WITH t AS ( +-- UPDATE y +-- SET a=a+1 +-- RETURNING * +--) +--SELECT * FROM t; +-- +--SELECT * FROM y; + +-- DELETE ... 
RETURNING +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH t AS ( +-- DELETE FROM y +-- WHERE a <= 10 +-- RETURNING * +--) +--SELECT * FROM t; +-- +--SELECT * FROM y; + +-- forward reference +-- [NOTE] Spark SQL doesn't support UPDATE statement +--WITH RECURSIVE t AS ( +-- INSERT INTO y +-- SELECT a+5 FROM t2 WHERE a > 5 +-- RETURNING * +--), t2 AS ( +-- UPDATE y SET a=a-11 RETURNING * +--) +--SELECT * FROM t +--UNION ALL +--SELECT * FROM t2; +-- +--SELECT * FROM y; + +-- unconditional DO INSTEAD rule +-- [NOTE] Spark SQL doesn't support RULEs +--CREATE RULE y_rule AS ON DELETE TO y DO INSTEAD +-- INSERT INTO y VALUES(42) RETURNING *; + +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH t AS ( +-- DELETE FROM y RETURNING * +--) +--SELECT * FROM t; +-- +--SELECT * FROM y; + +--DROP RULE y_rule ON y; + +-- check merging of outer CTE with CTE in a rule action +--CREATE TEMP TABLE bug6051 AS +-- select i from generate_series(1,3) as t(i); + +--SELECT * FROM bug6051; + +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH t1 AS ( DELETE FROM bug6051 RETURNING * ) +--INSERT INTO bug6051 SELECT * FROM t1; +-- +--SELECT * FROM bug6051; + +-- [NOTE] Spark SQL doesn't support RULEs +--CREATE TEMP TABLE bug6051_2 (i int); +-- +--CREATE RULE bug6051_ins AS ON INSERT TO bug6051 DO INSTEAD +-- INSERT INTO bug6051_2 +-- SELECT NEW.i; + +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH t1 AS ( DELETE FROM bug6051 RETURNING * ) +--INSERT INTO bug6051 SELECT * FROM t1; +-- +--SELECT * FROM bug6051; +--SELECT * FROM bug6051_2; + +-- a truly recursive CTE in the same list +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t(a) AS ( +-- SELECT 0 +-- UNION ALL +-- SELECT a+1 FROM t WHERE a+1 < 5 +--), t2 as ( +-- INSERT INTO y +-- SELECT * FROM t RETURNING * +--) +--SELECT * FROM t2 JOIN y USING (a) ORDER BY a; +-- +--SELECT * FROM y; + +-- data-modifying WITH in a modifying statement +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH t AS ( +-- DELETE FROM y +-- WHERE a <= 10 +-- RETURNING * +--) +--INSERT INTO y SELECT -a FROM t RETURNING *; +-- +--SELECT * FROM y; + +-- check that WITH query is run to completion even if outer query isn't +-- [NOTE] Spark SQL doesn't support UPDATE statement +--WITH t AS ( +-- UPDATE y SET a = a * 100 RETURNING * +--) +--SELECT * FROM t LIMIT 10; +-- +--SELECT * FROM y; + +-- data-modifying WITH containing INSERT...ON CONFLICT DO UPDATE +-- [ORIGINAL SQL] +--CREATE TABLE withz AS SELECT i AS k, (i || ' v')::text v FROM generate_series(1, 16, 3) i; +CREATE TABLE withz USING parquet AS SELECT i AS k, CAST(i || ' v' AS string) v FROM (SELECT EXPLODE(SEQUENCE(1, 16, 3)) i); +-- [NOTE] Spark SQL doesn't support UNIQUE constraints +--ALTER TABLE withz ADD UNIQUE (k); + +-- [NOTE] Spark SQL doesn't support UPDATE statement +--WITH t AS ( +-- INSERT INTO withz SELECT i, 'insert' +-- FROM generate_series(0, 16) i +-- ON CONFLICT (k) DO UPDATE SET v = withz.v || ', now update' +-- RETURNING * +--) +--SELECT * FROM t JOIN y ON t.k = y.a ORDER BY a, k; + +-- Test EXCLUDED.* reference within CTE +-- [NOTE] Spark SQL doesn't support ON CONFLICT clause +--WITH aa AS ( +-- INSERT INTO withz VALUES(1, 5) ON CONFLICT (k) DO UPDATE SET v = EXCLUDED.v +-- WHERE withz.k != EXCLUDED.k +-- RETURNING * +--) +--SELECT * FROM aa; + +-- New query/snapshot demonstrates side-effects of previous query. 
+SELECT * FROM withz ORDER BY k; + +-- +-- Ensure subqueries within the update clause work, even if they +-- reference outside values +-- +-- [NOTE] Spark SQL doesn't support ON CONFLICT clause +--WITH aa AS (SELECT 1 a, 2 b) +--INSERT INTO withz VALUES(1, 'insert') +--ON CONFLICT (k) DO UPDATE SET v = (SELECT b || ' update' FROM aa WHERE a = 1 LIMIT 1); +--WITH aa AS (SELECT 1 a, 2 b) +--INSERT INTO withz VALUES(1, 'insert') +--ON CONFLICT (k) DO UPDATE SET v = ' update' WHERE withz.k = (SELECT a FROM aa); +--WITH aa AS (SELECT 1 a, 2 b) +--INSERT INTO withz VALUES(1, 'insert') +--ON CONFLICT (k) DO UPDATE SET v = (SELECT b || ' update' FROM aa WHERE a = 1 LIMIT 1); +--WITH aa AS (SELECT 'a' a, 'b' b UNION ALL SELECT 'a' a, 'b' b) +--INSERT INTO withz VALUES(1, 'insert') +--ON CONFLICT (k) DO UPDATE SET v = (SELECT b || ' update' FROM aa WHERE a = 'a' LIMIT 1); +--WITH aa AS (SELECT 1 a, 2 b) +--INSERT INTO withz VALUES(1, (SELECT b || ' insert' FROM aa WHERE a = 1 )) +--ON CONFLICT (k) DO UPDATE SET v = (SELECT b || ' update' FROM aa WHERE a = 1 LIMIT 1); + +-- Update a row more than once, in different parts of a wCTE. That is +-- an allowed, presumably very rare, edge case, but since it was +-- broken in the past, having a test seems worthwhile. +-- [NOTE] Spark SQL doesn't support ON CONFLICT clause +--WITH simpletup AS ( +-- SELECT 2 k, 'Green' v), +--upsert_cte AS ( +-- INSERT INTO withz VALUES(2, 'Blue') ON CONFLICT (k) DO +-- UPDATE SET (k, v) = (SELECT k, v FROM simpletup WHERE simpletup.k = withz.k) +-- RETURNING k, v) +--INSERT INTO withz VALUES(2, 'Red') ON CONFLICT (k) DO +--UPDATE SET (k, v) = (SELECT k, v FROM upsert_cte WHERE upsert_cte.k = withz.k) +--RETURNING k, v; + +DROP TABLE withz; + +-- check that run to completion happens in proper ordering + +TRUNCATE TABLE y; +-- [ORIGINAL SQL] +--INSERT INTO y SELECT generate_series(1, 3); +INSERT INTO y SELECT EXPLODE(SEQUENCE(1, 3)); +-- [ORIGINAL SQL] +--CREATE TEMPORARY TABLE yy (a INTEGER); +CREATE TABLE yy (a INTEGER) USING parquet; + +-- [SPARK-24497] Support recursive SQL query +-- [SPARK-28147] Support RETURNING clause +--WITH RECURSIVE t1 AS ( +-- INSERT INTO y SELECT * FROM y RETURNING * +--), t2 AS ( +-- INSERT INTO yy SELECT * FROM t1 RETURNING * +--) +--SELECT 1; + +SELECT * FROM y; +SELECT * FROM yy; + +-- [SPARK-24497] Support recursive SQL query +-- [SPARK-28147] Support RETURNING clause +--WITH RECURSIVE t1 AS ( +-- INSERT INTO yy SELECT * FROM t2 RETURNING * +--), t2 AS ( +-- INSERT INTO y SELECT * FROM y RETURNING * +--) +--SELECT 1; + +SELECT * FROM y; +SELECT * FROM yy; + +-- [NOTE] Spark SQL doesn't support TRIGGERs +-- triggers +-- +--TRUNCATE TABLE y; +--INSERT INTO y SELECT generate_series(1, 10); +-- +--CREATE FUNCTION y_trigger() RETURNS trigger AS $$ +--begin +-- raise notice 'y_trigger: a = %', new.a; +-- return new; +--end; +--$$ LANGUAGE plpgsql; +-- +-- +--CREATE TRIGGER y_trig BEFORE INSERT ON y FOR EACH ROW +-- EXECUTE PROCEDURE y_trigger(); +-- +--WITH t AS ( +-- INSERT INTO y +-- VALUES +-- (21), +-- (22), +-- (23) +-- RETURNING * +--) +--SELECT * FROM t; +-- +--SELECT * FROM y; +-- +--DROP TRIGGER y_trig ON y; +-- +--CREATE TRIGGER y_trig AFTER INSERT ON y FOR EACH ROW +-- EXECUTE PROCEDURE y_trigger(); +-- +--WITH t AS ( +-- INSERT INTO y +-- VALUES +-- (31), +-- (32), +-- (33) +-- RETURNING * +--) +--SELECT * FROM t LIMIT 1; +-- +--SELECT * FROM y; +-- +--DROP TRIGGER y_trig ON y; +-- +--CREATE OR REPLACE FUNCTION y_trigger() RETURNS trigger AS $$ +--begin +-- raise notice 'y_trigger'; +-- 
return null; +--end; +--$$ LANGUAGE plpgsql; +-- +--CREATE TRIGGER y_trig AFTER INSERT ON y FOR EACH STATEMENT +-- EXECUTE PROCEDURE y_trigger(); +-- +--WITH t AS ( +-- INSERT INTO y +-- VALUES +-- (41), +-- (42), +-- (43) +-- RETURNING * +--) +--SELECT * FROM t; +-- +--SELECT * FROM y; +-- +--DROP TRIGGER y_trig ON y; +--DROP FUNCTION y_trigger(); + +-- WITH attached to inherited UPDATE or DELETE + +-- [ORIGINAL SQL] +--CREATE TEMP TABLE parent ( id int, val text ); +CREATE TABLE parent ( id int, val string ) USING parquet; +-- [NOTE] Spark SQL doesn't support INHERITS clause +--CREATE TEMP TABLE child1 ( ) INHERITS ( parent ); +-- [NOTE] Spark SQL doesn't support INHERITS clause +--CREATE TEMP TABLE child2 ( ) INHERITS ( parent ); + +INSERT INTO parent VALUES ( 1, 'p1' ); +--INSERT INTO child1 VALUES ( 11, 'c11' ),( 12, 'c12' ); +--INSERT INTO child2 VALUES ( 23, 'c21' ),( 24, 'c22' ); + +-- [NOTE] Spark SQL doesn't support UPDATE statement +--WITH rcte AS ( SELECT sum(id) AS totalid FROM parent ) +--UPDATE parent SET id = id + totalid FROM rcte; + +SELECT * FROM parent; + +-- [SPARK-28147] Support RETURNING clause +--WITH wcte AS ( INSERT INTO child1 VALUES ( 42, 'new' ) RETURNING id AS newid ) +--UPDATE parent SET id = id + newid FROM wcte; +-- +--SELECT * FROM parent; + +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH rcte AS ( SELECT max(id) AS maxid FROM parent ) +--DELETE FROM parent USING rcte WHERE id = maxid; + +SELECT * FROM parent; + +-- [NOTE] Spark SQL doesn't support DELETE statement +--WITH wcte AS ( INSERT INTO child2 VALUES ( 42, 'new2' ) RETURNING id AS newid ) +--DELETE FROM parent USING wcte WHERE id = newid; +-- +--SELECT * FROM parent; + +-- check EXPLAIN VERBOSE for a wCTE with RETURNING + +-- [NOTE] Spark SQL doesn't support DELETE statement +--EXPLAIN (VERBOSE, COSTS OFF) +--WITH wcte AS ( INSERT INTO int8_tbl VALUES ( 42, 47 ) RETURNING q2 ) +--DELETE FROM a USING wcte WHERE aa = q2; + +-- error cases + +-- data-modifying WITH tries to use its own output +-- [SPARK-24497] Support recursive SQL query +--WITH RECURSIVE t AS ( +-- INSERT INTO y +-- SELECT * FROM t +--) +--VALUES(FALSE); + +-- no RETURNING in a referenced data-modifying WITH +-- [SPARK-24497] Support recursive SQL query +--WITH t AS ( +-- INSERT INTO y VALUES(0) +--) +--SELECT * FROM t; + +-- data-modifying WITH allowed only at the top level +-- [SPARK-28147] Support RETURNING clause +--SELECT * FROM ( +-- WITH t AS (UPDATE y SET a=a+1 RETURNING *) +-- SELECT * FROM t +--) ss; + +-- most variants of rules aren't allowed +-- [NOTE] Spark SQL doesn't support RULEs +--CREATE RULE y_rule AS ON INSERT TO y WHERE a=0 DO INSTEAD DELETE FROM y; +--WITH t AS ( +-- INSERT INTO y VALUES(0) +--) +--VALUES(FALSE); +--DROP RULE y_rule ON y; + +-- check that parser lookahead for WITH doesn't cause any odd behavior +create table foo (with baz); -- fail, WITH is a reserved word +create table foo (with ordinality); -- fail, WITH is a reserved word +with ordinality as (select 1 as x) select * from ordinality; + +-- check sane response to attempt to modify CTE relation +WITH test AS (SELECT 42) INSERT INTO test VALUES (1); + +-- check response to attempt to modify table with same name as a CTE (perhaps +-- surprisingly it works, because CTEs don't hide tables from data-modifying +-- statements) +-- [ORIGINAL SQL] +--create temp table test (i int); +create table test (i int) USING parquet; +with test as (select 42) insert into test select * from test; +select * from test; +drop table test; + +-- +-- Clean 
up +-- + +DROP TABLE department; +DROP TABLE tree; +DROP TABLE graph; +DROP TABLE y; +DROP TABLE yy; +DROP TABLE parent; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part1.sql index 33b61666ca4d..d829a5c1159f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part1.sql @@ -9,8 +9,6 @@ -- SET extra_float_digits = 0; -- This test file was converted from pgSQL/aggregates_part1.sql. --- Note that currently registered UDF returns a string. So there are some differences, for instance --- in string cast within UDF in Scala and Python. SELECT avg(udf(four)) AS avg_1 FROM onek; @@ -31,13 +29,13 @@ SELECT udf(udf(sum(b))) AS avg_431_773 FROM aggtest; SELECT udf(max(four)) AS max_3 FROM onek; SELECT max(udf(a)) AS max_100 FROM aggtest; -SELECT CAST(udf(udf(max(aggtest.b))) AS int) AS max_324_78 FROM aggtest; +SELECT udf(udf(max(aggtest.b))) AS max_324_78 FROM aggtest; -- `student` has a column with data type POINT, which is not supported by Spark [SPARK-27766] -- SELECT max(student.gpa) AS max_3_7 FROM student; -SELECT CAST(stddev_pop(udf(b)) AS int) FROM aggtest; +SELECT stddev_pop(udf(b)) FROM aggtest; SELECT udf(stddev_samp(b)) FROM aggtest; -SELECT CAST(var_pop(udf(b)) as int) FROM aggtest; +SELECT var_pop(udf(b)) FROM aggtest; SELECT udf(var_samp(b)) FROM aggtest; SELECT udf(stddev_pop(CAST(b AS Decimal(38,0)))) FROM aggtest; @@ -89,7 +87,7 @@ FROM (VALUES (7000000000005), (7000000000007)) v(x); -- SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; -- SELECT regr_r2(b, a) FROM aggtest; -- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; -SELECT CAST(udf(covar_pop(b, udf(a))) AS int), CAST(covar_samp(udf(b), a) as int) FROM aggtest; +SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest; SELECT corr(b, udf(a)) FROM aggtest; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part2.sql index 57491a32c48f..5636537398a8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part2.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-aggregates_part2.sql @@ -6,8 +6,6 @@ -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L145-L350 -- -- This test file was converted from pgSQL/aggregates_part2.sql. --- Note that currently registered UDF returns a string. So there are some differences, for instance --- in string cast within UDF in Scala and Python. create temporary view int4_tbl as select * from values (0), diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-case.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-case.sql index b05c21d24b36..1865ee94ec1f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-case.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/pgSQL/udf-case.sql @@ -6,14 +6,8 @@ -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/case.sql -- Test the CASE statement -- --- This test suite contains two Cartesian products without using explicit CROSS JOIN syntax. --- Thus, we set spark.sql.crossJoin.enabled to true. - -- This test file was converted from pgSQL/case.sql. --- Note that currently registered UDF returns a string. 
So there are some differences, for instance --- in string cast within UDF in Scala and Python. -set spark.sql.crossJoin.enabled=true; CREATE TABLE CASE_TBL ( i integer, f double @@ -42,7 +36,7 @@ INSERT INTO CASE2_TBL VALUES (NULL, -6); SELECT '3' AS `One`, CASE - WHEN CAST(udf(1 < 2) AS boolean) THEN 3 + WHEN udf(1 < 2) THEN 3 END AS `Simple WHEN`; SELECT '' AS `One`, @@ -64,7 +58,7 @@ SELECT udf('4') AS `One`, SELECT udf('6') AS `One`, CASE - WHEN CAST(udf(1 > 2) AS boolean) THEN 3 + WHEN udf(1 > 2) THEN 3 WHEN udf(4) < 5 THEN 6 ELSE 7 END AS `Two WHEN with default`; @@ -74,7 +68,7 @@ SELECT '7' AS `None`, END AS `NULL on no matches`; -- Constant-expression folding shouldn't evaluate unreachable subexpressions -SELECT CASE WHEN CAST(udf(1=0) AS boolean) THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END; +SELECT CASE WHEN udf(1=0) THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END; SELECT CASE 1 WHEN 0 THEN 1/udf(0) WHEN 1 THEN 1 ELSE 2/0 END; -- [SPARK-27923] PostgreSQL throws an exception but Spark SQL is NULL @@ -146,7 +140,7 @@ SELECT udf('') AS Five, NULLIF(a.i,b.i) AS `NULLIF(a.i,b.i)`, SELECT '' AS `Two`, * FROM CASE_TBL a, CASE2_TBL b - WHERE CAST(udf(COALESCE(f,b.i) = 2) AS boolean); + WHERE udf(COALESCE(f,b.i) = 2); -- We don't support update now. -- @@ -269,4 +263,3 @@ SELECT CASE DROP TABLE CASE_TBL; DROP TABLE CASE2_TBL; -set spark.sql.crossJoin.enabled=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-having.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-having.sql index 6ae34ae589fa..ff8573ad7e56 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-having.sql @@ -1,6 +1,4 @@ -- This test file was converted from having.sql. --- Note that currently registered UDF returns a string. So there are some differences, for instance --- in string cast within UDF in Scala and Python. create temporary view hav as select * from values ("one", 1), diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql index 686268317800..e5eb812d69a1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql @@ -4,8 +4,6 @@ --SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false -- This test file was converted from natural-join.sql. --- Note that currently registered UDF returns a string. So there are some differences, for instance --- in string cast within UDF in Scala and Python. create temporary view nt1 as select * from values ("one", 1), diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-special-values.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-special-values.sql new file mode 100644 index 000000000000..9cd15369bb16 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-special-values.sql @@ -0,0 +1,8 @@ +-- This file tests special values such as NaN, Infinity and NULL. 
+ +SELECT udf(x) FROM (VALUES (1), (2), (NULL)) v(x); +SELECT udf(x) FROM (VALUES ('A'), ('B'), (NULL)) v(x); +SELECT udf(x) FROM (VALUES ('NaN'), ('1'), ('2')) v(x); +SELECT udf(x) FROM (VALUES ('Infinity'), ('1'), ('2')) v(x); +SELECT udf(x) FROM (VALUES ('-Infinity'), ('1'), ('2')) v(x); +SELECT udf(x) FROM (VALUES 0.00000001, 0.00000002, 0.00000003) v(x); diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-udaf.sql new file mode 100644 index 000000000000..7f5c2237499d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-udaf.sql @@ -0,0 +1,18 @@ +-- This test file was converted from udaf.sql. + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1); + +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg'; + +SELECT default.myDoubleAvg(udf(int_col1)) as my_avg, udf(default.myDoubleAvg(udf(int_col1))) as my_avg2, udf(default.myDoubleAvg(int_col1)) as my_avg3 from t1; + +SELECT default.myDoubleAvg(udf(int_col1), udf(3)) as my_avg from t1; + +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; + +SELECT default.udaf1(udf(int_col1)) as udaf1 from t1; + +DROP FUNCTION myDoubleAvg; +DROP FUNCTION udaf1; diff --git a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out new file mode 100644 index 000000000000..5193e2536c0c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out @@ -0,0 +1,208 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +create temporary view t as select * from values 0, 1, 2 as t(id) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values 0, 1 as t(id) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SET spark.sql.legacy.ctePrecedence.enabled=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.legacy.ctePrecedence.enabled true + + +-- !query 3 +WITH t as ( + WITH t2 AS (SELECT 1) + SELECT * FROM t2 +) +SELECT * FROM t +-- !query 3 schema +struct<1:int> +-- !query 3 output +1 + + +-- !query 4 +SELECT max(c) FROM ( + WITH t(c) AS (SELECT 1) + SELECT * FROM t +) +-- !query 4 schema +struct +-- !query 4 output +1 + + +-- !query 5 +SELECT ( + WITH t AS (SELECT 1) + SELECT * FROM t +) +-- !query 5 schema +struct +-- !query 5 output +1 + + +-- !query 6 +WITH + t AS (SELECT 1), + t2 AS ( + WITH t AS (SELECT 2) + SELECT * FROM t + ) +SELECT * FROM t2 +-- !query 6 schema +struct<1:int> +-- !query 6 output +1 + + +-- !query 7 +WITH + t(c) AS (SELECT 1), + t2 AS ( + SELECT ( + SELECT max(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t + ) + ) + ) +SELECT * FROM t2 +-- !query 7 schema +struct +-- !query 7 output +1 + + +-- !query 8 +WITH + t AS (SELECT 1), + t2 AS ( + WITH t AS (SELECT 2), + t2 AS ( + WITH t AS (SELECT 3) + SELECT * FROM t + ) + SELECT * FROM t2 + ) +SELECT * FROM t2 +-- !query 8 schema +struct<2:int> +-- !query 8 output +2 + + +-- !query 9 +WITH t(c) AS (SELECT 1) +SELECT max(c) FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t +) +-- !query 9 schema +struct +-- !query 9 output +2 + + +-- !query 10 +WITH t(c) AS (SELECT 1) +SELECT sum(c) FROM ( + SELECT max(c) AS c FROM ( + WITH t(c) AS (SELECT 2) + SELECT * FROM t + ) +) +-- !query 10 schema +struct +-- !query 10 output +2 + + +-- !query 11 +WITH t(c) AS (SELECT 1) +SELECT sum(c) FROM ( + WITH t(c) AS (SELECT 
2) + SELECT max(c) AS c FROM ( + WITH t(c) AS (SELECT 3) + SELECT * FROM t + ) +) +-- !query 11 schema +struct +-- !query 11 output +3 + + +-- !query 12 +WITH t AS (SELECT 1) +SELECT ( + WITH t AS (SELECT 2) + SELECT * FROM t +) +-- !query 12 schema +struct +-- !query 12 output +1 + + +-- !query 13 +WITH t AS (SELECT 1) +SELECT ( + SELECT ( + WITH t AS (SELECT 2) + SELECT * FROM t + ) +) +-- !query 13 schema +struct +-- !query 13 output +1 + + +-- !query 14 +WITH t AS (SELECT 1) +SELECT ( + WITH t AS (SELECT 2) + SELECT ( + WITH t AS (SELECT 3) + SELECT * FROM t + ) +) +-- !query 14 schema +struct +-- !query 14 output +1 + + +-- !query 15 +DROP VIEW IF EXISTS t +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +DROP VIEW IF EXISTS t2 +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out index 9e90908d92fa..b7dd76c72520 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out @@ -204,9 +204,9 @@ WITH ) SELECT * FROM t2 -- !query 16 schema -struct<1:int> +struct<2:int> -- !query 16 output -1 +2 -- !query 17 @@ -224,7 +224,7 @@ SELECT * FROM t2 -- !query 17 schema struct -- !query 17 output -1 +2 -- !query 18 @@ -240,9 +240,9 @@ WITH ) SELECT * FROM t2 -- !query 18 schema -struct<2:int> +struct<3:int> -- !query 18 output -2 +3 -- !query 19 @@ -295,7 +295,7 @@ SELECT ( -- !query 22 schema struct -- !query 22 output -1 +2 -- !query 23 @@ -309,7 +309,7 @@ SELECT ( -- !query 23 schema struct -- !query 23 output -1 +2 -- !query 24 @@ -324,7 +324,7 @@ SELECT ( -- !query 24 schema struct -- !query 24 output -1 +3 -- !query 25 diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out index 99c42ec2eb6c..b7cf3a9f1ad8 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/boolean.sql.out @@ -19,7 +19,7 @@ true -- !query 2 -SELECT false AS false +SELECT false AS `false` -- !query 2 schema struct -- !query 2 output @@ -35,7 +35,7 @@ true -- !query 4 -SELECT boolean(' f ') AS false +SELECT boolean(' f ') AS `false` -- !query 4 schema struct -- !query 4 output @@ -59,7 +59,7 @@ NULL -- !query 7 -SELECT boolean('false') AS false +SELECT boolean('false') AS `false` -- !query 7 schema struct -- !query 7 output @@ -99,7 +99,7 @@ NULL -- !query 12 -SELECT boolean('n') AS false +SELECT boolean('n') AS `false` -- !query 12 schema struct -- !query 12 output @@ -107,7 +107,7 @@ false -- !query 13 -SELECT boolean('no') AS false +SELECT boolean('no') AS `false` -- !query 13 schema struct -- !query 13 output @@ -131,7 +131,7 @@ NULL -- !query 16 -SELECT boolean('off') AS false +SELECT boolean('off') AS `false` -- !query 16 schema struct -- !query 16 output @@ -139,7 +139,7 @@ NULL -- !query 17 -SELECT boolean('of') AS false +SELECT boolean('of') AS `false` -- !query 17 schema struct -- !query 17 output @@ -187,7 +187,7 @@ NULL -- !query 23 -SELECT boolean('0') AS false +SELECT boolean('0') AS `false` -- !query 23 schema struct -- !query 23 output @@ -219,7 +219,7 @@ true -- !query 27 -SELECT boolean('t') and boolean('f') AS false +SELECT boolean('t') and boolean('f') AS `false` -- !query 27 schema struct -- !query 27 output @@ -235,7 +235,7 @@ true -- !query 29 -SELECT boolean('t') = boolean('f') AS false +SELECT 
boolean('t') = boolean('f') AS `false` -- !query 29 schema struct -- !query 29 output @@ -283,7 +283,7 @@ true -- !query 35 -SELECT boolean(string('TrUe')) AS true, boolean(string('fAlse')) AS false +SELECT boolean(string('TrUe')) AS true, boolean(string('fAlse')) AS `false` -- !query 35 schema struct -- !query 35 output @@ -292,7 +292,7 @@ true false -- !query 36 SELECT boolean(string(' true ')) AS true, - boolean(string(' FALSE')) AS false + boolean(string(' FALSE')) AS `false` -- !query 36 schema struct -- !query 36 output @@ -300,7 +300,7 @@ NULL NULL -- !query 37 -SELECT string(boolean(true)) AS true, string(boolean(false)) AS false +SELECT string(boolean(true)) AS true, string(boolean(false)) AS `false` -- !query 37 schema struct -- !query 37 output diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/case.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/case.sql.out index dbd775e5ebba..f95adcde81b3 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/case.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/case.sql.out @@ -1,19 +1,22 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 37 +-- Number of queries: 35 -- !query 0 -set spark.sql.crossJoin.enabled=true +CREATE TABLE CASE_TBL ( + i integer, + f double +) USING parquet -- !query 0 schema -struct +struct<> -- !query 0 output -spark.sql.crossJoin.enabled true + -- !query 1 -CREATE TABLE CASE_TBL ( +CREATE TABLE CASE2_TBL ( i integer, - f double + j integer ) USING parquet -- !query 1 schema struct<> @@ -22,10 +25,7 @@ struct<> -- !query 2 -CREATE TABLE CASE2_TBL ( - i integer, - j integer -) USING parquet +INSERT INTO CASE_TBL VALUES (1, 10.1) -- !query 2 schema struct<> -- !query 2 output @@ -33,7 +33,7 @@ struct<> -- !query 3 -INSERT INTO CASE_TBL VALUES (1, 10.1) +INSERT INTO CASE_TBL VALUES (2, 20.2) -- !query 3 schema struct<> -- !query 3 output @@ -41,7 +41,7 @@ struct<> -- !query 4 -INSERT INTO CASE_TBL VALUES (2, 20.2) +INSERT INTO CASE_TBL VALUES (3, -30.3) -- !query 4 schema struct<> -- !query 4 output @@ -49,7 +49,7 @@ struct<> -- !query 5 -INSERT INTO CASE_TBL VALUES (3, -30.3) +INSERT INTO CASE_TBL VALUES (4, NULL) -- !query 5 schema struct<> -- !query 5 output @@ -57,7 +57,7 @@ struct<> -- !query 6 -INSERT INTO CASE_TBL VALUES (4, NULL) +INSERT INTO CASE2_TBL VALUES (1, -1) -- !query 6 schema struct<> -- !query 6 output @@ -65,7 +65,7 @@ struct<> -- !query 7 -INSERT INTO CASE2_TBL VALUES (1, -1) +INSERT INTO CASE2_TBL VALUES (2, -2) -- !query 7 schema struct<> -- !query 7 output @@ -73,7 +73,7 @@ struct<> -- !query 8 -INSERT INTO CASE2_TBL VALUES (2, -2) +INSERT INTO CASE2_TBL VALUES (3, -3) -- !query 8 schema struct<> -- !query 8 output @@ -81,7 +81,7 @@ struct<> -- !query 9 -INSERT INTO CASE2_TBL VALUES (3, -3) +INSERT INTO CASE2_TBL VALUES (2, -4) -- !query 9 schema struct<> -- !query 9 output @@ -89,7 +89,7 @@ struct<> -- !query 10 -INSERT INTO CASE2_TBL VALUES (2, -4) +INSERT INTO CASE2_TBL VALUES (1, NULL) -- !query 10 schema struct<> -- !query 10 output @@ -97,7 +97,7 @@ struct<> -- !query 11 -INSERT INTO CASE2_TBL VALUES (1, NULL) +INSERT INTO CASE2_TBL VALUES (NULL, -6) -- !query 11 schema struct<> -- !query 11 output @@ -105,148 +105,140 @@ struct<> -- !query 12 -INSERT INTO CASE2_TBL VALUES (NULL, -6) --- !query 12 schema -struct<> --- !query 12 output - - - --- !query 13 SELECT '3' AS `One`, CASE WHEN 1 < 2 THEN 3 END AS `Simple WHEN` --- !query 13 schema +-- !query 12 schema struct --- !query 13 output +-- !query 12 output 3 3 
--- !query 14 +-- !query 13 SELECT '' AS `One`, CASE WHEN 1 > 2 THEN 3 END AS `Simple default` --- !query 14 schema +-- !query 13 schema struct --- !query 14 output +-- !query 13 output NULL --- !query 15 +-- !query 14 SELECT '3' AS `One`, CASE WHEN 1 < 2 THEN 3 ELSE 4 END AS `Simple ELSE` --- !query 15 schema +-- !query 14 schema struct --- !query 15 output +-- !query 14 output 3 3 --- !query 16 +-- !query 15 SELECT '4' AS `One`, CASE WHEN 1 > 2 THEN 3 ELSE 4 END AS `ELSE default` --- !query 16 schema +-- !query 15 schema struct --- !query 16 output +-- !query 15 output 4 4 --- !query 17 +-- !query 16 SELECT '6' AS `One`, CASE WHEN 1 > 2 THEN 3 WHEN 4 < 5 THEN 6 ELSE 7 END AS `Two WHEN with default` --- !query 17 schema +-- !query 16 schema struct --- !query 17 output +-- !query 16 output 6 6 --- !query 18 +-- !query 17 SELECT '7' AS `None`, CASE WHEN rand() < 0 THEN 1 END AS `NULL on no matches` --- !query 18 schema +-- !query 17 schema struct --- !query 18 output +-- !query 17 output 7 NULL --- !query 19 +-- !query 18 SELECT CASE WHEN 1=0 THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END +-- !query 18 schema +struct +-- !query 18 output +1 + + +-- !query 19 +SELECT CASE 1 WHEN 0 THEN 1/0 WHEN 1 THEN 1 ELSE 2/0 END -- !query 19 schema -struct +struct -- !query 19 output -1.0 +1 -- !query 20 -SELECT CASE 1 WHEN 0 THEN 1/0 WHEN 1 THEN 1 ELSE 2/0 END +SELECT CASE WHEN i > 100 THEN 1/0 ELSE 0 END FROM case_tbl -- !query 20 schema -struct +struct 100) THEN (1 div 0) ELSE 0 END:int> -- !query 20 output -1.0 +0 +0 +0 +0 -- !query 21 -SELECT CASE WHEN i > 100 THEN 1/0 ELSE 0 END FROM case_tbl --- !query 21 schema -struct 100) THEN (CAST(1 AS DOUBLE) / CAST(0 AS DOUBLE)) ELSE CAST(0 AS DOUBLE) END:double> --- !query 21 output -0.0 -0.0 -0.0 -0.0 - - --- !query 22 SELECT CASE 'a' WHEN 'a' THEN 1 ELSE 2 END --- !query 22 schema +-- !query 21 schema struct --- !query 22 output +-- !query 21 output 1 --- !query 23 +-- !query 22 SELECT '' AS `Five`, CASE WHEN i >= 3 THEN i END AS `>= 3 or Null` FROM CASE_TBL --- !query 23 schema +-- !query 22 schema struct= 3 or Null:int> --- !query 23 output +-- !query 22 output 3 4 NULL NULL --- !query 24 +-- !query 23 SELECT '' AS `Five`, CASE WHEN i >= 3 THEN (i + i) ELSE i END AS `Simplest Math` FROM CASE_TBL --- !query 24 schema +-- !query 23 schema struct --- !query 24 output +-- !query 23 output 1 2 6 8 --- !query 25 +-- !query 24 SELECT '' AS `Five`, i AS `Value`, CASE WHEN (i < 0) THEN 'small' WHEN (i = 0) THEN 'zero' @@ -255,16 +247,16 @@ SELECT '' AS `Five`, i AS `Value`, ELSE 'big' END AS `Category` FROM CASE_TBL --- !query 25 schema +-- !query 24 schema struct --- !query 25 output +-- !query 24 output 1 one 2 two 3 big 4 big --- !query 26 +-- !query 25 SELECT '' AS `Five`, CASE WHEN ((i < 0) or (i < 0)) THEN 'small' WHEN ((i = 0) or (i = 0)) THEN 'zero' @@ -273,37 +265,37 @@ SELECT '' AS `Five`, ELSE 'big' END AS `Category` FROM CASE_TBL --- !query 26 schema +-- !query 25 schema struct --- !query 26 output +-- !query 25 output big big one two --- !query 27 +-- !query 26 SELECT * FROM CASE_TBL WHERE COALESCE(f,i) = 4 --- !query 27 schema +-- !query 26 schema struct --- !query 27 output +-- !query 26 output 4 NULL --- !query 28 +-- !query 27 SELECT * FROM CASE_TBL WHERE NULLIF(f,i) = 2 --- !query 28 schema +-- !query 27 schema struct --- !query 28 output +-- !query 27 output --- !query 29 +-- !query 28 SELECT COALESCE(a.f, b.i, b.j) FROM CASE_TBL a, CASE2_TBL b --- !query 29 schema +-- !query 28 schema struct --- !query 29 output +-- !query 28 output -30.3 -30.3 -30.3 
@@ -330,24 +322,24 @@ struct 3.0 --- !query 30 +-- !query 29 SELECT * FROM CASE_TBL a, CASE2_TBL b WHERE COALESCE(a.f, b.i, b.j) = 2 --- !query 30 schema +-- !query 29 schema struct --- !query 30 output +-- !query 29 output 4 NULL 2 -2 4 NULL 2 -4 --- !query 31 +-- !query 30 SELECT '' AS Five, NULLIF(a.i,b.i) AS `NULLIF(a.i,b.i)`, NULLIF(b.i, 4) AS `NULLIF(b.i,4)` FROM CASE_TBL a, CASE2_TBL b --- !query 31 schema +-- !query 30 schema struct --- !query 31 output +-- !query 30 output 1 2 1 2 1 3 @@ -374,18 +366,18 @@ struct NULL 3 --- !query 32 +-- !query 31 SELECT '' AS `Two`, * FROM CASE_TBL a, CASE2_TBL b WHERE COALESCE(f,b.i) = 2 --- !query 32 schema +-- !query 31 schema struct --- !query 32 output +-- !query 31 output 4 NULL 2 -2 4 NULL 2 -4 --- !query 33 +-- !query 32 SELECT CASE (CASE vol('bar') WHEN 'foo' THEN 'it was foo!' @@ -395,31 +387,23 @@ SELECT CASE WHEN 'it was foo!' THEN 'foo recognized' WHEN 'it was bar!' THEN 'bar recognized' ELSE 'unrecognized' END --- !query 33 schema +-- !query 32 schema struct --- !query 33 output +-- !query 32 output bar recognized --- !query 34 +-- !query 33 DROP TABLE CASE_TBL --- !query 34 schema +-- !query 33 schema struct<> --- !query 34 output +-- !query 33 output --- !query 35 +-- !query 34 DROP TABLE CASE2_TBL --- !query 35 schema +-- !query 34 schema struct<> --- !query 35 output - - +-- !query 34 output --- !query 36 -set spark.sql.crossJoin.enabled=false --- !query 36 schema -struct --- !query 36 output -spark.sql.crossJoin.enabled false diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out new file mode 100644 index 000000000000..3e3f24d603ff --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out @@ -0,0 +1,839 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 95 + + +-- !query 0 +CREATE TABLE FLOAT8_TBL(f1 double) USING parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +INSERT INTO FLOAT8_TBL VALUES (' 0.0 ') +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +INSERT INTO FLOAT8_TBL VALUES ('1004.30 ') +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +INSERT INTO FLOAT8_TBL VALUES (' -34.84') +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +INSERT INTO FLOAT8_TBL VALUES ('1.2345678901234e+200') +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +INSERT INTO FLOAT8_TBL VALUES ('1.2345678901234e-200') +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SELECT double('10e400') +-- !query 6 schema +struct +-- !query 6 output +Infinity + + +-- !query 7 +SELECT double('-10e400') +-- !query 7 schema +struct +-- !query 7 output +-Infinity + + +-- !query 8 +SELECT double('10e-400') +-- !query 8 schema +struct +-- !query 8 output +0.0 + + +-- !query 9 +SELECT double('-10e-400') +-- !query 9 schema +struct +-- !query 9 output +-0.0 + + +-- !query 10 +SELECT double('NaN') +-- !query 10 schema +struct +-- !query 10 output +NaN + + +-- !query 11 +SELECT double('nan') +-- !query 11 schema +struct +-- !query 11 output +NULL + + +-- !query 12 +SELECT double(' NAN ') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +SELECT double('infinity') +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +SELECT double(' -INFINiTY ') +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +SELECT double('N A N') +-- !query 15 schema 
+struct +-- !query 15 output +NULL + + +-- !query 16 +SELECT double('NaN x') +-- !query 16 schema +struct +-- !query 16 output +NULL + + +-- !query 17 +SELECT double(' INFINITY x') +-- !query 17 schema +struct +-- !query 17 output +NULL + + +-- !query 18 +SELECT double('Infinity') + 100.0 +-- !query 18 schema +struct<(CAST(Infinity AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> +-- !query 18 output +Infinity + + +-- !query 19 +SELECT double('Infinity') / double('Infinity') +-- !query 19 schema +struct<(CAST(Infinity AS DOUBLE) / CAST(Infinity AS DOUBLE)):double> +-- !query 19 output +NaN + + +-- !query 20 +SELECT double('NaN') / double('NaN') +-- !query 20 schema +struct<(CAST(NaN AS DOUBLE) / CAST(NaN AS DOUBLE)):double> +-- !query 20 output +NaN + + +-- !query 21 +SELECT double(decimal('nan')) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +SELECT '' AS five, * FROM FLOAT8_TBL +-- !query 22 schema +struct +-- !query 22 output +-34.84 + 0.0 + 1.2345678901234E-200 + 1.2345678901234E200 + 1004.3 + + +-- !query 23 +SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <> '1004.3' +-- !query 23 schema +struct +-- !query 23 output +-34.84 + 0.0 + 1.2345678901234E-200 + 1.2345678901234E200 + + +-- !query 24 +SELECT '' AS one, f.* FROM FLOAT8_TBL f WHERE f.f1 = '1004.3' +-- !query 24 schema +struct +-- !query 24 output +1004.3 + + +-- !query 25 +SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE '1004.3' > f.f1 +-- !query 25 schema +struct +-- !query 25 output +-34.84 + 0.0 + 1.2345678901234E-200 + + +-- !query 26 +SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE f.f1 < '1004.3' +-- !query 26 schema +struct +-- !query 26 output +-34.84 + 0.0 + 1.2345678901234E-200 + + +-- !query 27 +SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE '1004.3' >= f.f1 +-- !query 27 schema +struct +-- !query 27 output +-34.84 + 0.0 + 1.2345678901234E-200 + 1004.3 + + +-- !query 28 +SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <= '1004.3' +-- !query 28 schema +struct +-- !query 28 output +-34.84 + 0.0 + 1.2345678901234E-200 + 1004.3 + + +-- !query 29 +SELECT '' AS three, f.f1, f.f1 * '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0' +-- !query 29 schema +struct +-- !query 29 output +1.2345678901234E-200 -1.2345678901234E-199 + 1.2345678901234E200 -1.2345678901234E201 + 1004.3 -10043.0 + + +-- !query 30 +SELECT '' AS three, f.f1, f.f1 + '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0' +-- !query 30 schema +struct +-- !query 30 output +1.2345678901234E-200 -10.0 + 1.2345678901234E200 1.2345678901234E200 + 1004.3 994.3 + + +-- !query 31 +SELECT '' AS three, f.f1, f.f1 / '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0' +-- !query 31 schema +struct +-- !query 31 output +1.2345678901234E-200 -1.2345678901234E-201 + 1.2345678901234E200 -1.2345678901234E199 + 1004.3 -100.42999999999999 + + +-- !query 32 +SELECT '' AS three, f.f1, f.f1 - '-10' AS x + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0' +-- !query 32 schema +struct +-- !query 32 output +1.2345678901234E-200 10.0 + 1.2345678901234E200 1.2345678901234E200 + 1004.3 1014.3 + + +-- !query 33 +SELECT '' AS five, f.f1, round(f.f1) AS round_f1 + FROM FLOAT8_TBL f +-- !query 33 schema +struct +-- !query 33 output +-34.84 -35.0 + 0.0 0.0 + 1.2345678901234E-200 0.0 + 1.2345678901234E200 1.2345678901234E200 + 1004.3 1004.0 + + +-- !query 34 +select ceil(f1) as ceil_f1 from float8_tbl f +-- !query 34 schema +struct +-- !query 34 output +-34 +0 +1 +1005 +9223372036854775807 + + +-- !query 35 +select ceiling(f1) as ceiling_f1 from float8_tbl f +-- !query 35 
schema +struct +-- !query 35 output +-34 +0 +1 +1005 +9223372036854775807 + + +-- !query 36 +select floor(f1) as floor_f1 from float8_tbl f +-- !query 36 schema +struct +-- !query 36 output +-35 +0 +0 +1004 +9223372036854775807 + + +-- !query 37 +select sign(f1) as sign_f1 from float8_tbl f +-- !query 37 schema +struct +-- !query 37 output +-1.0 +0.0 +1.0 +1.0 +1.0 + + +-- !query 38 +SELECT sqrt(double('64')) AS eight +-- !query 38 schema +struct +-- !query 38 output +8.0 + + +-- !query 39 +SELECT power(double('144'), double('0.5')) +-- !query 39 schema +struct +-- !query 39 output +12.0 + + +-- !query 40 +SELECT power(double('NaN'), double('0.5')) +-- !query 40 schema +struct +-- !query 40 output +NaN + + +-- !query 41 +SELECT power(double('144'), double('NaN')) +-- !query 41 schema +struct +-- !query 41 output +NaN + + +-- !query 42 +SELECT power(double('NaN'), double('NaN')) +-- !query 42 schema +struct +-- !query 42 output +NaN + + +-- !query 43 +SELECT power(double('-1'), double('NaN')) +-- !query 43 schema +struct +-- !query 43 output +NaN + + +-- !query 44 +SELECT power(double('1'), double('NaN')) +-- !query 44 schema +struct +-- !query 44 output +NaN + + +-- !query 45 +SELECT power(double('NaN'), double('0')) +-- !query 45 schema +struct +-- !query 45 output +1.0 + + +-- !query 46 +SELECT '' AS three, f.f1, exp(ln(f.f1)) AS exp_ln_f1 + FROM FLOAT8_TBL f + WHERE f.f1 > '0.0' +-- !query 46 schema +struct +-- !query 46 output +1.2345678901234E-200 1.2345678901233948E-200 + 1.2345678901234E200 1.234567890123379E200 + 1004.3 1004.3000000000004 + + +-- !query 47 +SELECT '' AS five, * FROM FLOAT8_TBL +-- !query 47 schema +struct +-- !query 47 output +-34.84 + 0.0 + 1.2345678901234E-200 + 1.2345678901234E200 + 1004.3 + + +-- !query 48 +CREATE TEMPORARY VIEW UPDATED_FLOAT8_TBL as +SELECT + CASE WHEN FLOAT8_TBL.f1 > '0.0' THEN FLOAT8_TBL.f1 * '-1' ELSE FLOAT8_TBL.f1 END AS f1 +FROM FLOAT8_TBL +-- !query 48 schema +struct<> +-- !query 48 output + + + +-- !query 49 +SELECT '' AS bad, f.f1 * '1e200' from UPDATED_FLOAT8_TBL f +-- !query 49 schema +struct +-- !query 49 output +-1.0042999999999999E203 + -1.2345678901234 + -3.484E201 + -Infinity + 0.0 + + +-- !query 50 +SELECT '' AS five, * FROM UPDATED_FLOAT8_TBL +-- !query 50 schema +struct +-- !query 50 output +-1.2345678901234E-200 + -1.2345678901234E200 + -1004.3 + -34.84 + 0.0 + + +-- !query 51 +SELECT sinh(double('1')) +-- !query 51 schema +struct +-- !query 51 output +1.1752011936438014 + + +-- !query 52 +SELECT cosh(double('1')) +-- !query 52 schema +struct +-- !query 52 output +1.543080634815244 + + +-- !query 53 +SELECT tanh(double('1')) +-- !query 53 schema +struct +-- !query 53 output +0.7615941559557649 + + +-- !query 54 +SELECT asinh(double('1')) +-- !query 54 schema +struct +-- !query 54 output +0.8813735870195429 + + +-- !query 55 +SELECT acosh(double('2')) +-- !query 55 schema +struct +-- !query 55 output +1.3169578969248166 + + +-- !query 56 +SELECT atanh(double('0.5')) +-- !query 56 schema +struct +-- !query 56 output +0.5493061443340549 + + +-- !query 57 +SELECT sinh(double('Infinity')) +-- !query 57 schema +struct +-- !query 57 output +Infinity + + +-- !query 58 +SELECT sinh(double('-Infinity')) +-- !query 58 schema +struct +-- !query 58 output +-Infinity + + +-- !query 59 +SELECT sinh(double('NaN')) +-- !query 59 schema +struct +-- !query 59 output +NaN + + +-- !query 60 +SELECT cosh(double('Infinity')) +-- !query 60 schema +struct +-- !query 60 output +Infinity + + +-- !query 61 +SELECT cosh(double('-Infinity')) +-- !query 
61 schema +struct +-- !query 61 output +Infinity + + +-- !query 62 +SELECT cosh(double('NaN')) +-- !query 62 schema +struct +-- !query 62 output +NaN + + +-- !query 63 +SELECT tanh(double('Infinity')) +-- !query 63 schema +struct +-- !query 63 output +1.0 + + +-- !query 64 +SELECT tanh(double('-Infinity')) +-- !query 64 schema +struct +-- !query 64 output +-1.0 + + +-- !query 65 +SELECT tanh(double('NaN')) +-- !query 65 schema +struct +-- !query 65 output +NaN + + +-- !query 66 +SELECT asinh(double('Infinity')) +-- !query 66 schema +struct +-- !query 66 output +Infinity + + +-- !query 67 +SELECT asinh(double('-Infinity')) +-- !query 67 schema +struct +-- !query 67 output +-Infinity + + +-- !query 68 +SELECT asinh(double('NaN')) +-- !query 68 schema +struct +-- !query 68 output +NaN + + +-- !query 69 +SELECT acosh(double('Infinity')) +-- !query 69 schema +struct +-- !query 69 output +Infinity + + +-- !query 70 +SELECT acosh(double('-Infinity')) +-- !query 70 schema +struct +-- !query 70 output +NaN + + +-- !query 71 +SELECT acosh(double('NaN')) +-- !query 71 schema +struct +-- !query 71 output +NaN + + +-- !query 72 +SELECT atanh(double('Infinity')) +-- !query 72 schema +struct +-- !query 72 output +NaN + + +-- !query 73 +SELECT atanh(double('-Infinity')) +-- !query 73 schema +struct +-- !query 73 output +NaN + + +-- !query 74 +SELECT atanh(double('NaN')) +-- !query 74 schema +struct +-- !query 74 output +NaN + + +-- !query 75 +TRUNCATE TABLE FLOAT8_TBL +-- !query 75 schema +struct<> +-- !query 75 output + + + +-- !query 76 +INSERT INTO FLOAT8_TBL VALUES ('0.0') +-- !query 76 schema +struct<> +-- !query 76 output + + + +-- !query 77 +INSERT INTO FLOAT8_TBL VALUES ('-34.84') +-- !query 77 schema +struct<> +-- !query 77 output + + + +-- !query 78 +INSERT INTO FLOAT8_TBL VALUES ('-1004.30') +-- !query 78 schema +struct<> +-- !query 78 output + + + +-- !query 79 +INSERT INTO FLOAT8_TBL VALUES ('-1.2345678901234e+200') +-- !query 79 schema +struct<> +-- !query 79 output + + + +-- !query 80 +INSERT INTO FLOAT8_TBL VALUES ('-1.2345678901234e-200') +-- !query 80 schema +struct<> +-- !query 80 output + + + +-- !query 81 +SELECT '' AS five, * FROM FLOAT8_TBL +-- !query 81 schema +struct +-- !query 81 output +-1.2345678901234E-200 + -1.2345678901234E200 + -1004.3 + -34.84 + 0.0 + + +-- !query 82 +SELECT smallint(double('32767.4')) +-- !query 82 schema +struct +-- !query 82 output +32767 + + +-- !query 83 +SELECT smallint(double('32767.6')) +-- !query 83 schema +struct +-- !query 83 output +32767 + + +-- !query 84 +SELECT smallint(double('-32768.4')) +-- !query 84 schema +struct +-- !query 84 output +-32768 + + +-- !query 85 +SELECT smallint(double('-32768.6')) +-- !query 85 schema +struct +-- !query 85 output +-32768 + + +-- !query 86 +SELECT int(double('2147483647.4')) +-- !query 86 schema +struct +-- !query 86 output +2147483647 + + +-- !query 87 +SELECT int(double('2147483647.6')) +-- !query 87 schema +struct +-- !query 87 output +2147483647 + + +-- !query 88 +SELECT int(double('-2147483648.4')) +-- !query 88 schema +struct +-- !query 88 output +-2147483648 + + +-- !query 89 +SELECT int(double('-2147483648.6')) +-- !query 89 schema +struct +-- !query 89 output +-2147483648 + + +-- !query 90 +SELECT bigint(double('9223372036854773760')) +-- !query 90 schema +struct +-- !query 90 output +9223372036854773760 + + +-- !query 91 +SELECT bigint(double('9223372036854775807')) +-- !query 91 schema +struct +-- !query 91 output +9223372036854775807 + + +-- !query 92 +SELECT 
bigint(double('-9223372036854775808.5')) +-- !query 92 schema +struct +-- !query 92 output +-9223372036854775808 + + +-- !query 93 +SELECT bigint(double('-9223372036854780000')) +-- !query 93 schema +struct +-- !query 93 output +-9223372036854775808 + + +-- !query 94 +DROP TABLE FLOAT8_TBL +-- !query 94 schema +struct<> +-- !query 94 output + diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int2.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int2.sql.out index 6b9246fed0c6..7a7ce5f37dea 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int2.sql.out @@ -266,7 +266,7 @@ struct -- !query 27 -SELECT '' AS five, i.f1, i.f1 div smallint('2') AS x FROM INT2_TBL i +SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT2_TBL i -- !query 27 schema struct -- !query 27 output @@ -278,7 +278,7 @@ struct -- !query 28 -SELECT '' AS five, i.f1, i.f1 div int('2') AS x FROM INT2_TBL i +SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT2_TBL i -- !query 28 schema struct -- !query 28 output diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out index 9c17e9a1a197..456b1ef962d4 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out @@ -339,25 +339,25 @@ struct -- !query 33 SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i -- !query 33 schema -struct +struct -- !query 33 output --123456 -61728.0 - -2147483647 -1.0737418235E9 - 0 0.0 - 123456 61728.0 - 2147483647 1.0737418235E9 +-123456 -61728 + -2147483647 -1073741823 + 0 0 + 123456 61728 + 2147483647 1073741823 -- !query 34 SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i -- !query 34 schema -struct +struct -- !query 34 output --123456 -61728.0 - -2147483647 -1.0737418235E9 - 0 0.0 - 123456 61728.0 - 2147483647 1.0737418235E9 +-123456 -61728 + -2147483647 -1073741823 + 0 0 + 123456 61728 + 2147483647 1073741823 -- !query 35 @@ -417,7 +417,7 @@ true -- !query 42 -SELECT int('1000') < int('999') AS false +SELECT int('1000') < int('999') AS `false` -- !query 42 schema struct -- !query 42 output @@ -435,17 +435,17 @@ struct -- !query 44 SELECT 2 + 2 / 2 AS three -- !query 44 schema -struct +struct -- !query 44 output -3.0 +3 -- !query 45 SELECT (2 + 2) / 2 AS two -- !query 45 schema -struct +struct -- !query 45 output -2.0 +2 -- !query 46 diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int8.sql.out index 13bc748dd2b4..6d7fae19aa7e 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int8.sql.out @@ -351,7 +351,7 @@ struct -- !query 37 -SELECT '' AS five, q1 AS plus, -q1 AS minus FROM INT8_TBL +SELECT '' AS five, q1 AS plus, -q1 AS `minus` FROM INT8_TBL -- !query 37 schema struct -- !query 37 output @@ -375,7 +375,7 @@ struct -- !query 39 -SELECT '' AS five, q1, q2, q1 - q2 AS minus FROM INT8_TBL +SELECT '' AS five, q1, q2, q1 - q2 AS `minus` FROM INT8_TBL -- !query 39 schema struct -- !query 39 output @@ -412,13 +412,13 @@ struct -- !query 42 SELECT '' AS five, q1, q2, q1 / q2 AS divide, q1 % q2 AS mod FROM INT8_TBL -- !query 42 schema -struct +struct -- !query 42 output -123 456 0.26973684210526316 123 - 123 4567890123456789 2.6927092525360204E-14 123 - 
4567890123456789 -4567890123456789 -1.0 0 - 4567890123456789 123 3.713731807688446E13 57 - 4567890123456789 4567890123456789 1.0 0 +123 456 0 123 + 123 4567890123456789 0 123 + 4567890123456789 -4567890123456789 -1 0 + 4567890123456789 123 37137318076884 57 + 4567890123456789 4567890123456789 1 0 -- !query 43 @@ -496,49 +496,49 @@ struct -- !query 49 SELECT q1 + int(42) AS `8plus4`, q1 - int(42) AS `8minus4`, q1 * int(42) AS `8mul4`, q1 / int(42) AS `8div4` FROM INT8_TBL -- !query 49 schema -struct<8plus4:bigint,8minus4:bigint,8mul4:bigint,8div4:double> +struct<8plus4:bigint,8minus4:bigint,8mul4:bigint,8div4:bigint> -- !query 49 output -165 81 5166 2.9285714285714284 -165 81 5166 2.9285714285714284 -4567890123456831 4567890123456747 191851385185185138 1.0875928865373308E14 -4567890123456831 4567890123456747 191851385185185138 1.0875928865373308E14 -4567890123456831 4567890123456747 191851385185185138 1.0875928865373308E14 +165 81 5166 2 +165 81 5166 2 +4567890123456831 4567890123456747 191851385185185138 108759288653733 +4567890123456831 4567890123456747 191851385185185138 108759288653733 +4567890123456831 4567890123456747 191851385185185138 108759288653733 -- !query 50 SELECT int(246) + q1 AS `4plus8`, int(246) - q1 AS `4minus8`, int(246) * q1 AS `4mul8`, int(246) / q1 AS `4div8` FROM INT8_TBL -- !query 50 schema -struct<4plus8:bigint,4minus8:bigint,4mul8:bigint,4div8:double> +struct<4plus8:bigint,4minus8:bigint,4mul8:bigint,4div8:bigint> -- !query 50 output -369 123 30258 2.0 -369 123 30258 2.0 -4567890123457035 -4567890123456543 1123700970370370094 5.385418505072041E-14 -4567890123457035 -4567890123456543 1123700970370370094 5.385418505072041E-14 -4567890123457035 -4567890123456543 1123700970370370094 5.385418505072041E-14 +369 123 30258 2 +369 123 30258 2 +4567890123457035 -4567890123456543 1123700970370370094 0 +4567890123457035 -4567890123456543 1123700970370370094 0 +4567890123457035 -4567890123456543 1123700970370370094 0 -- !query 51 SELECT q1 + smallint(42) AS `8plus2`, q1 - smallint(42) AS `8minus2`, q1 * smallint(42) AS `8mul2`, q1 / smallint(42) AS `8div2` FROM INT8_TBL -- !query 51 schema -struct<8plus2:bigint,8minus2:bigint,8mul2:bigint,8div2:double> +struct<8plus2:bigint,8minus2:bigint,8mul2:bigint,8div2:bigint> -- !query 51 output -165 81 5166 2.9285714285714284 -165 81 5166 2.9285714285714284 -4567890123456831 4567890123456747 191851385185185138 1.0875928865373308E14 -4567890123456831 4567890123456747 191851385185185138 1.0875928865373308E14 -4567890123456831 4567890123456747 191851385185185138 1.0875928865373308E14 +165 81 5166 2 +165 81 5166 2 +4567890123456831 4567890123456747 191851385185185138 108759288653733 +4567890123456831 4567890123456747 191851385185185138 108759288653733 +4567890123456831 4567890123456747 191851385185185138 108759288653733 -- !query 52 SELECT smallint(246) + q1 AS `2plus8`, smallint(246) - q1 AS `2minus8`, smallint(246) * q1 AS `2mul8`, smallint(246) / q1 AS `2div8` FROM INT8_TBL -- !query 52 schema -struct<2plus8:bigint,2minus8:bigint,2mul8:bigint,2div8:double> +struct<2plus8:bigint,2minus8:bigint,2mul8:bigint,2div8:bigint> -- !query 52 output -369 123 30258 2.0 -369 123 30258 2.0 -4567890123457035 -4567890123456543 1123700970370370094 5.385418505072041E-14 -4567890123457035 -4567890123456543 1123700970370370094 5.385418505072041E-14 -4567890123457035 -4567890123456543 1123700970370370094 5.385418505072041E-14 +369 123 30258 2 +369 123 30258 2 +4567890123457035 -4567890123456543 1123700970370370094 0 +4567890123457035 -4567890123456543 
1123700970370370094 0 +4567890123457035 -4567890123456543 1123700970370370094 0 -- !query 53 @@ -572,7 +572,7 @@ struct -- !query 56 select bigint('9223372036854775800') / bigint('0') -- !query 56 schema -struct<(CAST(CAST(9223372036854775800 AS BIGINT) AS DOUBLE) / CAST(CAST(0 AS BIGINT) AS DOUBLE)):double> +struct<(CAST(9223372036854775800 AS BIGINT) div CAST(0 AS BIGINT)):bigint> -- !query 56 output NULL @@ -580,7 +580,7 @@ NULL -- !query 57 select bigint('-9223372036854775808') / smallint('0') -- !query 57 schema -struct<(CAST(CAST(-9223372036854775808 AS BIGINT) AS DOUBLE) / CAST(CAST(0 AS SMALLINT) AS DOUBLE)):double> +struct<(CAST(-9223372036854775808 AS BIGINT) div CAST(CAST(0 AS SMALLINT) AS BIGINT)):bigint> -- !query 57 output NULL @@ -588,7 +588,7 @@ NULL -- !query 58 select smallint('100') / bigint('0') -- !query 58 schema -struct<(CAST(CAST(100 AS SMALLINT) AS DOUBLE) / CAST(CAST(0 AS BIGINT) AS DOUBLE)):double> +struct<(CAST(CAST(100 AS SMALLINT) AS BIGINT) div CAST(0 AS BIGINT)):bigint> -- !query 58 output NULL @@ -744,9 +744,9 @@ struct<(CAST(-9223372036854775808 AS BIGINT) * CAST(-1 AS BIGINT)):bigint> -- !query 74 SELECT bigint((-9223372036854775808)) / bigint((-1)) -- !query 74 schema -struct<(CAST(CAST(-9223372036854775808 AS BIGINT) AS DOUBLE) / CAST(CAST(-1 AS BIGINT) AS DOUBLE)):double> +struct<(CAST(-9223372036854775808 AS BIGINT) div CAST(-1 AS BIGINT)):bigint> -- !query 74 output -9.223372036854776E18 +-9223372036854775808 -- !query 75 @@ -768,9 +768,9 @@ struct<(CAST(-9223372036854775808 AS BIGINT) * CAST(CAST(-1 AS INT) AS BIGINT)): -- !query 77 SELECT bigint((-9223372036854775808)) / int((-1)) -- !query 77 schema -struct<(CAST(CAST(-9223372036854775808 AS BIGINT) AS DOUBLE) / CAST(CAST(-1 AS INT) AS DOUBLE)):double> +struct<(CAST(-9223372036854775808 AS BIGINT) div CAST(CAST(-1 AS INT) AS BIGINT)):bigint> -- !query 77 output -9.223372036854776E18 +-9223372036854775808 -- !query 78 @@ -792,9 +792,9 @@ struct<(CAST(-9223372036854775808 AS BIGINT) * CAST(CAST(-1 AS SMALLINT) AS BIGI -- !query 80 SELECT bigint((-9223372036854775808)) / smallint((-1)) -- !query 80 schema -struct<(CAST(CAST(-9223372036854775808 AS BIGINT) AS DOUBLE) / CAST(CAST(-1 AS SMALLINT) AS DOUBLE)):double> +struct<(CAST(-9223372036854775808 AS BIGINT) div CAST(CAST(-1 AS SMALLINT) AS BIGINT)):bigint> -- !query 80 output -9.223372036854776E18 +-9223372036854775808 -- !query 81 diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/select.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/select.sql.out new file mode 100644 index 000000000000..797f808dad11 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/select.sql.out @@ -0,0 +1,543 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 37 + + +-- !query 0 +create or replace temporary view onek2 as select * from onek +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create or replace temporary view INT8_TBL as select * from values + (cast(trim(' 123 ') as bigint), cast(trim(' 456') as bigint)), + (cast(trim('123 ') as bigint),cast('4567890123456789' as bigint)), + (cast('4567890123456789' as bigint),cast('123' as bigint)), + (cast(+4567890123456789 as bigint),cast('4567890123456789' as bigint)), + (cast('+4567890123456789' as bigint),cast('-4567890123456789' as bigint)) + as INT8_TBL(q1, q2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM onek + WHERE onek.unique1 < 10 + ORDER BY onek.unique1 +-- !query 2 schema +struct +-- 
!query 2 output +0 998 0 0 0 0 0 0 0 0 0 0 1 AAAAAA KMBAAA OOOOxx +1 214 1 1 1 1 1 1 1 1 1 2 3 BAAAAA GIAAAA OOOOxx +2 326 0 2 2 2 2 2 2 2 2 4 5 CAAAAA OMAAAA OOOOxx +3 431 1 3 3 3 3 3 3 3 3 6 7 DAAAAA PQAAAA VVVVxx +4 833 0 0 4 4 4 4 4 4 4 8 9 EAAAAA BGBAAA HHHHxx +5 541 1 1 5 5 5 5 5 5 5 10 11 FAAAAA VUAAAA HHHHxx +6 978 0 2 6 6 6 6 6 6 6 12 13 GAAAAA QLBAAA OOOOxx +7 647 1 3 7 7 7 7 7 7 7 14 15 HAAAAA XYAAAA VVVVxx +8 653 0 0 8 8 8 8 8 8 8 16 17 IAAAAA DZAAAA HHHHxx +9 49 1 1 9 9 9 9 9 9 9 18 19 JAAAAA XBAAAA HHHHxx + + +-- !query 3 +SELECT onek.unique1, onek.stringu1 FROM onek + WHERE onek.unique1 < 20 + ORDER BY unique1 DESC +-- !query 3 schema +struct +-- !query 3 output +19 TAAAAA +18 SAAAAA +17 RAAAAA +16 QAAAAA +15 PAAAAA +14 OAAAAA +13 NAAAAA +12 MAAAAA +11 LAAAAA +10 KAAAAA +9 JAAAAA +8 IAAAAA +7 HAAAAA +6 GAAAAA +5 FAAAAA +4 EAAAAA +3 DAAAAA +2 CAAAAA +1 BAAAAA +0 AAAAAA + + +-- !query 4 +SELECT onek.unique1, onek.stringu1 FROM onek + WHERE onek.unique1 > 980 + ORDER BY stringu1 ASC +-- !query 4 schema +struct +-- !query 4 output +988 AMAAAA +989 BMAAAA +990 CMAAAA +991 DMAAAA +992 EMAAAA +993 FMAAAA +994 GMAAAA +995 HMAAAA +996 IMAAAA +997 JMAAAA +998 KMAAAA +999 LMAAAA +981 TLAAAA +982 ULAAAA +983 VLAAAA +984 WLAAAA +985 XLAAAA +986 YLAAAA +987 ZLAAAA + + +-- !query 5 +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 > 980 + ORDER BY string4 ASC, unique1 DESC +-- !query 5 schema +struct +-- !query 5 output +999 AAAAxx +995 AAAAxx +983 AAAAxx +982 AAAAxx +981 AAAAxx +998 HHHHxx +997 HHHHxx +993 HHHHxx +990 HHHHxx +986 HHHHxx +996 OOOOxx +991 OOOOxx +988 OOOOxx +987 OOOOxx +985 OOOOxx +994 VVVVxx +992 VVVVxx +989 VVVVxx +984 VVVVxx + + +-- !query 6 +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 > 980 + ORDER BY string4 DESC, unique1 ASC +-- !query 6 schema +struct +-- !query 6 output +984 VVVVxx +989 VVVVxx +992 VVVVxx +994 VVVVxx +985 OOOOxx +987 OOOOxx +988 OOOOxx +991 OOOOxx +996 OOOOxx +986 HHHHxx +990 HHHHxx +993 HHHHxx +997 HHHHxx +998 HHHHxx +981 AAAAxx +982 AAAAxx +983 AAAAxx +995 AAAAxx +999 AAAAxx + + +-- !query 7 +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 < 20 + ORDER BY unique1 DESC, string4 ASC +-- !query 7 schema +struct +-- !query 7 output +19 OOOOxx +18 VVVVxx +17 HHHHxx +16 OOOOxx +15 VVVVxx +14 AAAAxx +13 OOOOxx +12 AAAAxx +11 OOOOxx +10 AAAAxx +9 HHHHxx +8 HHHHxx +7 VVVVxx +6 OOOOxx +5 HHHHxx +4 HHHHxx +3 VVVVxx +2 OOOOxx +1 OOOOxx +0 OOOOxx + + +-- !query 8 +SELECT onek.unique1, onek.string4 FROM onek + WHERE onek.unique1 < 20 + ORDER BY unique1 ASC, string4 DESC +-- !query 8 schema +struct +-- !query 8 output +0 OOOOxx +1 OOOOxx +2 OOOOxx +3 VVVVxx +4 HHHHxx +5 HHHHxx +6 OOOOxx +7 VVVVxx +8 HHHHxx +9 HHHHxx +10 AAAAxx +11 OOOOxx +12 AAAAxx +13 OOOOxx +14 AAAAxx +15 VVVVxx +16 OOOOxx +17 HHHHxx +18 VVVVxx +19 OOOOxx + + +-- !query 9 +SELECT onek2.* FROM onek2 WHERE onek2.unique1 < 10 +-- !query 9 schema +struct +-- !query 9 output +0 998 0 0 0 0 0 0 0 0 0 0 1 AAAAAA KMBAAA OOOOxx +1 214 1 1 1 1 1 1 1 1 1 2 3 BAAAAA GIAAAA OOOOxx +2 326 0 2 2 2 2 2 2 2 2 4 5 CAAAAA OMAAAA OOOOxx +3 431 1 3 3 3 3 3 3 3 3 6 7 DAAAAA PQAAAA VVVVxx +4 833 0 0 4 4 4 4 4 4 4 8 9 EAAAAA BGBAAA HHHHxx +5 541 1 1 5 5 5 5 5 5 5 10 11 FAAAAA VUAAAA HHHHxx +6 978 0 2 6 6 6 6 6 6 6 12 13 GAAAAA QLBAAA OOOOxx +7 647 1 3 7 7 7 7 7 7 7 14 15 HAAAAA XYAAAA VVVVxx +8 653 0 0 8 8 8 8 8 8 8 16 17 IAAAAA DZAAAA HHHHxx +9 49 1 1 9 9 9 9 9 9 9 18 19 JAAAAA XBAAAA HHHHxx + + +-- !query 10 +SELECT onek2.unique1, onek2.stringu1 FROM onek2 + 
WHERE onek2.unique1 < 20 + ORDER BY unique1 DESC +-- !query 10 schema +struct +-- !query 10 output +19 TAAAAA +18 SAAAAA +17 RAAAAA +16 QAAAAA +15 PAAAAA +14 OAAAAA +13 NAAAAA +12 MAAAAA +11 LAAAAA +10 KAAAAA +9 JAAAAA +8 IAAAAA +7 HAAAAA +6 GAAAAA +5 FAAAAA +4 EAAAAA +3 DAAAAA +2 CAAAAA +1 BAAAAA +0 AAAAAA + + +-- !query 11 +SELECT onek2.unique1, onek2.stringu1 FROM onek2 + WHERE onek2.unique1 > 980 +-- !query 11 schema +struct +-- !query 11 output +981 TLAAAA +982 ULAAAA +983 VLAAAA +984 WLAAAA +985 XLAAAA +986 YLAAAA +987 ZLAAAA +988 AMAAAA +989 BMAAAA +990 CMAAAA +991 DMAAAA +992 EMAAAA +993 FMAAAA +994 GMAAAA +995 HMAAAA +996 IMAAAA +997 JMAAAA +998 KMAAAA +999 LMAAAA + + +-- !query 12 +CREATE TABLE tmp USING parquet AS +SELECT two, stringu1, ten, string4 +FROM onek +-- !query 12 schema +struct<> +-- !query 12 output + + + +-- !query 13 +select foo.* from (select 1) as foo +-- !query 13 schema +struct<1:int> +-- !query 13 output +1 + + +-- !query 14 +select foo.* from (select null) as foo +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select foo.* from (select 'xyzzy',1,null) as foo +-- !query 15 schema +struct +-- !query 15 output +xyzzy 1 NULL + + +-- !query 16 +select * from onek, values(147, 'RFAAAA'), (931, 'VJAAAA') as v (i, j) + WHERE onek.unique1 = v.i and onek.stringu1 = v.j +-- !query 16 schema +struct +-- !query 16 output +147 0 1 3 7 7 7 47 147 147 147 14 15 RFAAAA AAAAAA AAAAxx 147 RFAAAA +931 1 1 3 1 11 1 31 131 431 931 2 3 VJAAAA BAAAAA HHHHxx 931 VJAAAA + + +-- !query 17 +VALUES (1,2), (3,4+4), (7,77.7) +-- !query 17 schema +struct +-- !query 17 output +1 2 +3 8 +7 77.7 + + +-- !query 18 +VALUES (1,2), (3,4+4), (7,77.7) +UNION ALL +SELECT 2+2, 57 +UNION ALL +TABLE int8_tbl +-- !query 18 schema +struct +-- !query 18 output +1 2 +123 456 +123 4567890123456789 +3 8 +4 57 +4567890123456789 -4567890123456789 +4567890123456789 123 +4567890123456789 4567890123456789 +7 77.7 + + +-- !query 19 +CREATE OR REPLACE TEMPORARY VIEW foo AS +SELECT * FROM (values(42),(3),(10),(7),(null),(null),(1)) as foo (f1) +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +SELECT * FROM foo ORDER BY f1 +-- !query 20 schema +struct +-- !query 20 output +NULL +NULL +1 +3 +7 +10 +42 + + +-- !query 21 +SELECT * FROM foo ORDER BY f1 ASC +-- !query 21 schema +struct +-- !query 21 output +NULL +NULL +1 +3 +7 +10 +42 + + +-- !query 22 +-- same thing +SELECT * FROM foo ORDER BY f1 NULLS FIRST +-- !query 22 schema +struct +-- !query 22 output +NULL +NULL +1 +3 +7 +10 +42 + + +-- !query 23 +SELECT * FROM foo ORDER BY f1 DESC +-- !query 23 schema +struct +-- !query 23 output +42 +10 +7 +3 +1 +NULL +NULL + + +-- !query 24 +SELECT * FROM foo ORDER BY f1 DESC NULLS LAST +-- !query 24 schema +struct +-- !query 24 output +42 +10 +7 +3 +1 +NULL +NULL + + +-- !query 25 +select * from onek2 where unique2 = 11 and stringu1 = 'ATAAAA' +-- !query 25 schema +struct +-- !query 25 output +494 11 0 2 4 14 4 94 94 494 494 8 9 ATAAAA LAAAAA VVVVxx + + +-- !query 26 +select unique2 from onek2 where unique2 = 11 and stringu1 = 'ATAAAA' +-- !query 26 schema +struct +-- !query 26 output +11 + + +-- !query 27 +select * from onek2 where unique2 = 11 and stringu1 < 'B' +-- !query 27 schema +struct +-- !query 27 output +494 11 0 2 4 14 4 94 94 494 494 8 9 ATAAAA LAAAAA VVVVxx + + +-- !query 28 +select unique2 from onek2 where unique2 = 11 and stringu1 < 'B' +-- !query 28 schema +struct +-- !query 28 output +11 + + +-- !query 29 +select unique2 from onek2 where unique2 = 11 and 
stringu1 < 'C' +-- !query 29 schema +struct +-- !query 29 output +11 + + +-- !query 30 +select unique2 from onek2 where unique2 = 11 and stringu1 < 'B' +-- !query 30 schema +struct +-- !query 30 output +11 + + +-- !query 31 +select unique1, unique2 from onek2 + where (unique2 = 11 or unique1 = 0) and stringu1 < 'B' +-- !query 31 schema +struct +-- !query 31 output +0 998 +494 11 + + +-- !query 32 +select unique1, unique2 from onek2 + where (unique2 = 11 and stringu1 < 'B') or unique1 = 0 +-- !query 32 schema +struct +-- !query 32 output +0 998 +494 11 + + +-- !query 33 +SELECT 1 AS x ORDER BY x +-- !query 33 schema +struct +-- !query 33 output +1 + + +-- !query 34 +select * from (values (2),(null),(1)) v(k) where k = k order by k +-- !query 34 schema +struct +-- !query 34 output +1 +2 + + +-- !query 35 +select * from (values (2),(null),(1)) v(k) where k = k +-- !query 35 schema +struct +-- !query 35 output +1 +2 + + +-- !query 36 +drop table tmp +-- !query 36 schema +struct<> +-- !query 36 output + diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/select_distinct.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/select_distinct.sql.out new file mode 100644 index 000000000000..38eae1739f55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/select_distinct.sql.out @@ -0,0 +1,225 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 19 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW tmp AS +SELECT two, stringu1, ten, string4 +FROM onek +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT DISTINCT two FROM tmp ORDER BY 1 +-- !query 1 schema +struct +-- !query 1 output +0 +1 + + +-- !query 2 +SELECT DISTINCT ten FROM tmp ORDER BY 1 +-- !query 2 schema +struct +-- !query 2 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 3 +SELECT DISTINCT string4 FROM tmp ORDER BY 1 +-- !query 3 schema +struct +-- !query 3 output +AAAAxx +HHHHxx +OOOOxx +VVVVxx + + +-- !query 4 +SELECT DISTINCT two, string4, ten + FROM tmp + ORDER BY two ASC, string4 ASC, ten ASC +-- !query 4 schema +struct +-- !query 4 output +0 AAAAxx 0 +0 AAAAxx 2 +0 AAAAxx 4 +0 AAAAxx 6 +0 AAAAxx 8 +0 HHHHxx 0 +0 HHHHxx 2 +0 HHHHxx 4 +0 HHHHxx 6 +0 HHHHxx 8 +0 OOOOxx 0 +0 OOOOxx 2 +0 OOOOxx 4 +0 OOOOxx 6 +0 OOOOxx 8 +0 VVVVxx 0 +0 VVVVxx 2 +0 VVVVxx 4 +0 VVVVxx 6 +0 VVVVxx 8 +1 AAAAxx 1 +1 AAAAxx 3 +1 AAAAxx 5 +1 AAAAxx 7 +1 AAAAxx 9 +1 HHHHxx 1 +1 HHHHxx 3 +1 HHHHxx 5 +1 HHHHxx 7 +1 HHHHxx 9 +1 OOOOxx 1 +1 OOOOxx 3 +1 OOOOxx 5 +1 OOOOxx 7 +1 OOOOxx 9 +1 VVVVxx 1 +1 VVVVxx 3 +1 VVVVxx 5 +1 VVVVxx 7 +1 VVVVxx 9 + + +-- !query 5 +SELECT count(*) FROM + (SELECT DISTINCT two, four, two FROM tenk1) ss +-- !query 5 schema +struct +-- !query 5 output +4 + + +-- !query 6 +CREATE OR REPLACE TEMPORARY VIEW disttable AS SELECT * FROM + (VALUES (1), (2), (3), (NULL)) + AS v(f1) +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +SELECT f1, f1 IS DISTINCT FROM 2 as `not 2` FROM disttable +-- !query 7 schema +struct +-- !query 7 output +1 true +2 false +3 true +NULL true + + +-- !query 8 +SELECT f1, f1 IS DISTINCT FROM NULL as `not null` FROM disttable +-- !query 8 schema +struct +-- !query 8 output +1 true +2 true +3 true +NULL false + + +-- !query 9 +SELECT f1, f1 IS DISTINCT FROM f1 as `false` FROM disttable +-- !query 9 schema +struct +-- !query 9 output +1 false +2 false +3 false +NULL false + + +-- !query 10 +SELECT f1, f1 IS DISTINCT FROM f1+1 as `not null` FROM disttable +-- !query 10 schema +struct +-- !query 10 output +1 true +2 
true +3 true +NULL false + + +-- !query 11 +SELECT 1 IS DISTINCT FROM 2 as `yes` +-- !query 11 schema +struct +-- !query 11 output +true + + +-- !query 12 +SELECT 2 IS DISTINCT FROM 2 as `no` +-- !query 12 schema +struct +-- !query 12 output +false + + +-- !query 13 +SELECT 2 IS DISTINCT FROM null as `yes` +-- !query 13 schema +struct +-- !query 13 output +true + + +-- !query 14 +SELECT null IS DISTINCT FROM null as `no` +-- !query 14 schema +struct +-- !query 14 output +false + + +-- !query 15 +SELECT 1 IS NOT DISTINCT FROM 2 as `no` +-- !query 15 schema +struct +-- !query 15 output +false + + +-- !query 16 +SELECT 2 IS NOT DISTINCT FROM 2 as `yes` +-- !query 16 schema +struct +-- !query 16 output +true + + +-- !query 17 +SELECT 2 IS NOT DISTINCT FROM null as `no` +-- !query 17 schema +struct +-- !query 17 output +false + + +-- !query 18 +SELECT null IS NOT DISTINCT FROM null as `yes` +-- !query 18 schema +struct +-- !query 18 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/select_having.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/select_having.sql.out new file mode 100644 index 000000000000..02536ebd8ebe --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/select_having.sql.out @@ -0,0 +1,187 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 22 + + +-- !query 0 +CREATE TABLE test_having (a int, b int, c string, d string) USING parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +INSERT INTO test_having VALUES (0, 1, 'XXXX', 'A') +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +INSERT INTO test_having VALUES (1, 2, 'AAAA', 'b') +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +INSERT INTO test_having VALUES (2, 2, 'AAAA', 'c') +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +INSERT INTO test_having VALUES (3, 3, 'BBBB', 'D') +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +INSERT INTO test_having VALUES (4, 3, 'BBBB', 'e') +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +INSERT INTO test_having VALUES (5, 3, 'bbbb', 'F') +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +INSERT INTO test_having VALUES (6, 4, 'cccc', 'g') +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +INSERT INTO test_having VALUES (7, 4, 'cccc', 'h') +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +INSERT INTO test_having VALUES (8, 4, 'CCCC', 'I') +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +INSERT INTO test_having VALUES (9, 4, 'CCCC', 'j') +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +SELECT b, c FROM test_having + GROUP BY b, c HAVING count(*) = 1 ORDER BY b, c +-- !query 11 schema +struct +-- !query 11 output +1 XXXX +3 bbbb + + +-- !query 12 +SELECT b, c FROM test_having + GROUP BY b, c HAVING b = 3 ORDER BY b, c +-- !query 12 schema +struct +-- !query 12 output +3 BBBB +3 bbbb + + +-- !query 13 +SELECT c, max(a) FROM test_having + GROUP BY c HAVING count(*) > 2 OR min(a) = max(a) + ORDER BY c +-- !query 13 schema +struct +-- !query 13 output +XXXX 0 +bbbb 5 + + +-- !query 14 +SELECT min(a), max(a) FROM test_having HAVING min(a) = max(a) +-- !query 14 schema +struct +-- !query 14 output + + + +-- !query 15 +SELECT min(a), max(a) FROM test_having HAVING min(a) < max(a) +-- !query 15 schema +struct +-- !query 15 output +0 9 + + +-- !query 16 +SELECT a FROM test_having HAVING min(a) < 
max(a) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and 'default.test_having.`a`' is not an aggregate function. Wrap '(min(default.test_having.`a`) AS `min(a#x)`, max(default.test_having.`a`) AS `max(a#x)`)' in windowing function(s) or wrap 'default.test_having.`a`' in first() (or first_value) if you don't care which value you get.; + + +-- !query 17 +SELECT 1 AS one FROM test_having HAVING a > 1 +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +cannot resolve '`a`' given input columns: [one]; line 1 pos 40 + + +-- !query 18 +SELECT 1 AS one FROM test_having HAVING 1 > 2 +-- !query 18 schema +struct +-- !query 18 output + + + +-- !query 19 +SELECT 1 AS one FROM test_having HAVING 1 < 2 +-- !query 19 schema +struct +-- !query 19 output +1 + + +-- !query 20 +SELECT 1 AS one FROM test_having WHERE 1/a = 1 HAVING 1 < 2 +-- !query 20 schema +struct +-- !query 20 output +1 + + +-- !query 21 +DROP TABLE test_having +-- !query 21 schema +struct<> +-- !query 21 output + diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/with.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/with.sql.out new file mode 100644 index 000000000000..366b65f3659c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/with.sql.out @@ -0,0 +1,471 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 51 + + +-- !query 0 +WITH q1(x,y) AS (SELECT 1,2) +SELECT * FROM q1, q1 AS q2 +-- !query 0 schema +struct +-- !query 0 output +1 2 1 2 + + +-- !query 1 +SELECT count(*) FROM ( + WITH q1(x) AS (SELECT rand() FROM (SELECT EXPLODE(SEQUENCE(1, 5)))) + SELECT * FROM q1 + UNION + SELECT * FROM q1 +) ss +-- !query 1 schema +struct +-- !query 1 output +10 + + +-- !query 2 +CREATE TABLE department ( + id INTEGER, -- department ID + parent_department INTEGER, -- upper department ID + name string -- department name +) USING parquet +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +INSERT INTO department VALUES (0, NULL, 'ROOT') +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +INSERT INTO department VALUES (1, 0, 'A') +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +INSERT INTO department VALUES (2, 1, 'B') +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +INSERT INTO department VALUES (3, 2, 'C') +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +INSERT INTO department VALUES (4, 2, 'D') +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +INSERT INTO department VALUES (5, 0, 'E') +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +INSERT INTO department VALUES (6, 4, 'F') +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +INSERT INTO department VALUES (7, 5, 'G') +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +CREATE TABLE tree( + id INTEGER, + parent_id INTEGER +) USING parquet +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +INSERT INTO tree +VALUES (1, NULL), (2, 1), (3,1), (4,2), (5,2), (6,2), (7,3), (8,3), + (9,4), (10,4), (11,7), (12,7), (13,7), (14, 9), (15,11), (16,11) +-- !query 12 schema +struct<> +-- !query 12 output + + + +-- !query 13 +create table graph( f int, t int, label string ) USING parquet +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +insert into graph values + (1, 2, 'arc 1 -> 2'), + (1, 3, 'arc 1 
-> 3'), + (2, 3, 'arc 2 -> 3'), + (1, 4, 'arc 1 -> 4'), + (4, 5, 'arc 4 -> 5'), + (5, 1, 'arc 5 -> 1') +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +CREATE TABLE y (a INTEGER) USING parquet +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +INSERT INTO y SELECT EXPLODE(SEQUENCE(1, 10)) +-- !query 16 schema +struct<> +-- !query 16 output + + + +-- !query 17 +DROP TABLE y +-- !query 17 schema +struct<> +-- !query 17 output + + + +-- !query 18 +CREATE TABLE y (a INTEGER) USING parquet +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +INSERT INTO y SELECT EXPLODE(SEQUENCE(1, 10)) +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +with cte(foo) as ( select 42 ) select * from ((select foo from cte)) q +-- !query 20 schema +struct +-- !query 20 output +42 + + +-- !query 21 +WITH outermost(x) AS ( + SELECT 1 + UNION (WITH innermost as (SELECT 2) + SELECT * FROM innermost + UNION SELECT 3) +) +SELECT * FROM outermost ORDER BY 1 +-- !query 21 schema +struct +-- !query 21 output +1 +2 +3 + + +-- !query 22 +WITH outermost(x) AS ( + SELECT 1 + UNION (WITH innermost as (SELECT 2) + SELECT * FROM outermost -- fail + UNION SELECT * FROM innermost) +) +SELECT * FROM outermost ORDER BY 1 +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Table or view not found: outermost; line 4 pos 23 + + +-- !query 23 +CREATE TABLE withz USING parquet AS SELECT i AS k, CAST(i || ' v' AS string) v FROM (SELECT EXPLODE(SEQUENCE(1, 16, 3)) i) +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT * FROM withz ORDER BY k +-- !query 24 schema +struct +-- !query 24 output +1 1 v +4 4 v +7 7 v +10 10 v +13 13 v +16 16 v + + +-- !query 25 +DROP TABLE withz +-- !query 25 schema +struct<> +-- !query 25 output + + + +-- !query 26 +TRUNCATE TABLE y +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +INSERT INTO y SELECT EXPLODE(SEQUENCE(1, 3)) +-- !query 27 schema +struct<> +-- !query 27 output + + + +-- !query 28 +CREATE TABLE yy (a INTEGER) USING parquet +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +SELECT * FROM y +-- !query 29 schema +struct +-- !query 29 output +1 +2 +3 + + +-- !query 30 +SELECT * FROM yy +-- !query 30 schema +struct +-- !query 30 output + + + +-- !query 31 +SELECT * FROM y +-- !query 31 schema +struct +-- !query 31 output +1 +2 +3 + + +-- !query 32 +SELECT * FROM yy +-- !query 32 schema +struct +-- !query 32 output + + + +-- !query 33 +CREATE TABLE parent ( id int, val string ) USING parquet +-- !query 33 schema +struct<> +-- !query 33 output + + + +-- !query 34 +INSERT INTO parent VALUES ( 1, 'p1' ) +-- !query 34 schema +struct<> +-- !query 34 output + + + +-- !query 35 +SELECT * FROM parent +-- !query 35 schema +struct +-- !query 35 output +1 p1 + + +-- !query 36 +SELECT * FROM parent +-- !query 36 schema +struct +-- !query 36 output +1 p1 + + +-- !query 37 +create table foo (with baz) +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'with'(line 1, pos 18) + +== SQL == +create table foo (with baz) +------------------^^^ + + +-- !query 38 +-- fail, WITH is a reserved word +create table foo (with ordinality) +-- !query 38 schema +struct<> +-- !query 38 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'with'(line 2, pos 18) + +== SQL == +-- fail, WITH is a reserved word 
+create table foo (with ordinality) +------------------^^^ + + +-- !query 39 +-- fail, WITH is a reserved word +with ordinality as (select 1 as x) select * from ordinality +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +WITH test AS (SELECT 42) INSERT INTO test VALUES (1) +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +Table not found: test; + + +-- !query 41 +create table test (i int) USING parquet +-- !query 41 schema +struct<> +-- !query 41 output + + + +-- !query 42 +with test as (select 42) insert into test select * from test +-- !query 42 schema +struct<> +-- !query 42 output + + + +-- !query 43 +select * from test +-- !query 43 schema +struct +-- !query 43 output +42 + + +-- !query 44 +drop table test +-- !query 44 schema +struct<> +-- !query 44 output + + + +-- !query 45 +DROP TABLE department +-- !query 45 schema +struct<> +-- !query 45 output + + + +-- !query 46 +DROP TABLE tree +-- !query 46 schema +struct<> +-- !query 46 output + + + +-- !query 47 +DROP TABLE graph +-- !query 47 schema +struct<> +-- !query 47 output + + + +-- !query 48 +DROP TABLE y +-- !query 48 schema +struct<> +-- !query 48 output + + + +-- !query 49 +DROP TABLE yy +-- !query 49 schema +struct<> +-- !query 49 output + + + +-- !query 50 +DROP TABLE parent +-- !query 50 schema +struct<> +-- !query 50 output + diff --git a/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part1.sql.out index 32be362d87ca..a2f64717d73a 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part1.sql.out @@ -13,7 +13,7 @@ struct -- !query 1 SELECT udf(avg(a)) AS avg_32 FROM aggtest WHERE a < 100 -- !query 1 schema -struct +struct -- !query 1 output 32.666666666666664 @@ -29,15 +29,15 @@ struct -- !query 3 SELECT sum(udf(four)) AS sum_1500 FROM onek -- !query 3 schema -struct +struct -- !query 3 output -1500.0 +1500 -- !query 4 SELECT udf(sum(a)) AS sum_198 FROM aggtest -- !query 4 schema -struct +struct -- !query 4 output 198 @@ -45,7 +45,7 @@ struct -- !query 5 SELECT udf(udf(sum(b))) AS avg_431_773 FROM aggtest -- !query 5 schema -struct +struct -- !query 5 output 431.77260909229517 @@ -53,7 +53,7 @@ struct -- !query 6 SELECT udf(max(four)) AS max_3 FROM onek -- !query 6 schema -struct +struct -- !query 6 output 3 @@ -61,47 +61,47 @@ struct -- !query 7 SELECT max(udf(a)) AS max_100 FROM aggtest -- !query 7 schema -struct +struct -- !query 7 output -56 +100 -- !query 8 -SELECT CAST(udf(udf(max(aggtest.b))) AS int) AS max_324_78 FROM aggtest +SELECT udf(udf(max(aggtest.b))) AS max_324_78 FROM aggtest -- !query 8 schema -struct +struct -- !query 8 output -324 +324.78 -- !query 9 -SELECT CAST(stddev_pop(udf(b)) AS int) FROM aggtest +SELECT stddev_pop(udf(b)) FROM aggtest -- !query 9 schema -struct +struct -- !query 9 output -131 +131.10703231895047 -- !query 10 SELECT udf(stddev_samp(b)) FROM aggtest -- !query 10 schema -struct +struct -- !query 10 output 151.38936080399804 -- !query 11 -SELECT CAST(var_pop(udf(b)) as int) FROM aggtest +SELECT var_pop(udf(b)) FROM aggtest -- !query 11 schema -struct +struct -- !query 11 output -17189 +17189.053923482323 -- !query 12 SELECT udf(var_samp(b)) FROM aggtest -- !query 12 schema -struct +struct -- !query 12 output 22918.738564643096 @@ -109,7 +109,7 @@ struct -- !query 13 SELECT 
udf(stddev_pop(CAST(b AS Decimal(38,0)))) FROM aggtest -- !query 13 schema -struct +struct -- !query 13 output 131.18117242958306 @@ -117,7 +117,7 @@ struct -- !query 14 SELECT stddev_samp(CAST(udf(b) AS Decimal(38,0))) FROM aggtest -- !query 14 schema -struct +struct -- !query 14 output 151.47497042966097 @@ -125,7 +125,7 @@ struct -- !query 15 SELECT udf(var_pop(CAST(b AS Decimal(38,0)))) FROM aggtest -- !query 15 schema -struct +struct -- !query 15 output 17208.5 @@ -133,7 +133,7 @@ struct -- !query 16 SELECT var_samp(udf(CAST(b AS Decimal(38,0)))) FROM aggtest -- !query 16 schema -struct +struct -- !query 16 output 22944.666666666668 @@ -141,7 +141,7 @@ struct -- !query 17 SELECT udf(var_pop(1.0)), var_samp(udf(2.0)) -- !query 17 schema -struct +struct -- !query 17 output 0.0 NaN @@ -149,7 +149,7 @@ struct +struct -- !query 18 output 0.0 NaN @@ -157,7 +157,7 @@ struct +struct -- !query 19 output NULL @@ -165,7 +165,7 @@ NULL -- !query 20 select sum(udf(CAST(null AS long))) from range(1,4) -- !query 20 schema -struct +struct -- !query 20 output NULL @@ -173,7 +173,7 @@ NULL -- !query 21 select sum(udf(CAST(null AS Decimal(38,0)))) from range(1,4) -- !query 21 schema -struct +struct -- !query 21 output NULL @@ -181,7 +181,7 @@ NULL -- !query 22 select sum(udf(CAST(null AS DOUBLE))) from range(1,4) -- !query 22 schema -struct +struct -- !query 22 output NULL @@ -189,7 +189,7 @@ NULL -- !query 23 select avg(udf(CAST(null AS int))) from range(1,4) -- !query 23 schema -struct +struct -- !query 23 output NULL @@ -197,7 +197,7 @@ NULL -- !query 24 select avg(udf(CAST(null AS long))) from range(1,4) -- !query 24 schema -struct +struct -- !query 24 output NULL @@ -205,7 +205,7 @@ NULL -- !query 25 select avg(udf(CAST(null AS Decimal(38,0)))) from range(1,4) -- !query 25 schema -struct +struct -- !query 25 output NULL @@ -213,7 +213,7 @@ NULL -- !query 26 select avg(udf(CAST(null AS DOUBLE))) from range(1,4) -- !query 26 schema -struct +struct -- !query 26 output NULL @@ -221,7 +221,7 @@ NULL -- !query 27 select sum(CAST(udf('NaN') AS DOUBLE)) from range(1,4) -- !query 27 schema -struct +struct -- !query 27 output NaN @@ -229,7 +229,7 @@ NaN -- !query 28 select avg(CAST(udf('NaN') AS DOUBLE)) from range(1,4) -- !query 28 schema -struct +struct -- !query 28 output NaN @@ -238,7 +238,7 @@ NaN SELECT avg(CAST(udf(x) AS DOUBLE)), var_pop(CAST(udf(x) AS DOUBLE)) FROM (VALUES ('Infinity'), ('1')) v(x) -- !query 29 schema -struct +struct -- !query 29 output Infinity NaN @@ -247,7 +247,7 @@ Infinity NaN SELECT avg(CAST(udf(x) AS DOUBLE)), var_pop(CAST(udf(x) AS DOUBLE)) FROM (VALUES ('Infinity'), ('Infinity')) v(x) -- !query 30 schema -struct +struct -- !query 30 output Infinity NaN @@ -256,7 +256,7 @@ Infinity NaN SELECT avg(CAST(udf(x) AS DOUBLE)), var_pop(CAST(udf(x) AS DOUBLE)) FROM (VALUES ('-Infinity'), ('Infinity')) v(x) -- !query 31 schema -struct +struct -- !query 31 output NaN NaN @@ -265,7 +265,7 @@ NaN NaN SELECT avg(udf(CAST(x AS DOUBLE))), udf(var_pop(CAST(x AS DOUBLE))) FROM (VALUES (100000003), (100000004), (100000006), (100000007)) v(x) -- !query 32 schema -struct +struct -- !query 32 output 1.00000005E8 2.5 @@ -274,23 +274,23 @@ struct +struct -- !query 33 output 7.000000000006E12 1.0 -- !query 34 -SELECT CAST(udf(covar_pop(b, udf(a))) AS int), CAST(covar_samp(udf(b), a) as int) FROM aggtest +SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest -- !query 34 schema -struct +struct -- !query 34 output -653 871 +653.6289553875104 871.5052738500139 -- !query 35 SELECT 
corr(b, udf(a)) FROM aggtest -- !query 35 schema -struct +struct -- !query 35 output 0.1396345165178734 @@ -306,7 +306,7 @@ struct -- !query 37 SELECT udf(count(DISTINCT four)) AS cnt_4 FROM onek -- !query 37 schema -struct +struct -- !query 37 output 4 @@ -315,25 +315,25 @@ struct select ten, udf(count(*)), sum(udf(four)) from onek group by ten order by ten -- !query 38 schema -struct +struct -- !query 38 output -0 100 100.0 -1 100 200.0 -2 100 100.0 -3 100 200.0 -4 100 100.0 -5 100 200.0 -6 100 100.0 -7 100 200.0 -8 100 100.0 -9 100 200.0 +0 100 100 +1 100 200 +2 100 100 +3 100 200 +4 100 100 +5 100 200 +6 100 100 +7 100 200 +8 100 100 +9 100 200 -- !query 39 select ten, count(udf(four)), udf(sum(DISTINCT four)) from onek group by ten order by ten -- !query 39 schema -struct +struct -- !query 39 output 0 100 2 1 100 4 @@ -352,7 +352,7 @@ select ten, udf(sum(distinct four)) from onek a group by ten having exists (select 1 from onek b where udf(sum(distinct a.four)) = b.four) -- !query 40 schema -struct +struct -- !query 40 output 0 2 2 2 @@ -372,7 +372,7 @@ struct<> org.apache.spark.sql.AnalysisException Aggregate/Window/Generate expressions are not valid in where clause of the query. -Expression in where clause: [(sum(DISTINCT CAST((outer() + b.`four`) AS BIGINT)) = CAST(udf(four) AS BIGINT))] +Expression in where clause: [(sum(DISTINCT CAST((outer() + b.`four`) AS BIGINT)) = CAST(CAST(udf(cast(four as string)) AS INT) AS BIGINT))] Invalid expressions: [sum(DISTINCT CAST((outer() + b.`four`) AS BIGINT))]; diff --git a/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part2.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part2.sql.out index d90aa11fc6ef..9fe943874c3e 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-aggregates_part2.sql.out @@ -59,7 +59,7 @@ true false true false true true true true true -- !query 3 select min(udf(unique1)) from tenk1 -- !query 3 schema -struct +struct -- !query 3 output 0 @@ -67,7 +67,7 @@ struct -- !query 4 select udf(max(unique1)) from tenk1 -- !query 4 schema -struct +struct -- !query 4 output 9999 @@ -115,7 +115,7 @@ struct -- !query 10 select distinct max(udf(unique2)) from tenk1 -- !query 10 schema -struct +struct -- !query 10 output 9999 @@ -139,7 +139,7 @@ struct -- !query 13 select udf(max(udf(unique2))) from tenk1 order by udf(max(unique2))+1 -- !query 13 schema -struct +struct -- !query 13 output 9999 @@ -147,7 +147,7 @@ struct -- !query 14 select t1.max_unique2, udf(g) from (select max(udf(unique2)) as max_unique2 FROM tenk1) t1 LATERAL VIEW explode(array(1,2,3)) t2 AS g order by g desc -- !query 14 schema -struct +struct -- !query 14 output 9999 3 9999 2 @@ -157,6 +157,6 @@ struct -- !query 15 select udf(max(100)) from tenk1 -- !query 15 schema -struct +struct -- !query 15 output 100 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-case.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-case.sql.out index 55bef64338f4..d9a8ca86361f 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-case.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-case.sql.out @@ -1,19 +1,22 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 37 +-- Number of queries: 35 -- !query 0 -set spark.sql.crossJoin.enabled=true +CREATE TABLE CASE_TBL ( + i integer, + f double +) USING parquet -- !query 0 
schema -struct +struct<> -- !query 0 output -spark.sql.crossJoin.enabled true + -- !query 1 -CREATE TABLE CASE_TBL ( +CREATE TABLE CASE2_TBL ( i integer, - f double + j integer ) USING parquet -- !query 1 schema struct<> @@ -22,10 +25,7 @@ struct<> -- !query 2 -CREATE TABLE CASE2_TBL ( - i integer, - j integer -) USING parquet +INSERT INTO CASE_TBL VALUES (1, 10.1) -- !query 2 schema struct<> -- !query 2 output @@ -33,7 +33,7 @@ struct<> -- !query 3 -INSERT INTO CASE_TBL VALUES (1, 10.1) +INSERT INTO CASE_TBL VALUES (2, 20.2) -- !query 3 schema struct<> -- !query 3 output @@ -41,7 +41,7 @@ struct<> -- !query 4 -INSERT INTO CASE_TBL VALUES (2, 20.2) +INSERT INTO CASE_TBL VALUES (3, -30.3) -- !query 4 schema struct<> -- !query 4 output @@ -49,7 +49,7 @@ struct<> -- !query 5 -INSERT INTO CASE_TBL VALUES (3, -30.3) +INSERT INTO CASE_TBL VALUES (4, NULL) -- !query 5 schema struct<> -- !query 5 output @@ -57,7 +57,7 @@ struct<> -- !query 6 -INSERT INTO CASE_TBL VALUES (4, NULL) +INSERT INTO CASE2_TBL VALUES (1, -1) -- !query 6 schema struct<> -- !query 6 output @@ -65,7 +65,7 @@ struct<> -- !query 7 -INSERT INTO CASE2_TBL VALUES (1, -1) +INSERT INTO CASE2_TBL VALUES (2, -2) -- !query 7 schema struct<> -- !query 7 output @@ -73,7 +73,7 @@ struct<> -- !query 8 -INSERT INTO CASE2_TBL VALUES (2, -2) +INSERT INTO CASE2_TBL VALUES (3, -3) -- !query 8 schema struct<> -- !query 8 output @@ -81,7 +81,7 @@ struct<> -- !query 9 -INSERT INTO CASE2_TBL VALUES (3, -3) +INSERT INTO CASE2_TBL VALUES (2, -4) -- !query 9 schema struct<> -- !query 9 output @@ -89,7 +89,7 @@ struct<> -- !query 10 -INSERT INTO CASE2_TBL VALUES (2, -4) +INSERT INTO CASE2_TBL VALUES (1, NULL) -- !query 10 schema struct<> -- !query 10 output @@ -97,7 +97,7 @@ struct<> -- !query 11 -INSERT INTO CASE2_TBL VALUES (1, NULL) +INSERT INTO CASE2_TBL VALUES (NULL, -6) -- !query 11 schema struct<> -- !query 11 output @@ -105,148 +105,140 @@ struct<> -- !query 12 -INSERT INTO CASE2_TBL VALUES (NULL, -6) --- !query 12 schema -struct<> --- !query 12 output - - - --- !query 13 SELECT '3' AS `One`, CASE - WHEN CAST(udf(1 < 2) AS boolean) THEN 3 + WHEN udf(1 < 2) THEN 3 END AS `Simple WHEN` --- !query 13 schema +-- !query 12 schema struct --- !query 13 output +-- !query 12 output 3 3 --- !query 14 +-- !query 13 SELECT '' AS `One`, CASE WHEN 1 > 2 THEN udf(3) END AS `Simple default` --- !query 14 schema -struct --- !query 14 output +-- !query 13 schema +struct +-- !query 13 output NULL --- !query 15 +-- !query 14 SELECT '3' AS `One`, CASE WHEN udf(1) < 2 THEN udf(3) ELSE udf(4) END AS `Simple ELSE` --- !query 15 schema -struct --- !query 15 output +-- !query 14 schema +struct +-- !query 14 output 3 3 --- !query 16 +-- !query 15 SELECT udf('4') AS `One`, CASE WHEN 1 > 2 THEN 3 ELSE 4 END AS `ELSE default` --- !query 16 schema +-- !query 15 schema struct --- !query 16 output +-- !query 15 output 4 4 --- !query 17 +-- !query 16 SELECT udf('6') AS `One`, CASE - WHEN CAST(udf(1 > 2) AS boolean) THEN 3 + WHEN udf(1 > 2) THEN 3 WHEN udf(4) < 5 THEN 6 ELSE 7 END AS `Two WHEN with default` --- !query 17 schema +-- !query 16 schema struct --- !query 17 output +-- !query 16 output 6 6 --- !query 18 +-- !query 17 SELECT '7' AS `None`, CASE WHEN rand() < udf(0) THEN 1 END AS `NULL on no matches` --- !query 18 schema +-- !query 17 schema struct --- !query 18 output +-- !query 17 output 7 NULL +-- !query 18 +SELECT CASE WHEN udf(1=0) THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END +-- !query 18 schema +struct +-- !query 18 output +1 + + -- !query 19 -SELECT CASE WHEN 
CAST(udf(1=0) AS boolean) THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END +SELECT CASE 1 WHEN 0 THEN 1/udf(0) WHEN 1 THEN 1 ELSE 2/0 END -- !query 19 schema -struct +struct -- !query 19 output -1.0 +1 -- !query 20 -SELECT CASE 1 WHEN 0 THEN 1/udf(0) WHEN 1 THEN 1 ELSE 2/0 END +SELECT CASE WHEN i > 100 THEN udf(1/0) ELSE udf(0) END FROM case_tbl -- !query 20 schema -struct +struct 100) THEN CAST(udf(cast((1 div 0) as string)) AS INT) ELSE CAST(udf(cast(0 as string)) AS INT) END:int> -- !query 20 output -1.0 - - --- !query 21 -SELECT CASE WHEN i > 100 THEN udf(1/0) ELSE udf(0) END FROM case_tbl --- !query 21 schema -struct 100) THEN udf((cast(1 as double) / cast(0 as double))) ELSE udf(0) END:string> --- !query 21 output 0 0 0 0 --- !query 22 +-- !query 21 SELECT CASE 'a' WHEN 'a' THEN udf(1) ELSE udf(2) END --- !query 22 schema -struct --- !query 22 output +-- !query 21 schema +struct +-- !query 21 output 1 --- !query 23 +-- !query 22 SELECT '' AS `Five`, CASE WHEN i >= 3 THEN i END AS `>= 3 or Null` FROM CASE_TBL --- !query 23 schema +-- !query 22 schema struct= 3 or Null:int> --- !query 23 output +-- !query 22 output 3 4 NULL NULL --- !query 24 +-- !query 23 SELECT '' AS `Five`, CASE WHEN i >= 3 THEN (i + i) ELSE i END AS `Simplest Math` FROM CASE_TBL --- !query 24 schema +-- !query 23 schema struct --- !query 24 output +-- !query 23 output 1 2 6 8 --- !query 25 +-- !query 24 SELECT '' AS `Five`, i AS `Value`, CASE WHEN (i < 0) THEN 'small' WHEN (i = 0) THEN 'zero' @@ -255,16 +247,16 @@ SELECT '' AS `Five`, i AS `Value`, ELSE 'big' END AS `Category` FROM CASE_TBL --- !query 25 schema +-- !query 24 schema struct --- !query 25 output +-- !query 24 output 1 one 2 two 3 big 4 big --- !query 26 +-- !query 25 SELECT '' AS `Five`, CASE WHEN ((i < 0) or (i < 0)) THEN 'small' WHEN ((i = 0) or (i = 0)) THEN 'zero' @@ -273,37 +265,37 @@ SELECT '' AS `Five`, ELSE 'big' END AS `Category` FROM CASE_TBL --- !query 26 schema +-- !query 25 schema struct --- !query 26 output +-- !query 25 output big big one two --- !query 27 +-- !query 26 SELECT * FROM CASE_TBL WHERE udf(COALESCE(f,i)) = 4 --- !query 27 schema +-- !query 26 schema struct --- !query 27 output +-- !query 26 output 4 NULL --- !query 28 +-- !query 27 SELECT * FROM CASE_TBL WHERE udf(NULLIF(f,i)) = 2 --- !query 28 schema +-- !query 27 schema struct --- !query 28 output +-- !query 27 output --- !query 29 +-- !query 28 SELECT udf(COALESCE(a.f, b.i, b.j)) FROM CASE_TBL a, CASE2_TBL b --- !query 29 schema -struct --- !query 29 output +-- !query 28 schema +struct +-- !query 28 output -30.3 -30.3 -30.3 @@ -330,24 +322,24 @@ struct 3.0 --- !query 30 +-- !query 29 SELECT * FROM CASE_TBL a, CASE2_TBL b WHERE udf(COALESCE(a.f, b.i, b.j)) = 2 --- !query 30 schema +-- !query 29 schema struct --- !query 30 output +-- !query 29 output 4 NULL 2 -2 4 NULL 2 -4 --- !query 31 +-- !query 30 SELECT udf('') AS Five, NULLIF(a.i,b.i) AS `NULLIF(a.i,b.i)`, NULLIF(b.i, 4) AS `NULLIF(b.i,4)` FROM CASE_TBL a, CASE2_TBL b --- !query 31 schema +-- !query 30 schema struct --- !query 31 output +-- !query 30 output 1 2 1 2 1 3 @@ -374,18 +366,18 @@ struct NULL 3 --- !query 32 +-- !query 31 SELECT '' AS `Two`, * FROM CASE_TBL a, CASE2_TBL b - WHERE CAST(udf(COALESCE(f,b.i) = 2) AS boolean) --- !query 32 schema + WHERE udf(COALESCE(f,b.i) = 2) +-- !query 31 schema struct --- !query 32 output +-- !query 31 output 4 NULL 2 -2 4 NULL 2 -4 --- !query 33 +-- !query 32 SELECT CASE (CASE vol('bar') WHEN udf('foo') THEN 'it was foo!' 
@@ -395,31 +387,23 @@ SELECT CASE WHEN udf('it was foo!') THEN 'foo recognized' WHEN 'it was bar!' THEN udf('bar recognized') ELSE 'unrecognized' END AS col --- !query 33 schema +-- !query 32 schema struct --- !query 33 output +-- !query 32 output bar recognized --- !query 34 +-- !query 33 DROP TABLE CASE_TBL --- !query 34 schema +-- !query 33 schema struct<> --- !query 34 output +-- !query 33 output --- !query 35 +-- !query 34 DROP TABLE CASE2_TBL --- !query 35 schema +-- !query 34 schema struct<> --- !query 35 output - - +-- !query 34 output --- !query 36 -set spark.sql.crossJoin.enabled=false --- !query 36 schema -struct --- !query 36 output -spark.sql.crossJoin.enabled false diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-count.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-count.sql.out index 9476937abd9e..3d7c64054a6a 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-count.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-count.sql.out @@ -17,7 +17,7 @@ SELECT udf(count(*)), udf(count(1)), udf(count(null)), udf(count(a)), udf(count(b)), udf(count(a + b)), udf(count((a, b))) FROM testData -- !query 1 schema -struct +struct -- !query 1 output 7 7 0 5 5 4 7 @@ -32,7 +32,7 @@ SELECT udf(count(DISTINCT (a, b))) FROM testData -- !query 2 schema -struct +struct -- !query 2 output 1 0 2 2 2 6 @@ -40,7 +40,7 @@ struct +struct -- !query 3 output 4 4 4 @@ -50,6 +50,6 @@ SELECT udf(count(DISTINCT a, b)), udf(count(DISTINCT b, a)), udf(count(DISTINCT *)), udf(count(DISTINCT testData.*)) FROM testData -- !query 4 schema -struct +struct -- !query 4 output 3 3 3 3 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-having.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-having.sql.out index 7cea2e5128f8..1effcc8470e1 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-having.sql.out @@ -18,7 +18,7 @@ struct<> -- !query 1 SELECT udf(k) AS k, udf(sum(v)) FROM hav GROUP BY k HAVING udf(sum(v)) > 2 -- !query 1 schema -struct +struct -- !query 1 output one 6 three 3 @@ -27,7 +27,7 @@ three 3 -- !query 2 SELECT udf(count(udf(k))) FROM hav GROUP BY v + 1 HAVING v + 1 = udf(2) -- !query 2 schema -struct +struct -- !query 2 output 1 @@ -35,7 +35,7 @@ struct -- !query 3 SELECT udf(MIN(t.v)) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(udf(COUNT(udf(1))) > 0) -- !query 3 schema -struct +struct -- !query 3 output 1 @@ -43,7 +43,7 @@ struct -- !query 4 SELECT udf(a + b) FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > udf(1) -- !query 4 schema -struct +struct -- !query 4 output 3 7 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-inner-join.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-inner-join.sql.out index 10952cb21e4f..120f2d39f73d 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-inner-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-inner-join.sql.out @@ -59,7 +59,7 @@ struct<> -- !query 6 SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag -- !query 6 schema -struct +struct -- !query 6 output 1 a 1 a diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-natural-join.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-natural-join.sql.out index 53ef177db0bb..950809ddcaf2 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-natural-join.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/udf/udf-natural-join.sql.out @@ -59,6 +59,6 @@ two 2 22 -- !query 5 SELECT udf(count(*)) FROM nt1 natural full outer join nt2 -- !query 5 schema -struct +struct -- !query 5 output 4 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-special-values.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-special-values.sql.out new file mode 100644 index 000000000000..7b2b5dbe578c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-special-values.sql.out @@ -0,0 +1,62 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT udf(x) FROM (VALUES (1), (2), (NULL)) v(x) +-- !query 0 schema +struct +-- !query 0 output +1 +2 +NULL + + +-- !query 1 +SELECT udf(x) FROM (VALUES ('A'), ('B'), (NULL)) v(x) +-- !query 1 schema +struct +-- !query 1 output +A +B +NULL + + +-- !query 2 +SELECT udf(x) FROM (VALUES ('NaN'), ('1'), ('2')) v(x) +-- !query 2 schema +struct +-- !query 2 output +1 +2 +NaN + + +-- !query 3 +SELECT udf(x) FROM (VALUES ('Infinity'), ('1'), ('2')) v(x) +-- !query 3 schema +struct +-- !query 3 output +1 +2 +Infinity + + +-- !query 4 +SELECT udf(x) FROM (VALUES ('-Infinity'), ('1'), ('2')) v(x) +-- !query 4 schema +struct +-- !query 4 output +-Infinity +1 +2 + + +-- !query 5 +SELECT udf(x) FROM (VALUES 0.00000001, 0.00000002, 0.00000003) v(x) +-- !query 5 schema +struct +-- !query 5 output +0.00000001 +0.00000002 +0.00000003 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-udaf.sql.out new file mode 100644 index 000000000000..6cfeb8c17f55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-udaf.sql.out @@ -0,0 +1,70 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg' +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT default.myDoubleAvg(udf(int_col1)) as my_avg, udf(default.myDoubleAvg(udf(int_col1))) as my_avg2, udf(default.myDoubleAvg(int_col1)) as my_avg3 from t1 +-- !query 2 schema +struct +-- !query 2 output +102.5 102.5 102.5 + + +-- !query 3 +SELECT default.myDoubleAvg(udf(int_col1), udf(3)) as my_avg from t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function default.myDoubleAvg. 
Expected: 1; Found: 2; line 1 pos 7 + + +-- !query 4 +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf' +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +SELECT default.udaf1(udf(int_col1)) as udaf1 from t1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 + + +-- !query 6 +DROP FUNCTION myDoubleAvg +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +DROP FUNCTION udaf1 +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 938d76c9f083..b253c4a70bbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -20,33 +20,34 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter import org.apache.spark.SparkConf +import org.apache.spark.sql.internal.SQLConf class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + .set(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { - assert(sparkConf.get("spark.sql.codegen.fallback") == "false", + assert(sparkConf.get(SQLConf.CODEGEN_FALLBACK.key) == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "false", + assert(sparkConf.get(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key) == "false", "configuration parameter changed in test body") } } class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + .set(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { - assert(sparkConf.get("spark.sql.codegen.fallback") == "false", + assert(sparkConf.get(SQLConf.CODEGEN_FALLBACK.key) == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", + assert(sparkConf.get(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key) == "true", "configuration parameter changed in test body") } } @@ -56,18 +57,18 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") - .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + .set(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "true") + .set(SQLConf.ENABLE_VECTORIZED_HASH_MAP.key, "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { - 
assert(sparkConf.get("spark.sql.codegen.fallback") == "false", + assert(sparkConf.get(SQLConf.CODEGEN_FALLBACK.key) == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", + assert(sparkConf.get(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key) == "true", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true", + assert(sparkConf.get(SQLConf.ENABLE_VECTORIZED_HASH_MAP.key) == "true", "configuration parameter changed in test body") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 6049e89c93cf..267f255a11e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -832,7 +832,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val df = spark.range(10).cache() df.queryExecution.executedPlan.foreach { case i: InMemoryTableScanExec => - assert(i.supportsBatch == vectorized && i.supportCodegen == vectorized) + assert(i.supportsColumnar == vectorized) case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 98936702a013..e8ddd4e1fd97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1672,7 +1672,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("reuse exchange") { - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index efd5db1c5b6c..ff6143162ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -603,6 +603,70 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } + test("typed aggregation: expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L), ("b", 3L, 5L, 2L, 1.5, 2L), ("c", 1L, 2L, 1L, 1.0, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + 
count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long], + min("_2").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long], + min("_2").as[Long], + mean("_2").as[Double]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L, 15.0), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L, 1.5), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L, 1.0)) + } + test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() @@ -1365,7 +1429,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = cp.groupBy('id % 2).agg(count('id)) agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchangeExec(_, _: RDDScanExec) => + case ShuffleExchangeExec(_, _: RDDScanExec, _) => case BroadcastExchangeExec(_, _: RDDScanExec) => }.foreach { _ => fail( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 3f91b91850e8..ff48ac8d7a6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -301,11 +301,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") checkAnswer( df.selectExpr(s"d - $i"), - Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + Seq(Row(Date.valueOf("2015-07-29")), Row(Date.valueOf("2015-12-28")))) checkAnswer( df.selectExpr(s"t - $i"), Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), - Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + Row(Timestamp.valueOf("2015-12-29 00:00:00")))) } test("function add_months") { @@ -314,10 +314,10 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((1, d1), (2, d2)).toDF("n", "d") checkAnswer( df.select(add_months(col("d"), 1)), - Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-28")))) checkAnswer( df.selectExpr("add_months(d, -1)"), - Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-28")))) } test("function months_between") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index fffe52d52dec..89195284a5b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -671,8 +671,8 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-22790,SPARK-27668: spark.sql.sources.compressionFactor takes effect") { Seq(1.0, 0.5).foreach { compressionFactor => - withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, - "spark.sql.autoBroadcastJoinThreshold" -> "250") { + withSQLConf(SQLConf.FILE_COMRESSION_FACTOR.key -> 
compressionFactor.toString, + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "250") { withTempPath { workDir => // the file size is 486 bytes val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index e379d6df867c..d62fe961117a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -26,6 +26,7 @@ import org.apache.spark.TestUtils import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.config.Tests +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction @@ -35,8 +36,12 @@ import org.apache.spark.sql.types.StringType * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and * Scalar Pandas UDFs can be tested in SBT & Maven tests. * - * The available UDFs cast input to strings, which take one column as input and return a string - * type column as output. + * The available UDFs are special. It defines an UDF wrapped by cast. So, the input column is + * casted into string, UDF returns strings as are, and then output column is casted back to + * the input column. In this way, UDF is virtually no-op. + * + * Note that, due to this implementation limitation, complex types such as map, array and struct + * types do not work with this UDFs because they cannot be same after the cast roundtrip. * * To register Scala UDF in SQL: * {{{ @@ -59,8 +64,9 @@ import org.apache.spark.sql.types.StringType * To use it in Scala API and SQL: * {{{ * sql("SELECT udf_name(1)") - * spark.range(10).select(expr("udf_name(id)") - * spark.range(10).select(pandasTestUDF($"id")) + * val df = spark.range(10) + * df.select(expr("udf_name(id)") + * df.select(pandasTestUDF(df("id"))) * }}} */ object IntegratedUDFTestUtils extends SQLHelper { @@ -137,7 +143,8 @@ object IntegratedUDFTestUtils extends SQLHelper { "from pyspark.sql.types import StringType; " + "from pyspark.serializers import CloudPickleSerializer; " + s"f = open('$path', 'wb');" + - s"f.write(CloudPickleSerializer().dumps((lambda x: str(x), StringType())))"), + "f.write(CloudPickleSerializer().dumps((" + + "lambda x: None if x is None else str(x), StringType())))"), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! binaryPythonFunc = Files.readAllBytes(path.toPath) @@ -158,7 +165,9 @@ object IntegratedUDFTestUtils extends SQLHelper { "from pyspark.sql.types import StringType; " + "from pyspark.serializers import CloudPickleSerializer; " + s"f = open('$path', 'wb');" + - s"f.write(CloudPickleSerializer().dumps((lambda x: x.apply(str), StringType())))"), + "f.write(CloudPickleSerializer().dumps((" + + "lambda x: x.apply(" + + "lambda v: None if v is None else str(v)), StringType())))"), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! binaryPandasFunc = Files.readAllBytes(path.toPath) @@ -198,11 +207,22 @@ object IntegratedUDFTestUtils extends SQLHelper { } /** - * A Python UDF that takes one column and returns a string column. 
- * Equivalent to `udf(lambda x: str(x), "string")` + * A Python UDF that takes one column, casts into string, executes the Python native function, + * and casts back to the type of input column. + * + * Virtually equivalent to: + * + * {{{ + * from pyspark.sql.functions import udf + * + * df = spark.range(3).toDF("col") + * python_udf = udf(lambda x: str(x), "string") + * casted_col = python_udf(df.col.cast("string")) + * casted_col.cast(df.schema["col"].dataType) + * }}} */ case class TestPythonUDF(name: String) extends TestUDF { - private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction( + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = PythonFunction( command = pythonFunc, @@ -214,7 +234,16 @@ object IntegratedUDFTestUtils extends SQLHelper { accumulator = null), dataType = StringType, pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, - udfDeterministic = true) + udfDeterministic = true) { + + override def builder(e: Seq[Expression]): Expression = { + assert(e.length == 1, "Defined UDF only has one column") + val expr = e.head + assert(expr.resolved, "column should be resolved to use the same type " + + "as input. Try df(name) or df.col(name)") + Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) + } + } def apply(exprs: Column*): Column = udf(exprs: _*) @@ -222,11 +251,22 @@ object IntegratedUDFTestUtils extends SQLHelper { } /** - * A Scalar Pandas UDF that takes one column and returns a string column. - * Equivalent to `pandas_udf(lambda x: x.apply(str), "string", PandasUDFType.SCALAR)`. + * A Scalar Pandas UDF that takes one column, casts into string, executes the + * Python native function, and casts back to the type of input column. + * + * Virtually equivalent to: + * + * {{{ + * from pyspark.sql.functions import pandas_udf + * + * df = spark.range(3).toDF("col") + * scalar_udf = pandas_udf(lambda x: x.apply(lambda v: str(v)), "string") + * casted_col = scalar_udf(df.col.cast("string")) + * casted_col.cast(df.schema["col"].dataType) + * }}} */ case class TestScalarPandasUDF(name: String) extends TestUDF { - private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction( + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = PythonFunction( command = pandasFunc, @@ -238,7 +278,16 @@ object IntegratedUDFTestUtils extends SQLHelper { accumulator = null), dataType = StringType, pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, - udfDeterministic = true) + udfDeterministic = true) { + + override def builder(e: Seq[Expression]): Expression = { + assert(e.length == 1, "Defined UDF only has one column") + val expr = e.head + assert(expr.resolved, "column should be resolved to use the same type " + + "as input. Try df(name) or df.col(name)") + Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) + } + } def apply(exprs: Column*): Column = udf(exprs: _*) @@ -246,15 +295,39 @@ object IntegratedUDFTestUtils extends SQLHelper { } /** - * A Scala UDF that takes one column and returns a string column. - * Equivalent to `udf((input: Any) => String.valueOf(input)`. + * A Scala UDF that takes one column, casts into string, executes the + * Scala native function, and casts back to the type of input column. 
+ * + * Virtually equivalent to: + * + * {{{ + * import org.apache.spark.sql.functions.udf + * + * val df = spark.range(3).toDF("col") + * val scala_udf = udf((input: Any) => input.toString) + * val casted_col = scala_udf(df.col("col").cast("string")) + * casted_col.cast(df.schema("col").dataType) + * }}} */ case class TestScalaUDF(name: String) extends TestUDF { - private[IntegratedUDFTestUtils] lazy val udf = SparkUserDefinedFunction( - (input: Any) => String.valueOf(input), + private[IntegratedUDFTestUtils] lazy val udf = new SparkUserDefinedFunction( + (input: Any) => if (input == null) { + null + } else { + input.toString + }, StringType, inputSchemas = Seq.fill(1)(None), - name = Some(name)) + name = Some(name)) { + + override def apply(exprs: Column*): Column = { + assert(exprs.length == 1, "Defined UDF only has one column") + val expr = exprs.head.expr + assert(expr.resolved, "column should be resolved to use the same type " + + "as input. Try df(name) or df.col(name)") + Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType)) + } + } def apply(exprs: Column*): Column = udf(exprs: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 32cddc94166b..059dbf892c65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} -import org.apache.spark.sql.execution.{BinaryExecNode, SortExec} +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf @@ -72,7 +73,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("join operator selection") { spark.sharedState.cacheManager.clearCache() - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0", + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", @@ -651,7 +652,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("test SortMergeJoin (without spill)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", - "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) { + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> Int.MaxValue.toString) { assertNotSpilled(sparkContext, "inner join") { checkAnswer( @@ -708,8 +709,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("test SortMergeJoin (with spill)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", - "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0", - "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") { + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0", + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { assertSpilled(sparkContext, "inner join") { checkAnswer( @@ -897,6 +898,26 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-27485: EnsureRequirements should not fail join with duplicate keys") { 
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val tbl_a = spark.range(40) + .select($"id" as "x", $"id" % 10 as "y") + .repartition(2, $"x", $"y", $"x") + .as("tbl_a") + + val tbl_b = spark.range(20) + .select($"id" as "x", $"id" % 2 as "y1", $"id" % 20 as "y2") + .as("tbl_b") + + val res = tbl_a + .join(tbl_b, + $"tbl_a.x" === $"tbl_b.x" && $"tbl_a.y" === $"tbl_b.y1" && $"tbl_a.y" === $"tbl_b.y2") + .select($"tbl_a.x") + checkAnswer(res, Row(0L) :: Row(1L) :: Nil) + } + } + test("SPARK-26352: join reordering should not change the order of columns") { withTable("tab1", "tab2", "tab3") { spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1") @@ -980,7 +1001,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val left = Seq((1, 2), (2, 3)).toDF("a", "b") val right = Seq((1, 2), (3, 4)).toDF("c", "d") - val df = left.join(right, pythonTestUDF($"a") === pythonTestUDF($"c")) + val df = left.join(right, pythonTestUDF(left("a")) === pythonTestUDF(right.col("c"))) val joinNode = df.queryExecution.executedPlan.find(_.isInstanceOf[BroadcastHashJoinExec]) assert(joinNode.isDefined) @@ -994,4 +1015,26 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 2, 1, 2) :: Nil) } + + test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") { + import IntegratedUDFTestUtils._ + + assume(shouldTestPythonUDFs) + + val pythonTestUDF = TestPythonUDF(name = "udf") + + val left = Seq((1, 2), (2, 3)).toDF("a", "b") + val right = Seq((1, 2), (3, 4)).toDF("c", "d") + val df = left.crossJoin(right).where(pythonTestUDF(left("a")) === right.col("c")) + + // Before optimization, there is a logical Filter operator. + val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter]) + assert(filterInAnalysis.isDefined) + + // Filter predicate was pushdown as join condition. So there is no Filter exec operator. 
+ val filterExec = df.queryExecution.executedPlan.find(_.isInstanceOf[FilterExec]) + assert(filterExec.isEmpty) + + checkAnswer(df, Row(1, 2, 1, 2) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala index 623a1b6f854c..e33870d4e1af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -22,12 +22,15 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} +import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger +import org.apache.spark.sql.streaming.Trigger class ProcessingTimeSuite extends SparkFunSuite { test("create") { - def getIntervalMs(trigger: Trigger): Long = trigger.asInstanceOf[ProcessingTime].intervalMs + def getIntervalMs(trigger: Trigger): Long = { + trigger.asInstanceOf[ProcessingTimeTrigger].intervalMs + } assert(getIntervalMs(Trigger.ProcessingTime(10.seconds)) === 10 * 1000) assert(getIntervalMs(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) === 10 * 1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index 328423160696..720d570ca838 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.internal.config +import org.apache.spark.sql.internal.SQLConf.CHECKPOINT_LOCATION +import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD class RuntimeConfigSuite extends SparkFunSuite { @@ -60,8 +62,8 @@ class RuntimeConfigSuite extends SparkFunSuite { val conf = newConf() // SQL configs - assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) - assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + assert(!conf.isModifiable(SCHEMA_STRING_LENGTH_THRESHOLD.key)) + assert(conf.isModifiable(CHECKPOINT_LOCATION.key)) // Core configs assert(!conf.isModifiable(config.CPUS_PER_TASK.key)) assert(!conf.isModifiable("spark.executor.cores")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2cc1be9fdda2..972950669198 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1896,7 +1896,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Star Expansion - group by") { - withSQLConf("spark.sql.retainGroupColumns" -> "false") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { checkAnswer( testData2.groupBy($"a", $"b").agg($"*"), sql("SELECT * FROM testData2 group by a, b")) @@ -1936,7 +1936,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Common subexpression elimination") { // TODO: support subexpression elimination in whole stage codegen - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { // select from a table to prevent constant folding. 
val df = sql("SELECT a, b from testData2 limit 1") checkAnswer(df, Row(1, 1)) @@ -1985,9 +1985,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. - spark.conf.set("spark.sql.subexpressionElimination.enabled", "false") + spark.conf.set(SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key, "false") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - spark.conf.set("spark.sql.subexpressionElimination.enabled", "true") + spark.conf.set(SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key, "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index c8a187b57a61..e4052b7ed3ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -151,17 +151,37 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val resultFile: String } + /** + * traits that indicate UDF or PgSQL to trigger the code path specific to each. For instance, + * PgSQL tests require to register some UDF functions. + */ + private trait PgSQLTest + + private trait UDFTest { + val udf: TestUDF + } + /** A regular test case. */ private case class RegularTestCase( name: String, inputFile: String, resultFile: String) extends TestCase /** A PostgreSQL test case. */ private case class PgSQLTestCase( - name: String, inputFile: String, resultFile: String) extends TestCase + name: String, inputFile: String, resultFile: String) extends TestCase with PgSQLTest /** A UDF test case. */ private case class UDFTestCase( - name: String, inputFile: String, resultFile: String, udf: TestUDF) extends TestCase + name: String, + inputFile: String, + resultFile: String, + udf: TestUDF) extends TestCase with UDFTest + + /** A UDF PostgreSQL test case. */ + private case class UDFPgSQLTestCase( + name: String, + inputFile: String, + resultFile: String, + udf: TestUDF) extends TestCase with UDFTest with PgSQLTest private def createScalaTestCase(testCase: TestCase): Unit = { if (blackList.exists(t => @@ -169,12 +189,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Create a test case to ignore this case. ignore(testCase.name) { /* Do nothing */ } } else testCase match { - case UDFTestCase(_, _, _, udf: TestPythonUDF) if !shouldTestPythonUDFs => + case udfTestCase: UDFTest + if udfTestCase.udf.isInstanceOf[TestPythonUDF] && !shouldTestPythonUDFs => ignore(s"${testCase.name} is skipped because " + s"[$pythonExec] and/or pyspark were not available.") { /* Do nothing */ } - case UDFTestCase(_, _, _, udf: TestScalarPandasUDF) if !shouldTestScalarPandasUDFs => + case udfTestCase: UDFTest + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs => ignore(s"${testCase.name} is skipped because pyspark," + s"pandas and/or pyarrow were not available in [$pythonExec].") { /* Do nothing */ @@ -254,18 +276,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + testCase match { - case udfTestCase: UDFTestCase => - // vol used by udf-case.sql. 
- localSparkSession.udf.register("vol", (s: String) => s) + case udfTestCase: UDFTest => registerTestUDF(udfTestCase.udf, localSparkSession) - case _: PgSQLTestCase => + case _ => + } + + testCase match { + case _: PgSQLTest => // booleq/boolne used by boolean.sql localSparkSession.udf.register("booleq", (b1: Boolean, b2: Boolean) => b1 == b2) localSparkSession.udf.register("boolne", (b1: Boolean, b2: Boolean) => b1 != b2) // vol used by boolean.sql and case.sql. localSparkSession.udf.register("vol", (s: String) => s) - case _ => // Don't add UDFs in Regular tests. + // PostgreSQL enabled cartesian product by default. + localSparkSession.conf.set(SQLConf.CROSS_JOINS_ENABLED.key, true) + localSparkSession.conf.set(SQLConf.ANSI_SQL_PARSER.key, true) + localSparkSession.conf.set(SQLConf.PREFER_INTEGRAL_DIVISION.key, true) + case _ => } if (configSet.isDefined) { @@ -385,13 +414,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val absPath = file.getAbsolutePath val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator) - if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) { + if (file.getAbsolutePath.startsWith( + s"$inputFilePath${File.separator}udf${File.separator}pgSQL")) { + Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf => + UDFPgSQLTestCase( + s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf) + } + } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) { Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf => UDFTestCase( - s"$testCaseName - ${udf.prettyName}", - absPath, - resultFile, - udf) + s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf) } } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) { PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 2e2e61b43896..74341f93dd5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, Metadata, StructType} import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch, ColumnarMap, ColumnVector} import org.apache.spark.unsafe.types.UTF8String @@ -152,7 +153,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite { test("use custom class for extensions") { val session = SparkSession.builder() .master("local[1]") - .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .config(SPARK_SESSION_EXTENSIONS.key, classOf[MyExtensions].getCanonicalName) .getOrCreate() try { assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) @@ -173,7 +174,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite { test("use multiple custom class for extensions in the specified order") { val session = SparkSession.builder() .master("local[1]") - .config("spark.sql.extensions", Seq( + 
.config(SPARK_SESSION_EXTENSIONS.key, Seq( classOf[MyExtensions2].getCanonicalName, classOf[MyExtensions].getCanonicalName).mkString(",")) .getOrCreate() @@ -201,7 +202,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite { test("allow an extension to be duplicated") { val session = SparkSession.builder() .master("local[1]") - .config("spark.sql.extensions", Seq( + .config(SPARK_SESSION_EXTENSIONS.key, Seq( classOf[MyExtensions].getCanonicalName, classOf[MyExtensions].getCanonicalName).mkString(",")) .getOrCreate() @@ -228,7 +229,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite { test("use the last registered function name when there are duplicates") { val session = SparkSession.builder() .master("local[1]") - .config("spark.sql.extensions", Seq( + .config(SPARK_SESSION_EXTENSIONS.key, Seq( classOf[MyExtensions2].getCanonicalName, classOf[MyExtensions2Duplicate].getCanonicalName).mkString(",")) .getOrCreate() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index fddc4f6bb350..b2c38684071d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} -import org.apache.spark.sql.execution.{ExecSubqueryExpression, FileSourceScanExec, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec} import org.apache.spark.sql.execution.datasources.FileScanRDD import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -1293,7 +1293,8 @@ class SubquerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row(0, 0), Row(2, 0))) // need to execute the query before we can examine fs.inputRDDs() assert(df.queryExecution.executedPlan match { - case WholeStageCodegenExec(fs @ FileSourceScanExec(_, _, _, partitionFilters, _, _, _)) => + case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter( + fs @ FileSourceScanExec(_, _, _, partitionFilters, _, _, _), _))) => partitionFilters.exists(ExecSubqueryExpression.hasSubquery) && fs.inputRDDs().forall( _.asInstanceOf[FileScanRDD].filePartitions.forall( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index f155b5dc80cf..058c5ba7e50b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -514,4 +514,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(df.collect().toSeq === Seq(Row(expected))) } } + + test("SPARK-28321 0-args Java UDF should not be called only once") { + val nonDeterministicJavaUDF = udf( + new UDF0[Int] { + override def call(): Int = scala.util.Random.nextInt() + }, IntegerType).asNondeterministic() + + assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala index 
b35348b4ea3b..b1143484a85e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala @@ -44,9 +44,14 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { } // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes. - private def isScanPlanTree(plan: SparkPlan): Boolean = plan match { - case p: ProjectExec => isScanPlanTree(p.child) - case f: FilterExec => isScanPlanTree(f.child) + // Because of how codegen and columnar to row transitions work, we may have InputAdaptors + // and ColumnarToRow transformations in the middle of it, but they will not have the tag + // we want, so skip them if they are the first thing we see + private def isScanPlanTree(plan: SparkPlan, first: Boolean): Boolean = plan match { + case i: InputAdapter if !first => isScanPlanTree(i.child, false) + case c: ColumnarToRowExec if !first => isScanPlanTree(c.child, false) + case p: ProjectExec => isScanPlanTree(p.child, false) + case f: FilterExec => isScanPlanTree(f.child, false) case _: LeafExecNode => true case _ => false } @@ -87,7 +92,7 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { case _: SubqueryExec | _: ReusedSubqueryExec => assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) - case _ if isScanPlanTree(plan) => + case _ if isScanPlanTree(plan, true) => // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes, // so it's not simple to check. Instead, we only check that the origin LogicalPlan // contains the corresponding leaf node of the SparkPlan. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c2d9e5498192..e30fb13d10df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -696,6 +696,32 @@ class PlannerSuite extends SharedSQLContext { } } + test("SPARK-27485: EnsureRequirements.reorder should handle duplicate expressions") { + val plan1 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprA :: exprB :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan() + val smjExec = SortMergeJoinExec( + leftKeys = exprA :: exprB :: exprB :: Nil, + rightKeys = exprA :: exprC :: exprC :: Nil, + joinType = Inner, + condition = None, + left = plan1, + right = plan2) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _), + SortExec(_, _, + ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _)) => + assert(leftKeys === smjExec.leftKeys) + assert(rightKeys === smjExec.rightKeys) + assert(leftKeys === leftPartitioningExpressions) + assert(rightKeys === rightPartitioningExpressions) + case _ => fail(outputPlan.toString) + } + } + test("SPARK-24500: create union with stream of children") { val df = Union(Stream( Range(1, 1, 1, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala index 20fed07d3872..35c33a7157d3 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala @@ -574,22 +574,17 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA withSparkSession(test, 4, None) } - test("Union two datasets with different pre-shuffle partition number") { + test("Do not reduce the number of shuffle partition for repartition") { val test: SparkSession => Unit = { spark: SparkSession => - val dataset1 = spark.range(3) - val dataset2 = spark.range(3) - - val resultDf = dataset1.repartition(2, dataset1.col("id")) - .union(dataset2.repartition(3, dataset2.col("id"))).toDF() + val ds = spark.range(3) + val resultDf = ds.repartition(2, ds.col("id")).toDF() checkAnswer(resultDf, - Seq((0), (0), (1), (1), (2), (2)).map(i => Row(i))) + Seq(0, 1, 2).map(i => Row(i))) val finalPlan = resultDf.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - // As the pre-shuffle partition number are different, we will skip reducing - // the shuffle partition numbers. assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.length == 0) } - withSparkSession(test, 100, None) + withSparkSession(test, 200, None) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8269d4d3a285..64e305cd5c37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.internal.SQLConf.MAX_NESTED_VIEW_DEPTH import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} class SimpleSQLViewSuite extends SQLViewSuite with SharedSQLContext @@ -665,17 +666,17 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql(s"CREATE VIEW view${idx + 1} AS SELECT * FROM view$idx") } - withSQLConf("spark.sql.view.maxNestedViewDepth" -> "10") { + withSQLConf(MAX_NESTED_VIEW_DEPTH.key -> "10") { val e = intercept[AnalysisException] { sql("SELECT * FROM view10") }.getMessage assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + "resolution depth (10). Analysis is aborted to avoid errors. 
Increase the value " + - "of spark.sql.view.maxNestedViewDepth to work around this.")) + s"of ${MAX_NESTED_VIEW_DEPTH.key} to work around this.")) } val e = intercept[IllegalArgumentException] { - withSQLConf("spark.sql.view.maxNestedViewDepth" -> "0") {} + withSQLConf(MAX_NESTED_VIEW_DEPTH.key -> "0") {} }.getMessage assert(e.contains("The maximum depth of a view reference in a nested view must be " + "positive.")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index 1c6fc3530cbe..971fd842f046 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.TestUtils.assertSpilled import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf.{WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD, WINDOW_EXEC_BUFFER_SPILL_THRESHOLD} import org.apache.spark.sql.test.SharedSQLContext case class WindowData(month: Int, area: String, product: Int) @@ -477,8 +478,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) """.stripMargin) - withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1", - "spark.sql.windowExec.buffer.spill.threshold" -> "2") { + withSQLConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { assertSpilled(sparkContext, "test with low buffer spill threshold") { checkAnswer(actual, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9462ee190a31..483a04610338 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.metrics.source.CodegenMetrics -import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -121,29 +119,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } - test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { - import testImplicits._ - - val dsInt = spark.range(3).cache() - dsInt.count() - val dsIntFilter = dsInt.filter(_ > 0) - val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.collect { - case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () - }.length == 1) - assert(dsIntFilter.collect() === Array(1, 2)) - - // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache() - 
dsString.count() - val dsStringFilter = dsString.filter(_ == "1") - val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.collect { - case i: InMemoryTableScanExec if !i.supportsBatch => () - }.length == 1) - assert(dsStringFilter.collect() === Array("1")) - } - test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) @@ -168,10 +143,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { .select("int") val plan = df.queryExecution.executedPlan - assert(!plan.find(p => + assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.children(0) - .isInstanceOf[SortMergeJoinExec]).isDefined) + .isInstanceOf[SortMergeJoinExec]).isEmpty) assert(df.collect() === Array(Row(1), Row(2))) } } @@ -204,6 +179,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } + def genCode(ds: Dataset[_]): Seq[CodeAndComment] = { + val plan = ds.queryExecution.executedPlan + val wholeStageCodeGenExecs = plan.collect { case p: WholeStageCodegenExec => p } + assert(wholeStageCodeGenExecs.nonEmpty, "WholeStageCodegenExec is expected") + wholeStageCodeGenExecs.map(_.doCodeGen()._2) + } + ignore("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) @@ -213,25 +195,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } - ignore("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { - import testImplicits._ - withTempPath { dir => - val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) - df.write.mode(SaveMode.Overwrite).parquet(path) - - withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", - SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "2000") { - // wide table batch scan causes the byte code of codegen exceeds the limit of - // WHOLESTAGE_HUGE_METHOD_LIMIT - val df2 = spark.read.parquet(path) - val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get - assert(fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch) - checkAnswer(df2, df) - } - } - } - test("Control splitting consume function by operators with config") { import testImplicits._ val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) @@ -283,9 +246,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val df = spark.range(100) val join = df.join(df, "id") val plan = join.queryExecution.executedPlan - assert(!plan.find(p => + assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined, + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty, "codegen stage IDs should be preserved through ReuseExchange") checkAnswer(join, df.toDF) } @@ -295,18 +258,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { import testImplicits._ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { - val bytecodeSizeHisto = CodegenMetrics.METRIC_COMPILATION_TIME - - // the same query run twice should hit the codegen cache - spark.range(3).select('id + 2).collect 
- val after1 = bytecodeSizeHisto.getCount - spark.range(3).select('id + 2).collect - val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately - // bytecodeSizeHisto's count is always monotonically increasing if new compilation to - // bytecode had occurred. If the count stayed the same that means we've got a cache hit. - assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected") - - // a different query can result in codegen cache miss, that's by design + // the same query run twice should produce identical code, which would imply a hit in + // the generated code cache. + val ds1 = spark.range(3).select('id + 2) + val code1 = genCode(ds1) + val ds2 = spark.range(3).select('id + 2) + val code2 = genCode(ds2) // same query shape as above, deliberately + assert(code1 == code2, "Should produce same code") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 86874b9817c2..67c3fa0d3bf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1191,7 +1191,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("max records in batch conf") { val totalRecords = 10 val maxRecordsPerBatch = 3 - spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + spark.conf.set(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key, maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") val arrowBatches = df.toArrowBatchRdd.collect() assert(arrowBatches.length >= 4) @@ -1206,7 +1206,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } assert(recordCount == totalRecords) allocator.close() - spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + spark.conf.unset(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key) } testQuietly("unsupported types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 81158d9e5424..2776bc310fef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -83,7 +83,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { f() } } @@ -92,7 +92,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "true") { f() } } @@ -119,7 +119,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { f() } } @@ -128,7 +128,7 @@ object AggregateBenchmark 
extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "true") { f() } } @@ -154,7 +154,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { f() } } @@ -163,7 +163,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "true") { f() } } @@ -189,7 +189,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { f() } } @@ -198,7 +198,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "true") { f() } } @@ -234,7 +234,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { f() } } @@ -243,7 +243,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", - "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "true") { f() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala index cd97324c997f..6925bdd72674 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.benchmark +import org.apache.spark.sql.internal.SQLConf + /** * Benchmark to measure built-in data sources write performance. 
* To run this benchmark: @@ -45,8 +47,8 @@ object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark { mainArgs } - spark.conf.set("spark.sql.parquet.compression.codec", "snappy") - spark.conf.set("spark.sql.orc.compression.codec", "snappy") + spark.conf.set(SQLConf.PARQUET_COMPRESSION.key, "snappy") + spark.conf.set(SQLConf.ORC_COMPRESSION.key, "snappy") formats.foreach { format => runBenchmark(s"$format writer benchmark") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index d31e49cf8cd4..711ecf1738ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarToRowExec, FilterExec, InputAdapter, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -437,7 +437,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") { - withSQLConf("spark.sql.shuffle.partitions" -> "200") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "200") { val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group") val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id") val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct() @@ -486,15 +486,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val df2 = df1.where("y = 3") val planBeforeFilter = df2.queryExecution.executedPlan.collect { - case f: FilterExec => f.child + case FilterExec(_, c: ColumnarToRowExec) => c.child + case WholeStageCodegenExec(FilterExec(_, ColumnarToRowExec(i: InputAdapter))) => i.child } assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) - val execPlan = if (codegenEnabled == "true") { - WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) - } else { - planBeforeFilter.head - } + val execPlan = planBeforeFilter.head assert(execPlan.executeCollectPublic().length == 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index b3a5c687f775..e74099202a1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -50,7 +50,7 @@ class PartitionBatchPruningSuite // Enable in-memory partition pruning spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, true) // Enable in-memory table scan accumulators - spark.conf.set("spark.sql.inMemoryTableScanStatistics.enable", "true") + spark.conf.set(SQLConf.IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED.key, 
"true") } override protected def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 7df0dabd67f8..ce209666024d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution} import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.internal.SQLConf.DEFAULT_V2_CATALOG import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -77,7 +78,7 @@ class PlanResolutionSuite extends AnalysisTest { def parseAndResolve(query: String, withDefault: Boolean = false): LogicalPlan = { val newConf = conf.copy() - newConf.setConfString("spark.sql.default.catalog", "testcat") + newConf.setConfString(DEFAULT_V2_CATALOG.key, "testcat") DataSourceResolution(newConf, if (withDefault) lookupWithDefault else lookupWithoutDefault) .apply(parsePlan(query)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index af524c7ca025..eaff5a2352a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -201,7 +201,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } test("partitioned table - case insensitive") { - withSQLConf("spark.sql.caseSensitive" -> "false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val table = createTable( files = Seq( @@ -437,7 +437,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } test("[SPARK-16818] exchange reuse respects differences in partition pruning") { - spark.conf.set("spark.sql.exchange.reuse", true) + spark.conf.set(SQLConf.EXCHANGE_REUSE_ENABLED.key, true) withTempPath { path => val tempDir = path.getCanonicalPath spark.range(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2e7d682a3bbc..fdb50a6dd929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1399,8 +1399,8 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te // that whole test file is mapped to only one partition. This will guarantee // reliable sampling of the input file. 
withSQLConf( - "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, - "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + SQLConf.FILES_MAX_PARTITION_BYTES.key -> (128 * 1024 * 1024).toString, + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> (4 * 1024 * 1024).toString )(withTempPath { path => val ds = sampledTestData.coalesce(1) ds.write.text(path.getAbsolutePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6316e89537ca..34b44be57689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2041,8 +2041,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // that whole test file is mapped to only one partition. This will guarantee // reliable sampling of the input file. withSQLConf( - "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, - "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + SQLConf.FILES_MAX_PARTITION_BYTES.key -> (128 * 1024 * 1024).toString, + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> (4 * 1024 * 1024).toString )(withTempPath { path => val ds = sampledTestData.coalesce(1) ds.write.text(path.getAbsolutePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c5d12d618e05..577d1bc8d6a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1208,6 +1208,14 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } + // SPARK-28371: make sure filter is null-safe. + withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df => + checkFilterPredicate( + '_1.startsWith("blah").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq.empty[Row]) + } + import testImplicits._ // Test canDrop() has taken effect testStringStartsWith(spark.range(1024).map(_.toString).toDF(), "value like 'a%'") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 6b05b9c0f720..6f2218ba82dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -475,7 +475,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { val extraOptions = Map( SQLConf.OUTPUT_COMMITTER_CLASS.key -> classOf[ParquetOutputCommitter].getCanonicalName, - "spark.sql.parquet.output.committer.class" -> + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName ) withTempPath { dir => @@ -505,7 +505,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Using a output committer that always fail when committing a task, so that both // `commitTask()` and `abortTask()` are invoked. 
val extraOptions = Map[String, String]( - "spark.sql.parquet.output.committer.class" -> + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 6f3ed3d85e93..04ace0a236e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -953,7 +953,7 @@ abstract class ParquetPartitionDiscoverySuite withSQLConf( ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", - "spark.sql.sources.commitProtocolClass" -> + SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { spark.range(3).write.parquet(s"$path/p0=0/p1=0") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 7aa0ba7f4e0c..a6429bfc5292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -924,14 +924,14 @@ class ParquetV1QuerySuite extends ParquetQuerySuite { // donot return batch, because whole stage codegen is disabled for wide table (>200 columns) val df2 = spark.read.parquet(path) val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get - assert(!fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch) + assert(!fileScan2.asInstanceOf[FileSourceScanExec].supportsColumnar) checkAnswer(df2, df) // return batch val columns = Seq.tabulate(9) {i => s"c$i"} val df3 = df2.selectExpr(columns : _*) val fileScan3 = df3.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get - assert(fileScan3.asInstanceOf[FileSourceScanExec].supportsBatch) + assert(fileScan3.asInstanceOf[FileSourceScanExec].supportsColumnar) checkAnswer(df3, df.selectExpr(columns : _*)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a8d230870aeb..dc4a2998a908 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -277,9 +277,9 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } test("ShuffledHashJoin metrics") { - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40", - "spark.sql.shuffle.partitions" -> "2", - "spark.sql.join.preferSortMergeJoin" -> "false") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40", + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") // Assume the execution plan is @@ -584,19 +584,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared sql("CREATE TEMPORARY VIEW inMemoryTable AS SELECT 1 AS c1") sql("CACHE TABLE inMemoryTable") 
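// (Illustrative aside, not part of this patch.) The expected node ids just below move from 0 to
// 1, and supportsBatch becomes supportsColumnar above, which appears to follow from the new
// columnar execution path: an extra ColumnarToRow operator now sits above columnar scans in the
// executed plan. One way to see it locally, assuming a SparkSession named spark:
//
//   val cached = spark.range(1).cache().select("id")
//   cached.collect()
//   // Inspect the plan string for a ColumnarToRow node above InMemoryTableScan.
//   println(cached.queryExecution.executedPlan.treeString)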
testSparkPlanMetrics(spark.table("inMemoryTable"), 1, - Map(0L -> (("Scan In-memory table `inMemoryTable`", Map.empty))) + Map(1L -> (("Scan In-memory table `inMemoryTable`", Map.empty))) ) sql("CREATE TEMPORARY VIEW ```a``b``` AS SELECT 2 AS c1") sql("CACHE TABLE ```a``b```") testSparkPlanMetrics(spark.table("```a``b```"), 1, - Map(0L -> (("Scan In-memory table ```a``b```", Map.empty))) + Map(1L -> (("Scan In-memory table ```a``b```", Map.empty))) ) } // Show InMemoryTableScan on UI testSparkPlanMetrics(spark.range(1).cache().select("id"), 1, - Map(0L -> (("InMemoryTableScan", Map.empty))) + Map(1L -> (("InMemoryTableScan", Map.empty))) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index f12eeaa58064..8f26c04307ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} +import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED import org.apache.spark.sql.test.SQLTestUtils @@ -154,7 +155,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { expectedNodeIds: Set[Long], enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { val previousExecutionIds = currentExecutionIds() - withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { + withSQLConf(WHOLESTAGE_CODEGEN_ENABLED.key -> enableWholeStage.toString) { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 723764c77727..c0fd3fe3ef7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -24,7 +24,6 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.sql.streaming.util.StreamManualClock class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { @@ -35,7 +34,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { val timeout = 10.seconds test("nextBatchTime") { - val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTimeTrigger(100)) assert(processingTimeExecutor.nextBatchTime(0) === 100) assert(processingTimeExecutor.nextBatchTime(1) === 100) assert(processingTimeExecutor.nextBatchTime(99) === 100) @@ -49,7 +48,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { val clock = new StreamManualClock() @volatile var continueExecuting = true @volatile var clockIncrementInTrigger = 0L - val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), clock) + val executor = ProcessingTimeExecutor(ProcessingTimeTrigger("1000 milliseconds"), clock) val executorThread 
= new Thread() { override def run(): Unit = { executor.execute(() => { @@ -97,7 +96,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { test("calling nextBatchTime with the result of a previous call should return the next interval") { val intervalMS = 100 - val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS)) + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTimeTrigger(intervalMS)) val ITERATION = 10 var nextBatchTime: Long = 0 @@ -111,7 +110,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { private def testBatchTermination(intervalMs: Long): Unit = { var batchCounts = 0 - val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTimeTrigger(intervalMs)) processingTimeExecutor.execute(() => { batchCounts += 1 // If the batch termination works correctly, batchCounts should be 3 after `execute` @@ -130,7 +129,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { @volatile var batchFallingBehindCalled = false val t = new Thread() { override def run(): Unit = { - val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { + val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTimeTrigger(100), clock) { override def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { batchFallingBehindCalled = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 2a1e7d615e5e..7bca225dfdd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -124,7 +125,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator implicit val sqlContext = spark.sqlContext - spark.conf.set("spark.sql.shuffle.partitions", "1") + spark.conf.set(SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index af4369de800e..a84d107f2cbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -569,7 +569,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] val spark = SparkSession.builder().master("local[2]").getOrCreate() SparkSession.setActiveSession(spark) implicit val sqlContext = spark.sqlContext - spark.conf.set("spark.sql.shuffle.partitions", "1") + 
spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") import spark.implicits._ val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index e3e5ddff9637..8edbb8770671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -647,7 +647,7 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { .setMaster("local") .setAppName("test") .set(config.TASK_MAX_FAILURES, 1) // Don't retry the tasks to run this test quickly - .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + .set(UI_RETAINED_EXECUTIONS.key, "50") // Set it to 50 to run this test quickly .set(ASYNC_TRACKING_ENABLED, false) withSpark(new SparkContext(conf)) { sc => quietly { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5f27e75addcf..89eaac8e5927 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -857,10 +857,7 @@ class JDBCSuite extends QueryTest Some(ArrayType(DecimalType.SYSTEM_DEFAULT))) assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") assert(Postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") - val errMsg = intercept[IllegalArgumentException] { - Postgres.getJDBCType(ByteType) - } - assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType") + assert(Postgres.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") } test("DerbyDialect jdbc type mapping") { @@ -895,6 +892,17 @@ class JDBCSuite extends QueryTest "BIT") assert(msSqlServerDialect.getJDBCType(BinaryType).map(_.databaseTypeDefinition).get == "VARBINARY(MAX)") + assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == + "SMALLINT") + } + + test("SPARK-28152 MsSqlServerDialect catalyst type mapping") { + val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver") + val metadata = new MetadataBuilder().putLong("scale", 1) + assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1, + metadata).get == ShortType) + assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, + metadata).get == FloatType) } test("table exists query by jdbc dialect") { @@ -1322,7 +1330,7 @@ class JDBCSuite extends QueryTest testJdbcParitionColumn("THEID", "THEID") testJdbcParitionColumn("\"THEID\"", "THEID") - withSQLConf("spark.sql.caseSensitive" -> "false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { testJdbcParitionColumn("ThEiD", "THEID") } testJdbcParitionColumn("THE ID", "THE ID") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index fc61050dc745..75f68dea96bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -63,7 +63,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val maxNrBuckets: Int = 200000 val catalog = spark.sessionState.catalog - 
withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + withSQLConf(SQLConf.BUCKETING_MAX_BUCKETS.key -> maxNrBuckets.toString) { // within the new limit Seq(100001, maxNrBuckets).foreach(numBuckets => { withTable("t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index d46029e84433..5f9856656ac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.internal.SQLConf.BUCKETING_MAX_BUCKETS import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -252,7 +253,7 @@ class CreateTableAsSelectSuite val maxNrBuckets: Int = 200000 val catalog = spark.sessionState.catalog - withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + withSQLConf(BUCKETING_MAX_BUCKETS.key -> maxNrBuckets.toString) { // Within the new limit Seq(100001, maxNrBuckets).foreach(numBuckets => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index 01752125ac26..c90090aca3d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -21,13 +21,15 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.QueryTest +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, StringType, StructField, StructType, TimestampType} class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { @@ -38,7 +40,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn before { spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) spark.conf.set("spark.sql.catalog.testcat2", classOf[TestInMemoryTableCatalog].getName) - spark.conf.set("spark.sql.catalog.session", classOf[TestInMemoryTableCatalog].getName) + spark.conf.set(V2_SESSION_CATALOG.key, classOf[TestInMemoryTableCatalog].getName) val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") df.createOrReplaceTempView("source") @@ -280,7 +282,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn test("CreateTableAsSelect: v2 session catalog can load v1 source table") { val sparkSession = spark.newSession() - 
sparkSession.conf.set("spark.sql.catalog.session", classOf[V2SessionCatalog].getName) + sparkSession.conf.set(V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName) val df = sparkSession.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") df.createOrReplaceTempView("source") @@ -366,4 +368,834 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn df_joined) } } + + test("AlterTable: table does not exist") { + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE testcat.ns1.table_name DROP COLUMN id") + } + + assert(exc.getMessage.contains("testcat.ns1.table_name")) + assert(exc.getMessage.contains("Table or view not found")) + } + + test("AlterTable: change rejected by implementation") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[SparkException] { + sql(s"ALTER TABLE $t DROP COLUMN id") + } + + assert(exc.getMessage.contains("Unsupported table change")) + assert(exc.getMessage.contains("Cannot drop all fields")) // from the implementation + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType().add("id", IntegerType)) + } + } + + test("AlterTable: add top-level column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN data string") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType().add("id", IntegerType).add("data", StringType)) + } + } + + test("AlterTable: add column with comment") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN data string COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType).withComment("doc")))) + } + } + + test("AlterTable: add multiple columns") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ADD COLUMNS data string COMMENT 'doc', ts timestamp") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType).withComment("doc"), + StructField("ts", TimestampType)))) + } + } + + test("AlterTable: add nested column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point struct) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN point.z double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType), + StructField("z", 
DoubleType))))) + } + } + + test("AlterTable: add nested column to map key") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map, bigint>) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN points.key.z double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType), + StructField("z", DoubleType))), LongType))) + } + } + + test("AlterTable: add nested column to map value") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map>) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN points.value.z double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StringType, StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType), + StructField("z", DoubleType)))))) + } + } + + test("AlterTable: add nested column to array element") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array>) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN points.element.z double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType), + StructField("z", DoubleType)))))) + } + } + + test("AlterTable: add complex column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN points array>") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType)))))) + } + } + + test("AlterTable: add nested column with comment") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array>) USING foo") + sql(s"ALTER TABLE $t ADD COLUMN points.element.z double COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType), + StructField("z", DoubleType).withComment("doc")))))) + } + } + + test("AlterTable: add nested column parent must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ADD COLUMN point.z double") + } + + assert(exc.getMessage.contains("point")) + assert(exc.getMessage.contains("missing 
field")) + } + } + + test("AlterTable: update column type int -> long") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN id TYPE bigint") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType().add("id", LongType)) + } + } + + test("AlterTable: update nested type float -> double") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point struct) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN point.x TYPE double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))))) + } + } + + test("AlterTable: update column with struct type fails") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point struct) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN point TYPE struct") + } + + assert(exc.getMessage.contains("point")) + assert(exc.getMessage.contains("update a struct by adding, deleting, or updating its fields")) + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))))) + } + } + + test("AlterTable: update column with array type fails") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN points TYPE array") + } + + assert(exc.getMessage.contains("update the element by updating points.element")) + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(IntegerType))) + } + } + + test("AlterTable: update column array element type") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.element TYPE long") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(LongType))) + } + } + + test("AlterTable: update column with map type fails") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, m map) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN m TYPE map") + } + + assert(exc.getMessage.contains("update a map by updating m.key or m.value")) + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), 
"table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("m", MapType(StringType, IntegerType))) + } + } + + test("AlterTable: update column map value type") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, m map) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN m.value TYPE long") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("m", MapType(StringType, LongType))) + } + } + + test("AlterTable: update nested type in map key") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map, bigint>) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.key.x TYPE double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))), LongType))) + } + } + + test("AlterTable: update nested type in map value") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map>) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.value.x TYPE double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StringType, StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType)))))) + } + } + + test("AlterTable: update nested type in array") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array>) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.element.x TYPE double") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType)))))) + } + } + + test("AlterTable: update column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN data TYPE string") + } + + assert(exc.getMessage.contains("data")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: nested update column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN point.x TYPE double") + } + + assert(exc.getMessage.contains("point.x")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: update column type must be compatible") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN id TYPE 
boolean") + } + + assert(exc.getMessage.contains("id")) + assert(exc.getMessage.contains("int cannot be cast to boolean")) + } + } + + test("AlterTable: update column comment") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN id COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == StructType(Seq(StructField("id", IntegerType).withComment("doc")))) + } + } + + test("AlterTable: update column type and comment") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN id TYPE bigint COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == StructType(Seq(StructField("id", LongType).withComment("doc")))) + } + } + + test("AlterTable: update nested column comment") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point struct) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN point.y COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType).withComment("doc"))))) + } + } + + test("AlterTable: update nested column comment in map key") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map, bigint>) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.key.y COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType).withComment("doc"))), LongType))) + } + } + + test("AlterTable: update nested column comment in map value") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map>) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.value.y COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StringType, StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType).withComment("doc")))))) + } + } + + test("AlterTable: update nested column comment in array") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array>) USING foo") + sql(s"ALTER TABLE $t ALTER COLUMN points.element.y COMMENT 'doc'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", 
IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType).withComment("doc")))))) + } + } + + test("AlterTable: comment update column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN data COMMENT 'doc'") + } + + assert(exc.getMessage.contains("data")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: nested comment update column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t ALTER COLUMN point.x COMMENT 'doc'") + } + + assert(exc.getMessage.contains("point.x")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: rename column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t RENAME COLUMN id TO user_id") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType().add("user_id", IntegerType)) + } + } + + test("AlterTable: rename nested column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point struct) USING foo") + sql(s"ALTER TABLE $t RENAME COLUMN point.y TO t") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", StructType(Seq( + StructField("x", DoubleType), + StructField("t", DoubleType))))) + } + } + + test("AlterTable: rename nested column in map key") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point map, bigint>) USING foo") + sql(s"ALTER TABLE $t RENAME COLUMN point.key.y TO t") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", MapType(StructType(Seq( + StructField("x", DoubleType), + StructField("t", DoubleType))), LongType))) + } + } + + test("AlterTable: rename nested column in map value") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map>) USING foo") + sql(s"ALTER TABLE $t RENAME COLUMN points.value.y TO t") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StringType, StructType(Seq( + StructField("x", DoubleType), + StructField("t", DoubleType)))))) + } + } + + test("AlterTable: rename nested column in array element") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array>) USING foo") + sql(s"ALTER TABLE $t RENAME COLUMN points.element.y TO t") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + 
assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType), + StructField("t", DoubleType)))))) + } + } + + test("AlterTable: rename column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t RENAME COLUMN data TO some_string") + } + + assert(exc.getMessage.contains("data")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: nested rename column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t RENAME COLUMN point.x TO z") + } + + assert(exc.getMessage.contains("point.x")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: drop column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"ALTER TABLE $t DROP COLUMN data") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType().add("id", IntegerType)) + } + } + + test("AlterTable: drop nested column") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point struct) USING foo") + sql(s"ALTER TABLE $t DROP COLUMN point.t") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))))) + } + } + + test("AlterTable: drop nested column in map key") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, point map, bigint>) USING foo") + sql(s"ALTER TABLE $t DROP COLUMN point.key.y") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("point", MapType(StructType(Seq( + StructField("x", DoubleType))), LongType))) + } + } + + test("AlterTable: drop nested column in map value") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points map>) USING foo") + sql(s"ALTER TABLE $t DROP COLUMN points.value.y") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", MapType(StringType, StructType(Seq( + StructField("x", DoubleType)))))) + } + } + + test("AlterTable: drop nested column in array element") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int, points array>) USING foo") + sql(s"ALTER TABLE $t DROP COLUMN points.element.y") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + 
assert(table.schema == new StructType() + .add("id", IntegerType) + .add("points", ArrayType(StructType(Seq( + StructField("x", DoubleType)))))) + } + } + + test("AlterTable: drop column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t DROP COLUMN data") + } + + assert(exc.getMessage.contains("data")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: nested drop column must exist") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + + val exc = intercept[AnalysisException] { + sql(s"ALTER TABLE $t DROP COLUMN point.x") + } + + assert(exc.getMessage.contains("point.x")) + assert(exc.getMessage.contains("missing field")) + } + } + + test("AlterTable: set location") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t SET LOCATION 's3://bucket/path'") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.properties == Map("provider" -> "foo", "location" -> "s3://bucket/path").asJava) + } + } + + test("AlterTable: set table property") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo") + sql(s"ALTER TABLE $t SET TBLPROPERTIES ('test'='34')") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.properties == Map("provider" -> "foo", "test" -> "34").asJava) + } + } + + test("AlterTable: remove table property") { + val t = "testcat.ns1.table_name" + withTable(t) { + sql(s"CREATE TABLE $t (id int) USING foo TBLPROPERTIES('test' = '34')") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(table.name == "testcat.ns1.table_name") + assert(table.properties == Map("provider" -> "foo", "test" -> "34").asJava) + + sql(s"ALTER TABLE $t UNSET TBLPROPERTIES ('test')") + + val updated = testCatalog.loadTable(Identifier.of(Array("ns1"), "table_name")) + + assert(updated.name == "testcat.ns1.table_name") + assert(updated.properties == Map("provider" -> "foo").asJava) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 4e9f961016de..380df7a36596 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -88,6 +88,12 @@ class TestInMemoryTableCatalog extends TableCatalog { case Some(table) => val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) val schema = CatalogV2Util.applySchemaChanges(table.schema, changes) + + // fail if the last column in the schema was dropped + if (schema.fields.isEmpty) { + throw new IllegalArgumentException(s"Cannot drop all fields") + } + val newTable = new InMemoryTable(table.name, schema, properties, table.data) tables.put(ident, newTable) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 4bf49ff4d5c6..92ec2a0c172e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -305,7 +305,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche test("update mode") { val inputData = MemoryStream[Int] - spark.conf.set("spark.sql.shuffle.partitions", "10") + spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "10") val windowedAggregation = inputData.toDF() .withColumn("eventTime", $"value".cast("timestamp")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 2b8d77386925..72f893845172 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1310,7 +1310,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val start = startId.map(new FileStreamSourceOffset(_)) val end = FileStreamSourceOffset(endId) - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f2f5fad59eb2..1ed2599444c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -871,7 +871,7 @@ class StreamSuite extends StreamTest { testQuietly("specify custom state store provider") { val providerClassName = classOf[TestStateStoreProvider].getCanonicalName - withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName) { val input = MemoryStream[Int] val df = input.toDS().groupBy().count() val query = df.writeStream.outputMode("complete").format("memory").queryName("name").start() @@ -888,9 +888,9 @@ class StreamSuite extends StreamTest { testQuietly("custom state store provider read from offset log") { val input = MemoryStream[Int] val df = input.toDS().groupBy().count() - val providerConf1 = "spark.sql.streaming.stateStore.providerClass" -> + val providerConf1 = SQLConf.STATE_STORE_PROVIDER_CLASS.key -> "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider" - val providerConf2 = "spark.sql.streaming.stateStore.providerClass" -> + val providerConf2 = SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[TestStateStoreProvider].getCanonicalName def runQuery(queryName: String, checkpointLoc: String): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala index 88f510c726fa..da2f221aaf10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkConf import 
org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -29,7 +30,7 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { import testImplicits._ override protected def sparkConf: SparkConf = - super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", + super.sparkConf.set(STREAMING_QUERY_LISTENERS.key, "org.apache.spark.sql.streaming.TestListener") test("test if the configured query lister is loaded") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index a5cb25c49b86..e6b56e5f46f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -413,9 +413,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi sources.nonEmpty } // Disabled by default - assert(spark.conf.get("spark.sql.streaming.metricsEnabled").toBoolean === false) + assert(spark.conf.get(SQLConf.STREAMING_METRICS_ENABLED.key).toBoolean === false) - withSQLConf("spark.sql.streaming.metricsEnabled" -> "false") { + withSQLConf(SQLConf.STREAMING_METRICS_ENABLED.key -> "false") { testStream(inputData.toDF)( AssertOnQuery { q => !isMetricsRegistered(q) }, StopStream, @@ -424,7 +424,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // Registered when enabled - withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + withSQLConf(SQLConf.STREAMING_METRICS_ENABLED.key -> "true") { testStream(inputData.toDF)( AssertOnQuery { q => isMetricsRegistered(q) }, StopStream, @@ -434,7 +434,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } test("SPARK-22975: MetricsReporter defaults when there was no progress reported") { - withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + withSQLConf(SQLConf.STREAMING_METRICS_ENABLED.key -> "true") { BlockingSource.latch = new CountDownLatch(1) withTempDir { tempDir => val sq = spark.readStream diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index c5b95fa9b64a..3ec4750c59fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming.continuous import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED import org.apache.spark.sql.streaming.OutputMode class ContinuousAggregationSuite extends ContinuousSuiteBase { @@ -36,7 +37,7 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { } test("basic") { - withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + withSQLConf((UNSUPPORTED_OPERATION_CHECK_ENABLED.key, "false")) { val input = ContinuousMemoryStream.singlePartition[Int] testStream(input.toDF().agg(max('value)), OutputMode.Complete)( @@ -112,7 +113,7 @@ class 
ContinuousAggregationSuite extends ContinuousSuiteBase { } test("repeated restart") { - withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + withSQLConf((UNSUPPORTED_OPERATION_CHECK_ENABLED.key, "false")) { val input = ContinuousMemoryStream.singlePartition[Int] testStream(input.toDF().agg(max('value)), OutputMode.Complete)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9840c7f06678..c6921010a002 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE +import org.apache.spark.sql.internal.SQLConf.{CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, MIN_BATCHES_TO_RETAIN} import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -37,18 +37,43 @@ class ContinuousSuiteBase extends StreamTest { "continuous-stream-test-sql-context", sparkConf.set("spark.sql.testkey", "true"))) - protected def waitForRateSourceTriggers(query: StreamExecution, numTriggers: Int): Unit = { - query match { - case s: ContinuousExecution => - assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") - val reader = s.lastExecution.executedPlan.collectFirst { - case ContinuousScanExec(_, _, r: RateStreamContinuousStream, _) => r - }.get - - val deltaMs = numTriggers * 1000 + 300 - while (System.currentTimeMillis < reader.creationTime + deltaMs) { - Thread.sleep(reader.creationTime + deltaMs - System.currentTimeMillis) + protected def waitForRateSourceTriggers(query: ContinuousExecution, numTriggers: Int): Unit = { + query.awaitEpoch(0) + + // This is called after waiting first epoch to be committed, so we can just treat + // it as partition readers for rate source are already initialized. + val firstCommittedTime = System.nanoTime() + val deltaNs = (numTriggers * 1000 + 300) * 1000000L + var toWaitNs = firstCommittedTime + deltaNs - System.nanoTime() + while (toWaitNs > 0) { + Thread.sleep(toWaitNs / 1000000) + toWaitNs = firstCommittedTime + deltaNs - System.nanoTime() + } + } + + protected def waitForRateSourceCommittedValue( + query: ContinuousExecution, + desiredValue: Long, + maxWaitTimeMs: Long): Unit = { + def readHighestCommittedValue(c: ContinuousExecution): Option[Long] = { + c.committedOffsets.lastOption.map { case (_, offset) => + offset match { + case o: RateStreamOffset => + o.partitionToValueAndRunTimeMs.map { + case (_, ValueRunTimeMsPair(value, _)) => value + }.max } + } + } + + val maxWait = System.currentTimeMillis() + maxWaitTimeMs + while (System.currentTimeMillis() < maxWait && + readHighestCommittedValue(query).getOrElse(Long.MinValue) < desiredValue) { + Thread.sleep(100) + } + if (System.currentTimeMillis() > maxWait) { + logWarning(s"Couldn't reach desired value in $maxWaitTimeMs milliseconds!" 
+ + s"Current highest committed value is ${readHighestCommittedValue(query)}") } } @@ -216,14 +241,16 @@ class ContinuousSuite extends ContinuousSuiteBase { .queryName("noharness") .trigger(Trigger.Continuous(100)) .start() + + val expected = Set(0, 1, 2, 3) val continuousExecution = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution] - continuousExecution.awaitEpoch(0) - waitForRateSourceTriggers(continuousExecution, 2) + waitForRateSourceCommittedValue(continuousExecution, expected.max, 20 * 1000) query.stop() val results = spark.read.table("noharness").collect() - assert(Set(0, 1, 2, 3).map(Row(_)).subsetOf(results.toSet)) + assert(expected.map(Row(_)).subsetOf(results.toSet), + s"Result set ${results.toSet} are not a superset of $expected!") } } @@ -241,7 +268,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { testStream(df)( StartStream(longContinuousTrigger), AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 10)), + Execute { exec => + waitForRateSourceTriggers(exec.asInstanceOf[ContinuousExecution], 5) + }, IncrementEpoch(), StopStream, CheckAnswerRowsContains(scala.Range(0, 2500).map(Row(_))) @@ -259,7 +288,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { testStream(df)( StartStream(Trigger.Continuous(2012)), AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 10)), + Execute { exec => + waitForRateSourceTriggers(exec.asInstanceOf[ContinuousExecution], 5) + }, IncrementEpoch(), StopStream, CheckAnswerRowsContains(scala.Range(0, 2500).map(Row(_)))) @@ -307,7 +338,7 @@ class ContinuousMetaSuite extends ContinuousSuiteBase { "local[10]", "continuous-stream-test-sql-context", sparkConf.set("spark.sql.testkey", "true") - .set("spark.sql.streaming.minBatchesToRetain", "2"))) + .set(MIN_BATCHES_TO_RETAIN.key, "2"))) test("SPARK-24351: check offsetLog/commitLog retained in the checkpoint directory") { withTempDir { checkpointDir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 7b2c1a56e8ba..4db605ee1b23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -24,8 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.execution.streaming.{ContinuousTrigger, RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ @@ -242,7 +241,7 @@ class StreamingDataSourceV2Suite extends StreamTest { override def beforeAll(): Unit = { super.beforeAll() val fakeCheckpoint = Utils.createTempDir() - spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) + spark.conf.set(SQLConf.CHECKPOINT_LOCATION.key, fakeCheckpoint.getCanonicalPath) } override def afterEach(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 8fb1400a9b5a..c630f1497a17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -203,7 +203,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .stop() assert(LastOptions.partitionColumns == Seq("a")) - withSQLConf("spark.sql.caseSensitive" -> "false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index e9ab62800f84..126e23e6e592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -409,7 +409,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be test("write path implements onTaskCommit API correctly") { withSQLConf( - "spark.sql.sources.commitProtocolClass" -> + SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[MessageCapturingCommitProtocol].getCanonicalName) { withTempDir { dir => val path = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index da0e5535df5e..115536da8949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -392,7 +392,7 @@ private[sql] trait SQLTestUtilsBase */ protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema - val withoutFilters = df.queryExecution.sparkPlan.transform { + val withoutFilters = df.queryExecution.executedPlan.transform { case FilterExec(_, child) => child } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index d1de9f037992..b4d1d0d58aad 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -137,7 +137,7 @@ object HiveThriftServer2 extends Logging { } private[thriftserver] object ExecutionState extends Enumeration { - val STARTED, COMPILED, FAILED, FINISHED = Value + val STARTED, COMPILED, FAILED, FINISHED, CLOSED = Value type ExecutionState = Value } @@ -147,16 +147,17 @@ object HiveThriftServer2 extends Logging { val startTimestamp: Long, val userName: String) { var finishTimestamp: Long = 0L + var closeTimestamp: Long = 0L var executePlan: String = "" var detail: String = "" var state: ExecutionState.Value = ExecutionState.STARTED val jobId: ArrayBuffer[String] = ArrayBuffer[String]() var groupId: String = "" - def totalTime: Long = { - if (finishTimestamp == 0L) { + def totalTime(endTime: Long): Long = { + if (endTime == 0L) { System.currentTimeMillis - startTimestamp } else { - finishTimestamp - startTimestamp + endTime - startTimestamp } } } @@ -254,6 +255,11 @@ object HiveThriftServer2 extends Logging { 
trimExecutionIfNecessary() } + def onOperationClosed(id: String): Unit = synchronized { + executionList(id).closeTimestamp = System.currentTimeMillis + executionList(id).state = ExecutionState.CLOSED + } + private def trimExecutionIfNecessary() = { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 820f76db6db3..2f011c25fe2c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -70,11 +70,12 @@ private[hive] class SparkExecuteStatementOperation( } } - def close(): Unit = { + override def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. logDebug(s"CLOSING $statementId") cleanup(OperationState.CLOSED) sqlContext.sparkContext.clearJobGroup() + HiveThriftServer2.listener.onOperationClosed(statementId) } def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala index 99ba968e1ae8..89faff2f6f91 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala @@ -58,8 +58,15 @@ private[hive] class SparkGetColumnsOperation( val catalog: SessionCatalog = sqlContext.sessionState.catalog + private var statementId: String = _ + + override def close(): Unit = { + super.close() + HiveThriftServer2.listener.onOperationClosed(statementId) + } + override def runInternal(): Unit = { - val statementId = UUID.randomUUID().toString + statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName, tablePattern : $tableName" val logMsg = s"Listing columns '$cmdStr, columnName : $columnName'" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala index 3ecbbd036c87..87ef154bcc8a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala @@ -45,8 +45,15 @@ private[hive] class SparkGetSchemasOperation( schemaName: String) extends GetSchemasOperation(parentSession, catalogName, schemaName) with Logging { + private var statementId: String = _ + + override def close(): Unit = { + super.close() + HiveThriftServer2.listener.onOperationClosed(statementId) + } + override def runInternal(): Unit = { - val statementId = UUID.randomUUID().toString + statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. 
val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" val logMsg = s"Listing databases '$cmdStr'" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala index 878683692fb6..952de42083c4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala @@ -55,8 +55,15 @@ private[hive] class SparkGetTablesOperation( extends GetTablesOperation(parentSession, catalogName, schemaName, tableName, tableTypes) with Logging{ + private var statementId: String = _ + + override def close(): Unit = { + super.close() + HiveThriftServer2.listener.onOperationClosed(statementId) + } + override def runInternal(): Unit = { - val statementId = UUID.randomUUID().toString + statementId = UUID.randomUUID().toString // Do not change cmdStr. It's used for Hive auditing and authorization. val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" val tableTypesStr = if (tableTypes == null) "null" else tableTypes.asScala.mkString(",") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 27d2c997ca3e..1747b5bafc93 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -70,8 +70,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { - val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", - "Statement", "State", "Detail") + val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time", + "Execution Time", "Duration", "Statement", "State", "Detail") val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { @@ -90,7 +90,9 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {info.groupId} {formatDate(info.startTimestamp)} {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} - {formatDurationOption(Some(info.totalTime))} + {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} + {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} + {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} {info.statement} {info.state} {errorMessageCell(detail)} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index fdc9bee5ed05..a45c6e363cbf 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -79,8 +79,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) 
.filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { - val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", - "Statement", "State", "Detail") + val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time", + "Execution Time", "Duration", "Statement", "State", "Detail") val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { @@ -99,7 +99,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) {info.groupId} {formatDate(info.startTimestamp)} {formatDate(info.finishTimestamp)} - {formatDurationOption(Some(info.totalTime))} + {formatDate(info.closeTimestamp)} + {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} + {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} {info.statement} {info.state} {errorMessageCell(detail)} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b06856b05479..dd18add53fde 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -44,6 +44,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.internal.StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.{ThreadUtils, Utils} @@ -536,9 +537,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } if (HiveUtils.isHive23) { - assert(conf.get("spark.sql.hive.version") === Some("2.3.5")) + assert(conf.get(HiveUtils.FAKE_HIVE_VERSION.key) === Some("2.3.5")) } else { - assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + assert(conf.get(HiveUtils.FAKE_HIVE_VERSION.key) === Some("1.2.1")) } } } @@ -553,9 +554,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } if (HiveUtils.isHive23) { - assert(conf.get("spark.sql.hive.version") === Some("2.3.5")) + assert(conf.get(HiveUtils.FAKE_HIVE_VERSION.key) === Some("2.3.5")) } else { - assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + assert(conf.get(HiveUtils.FAKE_HIVE_VERSION.key) === Some("1.2.1")) } } } @@ -659,7 +660,7 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary override protected def extraConf: Seq[String] = - "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil + s"--conf ${HIVE_THRIFT_SERVER_SINGLESESSION.key}=true" :: Nil test("share the temporary functions across JDBC connections") { withMultipleConnectionJdbcStatement()( diff --git a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index ad7a9a238f8a..8fce9d938343 100644 --- a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -344,6 +344,7 @@ 
diff --git a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
index ad7a9a238f8a..8fce9d938343 100644
--- a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
+++ b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
@@ -344,6 +344,7 @@ SessionHandle getSessionHandle(TOpenSessionReq req, TOpenSessionResp res)
     String ipAddress = getIpAddress();
     TProtocolVersion protocol = getMinVersion(CLIService.SERVER_VERSION,
         req.getClient_protocol());
+    res.setServerProtocolVersion(protocol);
     SessionHandle sessionHandle;
     if (cliService.getHiveConf().getBoolVar(ConfVars.HIVE_SERVER2_ENABLE_DOAS) &&
         (userName != null)) {
@@ -354,7 +355,6 @@ SessionHandle getSessionHandle(TOpenSessionReq req, TOpenSessionResp res)
       sessionHandle = cliService.openSession(protocol, userName, req.getPassword(),
           ipAddress, req.getConfiguration());
     }
-    res.setServerProtocolVersion(protocol);
     return sessionHandle;
   }

diff --git a/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
index 9552d9bd68cd..d41c3b493bb4 100644
--- a/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
+++ b/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java
@@ -345,6 +345,7 @@ SessionHandle getSessionHandle(TOpenSessionReq req, TOpenSessionResp res)
     String ipAddress = getIpAddress();
     TProtocolVersion protocol = getMinVersion(CLIService.SERVER_VERSION,
         req.getClient_protocol());
+    res.setServerProtocolVersion(protocol);
     SessionHandle sessionHandle;
     if (cliService.getHiveConf().getBoolVar(ConfVars.HIVE_SERVER2_ENABLE_DOAS) &&
         (userName != null)) {
@@ -355,7 +356,6 @@ SessionHandle getSessionHandle(TOpenSessionReq req, TOpenSessionResp res)
       sessionHandle = cliService.openSession(protocol, userName, req.getPassword(),
           ipAddress, req.getConfiguration());
     }
-    res.setServerProtocolVersion(protocol);
     return sessionHandle;
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
index 4351dc703684..9bc0be87be5a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
@@ -27,9 +27,12 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.conf.Configuration

 import org.apache.spark.{SecurityManager, SparkConf, TestUtils}
+import org.apache.spark.internal.config.MASTER_REST_SERVER_ENABLED
+import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql.{QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils

@@ -184,11 +187,11 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
       val args = Seq(
         "--name", "prepare testing tables",
         "--master", "local[2]",
-        "--conf", "spark.ui.enabled=false",
-        "--conf", "spark.master.rest.enabled=false",
-        "--conf", "spark.sql.hive.metastore.version=1.2.1",
-        "--conf", "spark.sql.hive.metastore.jars=maven",
-        "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}",
+        "--conf", s"${UI_ENABLED.key}=false",
+        "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false",
+        "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1",
+        "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven",
+        "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}",
         "--conf", s"spark.sql.test.version.index=$index",
         "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
         tempPyFile.getCanonicalPath)
@@ -203,11 +206,11 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
       "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"),
       "--name", "HiveExternalCatalog backward compatibility test",
       "--master", "local[2]",
-      "--conf", "spark.ui.enabled=false",
-      "--conf", "spark.master.rest.enabled=false",
-      "--conf", "spark.sql.hive.metastore.version=1.2.1",
-      "--conf", "spark.sql.hive.metastore.jars=maven",
-      "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}",
+      "--conf", s"${UI_ENABLED.key}=false",
+      "--conf", s"${MASTER_REST_SERVER_ENABLED.key}=false",
+      "--conf", s"${HiveUtils.HIVE_METASTORE_VERSION.key}=1.2.1",
+      "--conf", s"${HiveUtils.HIVE_METASTORE_JARS.key}=maven",
+      "--conf", s"${WAREHOUSE_PATH.key}=${wareHousePath.getCanonicalPath}",
       "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
       unusedJar.toString)
     runSparkSubmit(args)
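The argument lists above now derive every --conf key from a ConfigEntry (UI_ENABLED, MASTER_REST_SERVER_ENABLED, WAREHOUSE_PATH, ...), so a renamed or mistyped entry breaks compilation instead of silently setting an unknown property. A minimal sketch of the pattern using two entries the suite already imports (the helper itself is hypothetical):

// Illustrative sketch only: build spark-submit "--conf" pairs from typed keys.
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH

def submitConfArgs(warehouseDir: String): Seq[String] = Seq(
  "--conf", s"${UI_ENABLED.key}=false",
  "--conf", s"${WAREHOUSE_PATH.key}=$warehouseDir")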
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 0ff22150658b..e2ddec342766 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -33,6 +33,8 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{HiveTestUtils, TestHiveContext}
+import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS
+import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.types.{DecimalType, StructType}
 import org.apache.spark.tags.ExtendedHiveTest
 import org.apache.spark.util.{ResetSystemProperties, Utils}
@@ -338,10 +340,10 @@ object SetMetastoreURLTest extends Logging {
     val builder = SparkSession.builder()
       .config(sparkConf)
       .config(UI_ENABLED.key, "false")
-      .config("spark.sql.hive.metastore.version", "0.13.1")
+      .config(HiveUtils.HIVE_METASTORE_VERSION.key, "0.13.1")
       // The issue described in SPARK-16901 only appear when
       // spark.sql.hive.metastore.jars is not set to builtin.
-      .config("spark.sql.hive.metastore.jars", "maven")
+      .config(HiveUtils.HIVE_METASTORE_JARS.key, "maven")
       .enableHiveSupport()

     val spark = builder.getOrCreate()
@@ -392,16 +394,16 @@ object SetWarehouseLocationTest extends Logging {
       // We are expecting that the value of spark.sql.warehouse.dir will override the
       // value of hive.metastore.warehouse.dir.
       val session = new TestHiveContext(new SparkContext(sparkConf
-        .set("spark.sql.warehouse.dir", warehouseLocation.toString)
+        .set(WAREHOUSE_PATH.key, warehouseLocation.toString)
         .set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)))
         .sparkSession
       (session, warehouseLocation.toString)
     }

-    if (sparkSession.conf.get("spark.sql.warehouse.dir") != expectedWarehouseLocation) {
+    if (sparkSession.conf.get(WAREHOUSE_PATH.key) != expectedWarehouseLocation) {
       throw new Exception(
-        "spark.sql.warehouse.dir is not set to the expected warehouse location " +
+        s"${WAREHOUSE_PATH.key} is not set to the expected warehouse location " +
         s"$expectedWarehouseLocation.")
     }
@@ -564,7 +566,7 @@ object SparkSubmitClassLoaderTest extends Logging {
     val conf = new SparkConf()
     val hiveWarehouseLocation = Utils.createTempDir()
     conf.set(UI_ENABLED, false)
-    conf.set("spark.sql.warehouse.dir", hiveWarehouseLocation.toString)
+    conf.set(WAREHOUSE_PATH.key, hiveWarehouseLocation.toString)
     val sc = new SparkContext(conf)
     val hiveContext = new TestHiveContext(sc)
     val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j")
@@ -642,14 +644,14 @@ object SparkSQLConfTest extends Logging {
     val conf = new SparkConf() {
       override def getAll: Array[(String, String)] = {
         def isMetastoreSetting(conf: String): Boolean = {
-          conf == "spark.sql.hive.metastore.version" || conf == "spark.sql.hive.metastore.jars"
+          conf == HiveUtils.HIVE_METASTORE_VERSION.key || conf == HiveUtils.HIVE_METASTORE_JARS.key
         }
         // If there is any metastore settings, remove them.
         val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1))

         // Always add these two metastore settings at the beginning.
-        ("spark.sql.hive.metastore.version" -> "0.12") +:
-        ("spark.sql.hive.metastore.jars" -> "maven") +:
+        (HiveUtils.HIVE_METASTORE_VERSION.key -> "0.12") +:
+        (HiveUtils.HIVE_METASTORE_JARS.key -> "maven") +:
         filteredSettings
       }
@@ -676,10 +678,10 @@ object SPARK_9757 extends QueryTest {
     val hiveWarehouseLocation = Utils.createTempDir()
     val sparkContext = new SparkContext(
       new SparkConf()
-        .set("spark.sql.hive.metastore.version", "0.13.1")
-        .set("spark.sql.hive.metastore.jars", "maven")
+        .set(HiveUtils.HIVE_METASTORE_VERSION.key, "0.13.1")
+        .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven")
         .set(UI_ENABLED, false)
-        .set("spark.sql.warehouse.dir", hiveWarehouseLocation.toString))
+        .set(WAREHOUSE_PATH.key, hiveWarehouseLocation.toString))

     val hiveContext = new TestHiveContext(sparkContext)
     spark = hiveContext.sparkSession
@@ -725,7 +727,7 @@ object SPARK_11009 extends QueryTest {
     val sparkContext = new SparkContext(
       new SparkConf()
         .set(UI_ENABLED, false)
-        .set("spark.sql.shuffle.partitions", "100"))
+        .set(SHUFFLE_PARTITIONS.key, "100"))

     val hiveContext = new TestHiveContext(sparkContext)
     spark = hiveContext.sparkSession
@@ -756,7 +758,7 @@ object SPARK_14244 extends QueryTest {
     val sparkContext = new SparkContext(
       new SparkConf()
         .set(UI_ENABLED, false)
-        .set("spark.sql.shuffle.partitions", "100"))
+        .set(SHUFFLE_PARTITIONS.key, "100"))

     val hiveContext = new TestHiveContext(sparkContext)
     spark = hiveContext.sparkSession
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index ecd428780c67..d06cc1c0a88a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -1028,7 +1028,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu

   override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
     Seq("true", "false").foreach { enableTwoLevelMaps =>
-      withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" ->
+      withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key ->
         enableTwoLevelMaps) {
         (1 to 3).foreach { fallbackStartsAt =>
           withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 8cdb8dd84fb2..d68a47053f18 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.command.CacheTableCommand
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf}
-import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH}
 import org.apache.spark.util.{ShutdownHookManager, Utils}

 // SPARK-3729: Test key required to check for initialization errors with config.
@@ -57,9 +57,9 @@ object TestHive
     new SparkConf()
       .set("spark.sql.test", "")
       .set(SQLConf.CODEGEN_FALLBACK.key, "false")
-      .set("spark.sql.hive.metastore.barrierPrefixes",
+      .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
         "org.apache.spark.sql.hive.execution.PairSerDe")
-      .set("spark.sql.warehouse.dir", TestHiveContext.makeWarehouseDir().toURI.getPath)
+      .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
       // SPARK-8910
       .set(UI_ENABLED, false)
       .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
@@ -534,7 +534,7 @@ private[hive] class TestHiveSparkSession(
   }

   // Clean out the Hive warehouse between each suite
-  val warehouseDir = new File(new URI(sparkContext.conf.get("spark.sql.warehouse.dir")).getPath)
+  val warehouseDir = new File(new URI(sparkContext.conf.get(WAREHOUSE_PATH.key)).getPath)
   Utils.deleteRecursively(warehouseDir)
   warehouseDir.mkdir()