diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java new file mode 100644 index 0000000000..0911901d21 --- /dev/null +++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet; + +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.crypto.DecryptionKeyRetriever; +import org.apache.parquet.crypto.DecryptionPropertiesFactory; +import org.apache.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; + +// spotless:off +/* + * Architecture Overview: + * + * JVM Side | Native Side + * ┌─────────────────────────────────────┐ | ┌─────────────────────────────────────┐ + * │ CometFileKeyUnwrapper │ | │ Parquet File Reading │ + * │ │ | │ │ + * │ ┌─────────────────────────────┐ │ | │ ┌─────────────────────────────┐ │ + * │ │ hadoopConf │ │ | │ │ file1.parquet │ │ + * │ │ (Configuration) │ │ | │ │ file2.parquet │ │ + * │ └─────────────────────────────┘ │ | │ │ file3.parquet │ │ + * │ │ │ | │ └─────────────────────────────┘ │ + * │ ▼ │ | │ │ │ + * │ ┌─────────────────────────────┐ │ | │ │ │ + * │ │ factoryCache │ │ | │ ▼ │ + * │ │ (many-to-one mapping) │ │ | │ ┌─────────────────────────────┐ │ + * │ │ │ │ | │ │ Parse file metadata & │ │ + * │ │ file1 ──┐ │ │ | │ │ extract keyMetadata │ │ + * │ │ file2 ──┼─► DecryptionProps │ │ | │ └─────────────────────────────┘ │ + * │ │ file3 ──┘ Factory │ │ | │ │ │ + * │ └─────────────────────────────┘ │ | │ │ │ + * │ │ │ | │ ▼ │ + * │ ▼ │ | │ ╔═════════════════════════════╗ │ + * │ ┌─────────────────────────────┐ │ | │ ║ JNI CALL: ║ │ + * │ │ retrieverCache │ │ | │ ║ getKey(filePath, ║ │ + * │ │ filePath -> KeyRetriever │◄───┼───┼───┼──║ keyMetadata) ║ │ + * │ └─────────────────────────────┘ │ | │ ╚═════════════════════════════╝ │ + * │ │ │ | │ │ + * │ ▼ │ | │ │ + * │ ┌─────────────────────────────┐ │ | │ │ + * │ │ DecryptionKeyRetriever │ │ | │ │ + * │ │ .getKey(keyMetadata) │ │ | │ │ + * │ └─────────────────────────────┘ │ | │ │ + * │ │ │ | │ │ + * │ ▼ │ | │ │ + * │ ┌─────────────────────────────┐ │ | │ ┌─────────────────────────────┐ │ + * │ │ return key bytes │────┼───┼───┼─►│ Use key for decryption │ │ + * │ └─────────────────────────────┘ │ | │ │ of parquet data │ │ + * └─────────────────────────────────────┘ | │ └─────────────────────────────┘ │ + * | └─────────────────────────────────────┘ + * | + * JNI Boundary + * + * Setup Phase (storeDecryptionKeyRetriever): + * 1. 
hadoopConf → DecryptionPropertiesFactory (cached in factoryCache) + * 2. Factory + filePath → DecryptionKeyRetriever (cached in retrieverCache) + * + * Runtime Phase (getKey): + * 3. Native code calls getKey(filePath, keyMetadata) ──► JVM + * 4. Retrieve cached DecryptionKeyRetriever for filePath + * 5. KeyRetriever.getKey(keyMetadata) → decrypted key bytes + * 6. Return key bytes ──► Native code for parquet decryption + */ +// spotless:on + +/** + * Helper class to access DecryptionKeyRetriever.getKey from native code via JNI. This class handles + * the complexity of creating and caching properly configured DecryptionKeyRetriever instances using + * DecryptionPropertiesFactory. The life of this object is meant to map to a single Comet plan, so + * associated with CometExecIterator. + */ +public class CometFileKeyUnwrapper { + + // Each file path gets a unique DecryptionKeyRetriever + private final ConcurrentHashMap retrieverCache = + new ConcurrentHashMap<>(); + + // Cache the factory since we should be using the same hadoopConf for every file in this scan. + private DecryptionPropertiesFactory factory = null; + // Cache the hadoopConf just to assert the assumption above. + private Configuration conf = null; + + /** + * Creates and stores a DecryptionKeyRetriever instance for the given file path. + * + * @param filePath The path to the Parquet file + * @param hadoopConf The Hadoop Configuration to use for this file path + */ + public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) { + // Use DecryptionPropertiesFactory.loadFactory to get the factory and then call + // getFileDecryptionProperties + if (factory == null) { + factory = DecryptionPropertiesFactory.loadFactory(hadoopConf); + conf = hadoopConf; + } else { + // Check the assumption that all files have the same hadoopConf and thus same Factory + assert (conf == hadoopConf); + } + Path path = new Path(filePath); + FileDecryptionProperties decryptionProperties = + factory.getFileDecryptionProperties(hadoopConf, path); + + DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever(); + retrieverCache.put(filePath, keyRetriever); + } + + /** + * Gets the decryption key for the given key metadata using the cached DecryptionKeyRetriever for + * the specified file path. 
+ * + * @param filePath The path to the Parquet file + * @param keyMetadata The key metadata bytes from the Parquet file + * @return The decrypted key bytes + * @throws ParquetCryptoRuntimeException if key unwrapping fails + */ + public byte[] getKey(final String filePath, final byte[] keyMetadata) + throws ParquetCryptoRuntimeException { + DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath); + if (keyRetriever == null) { + throw new ParquetCryptoRuntimeException( + "Failed to find DecryptionKeyRetriever for path: " + filePath); + } + return keyRetriever.getKey(keyMetadata); + } +} diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java index cceb1085c3..dbddc3b743 100644 --- a/common/src/main/java/org/apache/comet/parquet/Native.java +++ b/common/src/main/java/org/apache/comet/parquet/Native.java @@ -267,9 +267,11 @@ public static native long initRecordBatchReader( String sessionTimezone, int batchSize, boolean caseSensitive, - Map objectStoreOptions); + Map objectStoreOptions, + CometFileKeyUnwrapper keyUnwrapper); // arrow native version of read batch + /** * Read the next batch of data into memory on native side * @@ -280,6 +282,7 @@ public static native long initRecordBatchReader( // arrow native equivalent of currentBatch. 'columnNum' is number of the column in the record // batch + /** * Load the column corresponding to columnNum in the currently loaded record batch into JVM * diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index 67c2775400..84918d9335 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -80,7 +80,7 @@ import org.apache.comet.vector.CometVector; import org.apache.comet.vector.NativeUtil; -import static scala.jdk.javaapi.CollectionConverters.*; +import static scala.jdk.javaapi.CollectionConverters.asJava; /** * A vectorized Parquet reader that reads a Parquet file in a batched fashion. 
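(Reviewer aid, not part of the patch: a minimal Scala sketch of the JVM-side flow that the init() hunk below implements. The helper name buildKeyUnwrapper and its filePath/conf parameters are illustrative stand-ins; the classes and methods are the ones added in this diff.)

    import org.apache.hadoop.conf.Configuration
    import org.apache.comet.parquet.{CometFileKeyUnwrapper, CometParquetUtils}

    // Hypothetical helper mirroring NativeBatchReader.init(): only build an unwrapper
    // when the Hadoop conf enables Parquet encryption; a null unwrapper means the
    // native side skips encryption-factory registration.
    def buildKeyUnwrapper(filePath: String, conf: Configuration): CometFileKeyUnwrapper = {
      if (CometParquetUtils.encryptionEnabled(conf)) {
        val unwrapper = new CometFileKeyUnwrapper()
        // Caches a DecryptionKeyRetriever for this file so native code can later
        // call getKey(filePath, keyMetadata) back over JNI.
        unwrapper.storeDecryptionKeyRetriever(filePath, conf)
        unwrapper
      } else {
        null // passed as-is to Native.initRecordBatchReader
      }
    }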
@@ -410,6 +410,15 @@ public void init() throws Throwable { } } + boolean encryptionEnabled = CometParquetUtils.encryptionEnabled(conf); + + // Create keyUnwrapper if encryption is enabled + CometFileKeyUnwrapper keyUnwrapper = null; + if (encryptionEnabled) { + keyUnwrapper = new CometFileKeyUnwrapper(); + keyUnwrapper.storeDecryptionKeyRetriever(file.filePath().toString(), conf); + } + int batchSize = conf.getInt( CometConf.COMET_BATCH_SIZE().key(), @@ -426,7 +435,8 @@ public void init() throws Throwable { timeZoneId, batchSize, caseSensitive, - objectStoreOptions); + objectStoreOptions, + keyUnwrapper); } isInitialized = true; } diff --git a/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala b/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala index b930aea17a..885b4686e7 100644 --- a/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala +++ b/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala @@ -58,7 +58,7 @@ object NativeConfig { def extractObjectStoreOptions(hadoopConf: Configuration, uri: URI): Map[String, String] = { val scheme = uri.getScheme.toLowerCase(Locale.ROOT) - import scala.collection.JavaConverters._ + import scala.jdk.CollectionConverters._ val options = scala.collection.mutable.Map[String, String]() // The schemes will use libhdfs diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala index a37ec7e66a..8bcf99dbd1 100644 --- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala +++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala @@ -20,6 +20,8 @@ package org.apache.comet.parquet import org.apache.hadoop.conf.Configuration +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory} import org.apache.spark.sql.internal.SQLConf object CometParquetUtils { @@ -27,6 +29,16 @@ object CometParquetUtils { private val PARQUET_FIELD_ID_READ_ENABLED = "spark.sql.parquet.fieldId.read.enabled" private val IGNORE_MISSING_PARQUET_FIELD_ID = "spark.sql.parquet.fieldId.read.ignoreMissing" + // Map of encryption configuration key-value pairs that, if present, are only supported with + // these specific values. Generally, these are the default values that won't be present, + // but if they are present we want to check them. + private val SUPPORTED_ENCRYPTION_CONFIGS: Map[String, Set[String]] = Map( + // https://github.com/apache/arrow-rs/blob/main/parquet/src/encryption/ciphers.rs#L21 + KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> Set(KeyToolkit.DATA_KEY_LENGTH_DEFAULT.toString), + KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> Set(KeyToolkit.KEK_LENGTH_DEFAULT.toString), + // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L494 + PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> Set("AES_GCM_V1")) + def writeFieldId(conf: SQLConf): Boolean = conf.getConfString(PARQUET_FIELD_ID_WRITE_ENABLED, "false").toBoolean @@ -38,4 +50,36 @@ object CometParquetUtils { def ignoreMissingIds(conf: SQLConf): Boolean = conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean + + /** + * Checks if the given Hadoop configuration contains any unsupported encryption settings. 
+ * + * @param hadoopConf + * The Hadoop configuration to check + * @return + * true if all encryption configurations are supported, false if any unsupported config is + * found + */ + def isEncryptionConfigSupported(hadoopConf: Configuration): Boolean = { + // Check configurations that, if present, can only have specific allowed values + val supportedListCheck = SUPPORTED_ENCRYPTION_CONFIGS.forall { + case (configKey, supportedValues) => + val configValue = Option(hadoopConf.get(configKey)) + configValue match { + case Some(value) => supportedValues.contains(value) + case None => true // Config not set, so it's supported + } + } + + supportedListCheck + } + + def encryptionEnabled(hadoopConf: Configuration): Boolean = { + // TODO: Are there any other properties to check? + val encryptionKeys = Seq( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME) + + encryptionKeys.exists(key => Option(hadoopConf.get(key)).exists(_.nonEmpty)) + } } diff --git a/native/Cargo.lock b/native/Cargo.lock index 483d2e0709..ad8c24d9dc 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1424,6 +1424,7 @@ dependencies = [ "datafusion-session", "datafusion-sql", "futures", + "hex", "itertools 0.14.0", "log", "object_store", @@ -1611,6 +1612,7 @@ dependencies = [ "chrono", "half", "hashbrown 0.14.5", + "hex", "indexmap", "libc", "log", @@ -1738,6 +1740,7 @@ dependencies = [ "datafusion-pruning", "datafusion-session", "futures", + "hex", "itertools 0.14.0", "log", "object_store", @@ -1768,6 +1771,7 @@ dependencies = [ "log", "object_store", "parking_lot", + "parquet", "rand", "tempfile", "url", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index aa4425c96d..e6e4c6f3cf 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -59,7 +59,7 @@ bytes = { workspace = true } tempfile = "3.8.0" itertools = "0.14.0" paste = "1.0.14" -datafusion = { workspace = true } +datafusion = { workspace = true, features = ["parquet_encryption"] } datafusion-spark = { workspace = true } once_cell = "1.18.0" regex = { workspace = true } diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs index b3241477b8..ecac7af94e 100644 --- a/native/core/src/errors.rs +++ b/native/core/src/errors.rs @@ -185,6 +185,15 @@ impl From for DataFusionError { } } +impl From for ParquetError { + fn from(value: CometError) -> Self { + match value { + CometError::Parquet { source } => source, + _ => ParquetError::General(value.to_string()), + } + } +} + impl From for ExecutionError { fn from(value: CometError) -> Self { match value { diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 83dbd68e76..b76108ad96 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -78,6 +78,7 @@ use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; +use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use datafusion_comet_proto::spark_operator::operator::OpStruct; use log::info; use once_cell::sync::Lazy; @@ -171,6 +172,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native: jboolean, tracing_enabled: jboolean, max_temp_directory_size: jlong, + key_unwrapper_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { with_trace("createPlan", tracing_enabled != JNI_FALSE, || { @@ -247,6 +249,17 @@ pub unsafe extern "system" fn 
Java_org_apache_comet_Native_createPlan( None }; + // Handle key unwrapper for encrypted files + if !key_unwrapper_obj.is_null() { + let encryption_factory = CometEncryptionFactory { + key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?, + }; + session.runtime_env().register_parquet_encryption_factory( + ENCRYPTION_FACTORY_ID, + Arc::new(encryption_factory), + ); + } + let exec_context = Box::new(ExecutionContext { id, task_attempt_id, diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 517c037e93..329edc5d2f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1358,6 +1358,8 @@ impl PhysicalPlanner { default_values, scan.session_timezone.as_str(), scan.case_sensitive, + self.session_ctx(), + scan.encryption_enabled, )?; Ok(( vec![], diff --git a/native/core/src/parquet/encryption_support.rs b/native/core/src/parquet/encryption_support.rs new file mode 100644 index 0000000000..ff67a3fcbd --- /dev/null +++ b/native/core/src/parquet/encryption_support.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::execution::operators::ExecutionError; +use crate::jvm_bridge::{check_exception, JVMClasses}; +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion::common::extensions_options; +use datafusion::config::EncryptionFactoryOptions; +use datafusion::error::DataFusionError; +use datafusion::execution::parquet_encryption::EncryptionFactory; +use jni::objects::{GlobalRef, JMethodID}; +use object_store::path::Path; +use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever}; +use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::errors::ParquetError; +use std::sync::Arc; + +pub const ENCRYPTION_FACTORY_ID: &str = "comet.jni_kms_encryption"; + +extensions_options! { + pub struct CometEncryptionConfig { + // Native side strips file down to a path (not a URI) but Spark wants the full URI, + // so we cache the prefix to stick on the front before calling over JNI + pub uri_base: String, default = "file:///".into() + } +} + +#[derive(Debug)] +pub struct CometEncryptionFactory { + pub(crate) key_unwrapper: GlobalRef, +} + +/// `EncryptionFactory` is a DataFusion trait for types that generate +/// file encryption and decryption properties. +#[async_trait] +impl EncryptionFactory for CometEncryptionFactory { + async fn get_file_encryption_properties( + &self, + _options: &EncryptionFactoryOptions, + _schema: &SchemaRef, + _file_path: &Path, + ) -> Result, DataFusionError> { + Err(DataFusionError::NotImplemented( + "Comet does not support Parquet encryption yet." + .parse() + .unwrap(), + )) + } + + /// Generate file decryption properties to use when reading a Parquet file. 
+ /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` + /// that can determine the keys using the encryption metadata. + async fn get_file_decryption_properties( + &self, + options: &EncryptionFactoryOptions, + file_path: &Path, + ) -> Result, DataFusionError> { + let config: CometEncryptionConfig = options.to_extension_options()?; + + let full_path: String = config.uri_base + file_path.as_ref(); + let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let decryption_properties = + FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?; + Ok(Some(decryption_properties)) + } +} + +pub struct CometKeyRetriever { + file_path: String, + key_unwrapper: GlobalRef, + get_key_method_id: JMethodID, +} + +impl CometKeyRetriever { + pub fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result { + let mut env = JVMClasses::get_env()?; + + Ok(CometKeyRetriever { + file_path: file_path.to_string(), + key_unwrapper, + get_key_method_id: env + .get_method_id( + "org/apache/comet/parquet/CometFileKeyUnwrapper", + "getKey", + "(Ljava/lang/String;[B)[B", + ) + .map_err(|e| { + ExecutionError::GeneralError(format!("Failed to get JNI method ID: {}", e)) + })?, + }) + } +} + +impl KeyRetriever for CometKeyRetriever { + /// Get a data encryption key using the metadata stored in the Parquet file. + fn retrieve_key(&self, key_metadata: &[u8]) -> datafusion::parquet::errors::Result> { + use jni::{objects::JObject, signature::ReturnType}; + + // Get JNI environment + let mut env = JVMClasses::get_env()?; + + // Get the key unwrapper instance from GlobalRef + let unwrapper_instance = self.key_unwrapper.as_obj(); + + let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) }; + + // Convert file path to JString + let file_path_jstring = env + .new_string(&self.file_path) + .map_err(|e| ParquetError::General(format!("Failed to create JString: {}", e)))?; + + // Convert key_metadata to JByteArray + let key_metadata_array = env + .byte_array_from_slice(key_metadata) + .map_err(|e| ParquetError::General(format!("Failed to create byte array: {}", e)))?; + + // Call instance method FileKeyUnwrapper.getKey(String, byte[]) -> byte[] + let result = unsafe { + env.call_method_unchecked( + instance, + self.get_key_method_id, + ReturnType::Array, + &[ + jni::objects::JValue::from(&file_path_jstring).as_jni(), + jni::objects::JValue::from(&key_metadata_array).as_jni(), + ], + ) + }; + + // Check for Java exceptions first, before processing the result + if let Some(exception) = check_exception(&mut env).map_err(|e| { + ParquetError::General(format!("Failed to check for Java exception: {}", e)) + })? 
{ + return Err(ParquetError::General(format!( + "Java exception during key retrieval: {}", + exception + ))); + } + + let result = + result.map_err(|e| ParquetError::General(format!("JNI method call failed: {}", e)))?; + + // Extract the byte array from the result + let result_array = result + .l() + .map_err(|e| ParquetError::General(format!("Failed to extract result: {}", e)))?; + + // Convert JObject to JByteArray and then to Vec + let byte_array: jni::objects::JByteArray = result_array.into(); + + let result_vec = env + .convert_byte_array(&byte_array) + .map_err(|e| ParquetError::General(format!("Failed to convert byte array: {}", e)))?; + Ok(result_vec) + } +} diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index a6efe4ed53..ca70c2fc34 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -16,6 +16,7 @@ // under the License. pub mod data_type; +pub mod encryption_support; pub mod mutable_vector; pub use mutable_vector::*; @@ -52,7 +53,9 @@ use crate::execution::operators::ExecutionError; use crate::execution::planner::PhysicalPlanner; use crate::execution::serde; use crate::execution::utils::SparkArrowConvert; +use crate::jvm_bridge::{jni_new_global_ref, JVMClasses}; use crate::parquet::data_type::AsBytes; +use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use crate::parquet::parquet_exec::init_datasource_exec; use crate::parquet::parquet_support::prepare_object_store_with_configs; use arrow::array::{Array, RecordBatch}; @@ -712,8 +715,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat batch_size: jint, case_sensitive: jboolean, object_store_options: jobject, + key_unwrapper_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |mut env| unsafe { + JVMClasses::init(&mut env); let session_config = SessionConfig::new().with_batch_size(batch_size as usize); let planner = PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)), 0); @@ -766,6 +771,22 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat .unwrap() .into(); + // Handle key unwrapper for encrypted files + let encryption_enabled = if !key_unwrapper_obj.is_null() { + let encryption_factory = CometEncryptionFactory { + key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?, + }; + session_ctx + .runtime_env() + .register_parquet_encryption_factory( + ENCRYPTION_FACTORY_ID, + Arc::new(encryption_factory), + ); + true + } else { + false + }; + let scan = init_datasource_exec( required_schema, Some(data_schema), @@ -778,6 +799,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat None, session_timezone.as_str(), case_sensitive != JNI_FALSE, + session_ctx, + encryption_enabled, )?; let partition_index: usize = 0; diff --git a/native/core/src/parquet/parquet_exec.rs b/native/core/src/parquet/parquet_exec.rs index 4b587b7ba6..0a95ec9996 100644 --- a/native/core/src/parquet/parquet_exec.rs +++ b/native/core/src/parquet/parquet_exec.rs @@ -16,6 +16,7 @@ // under the License. 
use crate::execution::operators::ExecutionError; +use crate::parquet::encryption_support::{CometEncryptionConfig, ENCRYPTION_FACTORY_ID}; use crate::parquet::parquet_support::SparkParquetOptions; use crate::parquet::schema_adapter::SparkSchemaAdapterFactory; use arrow::datatypes::{Field, SchemaRef}; @@ -28,6 +29,7 @@ use datafusion::datasource::source::DataSourceExec; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::physical_expr::expressions::BinaryExpr; use datafusion::physical_expr::PhysicalExpr; +use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_comet_spark_expr::EvalMode; use itertools::Itertools; @@ -66,9 +68,16 @@ pub(crate) fn init_datasource_exec( default_values: Option>, session_timezone: &str, case_sensitive: bool, + session_ctx: &Arc, + encryption_enabled: bool, ) -> Result, ExecutionError> { - let (table_parquet_options, spark_parquet_options) = - get_options(session_timezone, case_sensitive); + let (table_parquet_options, spark_parquet_options) = get_options( + session_timezone, + case_sensitive, + &object_store_url, + encryption_enabled, + ); + let mut parquet_source = ParquetSource::new(table_parquet_options); // Create a conjunctive form of the vector because ParquetExecBuilder takes @@ -87,6 +96,14 @@ pub(crate) fn init_datasource_exec( } } + if encryption_enabled { + parquet_source = parquet_source.with_encryption_factory( + session_ctx + .runtime_env() + .parquet_encryption_factory(ENCRYPTION_FACTORY_ID)?, + ); + } + let file_source = parquet_source.with_schema_adapter_factory(Arc::new( SparkSchemaAdapterFactory::new(spark_parquet_options, default_values), ))?; @@ -125,6 +142,8 @@ pub(crate) fn init_datasource_exec( fn get_options( session_timezone: &str, case_sensitive: bool, + object_store_url: &ObjectStoreUrl, + encryption_enabled: bool, ) -> (TableParquetOptions, SparkParquetOptions) { let mut table_parquet_options = TableParquetOptions::new(); table_parquet_options.global.pushdown_filters = true; @@ -134,6 +153,16 @@ fn get_options( SparkParquetOptions::new(EvalMode::Legacy, session_timezone, false); spark_parquet_options.allow_cast_unsigned_ints = true; spark_parquet_options.case_sensitive = case_sensitive; + + if encryption_enabled { + table_parquet_options.crypto.configure_factory( + ENCRYPTION_FACTORY_ID, + &CometEncryptionConfig { + uri_base: object_store_url.to_string(), + }, + ); + } + (table_parquet_options, spark_parquet_options) } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 57e012b369..a243ab6b03 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -104,6 +104,7 @@ message NativeScan { // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in // the map. 
map object_store_options = 13; + bool encryption_enabled = 14; } message Projection { diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index a4e9494b69..8603a7b9a8 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -24,14 +24,18 @@ import java.lang.management.ManagementFactory import scala.util.matching.Regex +import org.apache.hadoop.conf.Configuration import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.comet.CometMetricNode import org.apache.spark.sql.vectorized._ +import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL} +import org.apache.comet.CometConf._ import org.apache.comet.Tracing.withTrace +import org.apache.comet.parquet.CometFileKeyUnwrapper import org.apache.comet.serde.Config.ConfigMap import org.apache.comet.vector.NativeUtil @@ -52,6 +56,8 @@ import org.apache.comet.vector.NativeUtil * The number of partitions. * @param partitionIndex * The index of the partition. + * @param encryptedFilePaths + * Paths to encrypted Parquet files that need key unwrapping. */ class CometExecIterator( val id: Long, @@ -60,7 +66,9 @@ class CometExecIterator( protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode, numParts: Int, - partitionIndex: Int) + partitionIndex: Int, + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, + encryptedFilePaths: Seq[String] = Seq.empty) extends Iterator[ColumnarBatch] with Logging { @@ -73,6 +81,7 @@ class CometExecIterator( private val cometBatchIterators = inputs.map { iterator => new CometBatchIterator(iterator, nativeUtil) }.toArray + private val plan = { val conf = SparkEnv.get.conf val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs @@ -102,6 +111,19 @@ class CometExecIterator( getMemoryLimitPerTask(conf) } + // Create keyUnwrapper if encryption is enabled + val keyUnwrapper = if (encryptedFilePaths.nonEmpty) { + val unwrapper = new CometFileKeyUnwrapper() + val hadoopConf: Configuration = broadcastedHadoopConfForEncryption.get.value.value + + encryptedFilePaths.foreach(filePath => + unwrapper.storeDecryptionKeyRetriever(filePath, hadoopConf)) + + unwrapper + } else { + null + } + nativeLib.createPlan( id, cometBatchIterators, @@ -121,7 +143,8 @@ class CometExecIterator( debug = COMET_DEBUG_ENABLED.get(), explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), tracingEnabled, - maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get()) + maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get(), + keyUnwrapper) } private var nextBatch: Option[ColumnarBatch] = None @@ -145,6 +168,7 @@ class CometExecIterator( def convertToInt(threads: String): Int = { if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt } + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r val master = conf.get("spark.master") diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 13edf2997c..fb24dce0d3 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ 
b/spark/src/main/scala/org/apache/comet/Native.scala @@ -24,6 +24,8 @@ import java.nio.ByteBuffer import org.apache.spark.CometTaskMemoryManager import org.apache.spark.sql.comet.CometMetricNode +import org.apache.comet.parquet.CometFileKeyUnwrapper + class Native extends NativeBase { // scalastyle:off @@ -69,7 +71,8 @@ class Native extends NativeBase { debug: Boolean, explain: Boolean, tracingEnabled: Boolean, - maxTempDirectorySize: Long): Long + maxTempDirectorySize: Long, + keyUnwrapper: CometFileKeyUnwrapper): Long // scalastyle:on /** diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index cbca7304d2..950d0e9d37 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -46,12 +46,14 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanE import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet} +import org.apache.comet.parquet.CometParquetUtils.{encryptionEnabled, isEncryptionConfigSupported} import org.apache.comet.shims.CometTypeShim /** * Spark physical optimizer rule for replacing Spark scans with Comet scans. */ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with CometTypeShim { + import CometScanRule._ private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get() @@ -144,22 +146,14 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com return withInfos(scanExec, fallbackReasons.toSet) } - val encryptionEnabled: Boolean = - conf.getConfString("parquet.crypto.factory.class", "").nonEmpty && - conf.getConfString("parquet.encryption.kms.client.class", "").nonEmpty - var scanImpl = COMET_NATIVE_SCAN_IMPL.get() + val hadoopConf = scanExec.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scanExec.relation.options) + // if scan is auto then pick the best available scan if (scanImpl == SCAN_AUTO) { - if (encryptionEnabled) { - logInfo( - s"Auto scan mode falling back to $SCAN_NATIVE_COMET because " + - s"$SCAN_NATIVE_ICEBERG_COMPAT does not support reading encrypted Parquet files") - scanImpl = SCAN_NATIVE_COMET - } else { - scanImpl = selectScan(scanExec, r.partitionSchema) - } + scanImpl = selectScan(scanExec, r.partitionSchema, hadoopConf) } if (scanImpl == SCAN_NATIVE_DATAFUSION && !COMET_EXEC_ENABLED.get()) { @@ -206,10 +200,10 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com return withInfos(scanExec, fallbackReasons.toSet) } - if (scanImpl != CometConf.SCAN_NATIVE_COMET && encryptionEnabled) { - fallbackReasons += - "Full native scan disabled because encryption is not supported" - return withInfos(scanExec, fallbackReasons.toSet) + if (scanImpl != CometConf.SCAN_NATIVE_COMET && encryptionEnabled(hadoopConf)) { + if (!isEncryptionConfigSupported(hadoopConf)) { + return withInfos(scanExec, fallbackReasons.toSet) + } } val typeChecker = CometScanTypeChecker(scanImpl) @@ -303,7 +297,10 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com } } - private def selectScan(scanExec: FileSourceScanExec, partitionSchema: StructType): String = { + private def selectScan( + scanExec: FileSourceScanExec, + partitionSchema: StructType, + hadoopConf: Configuration): String = { val fallbackReasons = new 
ListBuffer[String]() @@ -313,10 +310,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com val filePath = scanExec.relation.inputFiles.headOption if (filePath.exists(_.startsWith("s3a://"))) { - validateObjectStoreConfig( - filePath.get, - session.sparkContext.hadoopConfiguration, - fallbackReasons) + validateObjectStoreConfig(filePath.get, hadoopConf, fallbackReasons) } } else { fallbackReasons += s"$SCAN_NATIVE_ICEBERG_COMPAT only supports local filesystem and S3" diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8fc7c2d63e..43f8b72935 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -48,6 +48,7 @@ import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig +import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} @@ -1161,6 +1162,9 @@ object QueryPlanSerde extends Logging with CometExprShim { // Collect S3/cloud storage configurations val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) + + nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf)) + firstPartition.foreach { partitionFile => val objectStoreOptions = NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) @@ -1702,6 +1706,7 @@ object QueryPlanSerde extends Logging with CometExprShim { } // scalastyle:off + /** * Align w/ Arrow's * [[https://github.com/apache/arrow-rs/blob/55.2.0/arrow-ord/src/rank.rs#L30-L40 can_rank]] and diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 3dfd1f8d03..43a1e5b9a0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -103,7 +103,9 @@ class CometNativeShuffleWriter[K, V]( nativePlan, nativeMetrics, numParts, - context.partitionId()) + context.partitionId(), + broadcastedHadoopConfForEncryption = None, + encryptedFilePaths = Seq.empty) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a7cfacc475..de6892638a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,27 +25,29 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, HashPartitioningLike, Partitioning, PartitioningCollection, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils -import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException} +import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.OperatorOuterClass.Operator /** @@ -114,7 +116,9 @@ object CometExec { nativePlan, CometMetricNode(Map.empty), numParts, - partitionIdx) + partitionIdx, + broadcastedHadoopConfForEncryption = None, + encryptedFilePaths = Seq.empty) } def getCometIterator( @@ -123,7 +127,9 @@ object CometExec { nativePlan: Operator, nativeMetrics: CometMetricNode, numParts: Int, - partitionIdx: Int): CometExecIterator = { + partitionIdx: Int, + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]], + encryptedFilePaths: Seq[String]): CometExecIterator = { val outputStream = new ByteArrayOutputStream() nativePlan.writeTo(outputStream) outputStream.close() @@ -135,7 +141,9 @@ object CometExec { bytes, nativeMetrics, numParts, - partitionIdx) + partitionIdx, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) } /** @@ -201,6 +209,39 @@ abstract class CometNativeExec extends CometExec { // TODO: support native metrics for all operators. val nativeMetrics = CometMetricNode.fromCometPlan(this) + // For each relation in a CometNativeScan generate a hadoopConf, + // for each file path in a relation associate with hadoopConf + val cometNativeScans: Seq[CometNativeScanExec] = this + .collectLeaves() + .filter(_.isInstanceOf[CometNativeScanExec]) + .map(_.asInstanceOf[CometNativeScanExec]) + assert( + cometNativeScans.size <= 1, + "We expect one native scan in a Comet plan since we will broadcast one hadoopConf.") + // If this assumption changes in the future, you can look at the commit history of #2447 + // to see how there used to be a map of relations to broadcasted confs in case multiple + // relations in a single plan. The example that came up was UNION. 
See discussion at: + // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264 + val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = + cometNativeScans.headOption.fold( + (None: Option[Broadcast[SerializableConfiguration]], Seq.empty[String])) { scan => + // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and + // per-relation configs since different tables might have different decryption + // properties. + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) + if (encryptionEnabled) { + // hadoopConf isn't serializable, so we have to do a broadcasted config. + val broadcastedConf = + scan.relation.sparkSession.sparkContext + .broadcast(new SerializableConfiguration(hadoopConf)) + (Some(broadcastedConf), scan.relation.inputFiles.toSeq) + } else { + (None, Seq.empty) + } + } + def createCometExecIter( inputs: Seq[Iterator[ColumnarBatch]], numParts: Int, @@ -212,7 +253,9 @@ abstract class CometNativeExec extends CometExec { serializedPlanCopy, nativeMetrics, numParts, - partitionIndex) + partitionIndex, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) setSubqueries(it.id, this) @@ -429,6 +472,7 @@ abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode */ case class SerializedPlan(plan: Option[Array[Byte]]) { def isDefined: Boolean = plan.isDefined + def isEmpty: Boolean = plan.isEmpty } @@ -442,6 +486,7 @@ case class CometProjectExec( extends CometUnaryExec with PartitioningPreservingUnaryExecNode { override def producedAttributes: AttributeSet = outputSet + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -474,6 +519,7 @@ case class CometFilterExec( extends CometUnaryExec { override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -551,7 +597,9 @@ case class CometLocalLimitExec( extends CometUnaryExec { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -585,7 +633,9 @@ case class CometGlobalLimitExec( extends CometUnaryExec { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -985,6 +1035,7 @@ case class CometScanWrapper(override val nativeOp: Operator, override val origin extends CometNativeExec with LeafExecNode { override val serializedPlanOpt: SerializedPlan = SerializedPlan(None) + override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, originalPlan) } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala index 6e6c624910..1cbe27be91 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql.benchmark import 
java.io.File +import java.nio.charset.StandardCharsets +import java.util.Base64 import scala.util.Random +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory} +import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} @@ -120,6 +125,41 @@ trait CometBenchmarkBase extends SqlBasedBenchmark { spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table") } + protected def prepareEncryptedTable( + dir: File, + df: DataFrame, + partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsEncryptedParquetV1Table(testDf, dir.getCanonicalPath + "/parquetV1") + } + + protected def saveAsEncryptedParquetV1Table(df: DataFrameWriter[Row], dir: String): Unit = { + val encoder = Base64.getEncoder + val footerKey = + encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8)) + val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8)) + val cryptoFactoryClass = + "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory" + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + df.mode("overwrite") + .option("compression", "snappy") + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: id") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table") + } + } + protected def makeDecimalDataFrame( values: Int, decimal: DecimalType, diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala index 02b9ca5dce..a5db4f290d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala @@ -20,11 +20,16 @@ package org.apache.spark.sql.benchmark import java.io.File +import java.nio.charset.StandardCharsets +import java.util.Base64 import scala.jdk.CollectionConverters._ import scala.util.Random import org.apache.hadoop.fs.Path +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.KeyToolkit +import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS import org.apache.spark.TestUtils import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, SparkSession} @@ -93,6 +98,94 @@ class CometReadBaseBenchmark extends CometBenchmarkBase { } } + def encryptedScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. 
+ val sqlBenchmark = + new Benchmark(s"SQL Single ${dataType.sql} Encrypted Column Scan", values, output = output) + + val encoder = Base64.getEncoder + val footerKey = + encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8)) + val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8)) + val cryptoFactoryClass = + "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory" + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareEncryptedTable( + dir, + spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM $tbl")) + + val query = dataType match { + case BooleanType => "sum(cast(id as bigint))" + case _ => "sum(id)" + } + + sqlBenchmark.addCase("SQL Parquet - Spark") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet - Comet") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET, + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_DATAFUSION, + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet - Comet Native Iceberg Compat") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT, + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.run() + } + } + } + def decimalScanBenchmark(values: Int, precision: Int, scale: Int): Unit = { val sqlBenchmark = new Benchmark( s"SQL Single Decimal(precision: $precision, scale: $scale) Column Scan", @@ -552,13 +645,20 @@ class CometReadBaseBenchmark extends CometBenchmarkBase { } } - runBenchmarkWithTable("SQL Single Numeric Column Scan", 1024 * 1024 * 15) { v => + 
runBenchmarkWithTable("SQL Single Numeric Column Scan", 1024 * 1024 * 128) { v => Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) .foreach { dataType => numericScanBenchmark(v, dataType) } } + runBenchmarkWithTable("SQL Single Numeric Encrypted Column Scan", 1024 * 1024 * 128) { v => + Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) + .foreach { dataType => + encryptedScanBenchmark(v, dataType) + } + } + runBenchmark("SQL Decimal Column Scan") { withTempTable(tbl) { import spark.implicits._ @@ -639,6 +739,7 @@ object CometReadHdfsBenchmark extends CometReadBaseBenchmark with WithHdfsCluste finally getFileSystem.delete(tempHdfsPath, true) } } + override protected def prepareTable( dir: File, df: DataFrame, diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index 8d2c3db727..cff21ecec3 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.comet -import java.io.File -import java.io.RandomAccessFile +import java.io.{File, RandomAccessFile} import java.nio.charset.StandardCharsets import java.util.Base64 @@ -29,12 +28,16 @@ import org.scalactic.source.Position import org.scalatest.Tag import org.scalatestplus.junit.JUnitRunner +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory} +import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{CometTestBase, SQLContext} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.comet.{CometConf, IntegrationTestSuite} +import org.apache.comet.CometConf.{SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT} /** * A integration test suite that tests parquet modular encryption usage. 
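(Reviewer aid, not part of the patch: a minimal Scala sketch of the encryption setup shared by the rewritten tests below. A SparkSession named spark, a DataFrame df, and an output path parquetDir are assumed; the tests set these properties through withSQLConf, while spark.conf.set is used here only to keep the sketch self-contained. The property-name constants are the same ones the tests use.)

    import java.nio.charset.StandardCharsets
    import java.util.Base64

    import org.apache.parquet.crypto.DecryptionPropertiesFactory
    import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
    import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS

    val encoder = Base64.getEncoder
    val footerKey = encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
    val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))

    // Session-level settings: use the properties-driven crypto factory with the
    // mock in-memory KMS, and register the master keys it can serve.
    spark.conf.set(
      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory")
    spark.conf.set(
      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS")
    spark.conf.set(InMemoryKMS.KEY_LIST_PROPERTY_NAME, s"footerKey: $footerKey, key1: $key1")

    // Per-write options: which columns are encrypted with which key, plus the footer key,
    // matching the write options in the tests below.
    df.write
      .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b")
      .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
      .parquet(parquetDir)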
@@ -47,90 +50,399 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8)) private val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8)) private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8)) + private val cryptoFactoryClass = + "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory" test("SPARK-34990: Write and read an encrypted parquet") { - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT) import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - "parquet.crypto.factory.class" -> factoryClass, - "parquet.encryption.kms.client.class" -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - "parquet.encryption.key.list" -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - // Make sure encryption works with multiple Parquet files - val inputDF = spark - .range(0, 2000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(10) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option("parquet.encryption.column.keys", "key1: a, b; key2: c") - .option("parquet.encryption.footer.key", "footerKey") - .parquet(parquetDir) - - verifyParquetEncrypted(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Make sure encryption works with multiple Parquet files + val inputDF = spark + .range(0, 2000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(10) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } test("SPARK-37117: Can't read files in Parquet encryption external key material mode") { - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT) import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - "parquet.crypto.factory.class" -> factoryClass, - "parquet.encryption.kms.client.class" -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - 
"parquet.encryption.key.material.store.internally" -> "false", - "parquet.encryption.key.list" -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - val inputDF = spark - .range(0, 2000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(10) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option("parquet.encryption.column.keys", "key1: a, b; key2: c") - .option("parquet.encryption.footer.key", "footerKey") - .parquet(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.KEY_MATERIAL_INTERNAL_PROPERTY_NAME -> "false", // default is true + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 2000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(10) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + + test("SPARK-42114: Test of uniform parquet encryption") { + + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"key1: ${key1}") { + + val inputDF = spark + .range(0, 2000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(10) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option("parquet.encryption.uniform.key", "key1") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + + test("Plain text footer mode") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + PropertiesDrivenCryptoFactory.PLAINTEXT_FOOTER_PROPERTY_NAME -> "true", // default is false + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val 
parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetPlaintextFooter(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + + test("Change encryption algorithm") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + // default is AES_GCM_V1 + PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> "AES_GCM_CTR_V1", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + // native_datafusion and native_iceberg_compat fall back due to Arrow-rs + // https://github.com/apache/arrow-rs/blob/da9829728e2a9dffb8d4f47ffe7b103793851724/parquet/src/file/metadata/parser.rs#L494 + if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( + conf) == SCAN_NATIVE_COMET) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + + test("Test double wrapping disabled") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.DOUBLE_WRAPPING_PROPERTY_NAME -> "false", // default is true + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + + test("Join between files with different encryption keys") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + 
KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Write first file + val inputDF1 = spark + .range(0, 100) + .map(i => (i, s"file1_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir1 = new File(dir, "parquet1").getCanonicalPath + inputDF1.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir1) + + // Write second file using different column key + val inputDF2 = spark + .range(0, 100) + .map(i => (i, s"file2_${i}", (i * 2).toFloat)) + .toDF("id", "description", "score") + val parquetDir2 = new File(dir, "parquet2").getCanonicalPath + inputDF2.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key2: id, description, score") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir2) + + verifyParquetEncrypted(parquetDir1) + verifyParquetEncrypted(parquetDir2) + + // Now perform a join between the two files with different encryption keys + // This tests that hadoopConf properties propagate correctly to each scan + val parquetDF1 = spark.read.parquet(parquetDir1).alias("f1") + val parquetDF2 = spark.read.parquet(parquetDir2).alias("f2") + + val joinedDF = parquetDF1 + .join(parquetDF2, parquetDF1("id") === parquetDF2("id"), "inner") + .select( + parquetDF1("id"), + parquetDF1("name"), + parquetDF2("description"), + parquetDF2("score")) + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(joinedDF) + } else { + checkSparkAnswer(joinedDF) + } + } + } + } + + // Union ends up with two scans in the same plan, so this ensures that Comet can distinguish + // between the hadoopConfs for each relation + test("Union between files with different encryption keys") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Write first file with key1 + val inputDF1 = spark + .range(0, 100) + .map(i => (i, s"file1_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir1 = new File(dir, "parquet1").getCanonicalPath + inputDF1.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir1) + + // Write second file with key2 - same schema, different encryption key + val inputDF2 = spark + .range(100, 200) + .map(i => (i, s"file2_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir2 = new File(dir, "parquet2").getCanonicalPath + inputDF2.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key2: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir2) + + verifyParquetEncrypted(parquetDir1) + verifyParquetEncrypted(parquetDir2) + + val parquetDF1 = spark.read.parquet(parquetDir1) + val parquetDF2 = spark.read.parquet(parquetDir2) + + val unionDF = parquetDF1.union(parquetDF2) + + if (CometConf.COMET_ENABLED.get(conf)) { + 
checkSparkAnswerAndOperator(unionDF) + } else { + checkSparkAnswer(unionDF) } + } + } + } + + test("Test different key lengths") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> "256", + KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> "256", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + // native_datafusion and native_iceberg_compat fall back due to Arrow-rs not + // supporting other key lengths + if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( + conf) == SCAN_NATIVE_COMET) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } } } @@ -146,13 +458,29 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: Position): Unit = { + Seq("true", "false").foreach { cometEnabled => - super.test(testName + s" Comet($cometEnabled)", testTags: _*) { - withSQLConf( - CometConf.COMET_ENABLED.key -> cometEnabled, - CometConf.COMET_EXEC_ENABLED.key -> "true", - SQLConf.ANSI_ENABLED.key -> "true") { - testFun + if (cometEnabled == "true") { + Seq(SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT).foreach { + scanImpl => + super.test(testName + s" Comet($cometEnabled)" + s" Scan($scanImpl)", testTags: _*) { + withSQLConf( + CometConf.COMET_ENABLED.key -> cometEnabled, + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.ANSI_ENABLED.key -> "false", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl) { + testFun + } + } + } + } else { + super.test(testName + s" Comet($cometEnabled)", testTags: _*) { + withSQLConf( + CometConf.COMET_ENABLED.key -> cometEnabled, + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.ANSI_ENABLED.key -> "false") { + testFun + } } } } @@ -164,7 +492,9 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { } private var _spark: SparkSessionType = _ + protected implicit override def spark: SparkSessionType = _spark + protected implicit override def sqlContext: SQLContext = _spark.sqlContext /** @@ -182,12 +512,50 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { val byteArray = new Array[Byte](magicStringLength) val randomAccessFile = new RandomAccessFile(parquetFile, "r") try { + // Check first 4 bytes + randomAccessFile.read(byteArray, 0, magicStringLength) + val firstMagicString = new String(byteArray, StandardCharsets.UTF_8) + assert(magicString == firstMagicString) + + // Check last 4 bytes + randomAccessFile.seek(randomAccessFile.length() - magicStringLength) + randomAccessFile.read(byteArray, 0, magicStringLength) + val lastMagicString = new 
String(byteArray, StandardCharsets.UTF_8)
+        assert(magicString == lastMagicString)
+      } finally {
+        randomAccessFile.close()
+      }
+    }
+  }
+
+  /**
+   * Verify that the directory contains Parquet files encrypted in plaintext footer mode by
+   * checking that each Parquet part file in the directory starts and ends with the plaintext
+   * magic string PAR1, as defined in the spec:
+   * https://github.com/apache/parquet-format/blob/master/Encryption.md#55-plaintext-footer-mode
+   */
+  private def verifyParquetPlaintextFooter(parquetDir: String): Unit = {
+    val parquetPartitionFiles = getListOfParquetFiles(new File(parquetDir))
+    assert(parquetPartitionFiles.size >= 1)
+    parquetPartitionFiles.foreach { parquetFile =>
+      val magicString = "PAR1"
+      val magicStringLength = magicString.length()
+      val byteArray = new Array[Byte](magicStringLength)
+      val randomAccessFile = new RandomAccessFile(parquetFile, "r")
+      try {
+        // Check first 4 bytes
+        randomAccessFile.read(byteArray, 0, magicStringLength)
+        val firstMagicString = new String(byteArray, StandardCharsets.UTF_8)
+        assert(magicString == firstMagicString)
+
+        // Check last 4 bytes
+        randomAccessFile.seek(randomAccessFile.length() - magicStringLength)
         randomAccessFile.read(byteArray, 0, magicStringLength)
+        val lastMagicString = new String(byteArray, StandardCharsets.UTF_8)
+        assert(magicString == lastMagicString)
       } finally {
         randomAccessFile.close()
       }
-      val stringRead = new String(byteArray, StandardCharsets.UTF_8)
-      assert(magicString == stringRead)
     }
   }
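
As a companion to the tests above, here is a minimal sketch of the same configuration used outside the test harness: write an encrypted Parquet table and read it back with Comet enabled. This is a sketch under stated assumptions, not part of the patch: it assumes a local SparkSession, the Comet jar on the classpath (the spark.plugins class name follows the Comet documentation), and the parquet-hadoop tests artifact for the InMemoryKMS mock; the object name, key values, and output path are illustrative. The encryption property constants are the same ones referenced in the tests.

// Sketch only: standalone round trip of an encrypted Parquet table, mirroring the
// properties exercised by ParquetEncryptionITCase. Paths, keys, and names are illustrative.
import java.nio.charset.StandardCharsets
import java.util.Base64

import org.apache.parquet.crypto.DecryptionPropertiesFactory
import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
import org.apache.spark.sql.SparkSession

object EncryptedParquetRoundTrip {
  def main(args: Array[String]): Unit = {
    val encoder = Base64.getEncoder
    val footerKey = encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
    val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))

    val spark = SparkSession
      .builder()
      .master("local[*]")
      // Comet is enabled through the plugin; assumes the Comet jar is on the classpath.
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .config("spark.comet.enabled", "true")
      .config("spark.comet.exec.enabled", "true")
      .getOrCreate()

    // The same encryption properties the tests set through withSQLConf.
    spark.conf.set(
      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory")
    spark.conf.set(
      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS")
    spark.conf.set(InMemoryKMS.KEY_LIST_PROPERTY_NAME, s"footerKey: $footerKey, key1: $key1")

    import spark.implicits._

    val path = "/tmp/encrypted_parquet_roundtrip" // illustrative output location

    // Write: the column keys and footer key options name the master keys registered
    // with the (mock) KMS above.
    spark
      .range(0, 1000)
      .map(i => (i, i.toString, i.toFloat))
      .toDF("a", "b", "c")
      .write
      .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b, c")
      .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
      .mode("overwrite")
      .parquet(path)

    // Read: with Comet enabled, the scan obtains decryption keys through the configured
    // DecryptionPropertiesFactory, so the plain Spark read API is all that is needed.
    spark.read.parquet(path).select("a", "b", "c").show()

    spark.stop()
  }
}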